From 2db67c410165f807af2ce5c2636397f2ac6a89e8 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 9 Mar 2024 19:05:48 +0800 Subject: [PATCH] refactor pipeline and remove node run run_args --- .../advanced_chat/generate_task_pipeline.py | 47 ++++++++---- .../apps/workflow/generate_task_pipeline.py | 48 +++++++++---- api/core/workflow/entities/variable_pool.py | 5 +- .../workflow/entities/workflow_entities.py | 4 +- api/core/workflow/nodes/base_node.py | 34 ++++++--- api/core/workflow/nodes/code/code_node.py | 45 ++++++------ .../nodes/direct_answer/direct_answer_node.py | 21 +++--- api/core/workflow/nodes/end/end_node.py | 71 ++++++++++--------- api/core/workflow/nodes/llm/llm_node.py | 16 ++++- api/core/workflow/nodes/start/start_node.py | 18 +++-- api/core/workflow/workflow_engine_manager.py | 6 +- 11 files changed, 201 insertions(+), 114 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 18bc9c8008..048b429304 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -55,6 +55,19 @@ class TaskState(BaseModel): """ TaskState entity """ + class NodeExecutionInfo(BaseModel): + """ + NodeExecutionInfo entity + """ + workflow_node_execution: WorkflowNodeExecution + start_at: float + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + answer: str = "" metadata: dict = {} usage: LLMUsage @@ -64,8 +77,8 @@ class TaskState(BaseModel): total_tokens: int = 0 total_steps: int = 0 - current_node_execution: Optional[WorkflowNodeExecution] = None - current_node_execution_start_at: Optional[float] = None + running_node_execution_infos: dict[str, NodeExecutionInfo] = {} + latest_node_execution_info: Optional[NodeExecutionInfo] = None class Config: """Configuration for this pydantic object.""" @@ -218,7 +231,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): self._on_node_start(event) - workflow_node_execution = self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution response = { 'event': 'node_started', @@ -237,7 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) - workflow_node_execution = self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: if workflow_node_execution.node_type == NodeType.LLM.value: @@ -447,15 +460,21 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): predecessor_node_id=event.predecessor_node_id ) - self._task_state.current_node_execution = workflow_node_execution - self._task_state.current_node_execution_start_at = time.perf_counter() + latest_node_execution_info = TaskState.NodeExecutionInfo( + workflow_node_execution=workflow_node_execution, + start_at=time.perf_counter() + ) + + self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.latest_node_execution_info = latest_node_execution_info self._task_state.total_steps += 1 def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + current_node_execution = self._task_state.running_node_execution_infos[event.node_id] if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, outputs=event.outputs, @@ -472,12 +491,14 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._task_state.metadata['usage'] = usage_dict else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, error=event.error ) - self._task_state.current_node_execution = workflow_node_execution + # remove running node execution info + del self._task_state.running_node_execution_infos[event.node_id] + self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: if isinstance(event, QueueStopEvent): @@ -504,8 +525,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.current_node_execution.outputs - if self._task_state.current_node_execution else None + outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs + if self._task_state.latest_node_execution_info else None ) self._task_state.workflow_run = workflow_run diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 721124c4c5..26e4769fa6 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -41,6 +41,19 @@ class TaskState(BaseModel): """ TaskState entity """ + class NodeExecutionInfo(BaseModel): + """ + NodeExecutionInfo entity + """ + workflow_node_execution: WorkflowNodeExecution + start_at: float + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + answer: str = "" metadata: dict = {} @@ -49,8 +62,8 @@ class TaskState(BaseModel): total_tokens: int = 0 total_steps: int = 0 - current_node_execution: Optional[WorkflowNodeExecution] = None - current_node_execution_start_at: Optional[float] = None + running_node_execution_infos: dict[str, NodeExecutionInfo] = {} + latest_node_execution_info: Optional[NodeExecutionInfo] = None class Config: """Configuration for this pydantic object.""" @@ -179,7 +192,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): self._on_node_start(event) - workflow_node_execution = self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution response = { 'event': 'node_started', @@ -198,7 +211,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) - workflow_node_execution = self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution response = { 'event': 'node_finished', @@ -339,15 +352,22 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): predecessor_node_id=event.predecessor_node_id ) - self._task_state.current_node_execution = workflow_node_execution - self._task_state.current_node_execution_start_at = time.perf_counter() + latest_node_execution_info = TaskState.NodeExecutionInfo( + workflow_node_execution=workflow_node_execution, + start_at=time.perf_counter() + ) + + self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.latest_node_execution_info = latest_node_execution_info + self._task_state.total_steps += 1 def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + current_node_execution = self._task_state.running_node_execution_infos[event.node_id] if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, outputs=event.outputs, @@ -359,12 +379,14 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, error=event.error ) - self._task_state.current_node_execution = workflow_node_execution + # remove running node execution info + del self._task_state.running_node_execution_infos[event.node_id] + self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: if isinstance(event, QueueStopEvent): @@ -391,8 +413,8 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.current_node_execution.outputs - if self._task_state.current_node_execution else None + outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs + if self._task_state.latest_node_execution_info else None ) self._task_state.workflow_run = workflow_run diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index e84044dede..3868041a8f 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -19,14 +19,17 @@ class ValueType(Enum): class VariablePool: variables_mapping = {} + user_inputs: dict - def __init__(self, system_variables: dict[SystemVariable, Any]) -> None: + def __init__(self, system_variables: dict[SystemVariable, Any], + user_inputs: dict) -> None: # system variables # for example: # { # 'query': 'abc', # 'files': [] # } + self.user_inputs = user_inputs for system_variable, value in system_variables.items(): self.append_variable('sys', [system_variable.value], value) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 6c2adfe0fb..768ad6a130 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -18,15 +18,13 @@ class WorkflowNodeAndResult: class WorkflowRunState: workflow: Workflow start_at: float - user_inputs: dict variable_pool: VariablePool total_tokens: int = 0 workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] - def __init__(self, workflow: Workflow, start_at: float, user_inputs: dict, variable_pool: VariablePool): + def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool): self.workflow = workflow self.start_at = start_at - self.user_inputs = user_inputs self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 6720017d9f..3f2e806433 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -28,31 +28,23 @@ class BaseNode(ABC): self.callbacks = callbacks or [] @abstractmethod - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ raise NotImplementedError - def run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node entry :param variable_pool: variable pool - :param run_args: run args :return: """ - if variable_pool is None and run_args is None: - raise ValueError("At least one of `variable_pool` or `run_args` must be provided.") - try: result = self._run( - variable_pool=variable_pool, - run_args=run_args + variable_pool=variable_pool ) except Exception as e: # process unhandled exception @@ -77,6 +69,26 @@ class BaseNode(ABC): text=text ) + @classmethod + def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict: + """ + Extract variable selector to variable mapping + :param config: node config + :return: + """ + node_data = cls._node_data_cls(**config.get("data", {})) + return cls._extract_variable_selector_to_variable_mapping(node_data) + + @classmethod + @abstractmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + raise NotImplementedError + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 77bcccab21..a65edafbad 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,5 +1,6 @@ from typing import Optional, Union, cast +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -15,6 +16,7 @@ MAX_STRING_LENGTH = 1000 MAX_STRING_ARRAY_LENGTH = 30 MAX_NUMBER_ARRAY_LENGTH = 1000 + class CodeNode(BaseNode): _node_data_cls = CodeNodeData node_type = NodeType.CODE @@ -78,21 +80,15 @@ class CodeNode(BaseNode): } } - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run code :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data - node_data: CodeNodeData = cast(self._node_data_cls, node_data) + node_data = cast(self._node_data_cls, node_data) - # SINGLE DEBUG NOT IMPLEMENTED YET - if variable_pool is None and run_args: - raise ValueError("Not support single step debug.") - # Get code language code_language = node_data.code_language code = node_data.code @@ -134,7 +130,6 @@ class CodeNode(BaseNode): Check string :param value: value :param variable: variable - :param max_length: max length :return: """ if not isinstance(value, str): @@ -142,9 +137,9 @@ class CodeNode(BaseNode): if len(value) > MAX_STRING_LENGTH: raise ValueError(f'{variable} in input form must be less than {MAX_STRING_LENGTH} characters') - + return value.replace('\x00', '') - + def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: """ Check number @@ -157,13 +152,13 @@ class CodeNode(BaseNode): if value > MAX_NUMBER or value < MIN_NUMBER: raise ValueError(f'{variable} in input form is out of range.') - + if isinstance(value, float): value = round(value, MAX_PRECISION) return value - def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output], + def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output], prefix: str = '', depth: int = 1) -> dict: """ @@ -174,7 +169,7 @@ class CodeNode(BaseNode): """ if depth > MAX_DEPTH: raise ValueError("Depth limit reached, object too deep.") - + transformed_result = {} for output_name, output_config in output_schema.items(): if output_config.type == 'object': @@ -183,7 +178,7 @@ class CodeNode(BaseNode): raise ValueError( f'Output {prefix}.{output_name} is not an object, got {type(result.get(output_name))} instead.' ) - + transformed_result[output_name] = self._transform_result( result=result[output_name], output_schema=output_config.children, @@ -208,7 +203,7 @@ class CodeNode(BaseNode): raise ValueError( f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.' ) - + if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH: raise ValueError( f'{prefix}.{output_name} in input form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters' @@ -227,12 +222,12 @@ class CodeNode(BaseNode): raise ValueError( f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.' ) - + if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH: raise ValueError( f'{prefix}.{output_name} in input form must be less than {MAX_STRING_ARRAY_LENGTH} characters' ) - + transformed_result[output_name] = [ self._check_string( value=value, @@ -242,5 +237,15 @@ class CodeNode(BaseNode): ] else: raise ValueError(f'Output type {output_config.type} is not supported.') - - return transformed_result \ No newline at end of file + + return transformed_result + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + # TODO extract variable selector to variable mapping for single step debugging + return {} diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index 971cbe536e..9193bab9ee 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -1,7 +1,8 @@ import time -from typing import Optional, cast +from typing import cast from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.base_node import BaseNode @@ -13,20 +14,15 @@ class DirectAnswerNode(BaseNode): _node_data_cls = DirectAnswerNodeData node_type = NodeType.DIRECT_ANSWER - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data node_data = cast(self._node_data_cls, node_data) - if variable_pool is None and run_args: - raise ValueError("Not support single step debug.") - variable_values = {} for variable_selector in node_data.variables: value = variable_pool.get_variable_value( @@ -43,7 +39,7 @@ class DirectAnswerNode(BaseNode): # publish answer as stream for word in answer: self.publish_text_chunk(word) - time.sleep(0.01) # todo sleep 0.01 + time.sleep(0.01) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -52,3 +48,12 @@ class DirectAnswerNode(BaseNode): "answer": answer } ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 62429e3ac2..65b0b86aa0 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,5 +1,6 @@ -from typing import Optional, cast +from typing import cast +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.base_node import BaseNode @@ -11,50 +12,54 @@ class EndNode(BaseNode): _node_data_cls = EndNodeData node_type = NodeType.END - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data node_data = cast(self._node_data_cls, node_data) outputs_config = node_data.outputs - if variable_pool is not None: - outputs = None - if outputs_config: - if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: - plain_text_selector = outputs_config.plain_text_selector - if plain_text_selector: - outputs = { - 'text': variable_pool.get_variable_value( - variable_selector=plain_text_selector, - target_value_type=ValueType.STRING - ) - } - else: - outputs = { - 'text': '' - } - elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED: - structured_variables = outputs_config.structured_variables - if structured_variables: - outputs = {} - for variable_selector in structured_variables: - variable_value = variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector - ) - outputs[variable_selector.variable] = variable_value - else: - outputs = {} - else: - raise ValueError("Not support single step debug.") + outputs = None + if outputs_config: + if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: + plain_text_selector = outputs_config.plain_text_selector + if plain_text_selector: + outputs = { + 'text': variable_pool.get_variable_value( + variable_selector=plain_text_selector, + target_value_type=ValueType.STRING + ) + } + else: + outputs = { + 'text': '' + } + elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED: + structured_variables = outputs_config.structured_variables + if structured_variables: + outputs = {} + for variable_selector in structured_variables: + variable_value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + outputs[variable_selector.variable] = variable_value + else: + outputs = {} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index e3ae9fc00f..90a7755b85 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,5 +1,6 @@ from typing import Optional, cast +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -10,12 +11,10 @@ class LLMNode(BaseNode): _node_data_cls = LLMNodeData node_type = NodeType.LLM - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data @@ -23,6 +22,17 @@ class LLMNode(BaseNode): pass + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + # TODO extract variable selector to variable mapping for single step debugging + return {} + + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index ce04031b04..2321e04bd4 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,6 +1,7 @@ -from typing import Optional, cast +from typing import cast from core.app.app_config.entities import VariableEntity +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -12,12 +13,10 @@ class StartNode(BaseNode): _node_data_cls = StartNodeData node_type = NodeType.START - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data @@ -25,7 +24,7 @@ class StartNode(BaseNode): variables = node_data.variables # Get cleaned inputs - cleaned_inputs = self._get_cleaned_inputs(variables, run_args) + cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -68,3 +67,12 @@ class StartNode(BaseNode): filtered_inputs[variable] = value.replace('\x00', '') if value else None return filtered_inputs + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index c5af015e87..0b96717de7 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -109,9 +109,9 @@ class WorkflowEngineManager: workflow_run_state = WorkflowRunState( workflow=workflow, start_at=time.perf_counter(), - user_inputs=user_inputs, variable_pool=VariablePool( system_variables=system_inputs, + user_inputs=user_inputs ) ) @@ -292,9 +292,7 @@ class WorkflowEngineManager: # run node, result must have inputs, process_data, outputs, execution_metadata node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool, - run_args=workflow_run_state.user_inputs - if (not predecessor_node and node.node_type == NodeType.START) else None # only on start node + variable_pool=workflow_run_state.variable_pool ) if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: