From 460a825ef14cdea8d7e73b3f4b88ad00ca06a6d4 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 18 Jul 2025 10:08:51 +0800 Subject: [PATCH] refactor: decouple Node and NodeData (#22581) Signed-off-by: -LAN- Co-authored-by: QuantumGhost --- .../app/apps/advanced_chat/app_generator.py | 3 +- api/core/app/apps/agent_chat/app_generator.py | 3 +- api/core/app/apps/base_app_queue_manager.py | 4 - api/core/app/apps/base_app_runner.py | 2 +- api/core/app/apps/chat/app_generator.py | 3 +- api/core/app/apps/completion/app_generator.py | 3 +- api/core/app/apps/exc.py | 2 + .../app/apps/message_based_app_generator.py | 3 +- .../apps/message_based_app_queue_manager.py | 3 +- api/core/app/apps/workflow/app_generator.py | 3 +- .../app/apps/workflow/app_queue_manager.py | 3 +- api/core/prompt/simple_prompt_transform.py | 15 +- api/core/rag/retrieval/dataset_retrieval.py | 2 +- api/core/workflow/errors.py | 8 +- api/core/workflow/graph_engine/__init__.py | 3 +- .../workflow/graph_engine/graph_engine.py | 168 ++++---- api/core/workflow/nodes/agent/agent_node.py | 375 ++++++++++++++---- api/core/workflow/nodes/agent/exc.py | 124 ++++++ api/core/workflow/nodes/answer/answer_node.py | 47 ++- api/core/workflow/nodes/base/entities.py | 4 +- api/core/workflow/nodes/base/node.py | 119 +++--- api/core/workflow/nodes/code/code_node.py | 59 ++- .../workflow/nodes/document_extractor/node.py | 47 ++- api/core/workflow/nodes/end/end_node.py | 34 +- api/core/workflow/nodes/enums.py | 4 - api/core/workflow/nodes/http_request/node.py | 60 ++- .../workflow/nodes/if_else/if_else_node.py | 46 ++- .../nodes/iteration/iteration_node.py | 121 +++--- .../nodes/iteration/iteration_start_node.py | 32 +- .../nodes/knowledge_retrieval/entities.py | 17 +- .../knowledge_retrieval_node.py | 138 +++++-- api/core/workflow/nodes/list_operator/node.py | 59 ++- api/core/workflow/nodes/llm/entities.py | 4 +- api/core/workflow/nodes/llm/node.py | 243 ++++++++---- api/core/workflow/nodes/loop/loop_end_node.py | 32 +- api/core/workflow/nodes/loop/loop_node.py | 93 +++-- .../workflow/nodes/loop/loop_start_node.py | 32 +- .../parameter_extractor_node.py | 54 ++- .../question_classifier_node.py | 115 ++++-- api/core/workflow/nodes/start/start_node.py | 32 +- .../template_transform_node.py | 47 ++- api/core/workflow/nodes/tool/tool_node.py | 159 ++++---- .../variable_aggregator_node.py | 36 +- .../nodes/variable_assigner/v1/node.py | 60 ++- .../nodes/variable_assigner/v2/node.py | 46 ++- api/core/workflow/workflow_entry.py | 33 +- api/services/workflow_service.py | 20 +- .../workflow/nodes/__mock/model.py | 2 +- .../workflow/nodes/test_code.py | 20 +- .../workflow/nodes/test_http.py | 8 +- .../workflow/nodes/test_llm.py | 197 ++++----- .../nodes/test_parameter_extractor.py | 4 +- .../workflow/nodes/test_template_transform.py | 1 + .../workflow/nodes/test_tool.py | 4 +- .../core/workflow/nodes/answer/test_answer.py | 21 +- .../http_request/test_http_request_node.py | 42 +- .../nodes/iteration/test_iteration.py | 159 ++++---- .../core/workflow/nodes/llm/test_node.py | 62 ++- .../core/workflow/nodes/test_answer.py | 21 +- .../nodes/test_document_extractor_node.py | 8 +- .../core/workflow/nodes/test_if_else.py | 151 +++---- .../core/workflow/nodes/test_list_operator.py | 11 +- .../workflow/nodes/tool/test_tool_node.py | 11 +- .../v1/test_variable_assigner_v1.py | 69 ++-- .../v2/test_variable_assigner_v2.py | 140 ++++--- 65 files changed, 2305 insertions(+), 1146 deletions(-) create mode 100644 api/core/app/apps/exc.py create mode 100644 
api/core/workflow/nodes/agent/exc.py diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 4b8f5ebe27..bd5ad9c51b 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -17,7 +17,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index edea6199d3..8665bc9d11 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -15,7 +15,8 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 0ba33fbe0d..9da0bae56a 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -169,7 +169,3 @@ class AppQueueManager: raise TypeError( "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed." 
) - - -class GenerateTaskStoppedError(Exception): - pass diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 428db607fa..6e8c261a6a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -118,7 +118,7 @@ class AppRunner: else: memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) - model_mode = ModelMode.value_of(model_config.mode) + model_mode = ModelMode(model_config.mode) prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index a28c106ce9..0c76cc39ae 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -11,10 +11,11 @@ from configs import dify_config from constants import UUID_NIL from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 966a6f1d66..195e7e2e3d 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -10,10 +10,11 @@ from pydantic import ValidationError from configs import dify_config from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom diff --git a/api/core/app/apps/exc.py b/api/core/app/apps/exc.py new file mode 100644 index 0000000000..4187118b9b --- /dev/null +++ b/api/core/app/apps/exc.py @@ -0,0 +1,2 @@ +class GenerateTaskStoppedError(Exception): + pass diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index e84d59209d..85fafe6980 100644 --- 
a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -6,7 +6,8 @@ from typing import Optional, Union, cast from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 363c3c82bb..8507f23f17 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,4 +1,5 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 2f9632e97d..6f560b3253 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -13,7 +13,8 @@ import contexts from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 349b8eb51b..40fc03afb7 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,4 +1,5 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 47808928f7..e19c6419ca 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum): COMPLETION = "completion" CHAT = "chat" - @classmethod - def value_of(cls, value: str) -> "ModelMode": - """ - Get value of given mode. 
- - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid mode value {value}") - prompt_file_contents: dict[str, Any] = {} @@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform): ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = {key: str(value) for key, value in inputs.items()} - model_mode = ModelMode.value_of(model_config.mode) + model_mode = ModelMode(model_config.mode) if model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( app_mode=app_mode, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 5c0360b064..3d0f0f97bc 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1137,7 +1137,7 @@ class DatasetRetrieval: def _get_prompt_template( self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str ): - model_mode = ModelMode.value_of(mode) + model_mode = ModelMode(mode) input_text = query prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index bd4ccc1072..594bb2b32e 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -2,7 +2,7 @@ from core.workflow.nodes.base import BaseNode class WorkflowNodeRunFailedError(Exception): - def __init__(self, node_instance: BaseNode, error: str): - self.node_instance = node_instance - self.error = error - super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") + def __init__(self, node: BaseNode, err_msg: str): + self._node = node + self._error = err_msg + super().__init__(f"Node {node.title} run failed: {err_msg}") diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py index 2fee3d7fad..12e1de464b 100644 --- a/api/core/workflow/graph_engine/__init__.py +++ b/api/core/workflow/graph_engine/__init__.py @@ -1,3 +1,4 @@ from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState +from .graph_engine import GraphEngine -__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] +__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 5a2915e2d3..b315129763 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -12,7 +12,7 @@ from typing import Any, Optional, cast from flask import Flask, current_app from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue @@ -48,11 +48,9 @@ from core.workflow.nodes.agent.entities import AgentNodeData from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums 
import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom @@ -260,12 +258,16 @@ class GraphEngine: # convert to specific node node_type = NodeType(node_config.get("data", {}).get("type")) node_version = node_config.get("data", {}).get("version", "1") + + # Import here to avoid circular import + from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None # init workflow run state - node_instance = node_cls( # type: ignore + node = node_cls( id=route_node_state.id, config=node_config, graph_init_params=self.init_params, @@ -274,11 +276,11 @@ class GraphEngine: previous_node_id=previous_node_id, thread_pool_id=self.thread_pool_id, ) - node_instance = cast(BaseNode[BaseNodeData], node_instance) + node.init_node_data(node_config.get("data", {})) try: # run node generator = self._run_node( - node_instance=node_instance, + node=node, route_node_state=route_node_state, parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, @@ -306,16 +308,16 @@ class GraphEngine: route_node_state.failed_reason = str(e) yield NodeRunFailedEvent( error=str(e), - id=node_instance.id, + id=node.id, node_id=next_node_id, node_type=node_type, - node_data=node_instance.node_data, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) raise e @@ -337,7 +339,7 @@ class GraphEngine: edge = edge_mappings[0] if ( previous_route_node_state.status == RouteNodeState.Status.EXCEPTION - and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + and node.error_strategy == ErrorStrategy.FAIL_BRANCH and edge.run_condition is None ): break @@ -413,8 +415,8 @@ class GraphEngine: next_node_id = final_node_id elif ( - node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH - and node_instance.should_continue_on_error + node.continue_on_error + and node.error_strategy == ErrorStrategy.FAIL_BRANCH and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION ): break @@ -597,7 +599,7 @@ class GraphEngine: def _run_node( self, - node_instance: BaseNode[BaseNodeData], + node: BaseNode, route_node_state: RouteNodeState, parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, @@ -611,29 +613,29 @@ class GraphEngine: # trigger node run start event agent_strategy = ( AgentNodeStrategyInit( - name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name, - icon=cast(AgentNode, node_instance).agent_strategy_icon, + name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name, + icon=cast(AgentNode, node).agent_strategy_icon, ) - if node_instance.node_type == NodeType.AGENT + if node.type_ == NodeType.AGENT else None ) yield NodeRunStartedEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + 
node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, - predecessor_node_id=node_instance.previous_node_id, + predecessor_node_id=node.previous_node_id, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, agent_strategy=agent_strategy, - node_version=node_instance.version(), + node_version=node.version(), ) - max_retries = node_instance.node_data.retry_config.max_retries - retry_interval = node_instance.node_data.retry_config.retry_interval_seconds + max_retries = node.retry_config.max_retries + retry_interval = node.retry_config.retry_interval_seconds retries = 0 should_continue_retry = True while should_continue_retry and retries <= max_retries: @@ -642,7 +644,7 @@ class GraphEngine: retry_start_at = datetime.now(UTC).replace(tzinfo=None) # yield control to other threads time.sleep(0.001) - event_stream = node_instance.run() + event_stream = node.run() for event in event_stream: if isinstance(event, GraphEngineEvent): # add parallel info to iteration event @@ -658,21 +660,21 @@ class GraphEngine: if run_result.status == WorkflowNodeExecutionStatus.FAILED: if ( retries == max_retries - and node_instance.node_type == NodeType.HTTP_REQUEST + and node.type_ == NodeType.HTTP_REQUEST and run_result.outputs - and not node_instance.should_continue_on_error + and not node.continue_on_error ): run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED - if node_instance.should_retry and retries < max_retries: + if node.retry and retries < max_retries: retries += 1 route_node_state.node_run_result = run_result yield NodeRunRetryEvent( id=str(uuid.uuid4()), - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, - predecessor_node_id=node_instance.previous_node_id, + predecessor_node_id=node.previous_node_id, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, @@ -680,17 +682,17 @@ class GraphEngine: error=run_result.error or "Unknown error", retry_index=retries, start_at=retry_start_at, - node_version=node_instance.version(), + node_version=node.version(), ) time.sleep(retry_interval) break route_node_state.set_finished(run_result=run_result) if run_result.status == WorkflowNodeExecutionStatus.FAILED: - if node_instance.should_continue_on_error: + if node.continue_on_error: # if run failed, handle error run_result = self._handle_continue_on_error( - node_instance, + node, event.run_result, self.graph_runtime_state.variable_pool, handle_exceptions=handle_exceptions, @@ -701,44 +703,44 @@ class GraphEngine: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively self._append_variables_recursively( - node_id=node_instance.node_id, + node_id=node.node_id, variable_key_list=[variable_key], variable_value=variable_value, ) yield NodeRunExceptionEvent( error=run_result.error or "System Error", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, 
parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False else: yield NodeRunFailedEvent( error=route_node_state.failed_reason or "Unknown error.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if ( - node_instance.should_continue_on_error - and self.graph.edge_mapping.get(node_instance.node_id) - and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH + node.continue_on_error + and self.graph.edge_mapping.get(node.node_id) + and node.error_strategy is ErrorStrategy.FAIL_BRANCH ): run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS if run_result.metadata and run_result.metadata.get( @@ -758,7 +760,7 @@ class GraphEngine: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively self._append_variables_recursively( - node_id=node_instance.node_id, + node_id=node.node_id, variable_key_list=[variable_key], variable_value=variable_value, ) @@ -783,26 +785,26 @@ class GraphEngine: run_result.metadata = metadata_dict yield NodeRunSucceededEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False break elif isinstance(event, RunStreamChunkEvent): yield NodeRunStreamChunkEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), chunk_content=event.chunk_content, from_variable_selector=event.from_variable_selector, route_node_state=route_node_state, @@ -810,14 +812,14 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) elif isinstance(event, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), retriever_resources=event.retriever_resources, context=event.context, route_node_state=route_node_state, @@ -825,7 +827,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - 
node_version=node_instance.version(), + node_version=node.version(), ) except GenerateTaskStoppedError: # trigger node run failed event @@ -833,20 +835,20 @@ class GraphEngine: route_node_state.failed_reason = "Workflow stopped." yield NodeRunFailedEvent( error="Workflow stopped.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) return except Exception as e: - logger.exception(f"Node {node_instance.node_data.title} run failed") + logger.exception(f"Node {node.title} run failed") raise e def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): @@ -886,22 +888,14 @@ class GraphEngine: def _handle_continue_on_error( self, - node_instance: BaseNode[BaseNodeData], + node: BaseNode, error_result: NodeRunResult, variable_pool: VariablePool, handle_exceptions: list[str] = [], ) -> NodeRunResult: - """ - handle continue on error when self._should_continue_on_error is True - - - :param error_result (NodeRunResult): error run result - :param variable_pool (VariablePool): variable pool - :return: excption run result - """ # add error message and error type to variable pool - variable_pool.add([node_instance.node_id, "error_message"], error_result.error) - variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) + variable_pool.add([node.node_id, "error_message"], error_result.error) + variable_pool.add([node.node_id, "error_type"], error_result.error_type) # add error message to handle_exceptions handle_exceptions.append(error_result.error or "") node_error_args: dict[str, Any] = { @@ -909,21 +903,21 @@ class GraphEngine: "error": error_result.error, "inputs": error_result.inputs, "metadata": { - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy, }, } - if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + if node.error_strategy is ErrorStrategy.DEFAULT_VALUE: return NodeRunResult( **node_error_args, outputs={ - **node_instance.node_data.default_value_dict, + **node.default_value_dict, "error_message": error_result.error, "error_type": error_result.error_type, }, ) - elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH: - if self.graph.edge_mapping.get(node_instance.node_id): + elif node.error_strategy is ErrorStrategy.FAIL_BRANCH: + if self.graph.edge_mapping.get(node.node_id): node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED return NodeRunResult( **node_error_args, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index ce67197a58..a4616eda69 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,5 +1,4 @@ import json -import uuid from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast @@ -11,8 +10,10 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import 
AgentStrategyParameter from core.agent.strategy.plugin import PluginAgentStrategy +from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.plugin.entities.request import InvokeCredentials from core.plugin.impl.exc import PluginDaemonClientSideError @@ -25,45 +26,75 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.tool_manager import ToolManager -from core.variables.segments import StringSegment +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.variables.segments import ArrayFileSegment, StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event.event import RunCompletedEvent -from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db +from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy +from models import ToolFile from models.model import Conversation +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .exc import ( + AgentInputTypeError, + AgentInvocationError, + AgentMessageTransformError, + AgentVariableNotFoundError, + AgentVariableTypeError, + ToolFileNotFoundError, +) -class AgentNode(ToolNode): +class AgentNode(BaseNode): """ Agent Node """ - _node_data_cls = AgentNodeData # type: ignore _node_type = NodeType.AGENT + _node_data: AgentNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = AgentNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data @classmethod def version(cls) -> str: return "1" def _run(self) -> Generator: - """ - Run the agent node - """ - node_data = cast(AgentNodeData, self.node_data) - try: strategy = get_plugin_agent_strategy( tenant_id=self.tenant_id, - 
agent_strategy_provider_name=node_data.agent_strategy_provider_name, - agent_strategy_name=node_data.agent_strategy_name, + agent_strategy_provider_name=self._node_data.agent_strategy_provider_name, + agent_strategy_name=self._node_data.agent_strategy_name, ) except Exception as e: yield RunCompletedEvent( @@ -81,13 +112,13 @@ class AgentNode(ToolNode): parameters = self._generate_agent_parameters( agent_parameters=agent_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=node_data, + node_data=self._node_data, strategy=strategy, ) parameters_for_log = self._generate_agent_parameters( agent_parameters=agent_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=node_data, + node_data=self._node_data, for_log=True, strategy=strategy, ) @@ -105,59 +136,39 @@ class AgentNode(ToolNode): credentials=credentials, ) except Exception as e: + error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - error=f"Failed to invoke agent: {str(e)}", + error=str(error), ) ) return try: - # convert tool messages - agent_thoughts: list = [] - - thought_log_message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LOG, - message=ToolInvokeMessage.LogMessage( - id=str(uuid.uuid4()), - label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}", - parent_id=None, - error=None, - status=ToolInvokeMessage.LogMessage.LogStatus.START, - data={ - "strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, - "parameters": parameters_for_log, - "thought_process": "Agent strategy execution started", - }, - metadata={ - "icon": self.agent_strategy_icon, - "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, - }, - ), - ) - - def enhanced_message_stream(): - yield thought_log_message - - yield from message_stream - yield from self._transform_message( - message_stream, - { + messages=message_stream, + tool_info={ "icon": self.agent_strategy_icon, - "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, + "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name, }, - parameters_for_log, - agent_thoughts, + parameters_for_log=parameters_for_log, + user_id=self.user_id, + tenant_id=self.tenant_id, + node_type=self.type_, + node_id=self.node_id, + node_execution_id=self.id, ) except PluginDaemonClientSideError as e: + transform_error = AgentMessageTransformError( + f"Failed to transform agent message: {str(e)}", original_error=e + ) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - error=f"Failed to transform agent message: {str(e)}", + error=str(transform_error), ) ) @@ -194,7 +205,7 @@ class AgentNode(ToolNode): if agent_input.type == "variable": variable = variable_pool.get(agent_input.value) # type: ignore if variable is None: - raise ValueError(f"Variable {agent_input.value} does not exist") + raise AgentVariableNotFoundError(str(agent_input.value)) parameter_value = variable.value elif agent_input.type in {"mixed", "constant"}: # variable_pool.convert_template expects a string template, @@ -216,7 +227,7 @@ class AgentNode(ToolNode): except json.JSONDecodeError: parameter_value = parameter_value else: - raise ValueError(f"Unknown agent input type '{agent_input.type}'") + raise AgentInputTypeError(agent_input.type) value = parameter_value if parameter.type == 
"array[tools]": value = cast(list[dict[str, Any]], value) @@ -259,7 +270,7 @@ class AgentNode(ToolNode): ) extra = tool.get("extra", {}) - runtime_variable_pool = variable_pool if self.node_data.version != "1" else None + runtime_variable_pool = variable_pool if self._node_data.version != "1" else None tool_runtime = ToolManager.get_agent_tool_runtime( self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool ) @@ -343,19 +354,14 @@ class AgentNode(ToolNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: BaseNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - node_data = cast(AgentNodeData, node_data) + # Create typed NodeData from dict + typed_node_data = AgentNodeData.model_validate(node_data) + result: dict[str, Any] = {} - for parameter_name in node_data.agent_parameters: - input = node_data.agent_parameters[parameter_name] + for parameter_name in typed_node_data.agent_parameters: + input = typed_node_data.agent_parameters[parameter_name] if input.type in ["mixed", "constant"]: selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() for selector in selectors: @@ -380,7 +386,7 @@ class AgentNode(ToolNode): plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" - == cast(AgentNodeData, self.node_data).agent_strategy_provider_name + == cast(AgentNodeData, self._node_data).agent_strategy_provider_name ) icon = current_plugin.declaration.icon except StopIteration: @@ -448,3 +454,236 @@ class AgentNode(ToolNode): return tools else: return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value] + + def _transform_message( + self, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + user_id: str, + tenant_id: str, + node_type: NodeType, + node_id: str, + node_execution_id: str, + ) -> Generator: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json: list[dict] = [] + + agent_logs: list[AgentLogEvent] = [] + agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} + llm_usage: LLMUsage | None = None + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + 
mapping=mapping, + tenant_id=tenant_id, + ) + files.append(file) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + ) + elif message.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + text += message.message.text + yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"]) + elif message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + if node_type == NodeType.AGENT: + msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) + llm_usage = LLMUsage.from_metadata(msg_metadata) + agent_execution_metadata = { + WorkflowNodeExecutionMetadataKey(key): value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + if message.message.json_object is not None: + json.append(message.message.json_object) + elif message.type == ToolInvokeMessage.MessageType.LINK: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"]) + elif message.type == ToolInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise AgentVariableTypeError( + "When 'stream' is True, 'variable_value' must be a string.", + variable_name=variable_name, + expected_type="str", + actual_type=type(variable_value).__name__, + ) + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield RunStreamChunkEvent( + chunk_content=variable_value, from_variable_selector=[node_id, variable_name] + ) + else: + variables[variable_name] = variable_value + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + assert isinstance(message.meta["file"], File) + files.append(message.meta["file"]) + elif message.type == ToolInvokeMessage.MessageType.LOG: + assert isinstance(message.message, ToolInvokeMessage.LogMessage) + if message.message.metadata: + icon = tool_info.get("icon", "") + dict_metadata = dict(message.message.metadata) + if dict_metadata.get("provider"): + manager = PluginInstaller() + plugins = manager.list_plugins(tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] + ) + icon = current_plugin.declaration.icon + except StopIteration: + pass + icon_dark = None + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + if provider.name ==
dict_metadata["provider"] + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + dict_metadata["icon"] = icon + dict_metadata["icon_dark"] = icon_dark + message.message.metadata = dict_metadata + agent_log = AgentLogEvent( + id=message.message.id, + node_execution_id=node_execution_id, + parent_id=message.message.parent_id, + error=message.message.error, + status=message.message.status.value, + data=message.message.data, + label=message.message.label, + metadata=message.message.metadata, + node_id=node_id, + ) + + # check if the agent log is already in the list + for log in agent_logs: + if log.id == agent_log.id: + # update the log + log.data = agent_log.data + log.status = agent_log.status + log.error = agent_log.error + log.label = agent_log.label + log.metadata = agent_log.metadata + break + else: + agent_logs.append(agent_log) + + yield agent_log + + # Add agent_logs to outputs['json'] to ensure frontend can access thinking process + json_output: list[dict[str, Any]] = [] + + # Step 1: append each agent log as its own dict. + if agent_logs: + for log in agent_logs: + json_output.append( + { + "id": log.id, + "parent_id": log.parent_id, + "error": log.error, + "status": log.status, + "data": log.data, + "label": log.label, + "metadata": log.metadata, + "node_id": log.node_id, + } + ) + # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] + if json: + json_output.extend(json) + else: + json_output.append({"data": []}) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, + metadata={ + **agent_execution_metadata, + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, + }, + inputs=parameters_for_log, + llm_usage=llm_usage, + ) + ) diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py new file mode 100644 index 0000000000..d5955bdd7d --- /dev/null +++ b/api/core/workflow/nodes/agent/exc.py @@ -0,0 +1,124 @@ +from typing import Optional + + +class AgentNodeError(Exception): + """Base exception for all agent node errors.""" + + def __init__(self, message: str): + self.message = message + super().__init__(self.message) + + +class AgentStrategyError(AgentNodeError): + """Exception raised when there's an error with the agent strategy.""" + + def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None): + self.strategy_name = strategy_name + self.provider_name = provider_name + super().__init__(message) + + +class AgentStrategyNotFoundError(AgentStrategyError): + """Exception raised when the specified agent strategy is not found.""" + + def __init__(self, strategy_name: str, provider_name: Optional[str] = None): + super().__init__( + f"Agent strategy '{strategy_name}' not found" + + (f" for provider '{provider_name}'" if provider_name else ""), + strategy_name, + provider_name, + ) + + +class AgentInvocationError(AgentNodeError): + """Exception raised when there's an error invoking the agent.""" + + def __init__(self, message: str, original_error: Optional[Exception] = None): + self.original_error = original_error + super().__init__(message) + + +class AgentParameterError(AgentNodeError): + """Exception raised when there's an error with agent parameters.""" + + def __init__(self, message: str, parameter_name: Optional[str] = None): + 
self.parameter_name = parameter_name + super().__init__(message) + + +class AgentVariableError(AgentNodeError): + """Exception raised when there's an error with variables in the agent node.""" + + def __init__(self, message: str, variable_name: Optional[str] = None): + self.variable_name = variable_name + super().__init__(message) + + +class AgentVariableNotFoundError(AgentVariableError): + """Exception raised when a variable is not found in the variable pool.""" + + def __init__(self, variable_name: str): + super().__init__(f"Variable '{variable_name}' does not exist", variable_name) + + +class AgentInputTypeError(AgentNodeError): + """Exception raised when an unknown agent input type is encountered.""" + + def __init__(self, input_type: str): + super().__init__(f"Unknown agent input type '{input_type}'") + + +class ToolFileError(AgentNodeError): + """Exception raised when there's an error with a tool file.""" + + def __init__(self, message: str, file_id: Optional[str] = None): + self.file_id = file_id + super().__init__(message) + + +class ToolFileNotFoundError(ToolFileError): + """Exception raised when a tool file is not found.""" + + def __init__(self, file_id: str): + super().__init__(f"Tool file '{file_id}' does not exist", file_id) + + +class AgentMessageTransformError(AgentNodeError): + """Exception raised when there's an error transforming agent messages.""" + + def __init__(self, message: str, original_error: Optional[Exception] = None): + self.original_error = original_error + super().__init__(message) + + +class AgentModelError(AgentNodeError): + """Exception raised when there's an error with the model used by the agent.""" + + def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None): + self.model_name = model_name + self.provider = provider + super().__init__(message) + + +class AgentMemoryError(AgentNodeError): + """Exception raised when there's an error with the agent's memory.""" + + def __init__(self, message: str, conversation_id: Optional[str] = None): + self.conversation_id = conversation_id + super().__init__(message) + + +class AgentVariableTypeError(AgentNodeError): + """Exception raised when a variable has an unexpected type.""" + + def __init__( + self, + message: str, + variable_name: Optional[str] = None, + expected_type: Optional[str] = None, + actual_type: Optional[str] = None, + ): + self.variable_name = variable_name + self.expected_type = expected_type + self.actual_type = actual_type + super().__init__(message) diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 38c2bcbdf5..84bbabca73 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast from core.variables import ArrayFileSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult @@ -12,14 +12,37 @@ from core.workflow.nodes.answer.entities import ( VarGenerateRouteChunk, ) from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser -class AnswerNode(BaseNode[AnswerNodeData]): - _node_data_cls = AnswerNodeData +class AnswerNode(BaseNode): _node_type = 
NodeType.ANSWER + _node_data: AnswerNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = AnswerNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -30,7 +53,7 @@ class AnswerNode(BaseNode[AnswerNodeData]): :return: """ # generate routes - generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data) + generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data) answer = "" files = [] @@ -60,16 +83,12 @@ class AnswerNode(BaseNode[AnswerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: AnswerNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - variable_template_parser = VariableTemplateParser(template=node_data.answer) + # Create typed NodeData from dict + typed_node_data = AnswerNodeData.model_validate(node_data) + + variable_template_parser = VariableTemplateParser(template=typed_node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() variable_mapping = {} diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index d853eb71be..dcfed5eed2 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -122,13 +122,13 @@ class RetryConfig(BaseModel): class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + version: str = "1" error_strategy: Optional[ErrorStrategy] = None default_value: Optional[list[DefaultValue]] = None - version: str = "1" retry_config: RetryConfig = RetryConfig() @property - def default_value_dict(self): + def default_value_dict(self) -> dict[str, Any]: if self.default_value: return {item.key: item.value for item in self.default_value} return {} diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 6973401429..fb5ec55453 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,28 +1,22 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent -from .entities import BaseNodeData - if TYPE_CHECKING: + from core.workflow.graph_engine import Graph, 
GraphInitParams, GraphRuntimeState from core.workflow.graph_engine.entities.event import InNodeEvent - from core.workflow.graph_engine.entities.graph import Graph - from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams - from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState logger = logging.getLogger(__name__) -GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) - -class BaseNode(Generic[GenericNodeData]): - _node_data_cls: type[GenericNodeData] +class BaseNode: _node_type: ClassVar[NodeType] def __init__( @@ -56,8 +50,8 @@ class BaseNode(Generic[GenericNodeData]): self.node_id = node_id - node_data = self._node_data_cls.model_validate(config.get("data", {})) - self.node_data = node_data + @abstractmethod + def init_node_data(self, data: Mapping[str, Any]) -> None: ... @abstractmethod def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: @@ -130,9 +124,9 @@ class BaseNode(Generic[GenericNodeData]): if not node_id: raise ValueError("Node ID is required when extracting variable selector to variable mapping.") - node_data = cls._node_data_cls(**config.get("data", {})) + # Pass raw dict data instead of creating NodeData instance data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) + graph_config=graph_config, node_id=node_id, node_data=config.get("data", {}) ) return data @@ -142,32 +136,16 @@ class BaseNode(Generic[GenericNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: GenericNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ return {} @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ return {} @property - def node_type(self) -> NodeType: - """ - Get node type - :return: - """ + def type_(self) -> NodeType: return self._node_type @classmethod @@ -181,19 +159,68 @@ class BaseNode(Generic[GenericNodeData]): raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @property - def should_continue_on_error(self) -> bool: - """judge if should continue on error - - Returns: - bool: if should continue on error - """ - return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE + def continue_on_error(self) -> bool: + return False @property - def should_retry(self) -> bool: - """judge if should retry + def retry(self) -> bool: + return False - Returns: - bool: if should retry - """ - return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE + # Abstract methods that subclasses must implement to provide access + # to BaseNodeData properties in a type-safe way + + @abstractmethod + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + """Get the error strategy for this node.""" + ... + + @abstractmethod + def _get_retry_config(self) -> RetryConfig: + """Get the retry configuration for this node.""" + ... + + @abstractmethod + def _get_title(self) -> str: + """Get the node title.""" + ... + + @abstractmethod + def _get_description(self) -> Optional[str]: + """Get the node description.""" + ... 
+ + @abstractmethod + def _get_default_value_dict(self) -> dict[str, Any]: + """Get the default values dictionary for this node.""" + ... + + @abstractmethod + def get_base_node_data(self) -> BaseNodeData: + """Get the BaseNodeData object for this node.""" + ... + + # Public interface properties that delegate to abstract methods + @property + def error_strategy(self) -> Optional[ErrorStrategy]: + """Get the error strategy for this node.""" + return self._get_error_strategy() + + @property + def retry_config(self) -> RetryConfig: + """Get the retry configuration for this node.""" + return self._get_retry_config() + + @property + def title(self) -> str: + """Get the node title.""" + return self._get_title() + + @property + def description(self) -> Optional[str]: + """Get the node description.""" + return self._get_description() + + @property + def default_value_dict(self) -> dict[str, Any]: + """Get the default values dictionary for this node.""" + return self._get_default_value_dict() diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 1adabf7247..fdf3932827 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -11,8 +11,9 @@ from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.code.entities import CodeNodeData -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .exc import ( CodeNodeError, @@ -21,10 +22,32 @@ from .exc import ( ) -class CodeNode(BaseNode[CodeNodeData]): - _node_data_cls = CodeNodeData +class CodeNode(BaseNode): _node_type = NodeType.CODE + _node_data: CodeNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = CodeNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ @@ -47,12 +70,12 @@ class CodeNode(BaseNode[CodeNodeData]): def _run(self) -> NodeRunResult: # Get code language - code_language = self.node_data.code_language - code = self.node_data.code + code_language = self._node_data.code_language + code = self._node_data.code # Get variables variables = {} - for variable_selector in self.node_data.variables: + for variable_selector in self._node_data.variables: variable_name = variable_selector.variable variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if isinstance(variable, ArrayFileSegment): @@ -68,7 +91,7 @@ class CodeNode(BaseNode[CodeNodeData]): ) # Transform result - result = self._transform_result(result=result, output_schema=self.node_data.outputs) + result = self._transform_result(result=result, output_schema=self._node_data.outputs) except (CodeExecutionError, 
CodeNodeError) as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ @@ -334,16 +357,20 @@ class CodeNode(BaseNode[CodeNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: CodeNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = CodeNodeData.model_validate(node_data) + return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables + for variable_selector in typed_node_data.variables } + + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 8e6150f9cc..ab5964ebd4 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,7 +5,7 @@ import logging import os import tempfile from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast import chardet import docx @@ -28,7 +28,8 @@ from core.variables.segments import ArrayStringSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .entities import DocumentExtractorNodeData from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError @@ -36,21 +37,43 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, logger = logging.getLogger(__name__) -class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): +class DocumentExtractorNode(BaseNode): """ Extracts text content from various file types. Supports plain text, PDF, and DOC/DOCX files. 
""" - _node_data_cls = DocumentExtractorNodeData _node_type = NodeType.DOCUMENT_EXTRACTOR + _node_data: DocumentExtractorNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = DocumentExtractorNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" def _run(self): - variable_selector = self.node_data.variable_selector + variable_selector = self._node_data.variable_selector variable = self.graph_runtime_state.variable_pool.get(variable_selector) if variable is None: @@ -97,16 +120,12 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: DocumentExtractorNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - return {node_id + ".files": node_data.variable_selector} + # Create typed NodeData from dict + typed_node_data = DocumentExtractorNodeData.model_validate(node_data) + + return {node_id + ".files": typed_node_data.variable_selector} def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 17a0b3adeb..f86f2e8129 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,14 +1,40 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType -class EndNode(BaseNode[EndNodeData]): - _node_data_cls = EndNodeData +class EndNode(BaseNode): _node_type = NodeType.END + _node_data: EndNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = EndNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -18,7 +44,7 @@ class EndNode(BaseNode[EndNodeData]): Run node :return: """ - output_variables = self.node_data.outputs + output_variables = self._node_data.outputs outputs = {} for variable_selector in output_variables: diff --git 
a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 73b43eeaf7..7cf9ab9107 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -35,7 +35,3 @@ class ErrorStrategy(StrEnum): class FailBranchSourceHandle(StrEnum): FAILED = "fail-branch" SUCCESS = "success-branch" - - -CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] -RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 971e0f73e7..6799d5c63c 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -11,7 +11,8 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.http_request.executor import Executor from core.workflow.utils import variable_template_parser from factories import file_factory @@ -32,10 +33,32 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( logger = logging.getLogger(__name__) -class HttpRequestNode(BaseNode[HttpRequestNodeData]): - _node_data_cls = HttpRequestNodeData +class HttpRequestNode(BaseNode): _node_type = NodeType.HTTP_REQUEST + _node_data: HttpRequestNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = HttpRequestNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: return { @@ -69,8 +92,8 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): process_data = {} try: http_executor = Executor( - node_data=self.node_data, - timeout=self._get_request_timeout(self.node_data), + node_data=self._node_data, + timeout=self._get_request_timeout(self._node_data), variable_pool=self.graph_runtime_state.variable_pool, max_retries=0, ) @@ -78,7 +101,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): response = http_executor.invoke() files = self.extract_files(url=http_executor.url, response=response) - if not response.response.is_success and (self.should_continue_on_error or self.should_retry): + if not response.response.is_success and (self.continue_on_error or self.retry): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, outputs={ @@ -131,15 +154,18 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: HttpRequestNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = HttpRequestNodeData.model_validate(node_data) + selectors: list[VariableSelector] = [] - selectors 
+= variable_template_parser.extract_selectors_from_template(node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(node_data.params) - if node_data.body: - body_type = node_data.body.type - data = node_data.body.data + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url) + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params) + if typed_node_data.body: + body_type = typed_node_data.body.type + data = typed_node_data.body.data match body_type: case "binary": if len(data) != 1: @@ -217,3 +243,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): files.append(file) return ArrayFileSegment(value=files) + + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 22b748030c..86e703dc68 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any, Literal, Optional from typing_extensions import deprecated @@ -7,16 +7,39 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.processor import ConditionProcessor -class IfElseNode(BaseNode[IfElseNodeData]): - _node_data_cls = IfElseNodeData +class IfElseNode(BaseNode): _node_type = NodeType.IF_ELSE + _node_data: IfElseNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = IfElseNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -36,8 +59,8 @@ class IfElseNode(BaseNode[IfElseNodeData]): condition_processor = ConditionProcessor() try: # Check if the new cases structure is used - if self.node_data.cases: - for case in self.node_data.cases: + if self._node_data.cases: + for case in self._node_data.cases: input_conditions, group_result, final_result = condition_processor.process_conditions( variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions, @@ -63,8 +86,8 @@ class IfElseNode(BaseNode[IfElseNodeData]): 
input_conditions, group_result, final_result = _should_not_use_old_function( condition_processor=condition_processor, variable_pool=self.graph_runtime_state.variable_pool, - conditions=self.node_data.conditions or [], - operator=self.node_data.logical_operator or "and", + conditions=self._node_data.conditions or [], + operator=self._node_data.logical_operator or "and", ) selected_case_id = "true" if final_result else "false" @@ -98,10 +121,13 @@ class IfElseNode(BaseNode[IfElseNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: IfElseNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = IfElseNodeData.model_validate(node_data) + var_mapping: dict[str, list[str]] = {} - for case in node_data.cases or []: + for case in typed_node_data.cases or []: for condition in case.conditions: key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) var_mapping[key] = condition.variable_selector diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 8b566c83cd..5842c8d64b 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -36,7 +36,8 @@ from core.workflow.graph_engine.entities.event import ( ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from factories.variable_factory import build_segment @@ -56,14 +57,36 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class IterationNode(BaseNode[IterationNodeData]): +class IterationNode(BaseNode): """ Iteration Node. """ - _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION + _node_data: IterationNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = IterationNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: return { @@ -83,10 +106,10 @@ class IterationNode(BaseNode[IterationNodeData]): """ Run the node. 
""" - variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector) if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") + raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") @@ -116,10 +139,10 @@ class IterationNode(BaseNode[IterationNodeData]): graph_config = self.graph_config - if not self.node_data.start_node_id: + if not self._node_data.start_node_id: raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") - root_node_id = self.node_data.start_node_id + root_node_id = self._node_data.start_node_id # init graph iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) @@ -161,8 +184,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunStartedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, metadata={"iterator_length": len(iterator_list_value)}, @@ -172,8 +195,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=0, pre_iteration_output=None, duration=None, @@ -181,11 +204,11 @@ class IterationNode(BaseNode[IterationNodeData]): iter_run_map: dict[str, float] = {} outputs: list[Any] = [None] * len(iterator_list_value) try: - if self.node_data.is_parallel: + if self._node_data.is_parallel: futures: list[Future] = [] q: Queue = Queue() thread_pool = GraphEngineThreadPool( - max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT + max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT ) for index, item in enumerate(iterator_list_value): future: Future = thread_pool.submit( @@ -242,7 +265,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_graph=iteration_graph, iter_run_map=iter_run_map, ) - if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs = [output for output in outputs if output is not None] # Flatten the list of lists @@ -253,8 +276,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunSucceededEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -278,8 +301,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -305,21 +328,17 @@ class 
IterationNode(BaseNode[IterationNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: IterationNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = IterationNodeData.model_validate(node_data) + variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": node_data.iterator_selector, + f"{node_id}.input_selector": typed_node_data.iterator_selector, } # init graph - iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id) if not iteration_graph: raise IterationGraphNotFoundError("iteration graph not found") @@ -375,7 +394,7 @@ class IterationNode(BaseNode[IterationNodeData]): """ if not isinstance(event, BaseNodeEvent): return event - if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): + if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent): event.parallel_mode_run_id = parallel_mode_run_id iter_metadata = { @@ -438,12 +457,12 @@ class IterationNode(BaseNode[IterationNodeData]): elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): # iteration run failed - if self.node_data.is_parallel: + if self._node_data.is_parallel: yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, inputs=inputs, @@ -456,8 +475,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -478,7 +497,7 @@ class IterationNode(BaseNode[IterationNodeData]): event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id ) if isinstance(event, NodeRunFailedEvent): - if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: + if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -491,15 +510,15 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=None, duration=duration, ) return - elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -512,15 +531,15 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + 
iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=None, duration=duration, ) return - elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: + elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -531,12 +550,12 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.remove([node_id]) # iteration run failed - if self.node_data.is_parallel: + if self._node_data.is_parallel: yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, inputs=inputs, @@ -549,8 +568,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -569,7 +588,7 @@ class IterationNode(BaseNode[IterationNodeData]): return yield metadata_event - current_output_segment = variable_pool.get(self.node_data.output_selector) + current_output_segment = variable_pool.get(self._node_data.output_selector) if current_output_segment is None: raise IterationNodeError("iteration output selector not found") current_iteration_output = current_output_segment.value @@ -588,8 +607,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=current_iteration_output or None, @@ -601,8 +620,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": None}, diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 9900aa225d..b82c29291a 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,18 +1,44 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.iteration.entities import IterationStartNodeData -class IterationStartNode(BaseNode[IterationStartNodeData]): +class IterationStartNode(BaseNode): """ Iteration Start Node. 
""" - _node_data_cls = IterationStartNodeData _node_type = NodeType.ITERATION_START + _node_data: IterationStartNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = IterationStartNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 19bdee4fe2..e9122b1eec 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,10 +1,10 @@ from collections.abc import Sequence -from typing import Any, Literal, Optional +from typing import Literal, Optional from pydantic import BaseModel, Field from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm.entities import VisionConfig +from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig class RerankingModelConfig(BaseModel): @@ -56,17 +56,6 @@ class MultipleRetrievalConfig(BaseModel): weights: Optional[WeightedScoreConfig] = None -class ModelConfig(BaseModel): - """ - Model Config. - """ - - provider: str - name: str - mode: str - completion_params: dict[str, Any] = {} - - class SingleRetrievalConfig(BaseModel): """ Single Retrieval Config. 
@@ -129,7 +118,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" - metadata_model_config: Optional[ModelConfig] = None + metadata_model_config: ModelConfig metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index f05d93d83e..4e9a38f552 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,7 +4,7 @@ import re import time from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import cast as sqlalchemy_cast @@ -15,20 +15,31 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.entities.message_entities import ( + PromptMessageRole, +) +from core.model_runtime.entities.model_entities import ( + ModelFeature, + ModelType, +) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.simple_prompt_transform import ModelMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.variables import StringSegment +from core.variables import ( + StringSegment, +) from core.variables.segments import ArrayObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event.event import ModelInvokeCompletedEvent +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event import ( + ModelInvokeCompletedEvent, +) from core.workflow.nodes.knowledge_retrieval.template_prompts import ( METADATA_FILTER_ASSISTANT_PROMPT_1, METADATA_FILTER_ASSISTANT_PROMPT_2, @@ -38,7 +49,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) -from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig +from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from core.workflow.nodes.llm.node import LLMNode from 
extensions.ext_database import db from extensions.ext_redis import redis_client @@ -46,7 +58,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog from services.feature_service import FeatureService -from .entities import KnowledgeRetrievalNodeData, ModelConfig +from .entities import KnowledgeRetrievalNodeData from .exc import ( InvalidModelTypeError, KnowledgeRetrievalNodeError, @@ -56,6 +68,10 @@ from .exc import ( ModelQuotaExceededError, ) +if TYPE_CHECKING: + from core.file.models import File + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState + logger = logging.getLogger(__name__) default_retrieval_model = { @@ -67,18 +83,76 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(LLMNode): - _node_data_cls = KnowledgeRetrievalNodeData # type: ignore +class KnowledgeRetrievalNode(BaseNode): _node_type = NodeType.KNOWLEDGE_RETRIEVAL + _node_data: KnowledgeRetrievalNodeData + + # Instance attributes specific to LLMNode. + # Output variable for file + _file_outputs: list["File"] + + _llm_file_saver: LLMFileSaver + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + *, + llm_file_saver: LLMFileSaver | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + # LLM file outputs, used for MultiModal outputs. + self._file_outputs: list[File] = [] + + if llm_file_saver is None: + llm_file_saver = FileSaverImpl( + user_id=graph_init_params.user_id, + tenant_id=graph_init_params.tenant_id, + ) + self._llm_file_saver = llm_file_saver + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = KnowledgeRetrievalNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls): return "1" def _run(self) -> NodeRunResult: # type: ignore - node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables - variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) if not isinstance(variable, StringSegment): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -119,7 +193,7 @@ class KnowledgeRetrievalNode(LLMNode): # retrieve knowledge try: - results = self._fetch_dataset_retriever(node_data=node_data, query=query) + results = self._fetch_dataset_retriever(node_data=self._node_data, query=query) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -435,20 +509,15 @@ class KnowledgeRetrievalNode(LLMNode): # get all metadata field metadata_fields = 
db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] - # get metadata model config - metadata_model_config = node_data.metadata_model_config - if metadata_model_config is None: - raise ValueError("metadata_model_config is required") - # get metadata model instance - # fetch model config - model_instance, model_config = self.get_model_config(metadata_model_config) + # get metadata model instance and fetch model config + model_instance, model_config = self.get_model_config(node_data.metadata_model_config) # fetch prompt messages prompt_template = self._get_prompt_template( node_data=node_data, metadata_fields=all_metadata_fields, query=query or "", ) - prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( prompt_template=prompt_template, sys_query=query, memory=None, @@ -458,16 +527,23 @@ class KnowledgeRetrievalNode(LLMNode): vision_detail=node_data.vision.configs.detail, variable_pool=self.graph_runtime_state.variable_pool, jinja2_variables=[], + tenant_id=self.tenant_id, ) result_text = "" try: # handle invoke result - generator = self._invoke_llm( - node_data_model=node_data.metadata_model_config, # type: ignore + generator = LLMNode.invoke_llm( + node_data_model=node_data.metadata_model_config, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=self._node_data.structured_output_enabled, + structured_output=None, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) for event in generator: @@ -557,17 +633,13 @@ class KnowledgeRetrievalNode(LLMNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: KnowledgeRetrievalNodeData, # type: ignore + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) + variable_mapping = {} - variable_mapping[node_id + ".query"] = node_data.query_variable_selector + variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector return variable_mapping def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: @@ -629,7 +701,7 @@ class KnowledgeRetrievalNode(LLMNode): ) def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str): - model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore + model_mode = ModelMode(node_data.metadata_model_config.mode) input_text = query prompt_messages: list[LLMNodeChatModelMessage] = [] diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 3c9ba44cf1..ae9401b056 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ -from collections.abc import Callable, Sequence -from typing import Any, Literal, Union +from collections.abc import Callable, Mapping, Sequence +from typing import Any, Literal, Optional, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -7,16 +7,39 @@ from core.variables.segments import ArrayAnySegment, 
ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .entities import ListOperatorNodeData from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError -class ListOperatorNode(BaseNode[ListOperatorNodeData]): - _node_data_cls = ListOperatorNodeData +class ListOperatorNode(BaseNode): _node_type = NodeType.LIST_OPERATOR + _node_data: ListOperatorNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = ListOperatorNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -26,9 +49,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): process_data: dict[str, list] = {} outputs: dict[str, Any] = {} - variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) if variable is None: - error_message = f"Variable not found for selector: {self.node_data.variable}" + error_message = f"Variable not found for selector: {self._node_data.variable}" return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs ) @@ -48,7 +71,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): ) if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): error_message = ( - f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " + f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " "or ArrayStringSegment" ) return NodeRunResult( @@ -64,19 +87,19 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): try: # Filter - if self.node_data.filter_by.enabled: + if self._node_data.filter_by.enabled: variable = self._apply_filter(variable) # Extract - if self.node_data.extract_by.enabled: + if self._node_data.extract_by.enabled: variable = self._extract_slice(variable) # Order - if self.node_data.order_by.enabled: + if self._node_data.order_by.enabled: variable = self._apply_order(variable) # Slice - if self.node_data.limit.enabled: + if self._node_data.limit.enabled: variable = self._apply_slice(variable) outputs = { @@ -104,7 +127,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: filter_func: Callable[[Any], bool] result: list[Any] = [] - for condition in self.node_data.filter_by.conditions: + for condition in self._node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): if not isinstance(condition.value, str): raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") @@ -137,14 +160,14 @@ class 
ListOperatorNode(BaseNode[ListOperatorNodeData]): self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: if isinstance(variable, ArrayStringSegment): - result = _order_string(order=self.node_data.order_by.value, array=variable.value) + result = _order_string(order=self._node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayNumberSegment): - result = _order_number(order=self.node_data.order_by.value, array=variable.value) + result = _order_number(order=self._node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayFileSegment): result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value ) variable = variable.model_copy(update={"value": result}) return variable @@ -152,13 +175,13 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): def _apply_slice( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - result = variable.value[: self.node_data.limit.size] + result = variable.value[: self._node_data.limit.size] return variable.model_copy(update={"value": result}) def _extract_slice( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) + value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text) if value < 1: raise ValueError(f"Invalid serial index: must be >= 1, got {value}") value -= 1 diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 36d0688807..4bb62d35a2 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Any, Optional from pydantic import BaseModel, Field, field_validator @@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData): memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) - structured_output: dict | None = None + structured_output: Mapping[str, Any] | None = None # We used 'structured_output_enabled' in the past, but it's not a good name. 
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index be0675a0f2..91e7312805 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -59,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import ( ModelInvokeCompletedEvent, NodeEvent, @@ -90,17 +91,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: from core.file.models import File - from core.workflow.graph_engine.entities.graph import Graph - from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams - from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState logger = logging.getLogger(__name__) -class LLMNode(BaseNode[LLMNodeData]): - _node_data_cls = LLMNodeData +class LLMNode(BaseNode): _node_type = NodeType.LLM + _node_data: LLMNodeData + # Instance attributes specific to LLMNode. # Output variable for file _file_outputs: list["File"] @@ -138,6 +138,27 @@ class LLMNode(BaseNode[LLMNodeData]): ) self._llm_file_saver = llm_file_saver + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LLMNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -152,13 +173,13 @@ class LLMNode(BaseNode[LLMNodeData]): try: # init messages template - self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) + self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template) # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data=self.node_data) + inputs = self._fetch_inputs(node_data=self._node_data) # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) + jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data) # merge inputs inputs.update(jinja_inputs) @@ -169,9 +190,9 @@ class LLMNode(BaseNode[LLMNodeData]): files = ( llm_utils.fetch_files( variable_pool=variable_pool, - selector=self.node_data.vision.configs.variable_selector, + selector=self._node_data.vision.configs.variable_selector, ) - if self.node_data.vision.enabled + if self._node_data.vision.enabled else [] ) @@ -179,7 +200,7 @@ class LLMNode(BaseNode[LLMNodeData]): node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value - generator = self._fetch_context(node_data=self.node_data) + generator = self._fetch_context(node_data=self._node_data) 
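[Annotation] The `_run` hunks here and below turn LLMNode's helpers (`_fetch_model_config`, `_fetch_prompt_messages` → `fetch_prompt_messages`, `_invoke_llm` → `invoke_llm`) into static methods whose dependencies — `tenant_id`, `user_id`, the `LLMFileSaver`, the `file_outputs` list, the `node_id` — arrive as explicit keyword arguments instead of being read off `self`. That is what lets `KnowledgeRetrievalNode` stop inheriting from `LLMNode` and simply call into it, as in its hunks earlier in this patch. A sketch of the resulting call shape, written as a method-body excerpt of a hypothetical reusing node (names mirror the knowledge-retrieval hunks; values are illustrative):

    # Excerpt: `self` is a BaseNode subclass; model_instance, prompt_messages
    # and stop were obtained earlier, as in KnowledgeRetrievalNode above.
    generator = LLMNode.invoke_llm(
        node_data_model=model_cfg,          # a ModelConfig for the call
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        stop=stop,
        user_id=self.user_id,               # was read from self inside _invoke_llm
        structured_output_enabled=False,    # the caller decides; no hidden node_data read
        structured_output=None,
        file_saver=self._llm_file_saver,    # injected collaborator, not reached for via self
        file_outputs=self._file_outputs,    # appended to by the multimodal save path
        node_id=self.node_id,               # used for RunStreamChunkEvent selectors
    )
    for event in generator:
        ...  # ModelInvokeCompletedEvent / RunStreamChunkEvent, as before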
context = None for event in generator: if isinstance(event, RunRetrieverResourceEvent): @@ -189,44 +210,54 @@ class LLMNode(BaseNode[LLMNodeData]): node_inputs["#context#"] = context # fetch model config - model_instance, model_config = self._fetch_model_config(self.node_data.model) + model_instance, model_config = LLMNode._fetch_model_config( + node_data_model=self._node_data.model, + tenant_id=self.tenant_id, + ) # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, app_id=self.app_id, - node_data_memory=self.node_data.memory, + node_data_memory=self._node_data.memory, model_instance=model_instance, ) query = None - if self.node_data.memory: - query = self.node_data.memory.query_prompt_template + if self._node_data.memory: + query = self._node_data.memory.query_prompt_template if not query and ( query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) ): query = query_variable.text - prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( sys_query=query, sys_files=files, context=context, memory=memory, model_config=model_config, - prompt_template=self.node_data.prompt_template, - memory_config=self.node_data.memory, - vision_enabled=self.node_data.vision.enabled, - vision_detail=self.node_data.vision.configs.detail, + prompt_template=self._node_data.prompt_template, + memory_config=self._node_data.memory, + vision_enabled=self._node_data.vision.enabled, + vision_detail=self._node_data.vision.configs.detail, variable_pool=variable_pool, - jinja2_variables=self.node_data.prompt_config.jinja2_variables, + jinja2_variables=self._node_data.prompt_config.jinja2_variables, + tenant_id=self.tenant_id, ) # handle invoke result - generator = self._invoke_llm( - node_data_model=self.node_data.model, + generator = LLMNode.invoke_llm( + node_data_model=self._node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=self._node_data.structured_output_enabled, + structured_output=self._node_data.structured_output, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) structured_output: LLMStructuredOutput | None = None @@ -296,12 +327,19 @@ class LLMNode(BaseNode[LLMNodeData]): ) ) - def _invoke_llm( - self, + @staticmethod + def invoke_llm( + *, node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Optional[Sequence[str]] = None, + user_id: str, + structured_output_enabled: bool, + structured_output: Optional[Mapping[str, Any]] = None, + file_saver: LLMFileSaver, + file_outputs: list["File"], + node_id: str, ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: model_schema = model_instance.model_type_instance.get_model_schema( node_data_model.name, model_instance.credentials @@ -309,8 +347,10 @@ class LLMNode(BaseNode[LLMNodeData]): if not model_schema: raise ValueError(f"Model schema not found for {node_data_model.name}") - if self.node_data.structured_output_enabled: - output_schema = self._fetch_structured_output_schema() + if structured_output_enabled: + output_schema = LLMNode.fetch_structured_output_schema( + structured_output=structured_output or {}, + ) invoke_result = invoke_llm_with_structured_output( provider=model_instance.provider, model_schema=model_schema, @@ -320,7 +360,7 @@ class LLMNode(BaseNode[LLMNodeData]): model_parameters=node_data_model.completion_params, stop=list(stop or []), 
stream=True, - user=self.user_id, + user=user_id, ) else: invoke_result = model_instance.invoke_llm( @@ -328,17 +368,31 @@ class LLMNode(BaseNode[LLMNodeData]): model_parameters=node_data_model.completion_params, stop=list(stop or []), stream=True, - user=self.user_id, + user=user_id, ) - return self._handle_invoke_result(invoke_result=invoke_result) + return LLMNode.handle_invoke_result( + invoke_result=invoke_result, + file_saver=file_saver, + file_outputs=file_outputs, + node_id=node_id, + ) - def _handle_invoke_result( - self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] + @staticmethod + def handle_invoke_result( + *, + invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], + file_saver: LLMFileSaver, + file_outputs: list["File"], + node_id: str, ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: # For blocking mode if isinstance(invoke_result, LLMResult): - event = self._handle_blocking_result(invoke_result=invoke_result) + event = LLMNode.handle_blocking_result( + invoke_result=invoke_result, + saver=file_saver, + file_outputs=file_outputs, + ) yield event return @@ -356,11 +410,13 @@ class LLMNode(BaseNode[LLMNodeData]): yield result if isinstance(result, LLMResultChunk): contents = result.delta.message.content - for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): + for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( + contents=contents, + file_saver=file_saver, + file_outputs=file_outputs, + ): full_text_buffer.write(text_part) - yield RunStreamChunkEvent( - chunk_content=text_part, from_variable_selector=[self.node_id, "text"] - ) + yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"]) # Update the whole metadata if not model and result.model: @@ -378,7 +434,8 @@ class LLMNode(BaseNode[LLMNodeData]): yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason) - def _image_file_to_markdown(self, file: "File", /): + @staticmethod + def _image_file_to_markdown(file: "File", /): text_chunk = f"![]({file.generate_url()})" return text_chunk @@ -539,11 +596,14 @@ class LLMNode(BaseNode[LLMNodeData]): return None + @staticmethod def _fetch_model_config( - self, node_data_model: ModelConfig + *, + node_data_model: ModelConfig, + tenant_id: str, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: model, model_config_with_cred = llm_utils.fetch_model_config( - tenant_id=self.tenant_id, node_data_model=node_data_model + tenant_id=tenant_id, node_data_model=node_data_model ) completion_params = model_config_with_cred.parameters @@ -556,8 +616,8 @@ class LLMNode(BaseNode[LLMNodeData]): node_data_model.completion_params = completion_params return model, model_config_with_cred - def _fetch_prompt_messages( - self, + @staticmethod + def fetch_prompt_messages( *, sys_query: str | None = None, sys_files: Sequence["File"], @@ -570,13 +630,14 @@ class LLMNode(BaseNode[LLMNodeData]): vision_detail: ImagePromptMessageContent.DETAIL, variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], + tenant_id: str, ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: prompt_messages: list[PromptMessage] = [] if isinstance(prompt_template, list): # For chat model prompt_messages.extend( - self._handle_list_messages( + LLMNode.handle_list_messages( messages=prompt_template, context=context, jinja2_variables=jinja2_variables, @@ -602,7 +663,7 @@ 
class LLMNode(BaseNode[LLMNodeData]): edition_type="basic", ) prompt_messages.extend( - self._handle_list_messages( + LLMNode.handle_list_messages( messages=[message], context="", jinja2_variables=[], @@ -731,7 +792,7 @@ class LLMNode(BaseNode[LLMNodeData]): ) model = ModelManager().get_model_instance( - tenant_id=self.tenant_id, + tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model, @@ -750,10 +811,12 @@ class LLMNode(BaseNode[LLMNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: LLMNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - prompt_template = node_data.prompt_template + # Create typed NodeData from dict + typed_node_data = LLMNodeData.model_validate(node_data) + prompt_template = typed_node_data.prompt_template variable_selectors = [] if isinstance(prompt_template, list) and all( isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template @@ -773,7 +836,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - memory = node_data.memory + memory = typed_node_data.memory if memory and memory.query_prompt_template: query_variable_selectors = VariableTemplateParser( template=memory.query_prompt_template @@ -781,16 +844,16 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - if node_data.context.enabled: - variable_mapping["#context#"] = node_data.context.variable_selector + if typed_node_data.context.enabled: + variable_mapping["#context#"] = typed_node_data.context.variable_selector - if node_data.vision.enabled: - variable_mapping["#files#"] = node_data.vision.configs.variable_selector + if typed_node_data.vision.enabled: + variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector - if node_data.memory: + if typed_node_data.memory: variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] - if node_data.prompt_config: + if typed_node_data.prompt_config: enable_jinja = False if isinstance(prompt_template, list): @@ -803,7 +866,7 @@ class LLMNode(BaseNode[LLMNodeData]): enable_jinja = True if enable_jinja: - for variable_selector in node_data.prompt_config.jinja2_variables or []: + for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping = {node_id + "." 
+ key: value for key, value in variable_mapping.items()} @@ -835,8 +898,8 @@ class LLMNode(BaseNode[LLMNodeData]): }, } - def _handle_list_messages( - self, + @staticmethod + def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str], @@ -897,9 +960,19 @@ class LLMNode(BaseNode[LLMNodeData]): return prompt_messages - def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent: + @staticmethod + def handle_blocking_result( + *, + invoke_result: LLMResult, + saver: LLMFileSaver, + file_outputs: list["File"], + ) -> ModelInvokeCompletedEvent: buffer = io.StringIO() - for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content): + for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( + contents=invoke_result.message.content, + file_saver=saver, + file_outputs=file_outputs, + ): buffer.write(text_part) return ModelInvokeCompletedEvent( @@ -908,7 +981,12 @@ class LLMNode(BaseNode[LLMNodeData]): finish_reason=None, ) - def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File": + @staticmethod + def save_multimodal_image_output( + *, + content: ImagePromptMessageContent, + file_saver: LLMFileSaver, + ) -> "File": """_save_multimodal_output saves multi-modal contents generated by LLM plugins. There are two kinds of multimodal outputs: @@ -918,26 +996,21 @@ class LLMNode(BaseNode[LLMNodeData]): Currently, only image files are supported. """ - # Inject the saver somehow... - _saver = self._llm_file_saver - - # If this if content.url != "": - saved_file = _saver.save_remote_url(content.url, FileType.IMAGE) + saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE) else: - saved_file = _saver.save_binary_string( + saved_file = file_saver.save_binary_string( data=base64.b64decode(content.base64_data), mime_type=content.mime_type, file_type=FileType.IMAGE, ) - self._file_outputs.append(saved_file) return saved_file def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: """ Fetch model schema """ - model_name = self.node_data.model.name + model_name = self._node_data.model.name model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name @@ -948,16 +1021,20 @@ class LLMNode(BaseNode[LLMNodeData]): model_schema = model_type_instance.get_model_schema(model_name, model_credentials) return model_schema - def _fetch_structured_output_schema(self) -> dict[str, Any]: + @staticmethod + def fetch_structured_output_schema( + *, + structured_output: Mapping[str, Any], + ) -> dict[str, Any]: """ Fetch the structured output schema from the node data. 
Returns: dict[str, Any]: The structured output schema """ - if not self.node_data.structured_output: + if not structured_output: raise LLMNodeError("Please provide a valid structured output schema") - structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False) + structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) if not structured_output_schema: raise LLMNodeError("Please provide a valid structured output schema") @@ -969,9 +1046,12 @@ class LLMNode(BaseNode[LLMNodeData]): except json.JSONDecodeError: raise LLMNodeError("structured_output_schema is not valid JSON format") + @staticmethod def _save_multimodal_output_and_convert_result_to_markdown( - self, + *, contents: str | list[PromptMessageContentUnionTypes] | None, + file_saver: LLMFileSaver, + file_outputs: list["File"], ) -> Generator[str, None, None]: """Convert intermediate prompt messages into strings and yield them to the caller. @@ -994,9 +1074,12 @@ class LLMNode(BaseNode[LLMNodeData]): if isinstance(item, TextPromptMessageContent): yield item.data elif isinstance(item, ImagePromptMessageContent): - file = self._save_multimodal_image_output(item) - self._file_outputs.append(file) - yield self._image_file_to_markdown(file) + file = LLMNode.save_multimodal_image_output( + content=item, + file_saver=file_saver, + ) + file_outputs.append(file) + yield LLMNode._image_file_to_markdown(file) else: logger.warning("unknown item type encountered, type=%s", type(item)) yield str(item) @@ -1004,6 +1087,14 @@ class LLMNode(BaseNode[LLMNodeData]): logger.warning("unknown contents type encountered, type=%s", type(contents)) yield str(contents) + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled + def _combine_message_content_with_role( *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index b144021bab..53cadc5251 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,18 +1,44 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.loop.entities import LoopEndNodeData -class LoopEndNode(BaseNode[LoopEndNodeData]): +class LoopEndNode(BaseNode): """ Loop End Node. 
""" - _node_data_cls = LoopEndNodeData _node_type = NodeType.LOOP_END + _node_data: LoopEndNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LoopEndNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 20501d0317..655de9362f 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -3,7 +3,7 @@ import logging import time from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, cast from configs import dify_config from core.variables import ( @@ -30,7 +30,8 @@ from core.workflow.graph_engine.entities.event import ( ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.utils.condition.processor import ConditionProcessor @@ -43,14 +44,36 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LoopNode(BaseNode[LoopNodeData]): +class LoopNode(BaseNode): """ Loop Node. 
""" - _node_data_cls = LoopNodeData _node_type = NodeType.LOOP + _node_data: LoopNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LoopNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -58,17 +81,17 @@ class LoopNode(BaseNode[LoopNodeData]): def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """Run the node.""" # Get inputs - loop_count = self.node_data.loop_count - break_conditions = self.node_data.break_conditions - logical_operator = self.node_data.logical_operator + loop_count = self._node_data.loop_count + break_conditions = self._node_data.break_conditions + logical_operator = self._node_data.logical_operator inputs = {"loop_count": loop_count} - if not self.node_data.start_node_id: + if not self._node_data.start_node_id: raise ValueError(f"field start_node_id in loop {self.node_id} not found") # Initialize graph - loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id) + loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id) if not loop_graph: raise ValueError("loop graph not found") @@ -78,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]): # Initialize loop variables loop_variable_selectors = {} - if self.node_data.loop_variables: - for loop_variable in self.node_data.loop_variables: + if self._node_data.loop_variables: + for loop_variable in self._node_data.loop_variables: value_processor = { "constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value), "variable": lambda var=loop_variable: variable_pool.get(var.value), @@ -127,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunStartedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, metadata={"loop_length": loop_count}, @@ -184,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunSucceededEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, - outputs=self.node_data.outputs, + outputs=self._node_data.outputs, steps=loop_count, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, @@ -206,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]): WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, - outputs=self.node_data.outputs, + outputs=self._node_data.outputs, inputs=inputs, ) ) @@ -217,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, 
start_at=start_at, inputs=inputs, steps=loop_count, @@ -320,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=current_index, @@ -351,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=current_index, @@ -388,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]): _outputs[loop_variable_key] = None _outputs["loop_round"] = current_index + 1 - self.node_data.outputs = _outputs + self._node_data.outputs = _outputs if check_break_result: return {"check_break_result": True} @@ -400,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunNextEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, index=next_index, - pre_loop_output=self.node_data.outputs, + pre_loop_output=self._node_data.outputs, ) return {"check_break_result": False} @@ -438,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: LoopNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = LoopNodeData.model_validate(node_data) + variable_mapping = {} # init graph - loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) + loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id) if not loop_graph: raise ValueError("loop graph not found") @@ -486,7 +505,7 @@ class LoopNode(BaseNode[LoopNodeData]): variable_mapping.update(sub_node_variable_mapping) - for loop_variable in node_data.loop_variables or []: + for loop_variable in typed_node_data.loop_variables or []: if loop_variable.value_type == "variable": assert loop_variable.value is not None, "Loop variable value must be provided for variable type" # add loop variable to variable mapping diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index f5e38b7516..29b45ea0c3 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,18 +1,44 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.loop.entities import LoopStartNodeData -class LoopStartNode(BaseNode[LoopStartNodeData]): +class LoopStartNode(BaseNode): """ Loop Start Node. 
""" - _node_data_cls = LoopStartNodeData _node_type = NodeType.LOOP_START + _node_data: LoopStartNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LoopStartNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 25a534256b..a23d284626 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -29,8 +29,9 @@ from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.utils import variable_template_parser from factories.variable_factory import build_segment_with_type @@ -91,10 +92,31 @@ class ParameterExtractorNode(BaseNode): Parameter Extractor Node. """ - # FIXME: figure out why here is different from super class - _node_data_cls = ParameterExtractorNodeData # type: ignore _node_type = NodeType.PARAMETER_EXTRACTOR + _node_data: ParameterExtractorNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = ParameterExtractorNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + _model_instance: Optional[ModelInstance] = None _model_config: Optional[ModelConfigWithCredentialsEntity] = None @@ -119,7 +141,7 @@ class ParameterExtractorNode(BaseNode): """ Run the node. """ - node_data = cast(ParameterExtractorNodeData, self.node_data) + node_data = cast(ParameterExtractorNodeData, self._node_data) variable = self.graph_runtime_state.variable_pool.get(node_data.query) query = variable.text if variable else "" @@ -398,7 +420,7 @@ class ParameterExtractorNode(BaseNode): """ Generate prompt engineering prompt. 
""" - model_mode = ModelMode.value_of(data.model.mode) + model_mode = ModelMode(data.model.mode) if model_mode == ModelMode.COMPLETION: return self._generate_prompt_engineering_completion_prompt( @@ -694,7 +716,7 @@ class ParameterExtractorNode(BaseNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ) -> list[ChatModelMessage]: - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -721,7 +743,7 @@ class ParameterExtractorNode(BaseNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ): - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -827,19 +849,15 @@ class ParameterExtractorNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: ParameterExtractorNodeData, # type: ignore + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} + # Create typed NodeData from dict + typed_node_data = ParameterExtractorNodeData.model_validate(node_data) - if node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) + variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query} + + if typed_node_data.instruction: + selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction) for selector in selectors: variable_mapping[selector.variable] = selector.value_selector diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 74024ed90c..15012fa48d 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -11,8 +11,11 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import ModelInvokeCompletedEvent from core.workflow.nodes.llm import ( LLMNode, @@ -20,6 +23,7 @@ from core.workflow.nodes.llm import ( LLMNodeCompletionModelPromptTemplate, llm_utils, ) +from core.workflow.nodes.llm.file_saver 
import FileSaverImpl, LLMFileSaver
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -35,17 +39,77 @@ from .template_prompts import (
     QUESTION_CLASSIFIER_USER_PROMPT_3,
 )
 
+if TYPE_CHECKING:
+    from core.file.models import File
+    from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 
-class QuestionClassifierNode(LLMNode):
-    _node_data_cls = QuestionClassifierNodeData  # type: ignore
+
+class QuestionClassifierNode(BaseNode):
     _node_type = NodeType.QUESTION_CLASSIFIER
+    _node_data: QuestionClassifierNodeData
+
+    _file_outputs: list["File"]
+    _llm_file_saver: LLMFileSaver
+
+    def __init__(
+        self,
+        id: str,
+        config: Mapping[str, Any],
+        graph_init_params: "GraphInitParams",
+        graph: "Graph",
+        graph_runtime_state: "GraphRuntimeState",
+        previous_node_id: Optional[str] = None,
+        thread_pool_id: Optional[str] = None,
+        *,
+        llm_file_saver: LLMFileSaver | None = None,
+    ) -> None:
+        super().__init__(
+            id=id,
+            config=config,
+            graph_init_params=graph_init_params,
+            graph=graph,
+            graph_runtime_state=graph_runtime_state,
+            previous_node_id=previous_node_id,
+            thread_pool_id=thread_pool_id,
+        )
+        # LLM file outputs, used for MultiModal outputs.
+        self._file_outputs: list["File"] = []
+
+        if llm_file_saver is None:
+            llm_file_saver = FileSaverImpl(
+                user_id=graph_init_params.user_id,
+                tenant_id=graph_init_params.tenant_id,
+            )
+        self._llm_file_saver = llm_file_saver
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = QuestionClassifierNodeData.model_validate(data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
     @classmethod
     def version(cls):
         return "1"
 
     def _run(self):
-        node_data = cast(QuestionClassifierNodeData, self.node_data)
+        node_data = cast(QuestionClassifierNodeData, self._node_data)
         variable_pool = self.graph_runtime_state.variable_pool
 
         # extract variables
@@ -53,7 +117,10 @@ class QuestionClassifierNode(LLMNode):
         query = variable.value if variable else None
         variables = {"query": query}
         # fetch model config
-        model_instance, model_config = self._fetch_model_config(node_data.model)
+        model_instance, model_config = LLMNode._fetch_model_config(
+            node_data_model=node_data.model,
+            tenant_id=self.tenant_id,
+        )
         # fetch memory
         memory = llm_utils.fetch_memory(
             variable_pool=variable_pool,
@@ -91,7 +158,7 @@ class QuestionClassifierNode(LLMNode):
         # If both self._get_prompt_template and LLMNode.fetch_prompt_messages append a user prompt,
         # two consecutive user prompts will be generated, causing a model error.
         # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
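The comment above is worth unpacking: the classifier template already ends in a user turn with the question inlined, so forwarding the real query through sys_query would append a second one. A compact illustration of the message list that would result (UserPromptMessage is the real entity; the contents are invented):

from core.model_runtime.entities.message_entities import UserPromptMessage

# What the model would receive if sys_query were forwarded as-is:
prompt_messages = [
    UserPromptMessage(content="<classifier instructions with the question inlined>"),
    UserPromptMessage(content="what is Dify?"),  # appended again from sys_query
]
# Two consecutive user turns like this make many chat models error out,
# which is why the call below passes sys_query="".
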
- prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( prompt_template=prompt_template, sys_query="", memory=memory, @@ -101,6 +168,7 @@ class QuestionClassifierNode(LLMNode): vision_detail=node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=[], + tenant_id=self.tenant_id, ) result_text = "" @@ -109,11 +177,17 @@ class QuestionClassifierNode(LLMNode): try: # handle invoke result - generator = self._invoke_llm( + generator = LLMNode.invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=False, + structured_output=None, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) for event in generator: @@ -183,23 +257,18 @@ class QuestionClassifierNode(LLMNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Any, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - node_data = cast(QuestionClassifierNodeData, node_data) - variable_mapping = {"query": node_data.query_variable_selector} - variable_selectors = [] - if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) + # Create typed NodeData from dict + typed_node_data = QuestionClassifierNodeData.model_validate(node_data) + + variable_mapping = {"query": typed_node_data.query_variable_selector} + variable_selectors: list[VariableSelector] = [] + if typed_node_data.instruction: + variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) variable_mapping = {node_id + "." 
+ key: value for key, value in variable_mapping.items()} @@ -265,7 +334,7 @@ class QuestionClassifierNode(LLMNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ): - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index e215591888..9e401e76bb 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,15 +1,41 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.start.entities import StartNodeData -class StartNode(BaseNode[StartNodeData]): - _node_data_cls = StartNodeData +class StartNode(BaseNode): _node_type = NodeType.START + _node_data: StartNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = StartNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index ba573074c3..1962c82db1 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -6,16 +6,39 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) -class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): - _node_data_cls = TemplateTransformNodeData +class TemplateTransformNode(BaseNode): _node_type = NodeType.TEMPLATE_TRANSFORM + _node_data: TemplateTransformNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = TemplateTransformNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return 
self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ @@ -35,14 +58,14 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): def _run(self) -> NodeRunResult: # Get variables variables = {} - for variable_selector in self.node_data.variables: + for variable_selector in self._node_data.variables: variable_name = variable_selector.variable value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) variables[variable_name] = value.to_object() if value else None # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables + language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables ) except CodeExecutionError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) @@ -60,16 +83,12 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData + cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any] ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = TemplateTransformNodeData.model_validate(node_data) + return { node_id + "." 
+ variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables + for variable_selector in typed_node_data.variables } diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 3853a5d920..c565ad15c1 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -6,7 +6,6 @@ from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file import File, FileTransferMethod -from core.model_runtime.entities.llm_entities import LLMUsage from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.plugin import PluginInstaller from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -19,10 +18,10 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey -from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from factories import file_factory @@ -37,14 +36,18 @@ from .exc import ( ) -class ToolNode(BaseNode[ToolNodeData]): +class ToolNode(BaseNode): """ Tool Node """ - _node_data_cls = ToolNodeData _node_type = NodeType.TOOL + _node_data: ToolNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = ToolNodeData.model_validate(data) + @classmethod def version(cls) -> str: return "1" @@ -54,7 +57,7 @@ class ToolNode(BaseNode[ToolNodeData]): Run the tool node """ - node_data = cast(ToolNodeData, self.node_data) + node_data = cast(ToolNodeData, self._node_data) # fetch tool icon tool_info = { @@ -67,9 +70,9 @@ class ToolNode(BaseNode[ToolNodeData]): try: from core.tools.tool_manager import ToolManager - variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None + variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool + self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool ) except ToolNodeError as e: yield RunCompletedEvent( @@ -88,12 +91,12 @@ class ToolNode(BaseNode[ToolNodeData]): parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, + node_data=self._node_data, ) parameters_for_log = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, + node_data=self._node_data, for_log=True, ) # get conversation id @@ -124,7 +127,14 @@ class ToolNode(BaseNode[ToolNodeData]): try: # convert tool 
messages - yield from self._transform_message(message_stream, tool_info, parameters_for_log) + yield from self._transform_message( + messages=message_stream, + tool_info=tool_info, + parameters_for_log=parameters_for_log, + user_id=self.user_id, + tenant_id=self.tenant_id, + node_id=self.node_id, + ) except (PluginDaemonClientSideError, ToolInvokeError) as e: yield RunCompletedEvent( run_result=NodeRunResult( @@ -191,7 +201,9 @@ class ToolNode(BaseNode[ToolNodeData]): messages: Generator[ToolInvokeMessage, None, None], tool_info: Mapping[str, Any], parameters_for_log: dict[str, Any], - agent_thoughts: Optional[list] = None, + user_id: str, + tenant_id: str, + node_id: str, ) -> Generator: """ Convert ToolInvokeMessages into tuple[plain_text, files] @@ -199,8 +211,8 @@ class ToolNode(BaseNode[ToolNodeData]): # transform message and handle file storage message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( messages=messages, - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=user_id, + tenant_id=tenant_id, conversation_id=None, ) @@ -208,9 +220,6 @@ class ToolNode(BaseNode[ToolNodeData]): files: list[File] = [] json: list[dict] = [] - agent_logs: list[AgentLogEvent] = [] - agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - llm_usage: LLMUsage | None = None variables: dict[str, Any] = {} for message in message_stream: @@ -243,7 +252,7 @@ class ToolNode(BaseNode[ToolNodeData]): } file = file_factory.build_from_mapping( mapping=mapping, - tenant_id=self.tenant_id, + tenant_id=tenant_id, ) files.append(file) elif message.type == ToolInvokeMessage.MessageType.BLOB: @@ -266,45 +275,36 @@ class ToolNode(BaseNode[ToolNodeData]): files.append( file_factory.build_from_mapping( mapping=mapping, - tenant_id=self.tenant_id, + tenant_id=tenant_id, ) ) elif message.type == ToolInvokeMessage.MessageType.TEXT: assert isinstance(message.message, ToolInvokeMessage.TextMessage) text += message.message.text - yield RunStreamChunkEvent( - chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] - ) + yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"]) elif message.type == ToolInvokeMessage.MessageType.JSON: assert isinstance(message.message, ToolInvokeMessage.JsonMessage) - if self.node_type == NodeType.AGENT: - msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) - llm_usage = LLMUsage.from_metadata(msg_metadata) - agent_execution_metadata = { - WorkflowNodeExecutionMetadataKey(key): value - for key, value in msg_metadata.items() - if key in WorkflowNodeExecutionMetadataKey.__members__.values() - } + # JSON message handling for tool node if message.message.json_object is not None: json.append(message.message.json_object) elif message.type == ToolInvokeMessage.MessageType.LINK: assert isinstance(message.message, ToolInvokeMessage.TextMessage) stream_text = f"Link: {message.message.text}\n" text += stream_text - yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"]) elif message.type == ToolInvokeMessage.MessageType.VARIABLE: assert isinstance(message.message, ToolInvokeMessage.VariableMessage) variable_name = message.message.variable_name variable_value = message.message.variable_value if message.message.stream: if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 
'variable_value' must be a string.") + raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.") if variable_name not in variables: variables[variable_name] = "" variables[variable_name] += variable_value yield RunStreamChunkEvent( - chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] + chunk_content=variable_value, from_variable_selector=[node_id, variable_name] ) else: variables[variable_name] = variable_value @@ -319,7 +319,7 @@ class ToolNode(BaseNode[ToolNodeData]): dict_metadata = dict(message.message.metadata) if dict_metadata.get("provider"): manager = PluginInstaller() - plugins = manager.list_plugins(self.tenant_id) + plugins = manager.list_plugins(tenant_id) try: current_plugin = next( plugin @@ -334,8 +334,8 @@ class ToolNode(BaseNode[ToolNodeData]): builtin_tool = next( provider for provider in BuiltinToolManageService.list_builtin_tools( - self.user_id, - self.tenant_id, + user_id, + tenant_id, ) if provider.name == dict_metadata["provider"] ) @@ -347,57 +347,10 @@ class ToolNode(BaseNode[ToolNodeData]): dict_metadata["icon"] = icon dict_metadata["icon_dark"] = icon_dark message.message.metadata = dict_metadata - agent_log = AgentLogEvent( - id=message.message.id, - node_execution_id=self.id, - parent_id=message.message.parent_id, - error=message.message.error, - status=message.message.status.value, - data=message.message.data, - label=message.message.label, - metadata=message.message.metadata, - node_id=self.node_id, - ) - - # check if the agent log is already in the list - for log in agent_logs: - if log.id == agent_log.id: - # update the log - log.data = agent_log.data - log.status = agent_log.status - log.error = agent_log.error - log.label = agent_log.label - log.metadata = agent_log.metadata - break - else: - agent_logs.append(agent_log) - - yield agent_log - elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES: - assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage) - yield RunRetrieverResourceEvent( - retriever_resources=message.message.retriever_resources, - context=message.message.context, - ) # Add agent_logs to outputs['json'] to ensure frontend can access thinking process json_output: list[dict[str, Any]] = [] - # Step 1: append each agent log as its own dict. 
- if agent_logs: - for log in agent_logs: - json_output.append( - { - "id": log.id, - "parent_id": log.parent_id, - "error": log.error, - "status": log.status, - "data": log.data, - "label": log.label, - "metadata": log.metadata, - "node_id": log.node_id, - } - ) # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] if json: json_output.extend(json) @@ -409,12 +362,9 @@ class ToolNode(BaseNode[ToolNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, metadata={ - **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, }, inputs=parameters_for_log, - llm_usage=llm_usage, ) ) @@ -424,7 +374,7 @@ class ToolNode(BaseNode[ToolNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: ToolNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -433,9 +383,12 @@ class ToolNode(BaseNode[ToolNodeData]): :param node_data: node data :return: """ + # Create typed NodeData from dict + typed_node_data = ToolNodeData.model_validate(node_data) + result = {} - for parameter_name in node_data.tool_parameters: - input = node_data.tool_parameters[parameter_name] + for parameter_name in typed_node_data.tool_parameters: + input = typed_node_data.tool_parameters[parameter_name] if input.type == "mixed": assert isinstance(input.value, str) selectors = VariableTemplateParser(input.value).extract_variable_selectors() @@ -449,3 +402,29 @@ class ToolNode(BaseNode[ToolNodeData]): result = {node_id + "." + key: value for key, value in result.items()} return result + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 96bb3e793a..98127bbeb6 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,17 +1,41 @@ from collections.abc import Mapping +from typing import Any, Optional from core.variables.segments import Segment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData -class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): - _node_data_cls = VariableAssignerNodeData +class 
VariableAggregatorNode(BaseNode): _node_type = NodeType.VARIABLE_AGGREGATOR + _node_data: VariableAssignerNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = VariableAssignerNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -21,8 +45,8 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): outputs: dict[str, Segment | Mapping[str, Segment]] = {} inputs = {} - if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: - for selector in self.node_data.variables: + if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled: + for selector in self._node_data.variables: variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: outputs = {"output": variable} @@ -30,7 +54,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): inputs = {".".join(selector[1:]): variable.to_object()} break else: - for group in self.node_data.advanced_settings.groups: + for group in self._node_data.advanced_settings.groups: for selector in group.variables: variable = self.graph_runtime_state.variable_pool.get(selector) diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 1864b13784..51383fa588 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -7,7 +7,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from factories import variable_factory @@ -22,11 +23,33 @@ if TYPE_CHECKING: _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] -class VariableAssignerNode(BaseNode[VariableAssignerData]): - _node_data_cls = VariableAssignerData +class VariableAssignerNode(BaseNode): _node_type = NodeType.VARIABLE_ASSIGNER _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY + _node_data: VariableAssignerData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = VariableAssignerData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> 
dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + def __init__( self, id: str, @@ -59,36 +82,39 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: VariableAssignerData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - mapping = {} - assigned_variable_node_id = node_data.assigned_variable_selector[0] - if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector + # Create typed NodeData from dict + typed_node_data = VariableAssignerData.model_validate(node_data) - selector_key = ".".join(node_data.input_variable_selector) + mapping = {} + assigned_variable_node_id = typed_node_data.assigned_variable_selector[0] + if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: + selector_key = ".".join(typed_node_data.assigned_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = typed_node_data.assigned_variable_selector + + selector_key = ".".join(typed_node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.input_variable_selector + mapping[key] = typed_node_data.input_variable_selector return mapping def _run(self) -> NodeRunResult: - assigned_variable_selector = self.node_data.assigned_variable_selector + assigned_variable_selector = self._node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableOperatorNodeError("assigned variable not found") - match self.node_data.write_mode: + match self._node_data.write_mode: case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector) if not income_value: raise VariableOperatorNodeError("input value not found") updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector) if not income_value: raise VariableOperatorNodeError("input value not found") updated_value = original_variable.value + [income_value.value] @@ -101,7 +127,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") + raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}") # Over write the variable. 
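For the conversation-variable branch in _extract_variable_selector_to_variable_mapping above, the exported keys follow a "<node_id>.#<dotted selector>#" scheme. A small self-contained worked example with illustrative values:

node_id = "assigner"
assigned = ["conversation", "memo"]  # conversation-scoped, so it is exported
source = ["llm", "text"]
mapping = {f"{node_id}.#{'.'.join(sel)}#": sel for sel in (assigned, source)}
# {'assigner.#conversation.memo#': ['conversation', 'memo'],
#  'assigner.#llm.text#': ['llm', 'text']}
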
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 9292da6f1c..c0215cae71 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,6 +1,6 @@ import json -from collections.abc import Callable, Mapping, MutableMapping, Sequence -from typing import Any, TypeAlias, cast +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable @@ -10,7 +10,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory @@ -28,8 +29,6 @@ from .exc import ( VariableNotFoundError, ) -_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] - def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): selector_node_id = item.variable_selector[0] @@ -54,10 +53,32 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ mapping[key] = selector -class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): - _node_data_cls = VariableAssignerNodeData +class VariableAssignerNode(BaseNode): _node_type = NodeType.VARIABLE_ASSIGNER + _node_data: VariableAssignerNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = VariableAssignerNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + def _conv_var_updater_factory(self) -> ConversationVariableUpdater: return conversation_variable_updater_factory() @@ -71,22 +92,25 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: VariableAssignerNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = VariableAssignerNodeData.model_validate(node_data) + var_mapping: dict[str, Sequence[str]] = {} - for item in node_data.items: + for item in typed_node_data.items: _target_mapping_from_item(var_mapping, node_id, item) _source_mapping_from_item(var_mapping, node_id, item) return var_mapping def _run(self) -> NodeRunResult: - inputs = self.node_data.model_dump() + inputs = 
self._node_data.model_dump() process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variable_selectors: list[Sequence[str]] = [] try: - for item in self.node_data.items: + for item in self._node_data.items: variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) # ==================== Validation Part diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 1399efcdb1..d2375da39c 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,7 +5,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File from core.workflow.callbacks import WorkflowCallback @@ -146,7 +146,7 @@ class WorkflowEntry: graph = Graph.init(graph_config=workflow.graph_dict) # init workflow run state - node_instance = node_cls( + node = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -190,17 +190,11 @@ class WorkflowEntry: try: # run node - generator = node_instance.run() + generator = node.run() except Exception as e: - logger.exception( - "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", - workflow.id, - node_instance.id, - node_instance.node_type, - node_instance.version(), - ) - raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) - return node_instance, generator + logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}") + raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) + return node, generator @classmethod def run_free_node( @@ -262,7 +256,7 @@ class WorkflowEntry: node_cls = cast(type[BaseNode], node_cls) # init workflow run state - node_instance: BaseNode = node_cls( + node: BaseNode = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -297,17 +291,12 @@ class WorkflowEntry: ) # run node - generator = node_instance.run() + generator = node.run() - return node_instance, generator + return node, generator except Exception as e: - logger.exception( - "error while running node_instance, node_id=%s, type=%s, version=%s", - node_instance.id, - node_instance.node_type, - node_instance.version(), - ) - raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}") + raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 677bc74237..e31f77607a 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -465,10 +465,10 @@ class WorkflowService: node_id: str, ) -> WorkflowNodeExecution: try: - node_instance, generator = invoke_node_fn() + node, node_events = invoke_node_fn() node_run_result: NodeRunResult | None = None - for event in generator: + for event in node_events: if isinstance(event, RunCompletedEvent): node_run_result = event.run_result @@ -479,18 +479,18 @@ class WorkflowService: if not node_run_result: raise ValueError("Node run failed with no run result") # single step debug 
mode error handling return - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error: node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": node_run_result.error, "inputs": node_run_result.inputs, - "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + "metadata": {"error_strategy": node.error_strategy}, } - if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + if node.error_strategy is ErrorStrategy.DEFAULT_VALUE: node_run_result = NodeRunResult( **node_error_args, outputs={ - **node_instance.node_data.default_value_dict, + **node.default_value_dict, "error_message": node_run_result.error, "error_type": node_run_result.error_type, }, @@ -509,10 +509,10 @@ class WorkflowService: ) error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: - node_instance = e.node_instance + node = e._node run_succeeded = False node_run_result = None - error = e.error + error = e._error # Create a NodeExecution domain model node_execution = WorkflowNodeExecution( @@ -520,8 +520,8 @@ class WorkflowService: workflow_id="", # This is a single-step execution, so no workflow ID index=1, node_id=node_id, - node_type=node_instance.node_type, - title=node_instance.node_data.title, + node_type=node.type_, + title=node.title, elapsed_time=time.perf_counter() - start_at, created_at=datetime.now(UTC).replace(tzinfo=None), finished_at=datetime.now(UTC).replace(tzinfo=None), diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index 7c48d84d69..330ebfd54a 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -15,7 +15,7 @@ def get_mocked_fetch_model_config( mode: str, credentials: dict, ): - model_provider_factory = ModelProviderFactory(tenant_id="test_tenant") + model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index daab974775..707b28e6d8 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -66,6 +66,10 @@ def init_code_node(code_config: dict): config=code_config, ) + # Initialize node data + if "data" in code_config: + node.init_node_data(code_config["data"]) + return node @@ -234,10 +238,10 @@ def test_execute_code_output_validator_depth(): "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } - node.node_data = cast(CodeNodeData, node.node_data) + node._node_data = cast(CodeNodeData, node._node_data) # validate - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) # construct result result = { @@ -250,7 +254,7 @@ def test_execute_code_output_validator_depth(): # validate with pytest.raises(ValueError): - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) # construct result result = { @@ -263,7 +267,7 @@ def 
test_execute_code_output_validator_depth(): # validate with pytest.raises(ValueError): - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) # construct result result = { @@ -276,7 +280,7 @@ def test_execute_code_output_validator_depth(): # validate with pytest.raises(ValueError): - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) def test_execute_code_output_object_list(): @@ -330,10 +334,10 @@ def test_execute_code_output_object_list(): ] } - node.node_data = cast(CodeNodeData, node.node_data) + node._node_data = cast(CodeNodeData, node._node_data) # validate - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) # construct result result = { @@ -353,7 +357,7 @@ def test_execute_code_output_object_list(): # validate with pytest.raises(ValueError): - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) def test_execute_code_scientific_notation(): diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 50e726febf..d7856129a3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -52,7 +52,7 @@ def init_http_node(config: dict): variable_pool.add(["a", "b123", "args1"], 1) variable_pool.add(["a", "b123", "args2"], 2) - return HttpRequestNode( + node = HttpRequestNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, @@ -60,6 +60,12 @@ def init_http_node(config: dict): config=config, ) + # Initialize node data + if "data" in config: + node.init_node_data(config["data"]) + + return node + @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_get(setup_http_mock): diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index ff119b7482..a14791bc67 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -2,15 +2,10 @@ import json import time import uuid from collections.abc import Generator -from decimal import Decimal from unittest.mock import MagicMock, patch -import pytest - from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.graph_engine.entities.graph import Graph @@ -24,8 +19,6 @@ from models.enums import UserFrom from models.workflow import WorkflowType """FOR MOCK FIXTURES, DO NOT REMOVE""" -from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock -from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock def init_llm_node(config: dict) -> LLMNode: @@ -84,10 +77,14 @@ def init_llm_node(config: dict) -> LLMNode: config=config, ) + # Initialize node data + if "data" in config: + node.init_node_data(config["data"]) + return node -def test_execute_llm(flask_req_ctx): +def test_execute_llm(): node = init_llm_node( config={ "id": 
"llm", @@ -95,7 +92,7 @@ def test_execute_llm(flask_req_ctx): "title": "123", "type": "llm", "model": { - "provider": "langgenius/openai/openai", + "provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}, @@ -114,53 +111,62 @@ def test_execute_llm(flask_req_ctx): }, ) - # Create a proper LLM result with real entities - mock_usage = LLMUsage( - prompt_tokens=30, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal(1000), - prompt_price=Decimal("0.00003"), - completion_tokens=20, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal(1000), - completion_price=Decimal("0.00004"), - total_tokens=50, - total_price=Decimal("0.00007"), - currency="USD", - latency=0.5, - ) + db.session.close = MagicMock() - mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") + # Mock the _fetch_model_config to avoid database calls + def mock_fetch_model_config(**_kwargs): + from decimal import Decimal + from unittest.mock import MagicMock - mock_llm_result = LLMResult( - model="gpt-3.5-turbo", - prompt_messages=[], - message=mock_message, - usage=mock_usage, - ) + from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from core.model_runtime.entities.message_entities import AssistantPromptMessage - # Create a simple mock model instance that doesn't call real providers - mock_model_instance = MagicMock() - mock_model_instance.invoke_llm.return_value = mock_llm_result + # Create mock model instance + mock_model_instance = MagicMock() + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal(1000), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal(1000), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) + mock_message = AssistantPromptMessage(content="Test response from mock") + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) + mock_model_instance.invoke_llm.return_value = mock_llm_result - # Create a simple mock model config with required attributes - mock_model_config = MagicMock() - mock_model_config.mode = "chat" - mock_model_config.provider = "langgenius/openai/openai" - mock_model_config.model = "gpt-3.5-turbo" - mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + # Create mock model config + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.parameters = {} - # Mock the _fetch_model_config method - def mock_fetch_model_config_func(_node_data_model): return mock_model_instance, mock_model_config - # Also mock ModelManager.get_model_instance to avoid database calls - def mock_get_model_instance(_self, **kwargs): - return mock_model_instance + # Mock fetch_prompt_messages to avoid database calls + def mock_fetch_prompt_messages_1(**_kwargs): + from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + + return [ + SystemPromptMessage(content="you are a helpful assistant. 
today's weather is sunny."), + UserPromptMessage(content="what's the weather today?"), + ], [] with ( - patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), - patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config), + patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1), ): # execute node result = node._run() @@ -168,6 +174,9 @@ def test_execute_llm(flask_req_ctx): for item in result: if isinstance(item, RunCompletedEvent): + if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED: + print(f"Error: {item.run_result.error}") + print(f"Error type: {item.run_result.error_type}") assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert item.run_result.process_data is not None assert item.run_result.outputs is not None @@ -175,8 +184,7 @@ def test_execute_llm(flask_req_ctx): assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 -@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) -def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock): +def test_execute_llm_with_jinja2(): """ Test execute LLM node with jinja2 """ @@ -217,53 +225,60 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock): # Mock db.session.close() db.session.close = MagicMock() - # Create a proper LLM result with real entities - mock_usage = LLMUsage( - prompt_tokens=30, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal(1000), - prompt_price=Decimal("0.00003"), - completion_tokens=20, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal(1000), - completion_price=Decimal("0.00004"), - total_tokens=50, - total_price=Decimal("0.00007"), - currency="USD", - latency=0.5, - ) - - mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") - - mock_llm_result = LLMResult( - model="gpt-3.5-turbo", - prompt_messages=[], - message=mock_message, - usage=mock_usage, - ) - - # Create a simple mock model instance that doesn't call real providers - mock_model_instance = MagicMock() - mock_model_instance.invoke_llm.return_value = mock_llm_result - - # Create a simple mock model config with required attributes - mock_model_config = MagicMock() - mock_model_config.mode = "chat" - mock_model_config.provider = "openai" - mock_model_config.model = "gpt-3.5-turbo" - mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" - # Mock the _fetch_model_config method - def mock_fetch_model_config_func(_node_data_model): + def mock_fetch_model_config(**_kwargs): + from decimal import Decimal + from unittest.mock import MagicMock + + from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from core.model_runtime.entities.message_entities import AssistantPromptMessage + + # Create mock model instance + mock_model_instance = MagicMock() + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal(1000), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal(1000), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) + mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") + mock_llm_result = LLMResult( + 
model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) + mock_model_instance.invoke_llm.return_value = mock_llm_result + + # Create mock model config + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.parameters = {} + return mock_model_instance, mock_model_config - # Also mock ModelManager.get_model_instance to avoid database calls - def mock_get_model_instance(_self, **kwargs): - return mock_model_instance + # Mock fetch_prompt_messages to avoid database calls + def mock_fetch_prompt_messages_2(**_kwargs): + from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + + return [ + SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), + UserPromptMessage(content="what's the weather today?"), + ], [] with ( - patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), - patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config), + patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2), ): # execute node result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index dd8466afa6..edd70193a8 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -74,13 +74,15 @@ def init_parameter_extractor_node(config: dict): variable_pool.add(["a", "b123", "args1"], 1) variable_pool.add(["a", "b123", "args2"], 2) - return ParameterExtractorNode( + node = ParameterExtractorNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config=config, ) + node.init_node_data(config.get("data", {})) + return node def test_function_calling_parameter_extractor(setup_model_mock): diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 1f617fc92d..f71a5ee140 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -76,6 +76,7 @@ def test_execute_code(setup_code_executor_mock): graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config=config, ) + node.init_node_data(config.get("data", {})) # execute node result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 6907e0163e..8476c1f874 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -50,13 +50,15 @@ def init_tool_node(config: dict): conversation_variables=[], ) - return ToolNode( + node = ToolNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config=config, ) + node.init_node_data(config.get("data", {})) + return node def test_tool_variable_invoke(): diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py 
b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 85ff4f9c05..1ef024f46b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -58,21 +58,26 @@ def test_execute_answer(): pool.add(["start", "weather"], "sunny") pool.add(["llm", "text"], "You are a helpful AI.") + node_config = { + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + } + node = AnswerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "id": "answer", - "data": { - "title": "123", - "type": "answer", - "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Mock db.session.close() db.session.close = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 33f9251a72..71b3a8f7d8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -57,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch): ), ), ) + + node_config = { + "id": "1", + "data": data.model_dump(), + } + node = HttpRequestNode( id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, + config=node_config, graph_init_params=GraphInitParams( tenant_id="1", app_id="1", @@ -90,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch): start_at=0, ), ) + + # Initialize node data + node.init_node_data(node_config["data"]) monkeypatch.setattr( "core.workflow.nodes.http_request.executor.file_manager.download", lambda *args, **kwargs: b"test", @@ -145,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch): ), ), ) + + node_config = { + "id": "1", + "data": data.model_dump(), + } + node = HttpRequestNode( id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, + config=node_config, graph_init_params=GraphInitParams( tenant_id="1", app_id="1", @@ -178,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch): start_at=0, ), ) + + # Initialize node data + node.init_node_data(node_config["data"]) + monkeypatch.setattr( "core.workflow.nodes.http_request.executor.file_manager.download", lambda *args, **kwargs: b"test", @@ -257,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch): ), ) + node_config = { + "id": "1", + "data": data.model_dump(), + } + node = HttpRequestNode( id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, + config=node_config, graph_init_params=GraphInitParams( tenant_id="1", app_id="1", @@ -291,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch): ), ) + # Initialize node data + node.init_node_data(node_config["data"]) + monkeypatch.setattr( "core.workflow.nodes.http_request.executor.file_manager.download", lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data", diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 17c23b7735..f53f391433 100644 --- 
a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -162,25 +162,30 @@ def test_run(): ) pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + node_config = { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + } + iteration_node = IterationNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "tt", - "title": "迭代", - "type": "iteration", - }, - "id": "iteration-1", - }, + config=node_config, ) + # Initialize node data + iteration_node.init_node_data(node_config["data"]) + def tt_generator(self): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -379,25 +384,30 @@ def test_run_parallel(): ) pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + node_config = { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + } + iteration_node = IterationNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "迭代", - "type": "iteration", - }, - "id": "iteration-1", - }, + config=node_config, ) + # Initialize node data + iteration_node.init_node_data(node_config["data"]) + def tt_generator(self): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -595,45 +605,55 @@ def test_iteration_run_in_parallel_mode(): ) pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + parallel_node_config = { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + } + parallel_iteration_node = IterationNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "迭代", - "type": "iteration", - "is_parallel": True, - }, - "id": "iteration-1", - }, + config=parallel_node_config, ) + + # Initialize node data + parallel_iteration_node.init_node_data(parallel_node_config["data"]) + sequential_node_config = { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": 
"template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + } + sequential_iteration_node = IterationNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "迭代", - "type": "iteration", - "is_parallel": True, - }, - "id": "iteration-1", - }, + config=sequential_node_config, ) + # Initialize node data + sequential_iteration_node.init_node_data(sequential_node_config["data"]) + def tt_generator(self): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -645,8 +665,8 @@ def test_iteration_run_in_parallel_mode(): # execute node parallel_result = parallel_iteration_node._run() sequential_result = sequential_iteration_node._run() - assert parallel_iteration_node.node_data.parallel_nums == 10 - assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED + assert parallel_iteration_node._node_data.parallel_nums == 10 + assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED count = 0 parallel_arr = [] sequential_arr = [] @@ -818,26 +838,31 @@ def test_iteration_run_error_handle(): environment_variables=[], ) pool.add(["pe", "list_output"], ["1", "1"]) + error_node_config = { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + "is_parallel": True, + "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, + }, + "id": "iteration-1", + } + iteration_node = IterationNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "iteration", - "type": "iteration", - "is_parallel": True, - "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, - }, - "id": "iteration-1", - }, + config=error_node_config, ) + + # Initialize node data + iteration_node.init_node_data(error_node_config["data"]) # execute continue on error node result = iteration_node._run() result_arr = [] @@ -851,7 +876,7 @@ def test_iteration_run_error_handle(): assert count == 14 # execute remove abnormal output - iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT + iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT result = iteration_node._run() count = 0 for item in result: diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index fefad0ec95..23a7fab7cf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -119,17 +119,20 @@ def llm_node( llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState ) -> LLMNode: 
mock_file_saver = mock.MagicMock(spec=LLMFileSaver) + node_config = { + "id": "1", + "data": llm_node_data.model_dump(), + } node = LLMNode( id="1", - config={ - "id": "1", - "data": llm_node_data.model_dump(), - }, + config=node_config, graph_init_params=graph_init_params, graph=graph, graph_runtime_state=graph_runtime_state, llm_file_saver=mock_file_saver, ) + # Initialize node data + node.init_node_data(node_config["data"]) return node @@ -488,7 +491,7 @@ def test_handle_list_messages_basic(llm_node): variable_pool = llm_node.graph_runtime_state.variable_pool vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH - result = llm_node._handle_list_messages( + result = llm_node.handle_list_messages( messages=messages, context=context, jinja2_variables=jinja2_variables, @@ -506,17 +509,20 @@ def llm_node_for_multimodal( llm_node_data, graph_init_params, graph, graph_runtime_state ) -> tuple[LLMNode, LLMFileSaver]: mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) + node_config = { + "id": "1", + "data": llm_node_data.model_dump(), + } node = LLMNode( id="1", - config={ - "id": "1", - "data": llm_node_data.model_dump(), - }, + config=node_config, graph_init_params=graph_init_params, graph=graph, graph_runtime_state=graph_runtime_state, llm_file_saver=mock_file_saver, ) + # Initialize node data + node.init_node_data(node_config["data"]) return node, mock_file_saver @@ -540,7 +546,12 @@ class TestLLMNodeSaveMultiModalImageOutput: size=9, ) mock_file_saver.save_binary_string.return_value = mock_file - file = llm_node._save_multimodal_image_output(content=content) + file = llm_node.save_multimodal_image_output( + content=content, + file_saver=mock_file_saver, + ) + # Manually append to _file_outputs since the static method doesn't do it + llm_node._file_outputs.append(file) assert llm_node._file_outputs == [mock_file] assert file == mock_file mock_file_saver.save_binary_string.assert_called_once_with( @@ -566,7 +577,12 @@ class TestLLMNodeSaveMultiModalImageOutput: size=9, ) mock_file_saver.save_remote_url.return_value = mock_file - file = llm_node._save_multimodal_image_output(content=content) + file = llm_node.save_multimodal_image_output( + content=content, + file_saver=mock_file_saver, + ) + # Manually append to _file_outputs since the static method doesn't do it + llm_node._file_outputs.append(file) assert llm_node._file_outputs == [mock_file] assert file == mock_file mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE) @@ -582,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode): class TestSaveMultimodalOutputAndConvertResultToMarkdown: def test_str_content(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal - gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world") + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + contents="hello world", file_saver=mock_file_saver, file_outputs=[] + ) assert list(gen) == ["hello world"] mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_remote_url.assert_not_called() @@ -590,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: def test_text_prompt_message_content(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( - [TextPromptMessageContent(data="hello world")] + contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, 
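# file_saver and file_outputs are now explicit parameters rather than
# attributes read off the node, so the test supplies both directly.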
file_outputs=[] ) assert list(gen) == ["hello world"] mock_file_saver.save_binary_string.assert_not_called() @@ -616,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: ) mock_file_saver.save_binary_string.return_value = mock_saved_file gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( - [ + contents=[ ImagePromptMessageContent( format="png", base64_data=image_b64_data, mime_type="image/png", ) - ] + ], + file_saver=mock_file_saver, + file_outputs=llm_node._file_outputs, ) yielded_strs = list(gen) assert len(yielded_strs) == 1 @@ -645,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: def test_unknown_content_type(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal - gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"])) + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[] + ) assert list(gen) == ["frozenset({'hello world'})"] mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_remote_url.assert_not_called() def test_unknown_item_type(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal - gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])]) + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[] + ) assert list(gen) == ["frozenset({'hello world'})"] mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_remote_url.assert_not_called() def test_none_content(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal - gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None) + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + contents=None, file_saver=mock_file_saver, file_outputs=[] + ) assert list(gen) == [] mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_remote_url.assert_not_called() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 44c31b212e..466d7bad06 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -61,21 +61,26 @@ def test_execute_answer(): variable_pool.add(["start", "weather"], "sunny") variable_pool.add(["llm", "text"], "You are a helpful AI.") + node_config = { + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + } + node = AnswerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "answer", - "data": { - "title": "123", - "type": "answer", - "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Mock db.session.close() db.session.close = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 66c7818adf..486ae51e5f 100644 --- 
a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -27,13 +27,17 @@ def document_extractor_node(): title="Test Document Extractor", variable_selector=["node_id", "variable_name"], ) - return DocumentExtractorNode( + node_config = {"id": "test_node_id", "data": node_data.model_dump()} + node = DocumentExtractorNode( id="test_node_id", - config={"id": "test_node_id", "data": node_data.model_dump()}, + config=node_config, graph_init_params=Mock(), graph=Mock(), graph_runtime_state=Mock(), ) + # Initialize node data + node.init_node_data(node_config["data"]) + return node @pytest.fixture diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 167a92484d..8383aee0e4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -57,57 +57,62 @@ def test_execute_if_else_result_true(): pool.add(["start", "null"], None) pool.add(["start", "not_null"], "1212") + node_config = { + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "and", + "conditions": [ + { + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", + }, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"}, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "not_contains"], + "value": "ab", + }, + {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"}, + {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"}, + {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"}, + {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"}, + {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"}, + {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"}, + {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"}, + {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"}, + {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"}, + {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"}, + { + "comparison_operator": "≥", + "variable_selector": ["start", "greater_than_or_equal"], + "value": "22", + }, + {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"}, + {"comparison_operator": "null", "variable_selector": ["start", "null"]}, + {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, + ], + }, + } + node = IfElseNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "id": "if-else", - "data": { - "title": "123", - "type": "if-else", - "logical_operator": "and", - "conditions": [ - { - "comparison_operator": "contains", - "variable_selector": ["start", "array_contains"], - "value": "ab", - }, - { - "comparison_operator": "not contains", - "variable_selector": ["start", 
"array_not_contains"], - "value": "ab", - }, - {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"}, - { - "comparison_operator": "not contains", - "variable_selector": ["start", "not_contains"], - "value": "ab", - }, - {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"}, - {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"}, - {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"}, - {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"}, - {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"}, - {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"}, - {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"}, - {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"}, - {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"}, - {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"}, - { - "comparison_operator": "≥", - "variable_selector": ["start", "greater_than_or_equal"], - "value": "22", - }, - {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"}, - {"comparison_operator": "null", "variable_selector": ["start", "null"]}, - {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Mock db.session.close() db.session.close = MagicMock() @@ -162,33 +167,38 @@ def test_execute_if_else_result_false(): pool.add(["start", "array_contains"], ["1ab", "def"]) pool.add(["start", "array_not_contains"], ["ab", "def"]) + node_config = { + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "or", + "conditions": [ + { + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", + }, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + ], + }, + } + node = IfElseNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "id": "if-else", - "data": { - "title": "123", - "type": "if-else", - "logical_operator": "or", - "conditions": [ - { - "comparison_operator": "contains", - "variable_selector": ["start", "array_contains"], - "value": "ab", - }, - { - "comparison_operator": "not contains", - "variable_selector": ["start", "array_not_contains"], - "value": "ab", - }, - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Mock db.session.close() db.session.close = MagicMock() @@ -228,17 +238,22 @@ def test_array_file_contains_file_name(): ], ) + node_config = { + "id": "if-else", + "data": node_data.model_dump(), + } + node = IfElseNode( id=str(uuid.uuid4()), graph_init_params=Mock(), graph=Mock(), graph_runtime_state=Mock(), - config={ - "id": "if-else", - "data": node_data.model_dump(), - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( diff --git 
a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 7d3a1d6a2d..5fc9eab2df 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -33,16 +33,19 @@ def list_operator_node(): "title": "Test Title", } node_data = ListOperatorNodeData(**config) + node_config = { + "id": "test_node_id", + "data": node_data.model_dump(), + } node = ListOperatorNode( id="test_node_id", - config={ - "id": "test_node_id", - "data": node_data.model_dump(), - }, + config=node_config, graph_init_params=MagicMock(), graph=MagicMock(), graph_runtime_state=MagicMock(), ) + # Initialize node data + node.init_node_data(node_config["data"]) node.graph_runtime_state = MagicMock() node.graph_runtime_state.variable_pool = MagicMock() return node diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 2776e57777..0eaabd0c40 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -38,12 +38,13 @@ def _create_tool_node(): system_variables=SystemVariable.empty(), user_inputs={}, ) + node_config = { + "id": "1", + "data": data.model_dump(), + } node = ToolNode( id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, + config=node_config, graph_init_params=GraphInitParams( tenant_id="1", app_id="1", @@ -71,6 +72,8 @@ def _create_tool_node(): start_at=0, ), ) + # Initialize node data + node.init_node_data(node_config["data"]) return node diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 62e3e37104..ee51339427 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -82,23 +82,28 @@ def test_overwrite_string_variable(): mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.OVER_WRITE.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.OVER_WRITE.value, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - }, + config=node_config, conv_var_updater_factory=mock_conv_var_updater_factory, ) + # Initialize node data + node.init_node_data(node_config["data"]) + list(node.run()) expected_var = StringVariable( id=conversation_variable.id, @@ -178,23 +183,28 @@ def test_append_variable_to_array(): mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node_config = { + "id": "node_id", + "data": { + "title": "test", + 
"assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.APPEND.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND.value, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - }, + config=node_config, conv_var_updater_factory=mock_conv_var_updater_factory, ) + # Initialize node data + node.init_node_data(node_config["data"]) + list(node.run()) expected_value = list(conversation_variable.value) expected_value.append(input_variable.value) @@ -265,23 +275,28 @@ def test_clear_array(): mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.CLEAR.value, + "input_variable_selector": [], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.CLEAR.value, - "input_variable_selector": [], - }, - }, + config=node_config, conv_var_updater_factory=mock_conv_var_updater_factory, ) + # Initialize node data + node.init_node_data(node_config["data"]) + list(node.run()) expected_var = ArrayStringVariable( id=conversation_variable.id, diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index a3a90b0599..987eaf7534 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -115,28 +115,33 @@ def test_remove_first_from_array(): conversation_variables=[conversation_variable], ) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "version": "2", + "items": [ + { + "variable_selector": ["conversation", conversation_variable.name], + "input_type": InputType.VARIABLE, + "operation": Operation.REMOVE_FIRST, + "value": None, + } + ], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Skip the mock assertion since we're in a test environment # Print the variable before running print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}") @@ 
-202,28 +207,33 @@ def test_remove_last_from_array(): conversation_variables=[conversation_variable], ) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "version": "2", + "items": [ + { + "variable_selector": ["conversation", conversation_variable.name], + "input_type": InputType.VARIABLE, + "operation": Operation.REMOVE_LAST, + "value": None, + } + ], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Skip the mock assertion since we're in a test environment list(node.run()) @@ -281,28 +291,33 @@ def test_remove_first_from_empty_array(): conversation_variables=[conversation_variable], ) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "version": "2", + "items": [ + { + "variable_selector": ["conversation", conversation_variable.name], + "input_type": InputType.VARIABLE, + "operation": Operation.REMOVE_FIRST, + "value": None, + } + ], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Skip the mock assertion since we're in a test environment list(node.run()) @@ -360,28 +375,33 @@ def test_remove_last_from_empty_array(): conversation_variables=[conversation_variable], ) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "version": "2", + "items": [ + { + "variable_selector": ["conversation", conversation_variable.name], + "input_type": InputType.VARIABLE, + "operation": Operation.REMOVE_LAST, + "value": None, + } + ], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Skip the mock assertion since we're in a test environment list(node.run())
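Taken together, the test changes above all apply the same mechanical transformation: the node's config dict is hoisted into a named node_config variable, the node is constructed from it, and node.init_node_data(node_config["data"]) is called before the node runs. The following is a minimal, self-contained sketch of that lifecycle; DemoNodeData and DemoNode are illustrative stand-ins, not identifiers from this patch, and the real base classes under api/core/workflow/nodes/base also carry graph and runtime state.

from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel


class DemoNodeData(BaseModel):
    # Typed model for the node's "data" mapping, validated once up front.
    title: str
    desc: Optional[str] = None
    answer: str = ""


class DemoNode:
    _node_data: DemoNodeData

    def __init__(self, id: str, config: Mapping[str, Any]) -> None:
        # Step 1: construct from the raw config; no typed data is attached yet.
        self.id = id
        self._config = config

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        # Step 2: validate the raw dict into the typed model that _run() and
        # the _get_* accessors read.
        self._node_data = DemoNodeData.model_validate(data)

    def _get_title(self) -> str:
        return self._node_data.title


node_config = {
    "id": "answer",
    "data": {"title": "demo", "type": "answer", "answer": "hi"},
}
node = DemoNode(id=node_config["id"], config=node_config)
node.init_node_data(node_config["data"])
assert node._get_title() == "demo"

Splitting construction from validation in this way is also what lets extract_variable_selector_to_variable_mapping accept a plain Mapping and call model_validate itself, as the variable-assigner hunks earlier in this patch show.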