From 61196180b8c817c22c937b5fd46d024f72787147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yanli=20=E7=9B=90=E7=B2=92?= Date: Tue, 17 Mar 2026 19:31:00 +0800 Subject: [PATCH] Type phase 3 tool inputs --- api/core/tools/tool_manager.py | 5 +- api/core/workflow/nodes/agent/entities.py | 31 ++- .../workflow/nodes/agent/runtime_support.py | 200 ++++++++++++------ api/dify_graph/nodes/tool/entities.py | 75 +++---- api/dify_graph/nodes/tool/tool_node.py | 10 +- 5 files changed, 214 insertions(+), 107 deletions(-) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 7f7787b92a..febce35985 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1040,9 +1040,10 @@ class ToolManager: continue tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {})) if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) + variable_selector = tool_input.require_variable_selector() + variable = variable_pool.get(variable_selector) if variable is None: - raise ToolParameterError(f"Variable {tool_input.value} does not exist") + raise ToolParameterError(f"Variable {variable_selector} does not exist") parameter_value = variable.value elif tool_input.type == "constant": parameter_value = tool_input.value diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 91fed39795..b517bbddf1 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,13 +1,24 @@ -from enum import IntEnum, StrEnum, auto -from typing import Any, Literal, Union +from __future__ import annotations -from pydantic import BaseModel +from enum import IntEnum, StrEnum, auto +from typing import Literal, TypeAlias + +from pydantic import BaseModel, TypeAdapter, field_validator +from pydantic_core.core_schema import ValidationInfo from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType +AgentInputConstantValue: TypeAlias = ( + list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None +) +VariableSelector: TypeAlias = list[str] + +_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue) +_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector) + class AgentNodeData(BaseNodeData): type: NodeType = BuiltinNodeTypes.AGENT @@ -21,8 +32,20 @@ class AgentNodeData(BaseNodeData): tool_node_version: str | None = None class AgentInput(BaseModel): - value: Union[list[str], list[ToolSelector], Any] type: Literal["mixed", "variable", "constant"] + value: AgentInputConstantValue | VariableSelector + + @field_validator("value", mode="before") + @classmethod + def validate_value( + cls, value: object, validation_info: ValidationInfo + ) -> AgentInputConstantValue | VariableSelector: + input_type = validation_info.data.get("type") + if input_type == "variable": + return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value) + if input_type in {"mixed", "constant"}: + return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value) + raise ValueError(f"Unknown agent input type: {input_type}") agent_parameters: dict[str, AgentInput] diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index 2ff7c964b9..7b5518afca 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -1,16 +1,17 @@ from __future__ import annotations import json -from collections.abc import Sequence -from typing import Any, cast +from collections.abc import Mapping, Sequence +from typing import TypeAlias from packaging.version import Version -from pydantic import ValidationError +from pydantic import TypeAdapter, ValidationError from sqlalchemy import select from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter +from core.app.entities.app_invoke_entities import InvokeFrom from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.plugin.entities.request import InvokeCredentials @@ -28,6 +29,14 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen from .exceptions import AgentInputTypeError, AgentVariableNotFoundError from .strategy_protocols import ResolvedAgentStrategy +JsonObject: TypeAlias = dict[str, object] +JsonObjectList: TypeAlias = list[JsonObject] +VariableSelector: TypeAlias = list[str] + +_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject) +_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList) +_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector) + class AgentRuntimeSupport: def build_parameters( @@ -39,12 +48,12 @@ class AgentRuntimeSupport: strategy: ResolvedAgentStrategy, tenant_id: str, app_id: str, - invoke_from: Any, + invoke_from: InvokeFrom, for_log: bool = False, - ) -> dict[str, Any]: + ) -> dict[str, object]: agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} - result: dict[str, Any] = {} + result: dict[str, object] = {} for parameter_name in node_data.agent_parameters: parameter = agent_parameters_dictionary.get(parameter_name) if not parameter: @@ -54,9 +63,10 @@ class AgentRuntimeSupport: agent_input = node_data.agent_parameters[parameter_name] match agent_input.type: case "variable": - variable = variable_pool.get(agent_input.value) # type: ignore[arg-type] + variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value) + variable = variable_pool.get(variable_selector) if variable is None: - raise AgentVariableNotFoundError(str(agent_input.value)) + raise AgentVariableNotFoundError(str(variable_selector)) parameter_value = variable.value case "mixed" | "constant": try: @@ -79,60 +89,38 @@ class AgentRuntimeSupport: value = parameter_value if parameter.type == "array[tools]": - value = cast(list[dict[str, Any]], value) - value = [tool for tool in value if tool.get("enabled", False)] - value = self._filter_mcp_type_tool(strategy, value) - for tool in value: - if "schemas" in tool: - tool.pop("schemas") - parameters = tool.get("parameters", {}) - if all(isinstance(v, dict) for _, v in parameters.items()): - params = {} - for key, param in parameters.items(): - if param.get("auto", ParamsAutoGenerated.OPEN) in ( - ParamsAutoGenerated.CLOSE, - 0, - ): - value_param = param.get("value", {}) - if value_param and value_param.get("type", "") == "variable": - variable_selector = value_param.get("value") - if not variable_selector: - raise ValueError("Variable selector is missing for a variable-type parameter.") - - variable = variable_pool.get(variable_selector) - if variable is None: - raise AgentVariableNotFoundError(str(variable_selector)) - - params[key] = variable.value - else: - params[key] = value_param.get("value", "") if value_param is not None else None - else: - params[key] = None - parameters = params - tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()} - tool["parameters"] = parameters + tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value) + value = self._normalize_tool_payloads( + strategy=strategy, + tools=tool_payloads, + variable_pool=variable_pool, + ) if not for_log: if parameter.type == "array[tools]": - value = cast(list[dict[str, Any]], value) + value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value) tool_value = [] for tool in value: - provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) - setting_params = tool.get("settings", {}) - parameters = tool.get("parameters", {}) + provider_type = self._coerce_tool_provider_type(tool.get("type")) + setting_params = self._coerce_json_object(tool.get("settings")) or {} + parameters = self._coerce_json_object(tool.get("parameters")) or {} manual_input_params = [key for key, value in parameters.items() if value is not None] parameters = {**parameters, **setting_params} + provider_id = self._coerce_optional_string(tool.get("provider_name")) or "" + tool_name = self._coerce_optional_string(tool.get("tool_name")) or "" + plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier")) + credential_id = self._coerce_optional_string(tool.get("credential_id")) entity = AgentToolEntity( - provider_id=tool.get("provider_name", ""), + provider_id=provider_id, provider_type=provider_type, - tool_name=tool.get("tool_name", ""), + tool_name=tool_name, tool_parameters=parameters, - plugin_unique_identifier=tool.get("plugin_unique_identifier", None), - credential_id=tool.get("credential_id", None), + plugin_unique_identifier=plugin_unique_identifier, + credential_id=credential_id, ) - extra = tool.get("extra", {}) + extra = self._coerce_json_object(tool.get("extra")) or {} runtime_variable_pool: VariablePool | None = None if node_data.version != "1" or node_data.tool_node_version is not None: @@ -145,8 +133,9 @@ class AgentRuntimeSupport: runtime_variable_pool, ) if tool_runtime.entity.description: + description_override = self._coerce_optional_string(extra.get("description")) tool_runtime.entity.description.llm = ( - extra.get("description", "") or tool_runtime.entity.description.llm + description_override or tool_runtime.entity.description.llm ) for tool_runtime_params in tool_runtime.entity.parameters: tool_runtime_params.form = ( @@ -167,13 +156,13 @@ class AgentRuntimeSupport: { **tool_runtime.entity.model_dump(mode="json"), "runtime_parameters": runtime_parameters, - "credential_id": tool.get("credential_id", None), + "credential_id": credential_id, "provider_type": provider_type.value, } ) value = tool_value if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: - value = cast(dict[str, Any], value) + value = _JSON_OBJECT_ADAPTER.validate_python(value) model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value) history_prompt_messages = [] if node_data.memory: @@ -199,17 +188,27 @@ class AgentRuntimeSupport: return result - def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials: + def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials: credentials = InvokeCredentials() credentials.tool_credentials = {} - for tool in parameters.get("tools", []): + tools = parameters.get("tools") + if not isinstance(tools, list): + return credentials + + for raw_tool in tools: + tool = self._coerce_json_object(raw_tool) + if tool is None: + continue if not tool.get("credential_id"): continue try: identity = ToolIdentity.model_validate(tool.get("identity", {})) except ValidationError: continue - credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) + credential_id = self._coerce_optional_string(tool.get("credential_id")) + if credential_id is None: + continue + credentials.tool_credentials[identity.provider] = credential_id return credentials def fetch_memory( @@ -232,14 +231,14 @@ class AgentRuntimeSupport: return TokenBufferMemory(conversation=conversation, model_instance=model_instance) - def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: + def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]: provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id=tenant_id, - provider=value.get("provider", ""), + provider=str(value.get("provider", "")), model_type=ModelType.LLM, ) - model_name = value.get("model", "") + model_name = str(value.get("model", "")) model_credentials = provider_model_bundle.configuration.get_current_credentials( model_type=ModelType.LLM, model=model_name, @@ -249,7 +248,7 @@ class AgentRuntimeSupport: model_instance = ModelManager().get_model_instance( tenant_id=tenant_id, provider=provider_name, - model_type=ModelType(value.get("model_type", "")), + model_type=ModelType(str(value.get("model_type", ""))), model=model_name, ) model_schema = model_type_instance.get_model_schema(model_name, model_credentials) @@ -268,9 +267,88 @@ class AgentRuntimeSupport: @staticmethod def _filter_mcp_type_tool( strategy: ResolvedAgentStrategy, - tools: list[dict[str, Any]], - ) -> list[dict[str, Any]]: + tools: JsonObjectList, + ) -> JsonObjectList: meta_version = strategy.meta_version if meta_version and Version(meta_version) > Version("0.0.1"): return tools return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] + + def _normalize_tool_payloads( + self, + *, + strategy: ResolvedAgentStrategy, + tools: JsonObjectList, + variable_pool: VariablePool, + ) -> JsonObjectList: + enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))] + normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools) + for tool in normalized_tools: + tool.pop("schemas", None) + tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool) + tool["settings"] = self._resolve_tool_settings(tool) + return normalized_tools + + def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject: + parameter_configs = self._coerce_named_json_objects(tool.get("parameters")) + if parameter_configs is None: + raw_parameters = self._coerce_json_object(tool.get("parameters")) + return raw_parameters or {} + + resolved_parameters: JsonObject = {} + for key, parameter_config in parameter_configs.items(): + if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0): + value_param = self._coerce_json_object(parameter_config.get("value")) + if value_param and value_param.get("type") == "variable": + variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value")) + variable = variable_pool.get(variable_selector) + if variable is None: + raise AgentVariableNotFoundError(str(variable_selector)) + resolved_parameters[key] = variable.value + else: + resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None + else: + resolved_parameters[key] = None + + return resolved_parameters + + @staticmethod + def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject: + settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings")) + if settings is None: + return {} + return {key: setting.get("value") for key, setting in settings.items()} + + @staticmethod + def _coerce_json_object(value: object) -> JsonObject | None: + try: + return _JSON_OBJECT_ADAPTER.validate_python(value) + except ValidationError: + return None + + @staticmethod + def _coerce_optional_string(value: object) -> str | None: + return value if isinstance(value, str) else None + + @staticmethod + def _coerce_tool_provider_type(value: object) -> ToolProviderType: + if isinstance(value, ToolProviderType): + return value + if isinstance(value, str): + return ToolProviderType(value) + return ToolProviderType.BUILT_IN + + @classmethod + def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None: + if not isinstance(value, dict): + return None + + coerced: dict[str, JsonObject] = {} + for key, item in value.items(): + if not isinstance(key, str): + return None + json_object = cls._coerce_json_object(item) + if json_object is None: + return None + coerced[key] = json_object + return coerced diff --git a/api/dify_graph/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py index b041ee66fd..56ff3f58d5 100644 --- a/api/dify_graph/nodes/tool/entities.py +++ b/api/dify_graph/nodes/tool/entities.py @@ -1,12 +1,24 @@ -from typing import Any, Literal, Union +from __future__ import annotations -from pydantic import BaseModel, field_validator +from typing import Literal, TypeAlias + +from pydantic import BaseModel, TypeAdapter, field_validator from pydantic_core.core_schema import ValidationInfo from core.tools.entities.tool_entities import ToolProviderType from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType +ToolConfigurationValue: TypeAlias = str | int | float | bool +ToolConfigurations: TypeAlias = dict[str, ToolConfigurationValue] +ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None +VariableSelector: TypeAlias = list[str] + +_TOOL_CONFIGURATIONS_ADAPTER: TypeAdapter[ToolConfigurations] = TypeAdapter(ToolConfigurations) +_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str) +_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue) +_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector) + class ToolEntity(BaseModel): provider_id: str @@ -14,52 +26,41 @@ class ToolEntity(BaseModel): provider_name: str # redundancy tool_name: str tool_label: str # redundancy - tool_configurations: dict[str, Any] + tool_configurations: ToolConfigurations credential_id: str | None = None plugin_unique_identifier: str | None = None # redundancy @field_validator("tool_configurations", mode="before") @classmethod - def validate_tool_configurations(cls, value, values: ValidationInfo): - if not isinstance(value, dict): - raise ValueError("tool_configurations must be a dictionary") - - for key in values.data.get("tool_configurations", {}): - value = values.data.get("tool_configurations", {}).get(key) - if not isinstance(value, str | int | float | bool): - raise ValueError(f"{key} must be a string") - - return value + def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations: + return _TOOL_CONFIGURATIONS_ADAPTER.validate_python(value) class ToolNodeData(BaseNodeData, ToolEntity): type: NodeType = BuiltinNodeTypes.TOOL class ToolInput(BaseModel): - # TODO: check this type - value: Union[Any, list[str]] type: Literal["mixed", "variable", "constant"] + value: ToolInputConstantValue | VariableSelector - @field_validator("type", mode="before") + @field_validator("value", mode="before") @classmethod - def check_type(cls, value, validation_info: ValidationInfo): - typ = value - value = validation_info.data.get("value") + def validate_value( + cls, value: object, validation_info: ValidationInfo + ) -> ToolInputConstantValue | VariableSelector: + input_type = validation_info.data.get("type") + if input_type == "mixed": + return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value) + if input_type == "variable": + return _VARIABLE_SELECTOR_ADAPTER.validate_python(value) + if input_type == "constant": + return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value) + raise ValueError(f"Unknown tool input type: {input_type}") - if value is None: - return typ - - if typ == "mixed" and not isinstance(value, str): - raise ValueError("value must be a string") - elif typ == "variable": - if not isinstance(value, list): - raise ValueError("value must be a list") - for val in value: - if not isinstance(val, str): - raise ValueError("value must be a list of strings") - elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))): - raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}") - return typ + def require_variable_selector(self) -> VariableSelector: + if self.type != "variable": + raise ValueError(f"Expected variable tool input, got {self.type}") + return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value) tool_parameters: dict[str, ToolInput] # The version of the tool parameter. @@ -69,7 +70,7 @@ class ToolNodeData(BaseNodeData, ToolEntity): @field_validator("tool_parameters", mode="before") @classmethod - def filter_none_tool_inputs(cls, value): + def filter_none_tool_inputs(cls, value: object) -> object: if not isinstance(value, dict): return value @@ -80,8 +81,10 @@ class ToolNodeData(BaseNodeData, ToolEntity): } @staticmethod - def _has_valid_value(tool_input): + def _has_valid_value(tool_input: object) -> bool: """Check if the value is valid""" if isinstance(tool_input, dict): return tool_input.get("value") is not None - return getattr(tool_input, "value", None) is not None + if isinstance(tool_input, ToolNodeData.ToolInput): + return tool_input.value is not None + return False diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index 598f0da92e..b549d6451a 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -225,10 +225,11 @@ class ToolNode(Node[ToolNodeData]): continue tool_input = node_data.tool_parameters[parameter_name] if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) + variable_selector = tool_input.require_variable_selector() + variable = variable_pool.get(variable_selector) if variable is None: if parameter.required: - raise ToolParameterError(f"Variable {tool_input.value} does not exist") + raise ToolParameterError(f"Variable {variable_selector} does not exist") continue parameter_value = variable.value elif tool_input.type in {"mixed", "constant"}: @@ -510,8 +511,9 @@ class ToolNode(Node[ToolNodeData]): for selector in selectors: result[selector.variable] = selector.value_selector case "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value + variable_selector = input.require_variable_selector() + selector_key = ".".join(variable_selector) + result[f"#{selector_key}#"] = variable_selector case "constant": pass