From ea883b5e4806d7f0e482accf3e5ebfd62202475c Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 15:43:55 +0800 Subject: [PATCH] add start, end, direct answer node --- .../entities/base_node_data_entities.py | 2 - api/core/workflow/entities/node_entities.py | 13 ++++- .../workflow/entities/variable_entities.py | 9 +++ .../workflow/entities/workflow_entities.py | 7 ++- api/core/workflow/nodes/base_node.py | 4 +- .../nodes/direct_answer/direct_answer_node.py | 51 ++++++++++++++++- .../workflow/nodes/direct_answer/entities.py | 10 ++++ api/core/workflow/nodes/end/end_node.py | 57 ++++++++++++++++++- api/core/workflow/nodes/end/entities.py | 43 ++++++++++++++ api/core/workflow/nodes/llm/entities.py | 8 +++ api/core/workflow/nodes/llm/llm_node.py | 21 ++++++- api/core/workflow/nodes/start/entities.py | 16 +----- api/core/workflow/nodes/start/start_node.py | 56 ++++++++++++++++-- api/core/workflow/workflow_engine_manager.py | 8 ++- 14 files changed, 274 insertions(+), 31 deletions(-) create mode 100644 api/core/workflow/entities/variable_entities.py create mode 100644 api/core/workflow/nodes/direct_answer/entities.py create mode 100644 api/core/workflow/nodes/llm/entities.py diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index afa6ddff04..fc6ee231ff 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -5,7 +5,5 @@ from pydantic import BaseModel class BaseNodeData(ABC, BaseModel): - type: str - title: str desc: Optional[str] = None diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index af539692ef..263172da31 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel @@ -46,6 +46,15 @@ class SystemVariable(Enum): CONVERSATION = 'conversation' +class NodeRunMetadataKey(Enum): + """ + Node Run Metadata Key. + """ + TOTAL_TOKENS = 'total_tokens' + TOTAL_PRICE = 'total_price' + CURRENCY = 'currency' + + class NodeRunResult(BaseModel): """ Node Run Result. @@ -55,7 +64,7 @@ class NodeRunResult(BaseModel): inputs: Optional[dict] = None # node inputs process_data: Optional[dict] = None # process data outputs: Optional[dict] = None # node outputs - metadata: Optional[dict] = None # node metadata + metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata edge_source_handle: Optional[str] = None # source handle id of node with multiple branches diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py new file mode 100644 index 0000000000..19d9af2a61 --- /dev/null +++ b/api/core/workflow/entities/variable_entities.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +class VariableSelector(BaseModel): + """ + Variable Selector. + """ + variable: str + value_selector: list[str] diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 0d78e4c4f1..8c15cb95cd 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -5,13 +5,18 @@ from models.workflow import WorkflowNodeExecution, WorkflowRun class WorkflowRunState: workflow_run: WorkflowRun start_at: float + user_inputs: dict variable_pool: VariablePool total_tokens: int = 0 workflow_node_executions: list[WorkflowNodeExecution] = [] - def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None: + def __init__(self, workflow_run: WorkflowRun, + start_at: float, + user_inputs: dict, + variable_pool: VariablePool) -> None: self.workflow_run = workflow_run self.start_at = start_at + self.user_inputs = user_inputs self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 1ff05f9f4e..6720017d9f 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,4 +1,4 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Optional from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback @@ -8,7 +8,7 @@ from core.workflow.entities.variable_pool import VariablePool from models.workflow import WorkflowNodeExecutionStatus -class BaseNode: +class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index c6013974b8..80ecdf7757 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -1,5 +1,54 @@ +import time +from typing import Optional, cast + +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData +from models.workflow import WorkflowNodeExecutionStatus class DirectAnswerNode(BaseNode): - pass + _node_data_cls = DirectAnswerNodeData + node_type = NodeType.DIRECT_ANSWER + + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + + if variable_pool is None and run_args: + raise ValueError("Not support single step debug.") + + variable_values = {} + for variable_selector in node_data.variables: + value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector, + target_value_type=ValueType.STRING + ) + + variable_values[variable_selector.variable] = value + + # format answer template + template_parser = PromptTemplateParser(node_data.answer) + answer = template_parser.format(variable_values) + + # publish answer as stream + for word in answer: + self.publish_text_chunk(word) + time.sleep(0.01) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variable_values, + output={ + "answer": answer + } + ) diff --git a/api/core/workflow/nodes/direct_answer/entities.py b/api/core/workflow/nodes/direct_answer/entities.py new file mode 100644 index 0000000000..e7c11e3c4d --- /dev/null +++ b/api/core/workflow/nodes/direct_answer/entities.py @@ -0,0 +1,10 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class DirectAnswerNodeData(BaseNodeData): + """ + DirectAnswer Node Data. + """ + variables: list[VariableSelector] = [] + answer: str diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index f9aea89af7..62429e3ac2 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,5 +1,60 @@ +from typing import Optional, cast + +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs +from models.workflow import WorkflowNodeExecutionStatus class EndNode(BaseNode): - pass + _node_data_cls = EndNodeData + node_type = NodeType.END + + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + outputs_config = node_data.outputs + + if variable_pool is not None: + outputs = None + if outputs_config: + if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: + plain_text_selector = outputs_config.plain_text_selector + if plain_text_selector: + outputs = { + 'text': variable_pool.get_variable_value( + variable_selector=plain_text_selector, + target_value_type=ValueType.STRING + ) + } + else: + outputs = { + 'text': '' + } + elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED: + structured_variables = outputs_config.structured_variables + if structured_variables: + outputs = {} + for variable_selector in structured_variables: + variable_value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + outputs[variable_selector.variable] = variable_value + else: + outputs = {} + else: + raise ValueError("Not support single step debug.") + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=outputs, + outputs=outputs + ) diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index 045e7effc4..32212ae7fa 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -1,4 +1,10 @@ from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector class EndNodeOutputType(Enum): @@ -23,3 +29,40 @@ class EndNodeOutputType(Enum): if output_type.value == value: return output_type raise ValueError(f'invalid output type value {value}') + + +class EndNodeDataOutputs(BaseModel): + """ + END Node Data Outputs. + """ + class OutputType(Enum): + """ + Output Types. + """ + NONE = 'none' + PLAIN_TEXT = 'plain-text' + STRUCTURED = 'structured' + + @classmethod + def value_of(cls, value: str) -> 'OutputType': + """ + Get value of given output type. + + :param value: output type value + :return: output type + """ + for output_type in cls: + if output_type.value == value: + return output_type + raise ValueError(f'invalid output type value {value}') + + type: OutputType = OutputType.NONE + plain_text_selector: Optional[list[str]] = None + structured_variables: Optional[list[VariableSelector]] = None + + +class EndNodeData(BaseNodeData): + """ + END Node Data. + """ + outputs: Optional[EndNodeDataOutputs] = None diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py new file mode 100644 index 0000000000..bd499543d9 --- /dev/null +++ b/api/core/workflow/nodes/llm/entities.py @@ -0,0 +1,8 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class LLMNodeData(BaseNodeData): + """ + LLM Node Data. + """ + pass diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 1c7277e942..e3ae9fc00f 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,9 +1,28 @@ -from typing import Optional +from typing import Optional, cast +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.llm.entities import LLMNodeData class LLMNode(BaseNode): + _node_data_cls = LLMNodeData + node_type = NodeType.LLM + + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + + pass + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 64687db042..0bd5f203bf 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -1,23 +1,9 @@ from core.app.app_config.entities import VariableEntity from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType class StartNodeData(BaseNodeData): """ - - title (string) 节点标题 - - desc (string) optional 节点描述 - - type (string) 节点类型,固定为 start - - variables (array[object]) 表单变量列表 - - type (string) 表单变量类型,text-input, paragraph, select, number, files(文件暂不支持自定义) - - label (string) 控件展示标签名 - - variable (string) 变量 key - - max_length (int) 最大长度,适用于 text-input 和 paragraph - - default (string) optional 默认值 - - required (bool) optional是否必填,默认 false - - hint (string) optional 提示信息 - - options (array[string]) 选项值(仅 select 可用) + Start Node Data """ - type: str = NodeType.START.value - variables: list[VariableEntity] = [] diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 74d8541436..ce04031b04 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,9 +1,11 @@ -from typing import Optional +from typing import Optional, cast -from core.workflow.entities.node_entities import NodeType +from core.app.app_config.entities import VariableEntity +from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData +from models.workflow import WorkflowNodeExecutionStatus class StartNode(BaseNode): @@ -11,12 +13,58 @@ class StartNode(BaseNode): node_type = NodeType.START def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> dict: + run_args: Optional[dict] = None) -> NodeRunResult: """ Run node :param variable_pool: variable pool :param run_args: run args :return: """ - pass + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + variables = node_data.variables + # Get cleaned inputs + cleaned_inputs = self._get_cleaned_inputs(variables, run_args) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=cleaned_inputs, + outputs=cleaned_inputs + ) + + def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict): + if user_inputs is None: + user_inputs = {} + + filtered_inputs = {} + + for variable_config in variables: + variable = variable_config.variable + + if variable not in user_inputs or not user_inputs[variable]: + if variable_config.required: + raise ValueError(f"Input form variable {variable} is required") + else: + filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" + continue + + value = user_inputs[variable] + + if value: + if not isinstance(value, str): + raise ValueError(f"{variable} in input form must be a string") + + if variable_config.type == VariableEntity.Type.SELECT: + options = variable_config.options if variable_config.options is not None else [] + if value not in options: + raise ValueError(f"{variable} in input form must be one of the following: {options}") + else: + if variable_config.max_length is not None: + max_length = variable_config.max_length + if len(value) > max_length: + raise ValueError(f'{variable} in input form must be less than {max_length} characters') + + filtered_inputs[variable] = value.replace('\x00', '') if value else None + + return filtered_inputs diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 8ab0eb4802..5423546957 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -3,6 +3,7 @@ import time from datetime import datetime from typing import Optional, Union +from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue @@ -141,6 +142,7 @@ class WorkflowEngineManager: workflow_run_state = WorkflowRunState( workflow_run=workflow_run, start_at=time.perf_counter(), + user_inputs=user_inputs, variable_pool=VariablePool( system_variables=system_inputs, ) @@ -399,7 +401,9 @@ class WorkflowEngineManager: # run node, result must have inputs, process_data, outputs, execution_metadata node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool + variable_pool=workflow_run_state.variable_pool, + run_args=workflow_run_state.user_inputs + if (not predecessor_node and node.node_type == NodeType.START) else None # only on start node ) if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: @@ -492,7 +496,7 @@ class WorkflowEngineManager: workflow_node_execution.inputs = json.dumps(result.inputs) workflow_node_execution.process_data = json.dumps(result.process_data) workflow_node_execution.outputs = json.dumps(result.outputs) - workflow_node_execution.execution_metadata = json.dumps(result.metadata) + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) workflow_node_execution.finished_at = datetime.utcnow() db.session.commit()