From 75f1355d4c742399f247a7dd0737512b6f1741db Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 23:34:23 +0800 Subject: [PATCH] add few workflow run codes --- api/commands.py | 2 +- api/core/app/app_config/entities.py | 1 + api/core/app/apps/advanced_chat/app_runner.py | 7 +- api/core/callback_handler/__init__.py | 0 .../std_out_callback_handler.py | 157 ------------------ .../workflow_event_trigger_callback.py | 45 +++++ api/core/workflow/callbacks/__init__.py | 0 api/core/workflow/callbacks/base_callback.py | 33 ++++ .../entities/base_node_data_entities.py | 7 + api/core/workflow/nodes/base_node.py | 43 ++--- api/core/workflow/nodes/start/entities.py | 27 +++ api/core/workflow/nodes/start/start_node.py | 19 ++- api/core/workflow/workflow_engine_manager.py | 96 ++++++++++- 13 files changed, 254 insertions(+), 183 deletions(-) create mode 100644 api/core/callback_handler/__init__.py delete mode 100644 api/core/callback_handler/std_out_callback_handler.py create mode 100644 api/core/callback_handler/workflow_event_trigger_callback.py create mode 100644 api/core/workflow/callbacks/__init__.py create mode 100644 api/core/workflow/callbacks/base_callback.py create mode 100644 api/core/workflow/entities/base_node_data_entities.py create mode 100644 api/core/workflow/nodes/start/entities.py diff --git a/api/commands.py b/api/commands.py index 73325620ee..376a394d1e 100644 --- a/api/commands.py +++ b/api/commands.py @@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair from models.account import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppMode, AppModelConfig, AppAnnotationSetting, Conversation, MessageAnnotation +from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index e155dc1c4d..6a521dfcc5 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -112,6 +112,7 @@ class VariableEntity(BaseModel): max_length: Optional[int] = None options: Optional[list[str]] = None default: Optional[str] = None + hint: Optional[str] = None class ExternalDataVariableEntity(BaseModel): diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 02d22072df..920adcfb79 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -10,12 +10,14 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import QueueStopEvent +from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message +from models.workflow import WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) @@ -83,13 +85,16 @@ class AdvancedChatAppRunner(AppRunner): result_generator = workflow_engine_manager.run_workflow( app_model=app_record, workflow=workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if application_generate_entity.invoke_from == 
InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, user=user, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, SystemVariable.FILES: files, SystemVariable.CONVERSATION: conversation.id, - } + }, + callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] ) for result in result_generator: diff --git a/api/core/callback_handler/__init__.py b/api/core/callback_handler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/callback_handler/std_out_callback_handler.py b/api/core/callback_handler/std_out_callback_handler.py deleted file mode 100644 index 1f95471afb..0000000000 --- a/api/core/callback_handler/std_out_callback_handler.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -import sys -from typing import Any, Optional, Union - -from langchain.callbacks.base import BaseCallbackHandler -from langchain.input import print_text -from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult - - -class DifyStdOutCallbackHandler(BaseCallbackHandler): - """Callback Handler that prints to std out.""" - - def __init__(self, color: Optional[str] = None) -> None: - """Initialize callback handler.""" - self.color = color - - def on_chat_model_start( - self, - serialized: dict[str, Any], - messages: list[list[BaseMessage]], - **kwargs: Any - ) -> Any: - print_text("\n[on_chat_model_start]\n", color='blue') - for sub_messages in messages: - for sub_message in sub_messages: - print_text(str(sub_message) + "\n", color='blue') - - def on_llm_start( - self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any - ) -> None: - """Print out the prompts.""" - print_text("\n[on_llm_start]\n", color='blue') - print_text(prompts[0] + "\n", color='blue') - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str( - response.llm_output) + "\n", color='blue') - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue') - - def on_chain_start( - self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - chain_type = serialized['id'][-1] - print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink') - - def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink') - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink') - - def on_tool_start( - self, - serialized: dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - print_text("\n[on_tool_start] " + str(serialized), color='yellow') - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - """Run on agent action.""" - tool = action.tool - tool_input = action.tool_input - try: - action_name_position = action.log.index("\nAction:") + 1 if action.log else -1 - thought = action.log[:action_name_position].strip() if action.log else '' - except ValueError: - 
thought = '' - - log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}" - print_text("\n[on_agent_action]\n" + log + "\n", color='green') - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """If not the final action, print out observation.""" - print_text("\n[on_tool_end]\n", color='yellow') - if observation_prefix: - print_text(f"\n{observation_prefix}") - print_text(output, color='yellow') - if llm_prefix: - print_text(f"\n{llm_prefix}") - print_text("\n") - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow') - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Optional[str], - ) -> None: - """Run when agent ends.""" - print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end) - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Run on agent end.""" - print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n") - - @property - def ignore_llm(self) -> bool: - """Whether to ignore LLM callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_agent(self) -> bool: - """Whether to ignore agent callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_chat_model(self) -> bool: - """Whether to ignore chat model callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - -class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler): - """Callback handler for streaming. Only works with LLMs that support streaming.""" - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. 
Only available when streaming is enabled.""" - sys.stdout.write(token) - sys.stdout.flush() diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/callback_handler/workflow_event_trigger_callback.py new file mode 100644 index 0000000000..2f81f27426 --- /dev/null +++ b/api/core/callback_handler/workflow_event_trigger_callback.py @@ -0,0 +1,45 @@ +from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.workflow.callbacks.base_callback import BaseWorkflowCallback +from models.workflow import WorkflowRun, WorkflowNodeExecution + + +class WorkflowEventTriggerCallback(BaseWorkflowCallback): + + def __init__(self, queue_manager: AppQueueManager): + self._queue_manager = queue_manager + + def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run started + """ + self._queue_manager.publish_workflow_started( + workflow_run_id=workflow_run.id, + pub_from=PublishFrom.TASK_PIPELINE + ) + + def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run finished + """ + self._queue_manager.publish_workflow_finished( + workflow_run_id=workflow_run.id, + pub_from=PublishFrom.TASK_PIPELINE + ) + + def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute started + """ + self._queue_manager.publish_node_started( + workflow_node_execution_id=workflow_node_execution.id, + pub_from=PublishFrom.TASK_PIPELINE + ) + + def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute finished + """ + self._queue_manager.publish_node_finished( + workflow_node_execution_id=workflow_node_execution.id, + pub_from=PublishFrom.TASK_PIPELINE + ) diff --git a/api/core/workflow/callbacks/__init__.py b/api/core/workflow/callbacks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/callbacks/base_callback.py b/api/core/workflow/callbacks/base_callback.py new file mode 100644 index 0000000000..a564af498c --- /dev/null +++ b/api/core/workflow/callbacks/base_callback.py @@ -0,0 +1,33 @@ +from abc import abstractmethod + +from models.workflow import WorkflowRun, WorkflowNodeExecution + + +class BaseWorkflowCallback: + @abstractmethod + def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run started + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run finished + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute started + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute finished + """ + raise NotImplementedError diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py new file mode 100644 index 0000000000..32b93ea094 --- /dev/null +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -0,0 +1,7 @@ +from abc import ABC + +from pydantic import BaseModel + + +class BaseNodeData(ABC, BaseModel): + pass diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index a2751b346f..a95a232ae6 100644 --- a/api/core/workflow/nodes/base_node.py +++ 
b/api/core/workflow/nodes/base_node.py
@@ -1,32 +1,21 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import Optional, Type
 
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
 from core.workflow.entities.variable_pool import VariablePool
 
 
 class BaseNode:
     _node_type: NodeType
+    _node_data_cls: Type[BaseNodeData]
 
-    def __int__(self, node_config: dict) -> None:
-        self._node_config = node_config
+    def __init__(self, config: dict) -> None:
+        self._node_id = config.get("id")
+        if not self._node_id:
+            raise ValueError("Node ID is required.")
 
-    @abstractmethod
-    def run(self, variable_pool: Optional[VariablePool] = None,
-            run_args: Optional[dict] = None) -> dict:
-        """
-        Run node
-        :param variable_pool: variable pool
-        :param run_args: run args
-        :return:
-        """
-        if variable_pool is None and run_args is None:
-            raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
-
-        return self._run(
-            variable_pool=variable_pool,
-            run_args=run_args
-        )
+        self._node_data = self._node_data_cls(**config.get("data", {}))
 
     @abstractmethod
     def _run(self, variable_pool: Optional[VariablePool] = None,
@@ -39,6 +28,22 @@ class BaseNode:
         """
         raise NotImplementedError
 
+    def run(self, variable_pool: Optional[VariablePool] = None,
+            run_args: Optional[dict] = None) -> dict:
+        """
+        Run node entry
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        if variable_pool is None and run_args is None:
+            raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
+
+        return self._run(
+            variable_pool=variable_pool,
+            run_args=run_args
+        )
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """
diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py
new file mode 100644
index 0000000000..25b27cf192
--- /dev/null
+++ b/api/core/workflow/nodes/start/entities.py
@@ -0,0 +1,27 @@
+from typing import Optional
+
+from core.app.app_config.entities import VariableEntity
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeType
+
+
+class StartNodeData(BaseNodeData):
+    """
+    - title (string) Node title
+    - desc (string) optional Node description
+    - type (string) Node type, fixed to `start`
+    - variables (array[object]) List of form variables
+      - type (string) Form variable type: text-input, paragraph, select, number, files (custom file variables not supported yet)
+      - label (string) Display label of the form control
+      - variable (string) Variable key
+      - max_length (int) Maximum length, applies to text-input and paragraph
+      - default (string) optional Default value
+      - required (bool) optional Whether the variable is required, defaults to false
+      - hint (string) optional Hint text
+      - options (array[string]) Option values (select only)
+    """
+    type: str = NodeType.START.value
+
+    title: str
+    desc: Optional[str] = None
+    variables: list[VariableEntity] = []
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
index 8cce655728..014a146c93 100644
--- a/api/core/workflow/nodes/start/start_node.py
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -1,5 +1,22 @@
+from typing import Type, Optional
+
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.start.entities import StartNodeData
 
 
 class StartNode(BaseNode):
-    pass
+    _node_type = NodeType.START
+    _node_data_cls = StartNodeData
+
+    def _run(self, variable_pool: 
Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> dict: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + pass + diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 8a23048705..afa4dbb321 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,6 +1,8 @@ +import json from collections.abc import Generator from typing import Optional, Union +from core.workflow.callbacks.base_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -17,7 +19,7 @@ from core.workflow.nodes.variable_assigner.variable_assigner_node import Variabl from extensions.ext_database import db from models.account import Account from models.model import App, EndUser -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowRunTriggeredFrom, WorkflowRun, WorkflowRunStatus, CreatedByRole node_classes = { NodeType.START: StartNode, @@ -108,17 +110,103 @@ class WorkflowEngineManager: def run_workflow(self, app_model: App, workflow: Workflow, + triggered_from: WorkflowRunTriggeredFrom, user: Union[Account, EndUser], user_inputs: dict, - system_inputs: Optional[dict] = None) -> Generator: + system_inputs: Optional[dict] = None, + callbacks: list[BaseWorkflowCallback] = None) -> Generator: """ Run workflow :param app_model: App instance :param workflow: Workflow instance + :param triggered_from: triggered from + :param user: account or end user + :param user_inputs: user variables inputs + :param system_inputs: system inputs, like: query, files + :param callbacks: workflow callbacks + :return: + """ + # fetch workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError('workflow graph not found') + + # init workflow run + workflow_run = self._init_workflow_run( + workflow=workflow, + triggered_from=triggered_from, + user=user, + user_inputs=user_inputs, + system_inputs=system_inputs + ) + + if callbacks: + for callback in callbacks: + callback.on_workflow_run_started(workflow_run) + + pass + + def _init_workflow_run(self, workflow: Workflow, + triggered_from: WorkflowRunTriggeredFrom, + user: Union[Account, EndUser], + user_inputs: dict, + system_inputs: Optional[dict] = None) -> WorkflowRun: + """ + Init workflow run + :param workflow: Workflow instance + :param triggered_from: triggered from :param user: account or end user :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files :return: """ - # TODO - pass + try: + db.session.begin() + + max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ + .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ + .filter(WorkflowRun.app_id == workflow.app_id) \ + .for_update() \ + .scalar() or 0 + new_sequence_number = max_sequence + 1 + + # init workflow run + workflow_run = WorkflowRun( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + sequence_number=new_sequence_number, + workflow_id=workflow.id, + type=workflow.type, + triggered_from=triggered_from.value, + version=workflow.version, + graph=workflow.graph, + inputs=json.dumps({**user_inputs, **system_inputs}), + status=WorkflowRunStatus.RUNNING.value, + created_by_role=(CreatedByRole.ACCOUNT.value + if isinstance(user, Account) else 
CreatedByRole.END_USER.value),
+                created_by_id=user.id
+            )
+
+            db.session.add(workflow_run)
+            db.session.commit()
+        except:
+            db.session.rollback()
+            raise
+
+        return workflow_run
+
+    def _get_entry_node(self, graph: dict) -> Optional[StartNode]:
+        """
+        Get entry node
+        :param graph: workflow graph
+        :return:
+        """
+        nodes = graph.get('nodes')
+        if not nodes:
+            return None
+
+        for node_config in nodes:
+            if node_config.get('type') == NodeType.START.value:
+                return StartNode(config=node_config)
+
+        return None
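
Illustrative sketch of how the pieces introduced by this patch are expected to fit together (not part of the diff): a caller implements BaseWorkflowCallback and passes it to WorkflowEngineManager.run_workflow, mirroring the wiring in AdvancedChatAppRunner above. The LoggingWorkflowCallback class and run_debug helper are hypothetical names used only for the example, and the engine loop itself is still a stub in this commit.

# Hypothetical usage sketch, assuming the signatures added in this patch.
from typing import Optional

from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from models.workflow import WorkflowNodeExecution, WorkflowRun, WorkflowRunTriggeredFrom


class LoggingWorkflowCallback(BaseWorkflowCallback):
    """Hypothetical callback that just prints run/node lifecycle events."""

    def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
        print(f"workflow run started: {workflow_run.id}")

    def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
        print(f"workflow run finished: {workflow_run.id}")

    def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
        print(f"node execution started: {workflow_node_execution.id}")

    def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
        print(f"node execution finished: {workflow_node_execution.id}")


def run_debug(app_model, workflow, user, inputs: dict, system_inputs: Optional[dict] = None):
    # app_model, workflow and user are assumed to be loaded by the caller
    # (App, Workflow, and Account/EndUser instances respectively).
    manager = WorkflowEngineManager()
    generator = manager.run_workflow(
        app_model=app_model,
        workflow=workflow,
        triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        user=user,
        user_inputs=inputs,
        system_inputs=system_inputs or {},
        callbacks=[LoggingWorkflowCallback()]
    )

    for result in generator:
        # run_workflow is still a work in progress in this commit; the generator
        # is expected to yield run events once the engine loop is implemented.
        print(result)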