diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 88ac5fd235..d5d3feded0 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -59,7 +59,7 @@ class TaskState(BaseModel): """ NodeExecutionInfo entity """ - workflow_node_execution: WorkflowNodeExecution + workflow_node_execution_id: str start_at: float class Config: @@ -72,7 +72,7 @@ class TaskState(BaseModel): metadata: dict = {} usage: LLMUsage - workflow_run: Optional[WorkflowRun] = None + workflow_run_id: Optional[str] = None start_at: Optional[float] = None total_tokens: int = 0 total_steps: int = 0 @@ -168,8 +168,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) @@ -218,8 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(data) break elif isinstance(event, QueueWorkflowStartedEvent): - self._on_workflow_start() - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_start() response = { 'event': 'workflow_started', @@ -234,8 +232,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): - self._on_node_start(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_start(event) response = { 'event': 'node_started', @@ -253,8 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - self._on_node_finished(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_finished(event) if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: if workflow_node_execution.node_type == NodeType.LLM.value: @@ -285,8 +281,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) @@ -435,7 +430,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): else: continue - def _on_workflow_start(self) -> None: + def _on_workflow_start(self) -> WorkflowRun: self._task_state.start_at = time.perf_counter() workflow_run = self._init_workflow_run( @@ -452,11 +447,16 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): } ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id - def _on_node_start(self, event: QueueNodeStartedEvent) -> None: + db.session.close() + + return workflow_run + + def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, node_id=event.node_id, node_type=event.node_type, node_title=event.node_data.title, @@ -465,19 +465,26 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): ) latest_node_execution_info = TaskState.NodeExecutionInfo( - workflow_node_execution=workflow_node_execution, + workflow_node_execution_id=workflow_node_execution.id, start_at=time.perf_counter() ) self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info self._task_state.latest_node_execution_info = latest_node_execution_info + self._task_state.total_steps += 1 - def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: current_node_execution = self._task_state.running_node_execution_infos[event.node_id] + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, @@ -495,19 +502,24 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._task_state.metadata['usage'] = usage_dict else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, error=event.error ) - # remove running node execution info - del self._task_state.running_node_execution_infos[event.node_id] - self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution + # remove running node execution info + del self._task_state.running_node_execution_infos[event.node_id] - def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ + -> WorkflowRun: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() if isinstance(event, QueueStopEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -516,7 +528,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): ) elif isinstance(event, QueueWorkflowFailedEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -524,39 +536,30 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): error=event.error ) else: + if self._task_state.latest_node_execution_info: + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() + outputs = workflow_node_execution.outputs + else: + outputs = None + workflow_run = self._workflow_run_success( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs - if self._task_state.latest_node_execution_info else None + outputs=outputs ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') - def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: - """ - Get workflow run. - :param workflow_run_id: workflow run id - :return: - """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - return workflow_run + db.session.close() - def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: - """ - Get workflow node execution. - :param workflow_node_execution_id: workflow node execution id - :return: - """ - workflow_node_execution = (db.session.query(WorkflowNodeExecution) - .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) - return workflow_node_execution + return workflow_run def _save_message(self) -> None: """ @@ -567,7 +570,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.workflow_run_id = self._task_state.workflow_run.id + self._message.workflow_run_id = self._task_state.workflow_run_id if self._task_state.metadata and self._task_state.metadata.get('usage'): usage = LLMUsage(**self._task_state.metadata['usage']) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 9bd20f9785..8516feb87d 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -45,7 +45,7 @@ class TaskState(BaseModel): """ NodeExecutionInfo entity """ - workflow_node_execution: WorkflowNodeExecution + workflow_node_execution_id: str start_at: float class Config: @@ -57,7 +57,7 @@ class TaskState(BaseModel): answer: str = "" metadata: dict = {} - workflow_run: Optional[WorkflowRun] = None + workflow_run_id: Optional[str] = None start_at: Optional[float] = None total_tokens: int = 0 total_steps: int = 0 @@ -130,8 +130,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) # response moderation if self._output_moderation_handler: @@ -179,8 +178,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(data) break elif isinstance(event, QueueWorkflowStartedEvent): - self._on_workflow_start() - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_start() response = { 'event': 'workflow_started', @@ -195,8 +193,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): - self._on_node_start(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_start(event) response = { 'event': 'node_started', @@ -214,8 +211,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - self._on_node_finished(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_finished(event) response = { 'event': 'node_finished', @@ -240,8 +236,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) # response moderation if self._output_moderation_handler: @@ -257,7 +252,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): replace_response = { 'event': 'text_replace', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run.id, + 'workflow_run_id': self._task_state.workflow_run_id, 'data': { 'text': self._task_state.answer } @@ -317,7 +312,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): response = { 'event': 'text_replace', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run.id, + 'workflow_run_id': self._task_state.workflow_run_id, 'data': { 'text': event.text } @@ -329,7 +324,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): else: continue - def _on_workflow_start(self) -> None: + def _on_workflow_start(self) -> WorkflowRun: self._task_state.start_at = time.perf_counter() workflow_run = self._init_workflow_run( @@ -344,11 +339,16 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): } ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id - def _on_node_start(self, event: QueueNodeStartedEvent) -> None: + db.session.close() + + return workflow_run + + def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, node_id=event.node_id, node_type=event.node_type, node_title=event.node_data.title, @@ -357,7 +357,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): ) latest_node_execution_info = TaskState.NodeExecutionInfo( - workflow_node_execution=workflow_node_execution, + workflow_node_execution_id=workflow_node_execution.id, start_at=time.perf_counter() ) @@ -366,11 +366,17 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._task_state.total_steps += 1 - def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: current_node_execution = self._task_state.running_node_execution_infos[event.node_id] + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, @@ -383,19 +389,24 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, error=event.error ) # remove running node execution info del self._task_state.running_node_execution_infos[event.node_id] - self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution - def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ + -> WorkflowRun: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() if isinstance(event, QueueStopEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -404,7 +415,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): ) elif isinstance(event, QueueWorkflowFailedEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -412,39 +423,30 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): error=event.error ) else: + if self._task_state.latest_node_execution_info: + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() + outputs = workflow_node_execution.outputs + else: + outputs = None + workflow_run = self._workflow_run_success( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs - if self._task_state.latest_node_execution_info else None + outputs=outputs ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') - def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: - """ - Get workflow run. - :param workflow_run_id: workflow run id - :return: - """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - return workflow_run + db.session.close() - def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: - """ - Get workflow node execution. - :param workflow_node_execution_id: workflow node execution id - :return: - """ - workflow_node_execution = (db.session.query(WorkflowNodeExecution) - .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) - return workflow_node_execution + return workflow_run def _save_workflow_app_log(self) -> None: """ @@ -461,7 +463,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): """ response = { 'event': 'text_chunk', - 'workflow_run_id': self._task_state.workflow_run.id, + 'workflow_run_id': self._task_state.workflow_run_id, 'task_id': self._application_generate_entity.task_id, 'data': { 'text': text diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py index d29cee3ac4..2b373d28e8 100644 --- a/api/core/app/apps/workflow_based_generate_task_pipeline.py +++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py @@ -87,6 +87,7 @@ class WorkflowBasedGenerateTaskPipeline: workflow_run.finished_at = datetime.utcnow() db.session.commit() + db.session.refresh(workflow_run) db.session.close() return workflow_run @@ -115,6 +116,7 @@ class WorkflowBasedGenerateTaskPipeline: workflow_run.finished_at = datetime.utcnow() db.session.commit() + db.session.refresh(workflow_run) db.session.close() return workflow_run @@ -185,6 +187,7 @@ class WorkflowBasedGenerateTaskPipeline: workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() + db.session.refresh(workflow_node_execution) db.session.close() return workflow_node_execution @@ -205,6 +208,7 @@ class WorkflowBasedGenerateTaskPipeline: workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() + db.session.refresh(workflow_node_execution) db.session.close() return workflow_node_execution