diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 624a0f430a..c1076fa947 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -47,6 +47,7 @@ class TaskState(BaseModel): answer: str = "" metadata: dict = {} usage: LLMUsage + workflow_run_id: Optional[str] = None class AdvancedChatAppGenerateTaskPipeline: @@ -110,6 +111,8 @@ class AdvancedChatAppGenerateTaskPipeline: } self._task_state.answer = annotation.content + elif isinstance(event, QueueWorkflowStartedEvent): + self._task_state.workflow_run_id = event.workflow_run_id elif isinstance(event, QueueNodeFinishedEvent): workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: @@ -171,6 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: break elif isinstance(event, QueueWorkflowStartedEvent): workflow_run = self._get_workflow_run(event.workflow_run_id) + self._task_state.workflow_run_id = workflow_run.id response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, @@ -234,7 +238,7 @@ class AdvancedChatAppGenerateTaskPipeline: if isinstance(event, QueueWorkflowFinishedEvent): workflow_run = self._get_workflow_run(event.workflow_run_id) if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs + outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') else: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) @@ -389,7 +393,13 @@ class AdvancedChatAppGenerateTaskPipeline: :param workflow_run_id: workflow run id :return: """ - return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + if workflow_run: + # Because the workflow_run will be modified in the sub-thread, + # and the first query in the main thread will cache the entity, + # you need to expire the entity after the query + db.session.expire(workflow_run) + return workflow_run def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: """ @@ -397,7 +407,14 @@ class AdvancedChatAppGenerateTaskPipeline: :param workflow_node_execution_id: workflow node execution id :return: """ - return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first() + workflow_node_execution = (db.session.query(WorkflowNodeExecution) + .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) + if workflow_node_execution: + # Because the workflow_node_execution will be modified in the sub-thread, + # and the first query in the main thread will cache the entity, + # you need to expire the entity after the query + db.session.expire(workflow_node_execution) + return workflow_node_execution def _save_message(self) -> None: """ @@ -408,6 +425,7 @@ class AdvancedChatAppGenerateTaskPipeline: self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at + self._message.workflow_run_id = self._task_state.workflow_run_id if self._task_state.metadata and self._task_state.metadata.get('usage'): usage = LLMUsage(**self._task_state.metadata['usage']) diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index 80ecdf7757..bc6e4bd800 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -48,7 +48,7 @@ class DirectAnswerNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variable_values, - output={ + outputs={ "answer": answer } ) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 05a784c221..19dac76631 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -33,6 +33,7 @@ from models.workflow import ( WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom, + WorkflowType, ) node_classes = { @@ -268,7 +269,7 @@ class WorkflowEngineManager: # fetch last workflow_node_executions last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1] if last_workflow_node_execution: - workflow_run.outputs = json.dumps(last_workflow_node_execution.node_run_result.outputs) + workflow_run.outputs = last_workflow_node_execution.outputs workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at workflow_run.total_tokens = workflow_run_state.total_tokens @@ -390,6 +391,7 @@ class WorkflowEngineManager: workflow_run_state=workflow_run_state, node=node, predecessor_node=predecessor_node, + callbacks=callbacks ) # add to workflow node executions @@ -412,6 +414,9 @@ class WorkflowEngineManager: ) raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") + # set end node output if in chat + self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result) + # node run success self._workflow_node_execution_success( workflow_node_execution=workflow_node_execution, @@ -529,6 +534,32 @@ class WorkflowEngineManager: return workflow_node_execution + def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, + node: BaseNode, + node_run_result: NodeRunResult): + """ + Set end node output if in chat + :param workflow_run_state: workflow run state + :param node: current node + :param node_run_result: node run result + :return: + """ + if workflow_run_state.workflow_run.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: + workflow_node_execution_before_end = workflow_run_state.workflow_node_executions[-2] + if workflow_node_execution_before_end: + if workflow_node_execution_before_end.node_type == NodeType.LLM.value: + if not node_run_result.outputs: + node_run_result.outputs = {} + + node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('text') + elif workflow_node_execution_before_end.node_type == NodeType.DIRECT_ANSWER.value: + if not node_run_result.outputs: + node_run_result.outputs = {} + + node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('answer') + + return node_run_result + def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str],