From e34497ded15764e579b6772386b21bee83528399 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 21 Aug 2024 17:25:26 +0800 Subject: [PATCH] fix: merge issues --- .../app/apps/advanced_chat/app_generator.py | 56 ------------------- api/core/app/apps/advanced_chat/app_runner.py | 27 ++++++--- .../apps/workflow/generate_task_pipeline.py | 2 - .../task_pipeline/workflow_cycle_manage.py | 4 +- api/core/workflow/nodes/start/start_node.py | 2 +- .../workflow/nodes/variable_assigner/node.py | 12 ++-- ...21501b_add_node_execution_id_into_node_.py | 4 +- .../graph_engine/test_graph_engine.py | 22 ++++---- .../core/workflow/nodes/answer/test_answer.py | 6 +- .../answer/test_answer_stream_processor.py | 10 ++-- .../nodes/iteration/test_iteration.py | 10 ++-- .../core/workflow/nodes/test_answer.py | 4 +- .../workflow/nodes/test_variable_assigner.py | 2 +- 13 files changed, 56 insertions(+), 105 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index f896658655..458bb5d2f6 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -228,62 +228,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message_id=message.id ) - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id - ) - with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: - # Create conversation variables if they don't exist. - conversation_variables = [ - ConversationVariable.from_variable( - app_id=conversation.app_id, conversation_id=conversation.id, variable=variable - ) - for variable in workflow.conversation_variables - ] - session.add_all(conversation_variables) - # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in conversation_variables] - - session.commit() - - # Increment dialogue count. - conversation.dialogue_count += 1 - - conversation_id = conversation.id - conversation_dialogue_count = conversation.dialogue_count - db.session.commit() - db.session.refresh(conversation) - - inputs = application_generate_entity.inputs - query = application_generate_entity.query - files = application_generate_entity.files - - user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = application_generate_entity.user_id - - # Create a variable pool. - system_inputs = { - SystemVariableKey.QUERY: query, - SystemVariableKey.FILES: files, - SystemVariableKey.CONVERSATION_ID: conversation_id, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, - ) - contexts.workflow_variable_pool.set(variable_pool) - # new thread worker_thread = threading.Thread(target=self._generate_worker, kwargs={ 'flask_app': current_app._get_current_object(), # type: ignore diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d2cd7fd4e1..33c3f7cec5 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -23,11 +23,11 @@ from core.moderation.base import ModerationException from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App, Conversation, EndUser, Message -from models.workflow import ConversationVariable +from models.workflow import ConversationVariable, WorkflowType logger = logging.getLogger(__name__) @@ -124,6 +124,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): with Session(db.engine) as session: conversation_variables = session.scalars(stmt).all() if not conversation_variables: + # Create conversation variables if they don't exist. conversation_variables = [ ConversationVariable.from_variable( app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable @@ -131,16 +132,24 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): for variable in workflow.conversation_variables ] session.add_all(conversation_variables) - session.commit() - # Convert database entities to variables + # Convert database entities to variables. conversation_variables = [item.to_variable() for item in conversation_variables] + session.commit() + + # Increment dialogue count. + self.conversation.dialogue_count += 1 + + conversation_dialogue_count = self.conversation.dialogue_count + db.session.commit() + # Create a variable pool. system_inputs = { - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: self.conversation.id, - SystemVariable.USER_ID: user_id, + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.CONVERSATION_ID: self.conversation.id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, } # init variable pool @@ -159,7 +168,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): tenant_id=workflow.tenant_id, app_id=workflow.app_id, workflow_id=workflow.id, - workflow_type=workflow.type, + workflow_type=WorkflowType.value_of(workflow.type), graph=graph, graph_config=workflow.graph_dict, user_id=self.application_generate_entity.user_id, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index da72fe9434..1f0db7ff34 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -41,9 +41,7 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account from models.model import EndUser diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index caca6d00b2..a7b9872d45 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -33,7 +33,7 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager from core.workflow.entities.node_entities import NodeType -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -56,7 +56,7 @@ class WorkflowCycleManage: _workflow: Workflow _user: Union[Account, EndUser] _task_state: WorkflowTaskState - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] def _handle_workflow_run_start(self) -> WorkflowRun: max_sequence = ( diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 1afc53f341..69cdec6a92 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -3,7 +3,7 @@ from collections.abc import Mapping, Sequence from typing import Any from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool +from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py index 8c2adcabb9..79b7bb1c26 100644 --- a/api/core/workflow/nodes/variable_assigner/node.py +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -19,23 +19,23 @@ class VariableAssignerNode(BaseNode): _node_data_cls: type[BaseNodeData] = VariableAssignerData _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: data = cast(VariableAssignerData, self.node_data) # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = variable_pool.get(data.assigned_variable_selector) + original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableAssignerNodeError('assigned variable not found') match data.write_mode: case WriteMode.OVER_WRITE: - income_value = variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: raise VariableAssignerNodeError('input value not found') updated_variable = original_variable.model_copy(update={'value': income_value.value}) case WriteMode.APPEND: - income_value = variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: raise VariableAssignerNodeError('input value not found') updated_value = original_variable.value + [income_value.value] @@ -49,11 +49,11 @@ class VariableAssignerNode(BaseNode): raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') # Over write the variable. - variable_pool.add(data.assigned_variable_selector, updated_variable) + self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. - conversation_id = variable_pool.get(['sys', 'conversation_id']) + conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id']) if not conversation_id: raise VariableAssignerNodeError('conversation_id not found') update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) diff --git a/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py index 3048ebd053..1b148a669f 100644 --- a/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py +++ b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py @@ -1,7 +1,7 @@ """add node_execution_id into node_executions Revision ID: 675b5321501b -Revises: 8782057ff0dc +Revises: 2dbe42621d96 Create Date: 2024-08-12 10:54:02.259331 """ @@ -12,7 +12,7 @@ import models as models # revision identifiers, used by Alembic. revision = '675b5321501b' -down_revision = '8782057ff0dc' +down_revision = '2dbe42621d96' branch_labels = None depends_on = None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 785299f327..aa341f065b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -3,7 +3,7 @@ from unittest.mock import patch from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( BaseNodeEvent, GraphRunFailedEvent, @@ -201,8 +201,8 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): ) variable_pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.FILES: [], + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={ "query": "hi" }) @@ -363,10 +363,10 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): ) variable_pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}) graph_engine = GraphEngine( @@ -521,10 +521,10 @@ def test_run_branch(mock_close, mock_remove): ) variable_pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'hi', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'hi', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={ "uid": "takato" }) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 2483e576ec..fbcb209d07 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState @@ -57,8 +57,8 @@ def test_execute_answer(): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.FILES: [], + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['start', 'weather'], 'sunny') pool.add(['llm', 'text'], 'You are a helpful AI.') diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py index f041d0395e..13b74d65ac 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, NodeRunStartedEvent, @@ -198,10 +198,10 @@ def test_process(): ) variable_pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}) answer_stream_processor = AnswerStreamProcessor( diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 1b559e7d65..ff46e62d1f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -5,7 +5,7 @@ from unittest.mock import patch from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunResult, UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState @@ -155,10 +155,10 @@ def test_run(): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'dify', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: '1' + SystemVariableKey.QUERY: 'dify', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: '1' }, user_inputs={}, environment_variables=[]) pool.add(['pe', 'list_output'], ["dify-1", "dify-2"]) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 049f969916..fd2971eb57 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -90,8 +90,8 @@ def test_execute_answer(): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.FILES: [], + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['start', 'weather'], 'sunny') pool.add(['llm', 'text'], 'You are a helpful AI.') diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index 83f7fbb609..61853cbcf5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -274,7 +274,7 @@ def test_clear_array(): }, ) - with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run: list(node.run()) mock_run.assert_called_once()