mirror of
https://github.com/langgenius/dify.git
synced 2026-05-11 23:18:39 +08:00
121 lines
4.5 KiB
Python
121 lines
4.5 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Literal, TypeAlias, cast
|
|
|
|
from pydantic import BaseModel, TypeAdapter, field_validator
|
|
from pydantic_core.core_schema import ValidationInfo
|
|
from typing_extensions import TypedDict
|
|
|
|
from core.tools.entities.tool_entities import ToolProviderType
|
|
from dify_graph.entities.base_node_data import BaseNodeData
|
|
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
|
|
|
ToolConfigurationValue: TypeAlias = str | int | float | bool
|
|
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
|
|
VariableSelector: TypeAlias = list[str]
|
|
|
|
# One adapter per tool-input kind, constructed once at import time and shared by the
# tool-input validators in this module.

# "variable" inputs: a selector (list of strings).
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
# "constant" inputs: any ToolInputConstantValue.
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
# "mixed" inputs: a plain string.
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
|
|
|
|
|
|
# Plain-dict form of a workflow tool input (same keys as ToolInputPayload).
# Kept as comments rather than a docstring so pydantic schemas that embed this
# TypedDict do not gain a description.
class WorkflowToolInputValue(TypedDict, total=True):
    # Which kind of tool input this is.
    type: Literal["mixed", "variable", "constant"]
    # Payload whose expected shape depends on ``type``. This TypedDict does not
    # enforce the pairing; ToolInputPayload's validator does.
    value: ToolInputConstantValue | VariableSelector
|
|
|
|
|
|
# A single ``tool_configurations`` value: either a bare scalar or a workflow tool
# input in its plain-dict form.
ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue

# The whole ``tool_configurations`` mapping (string keys -> entries).
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry]
|
|
|
|
|
|
class ToolInputPayload(BaseModel):
|
|
type: Literal["mixed", "variable", "constant"]
|
|
value: ToolInputConstantValue | VariableSelector
|
|
|
|
@field_validator("value", mode="before")
|
|
@classmethod
|
|
def validate_value(
|
|
cls, value: object, validation_info: ValidationInfo
|
|
) -> ToolInputConstantValue | VariableSelector:
|
|
input_type = validation_info.data.get("type")
|
|
if input_type == "mixed":
|
|
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
|
|
if input_type == "variable":
|
|
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
|
if input_type == "constant":
|
|
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
|
|
raise ValueError(f"Unknown tool input type: {input_type}")
|
|
|
|
def require_variable_selector(self) -> VariableSelector:
|
|
if self.type != "variable":
|
|
raise ValueError(f"Expected variable tool input, got {self.type}")
|
|
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
|
|
|
|
|
def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry:
|
|
if isinstance(value, (str, int, float, bool)):
|
|
return cast(ToolConfigurationEntry, value)
|
|
|
|
if isinstance(value, dict):
|
|
return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump())
|
|
|
|
raise TypeError("Tool configuration values must be primitives or workflow tool input objects")
|
|
|
|
|
|
# Identifies a tool (provider + tool name) and carries its configuration values.
# (Comments instead of a class docstring so the model's JSON schema is unchanged.)
class ToolEntity(BaseModel):
    provider_id: str
    provider_type: ToolProviderType
    provider_name: str  # redundancy
    tool_name: str
    tool_label: str  # redundancy
    tool_configurations: ToolConfigurations
    credential_id: str | None = None
    plugin_unique_identifier: str | None = None  # redundancy

    @field_validator("tool_configurations", mode="before")
    @classmethod
    def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
        """Require a str-keyed dict and normalize every entry.

        Each entry goes through ``_validate_tool_configuration_entry``. Keys keep
        their original order.

        Raises:
            TypeError: If ``value`` is not a dict, or one of its keys is not a string.
        """
        if not isinstance(value, dict):
            raise TypeError("tool_configurations must be a dictionary")

        def require_str_key(key: object) -> str:
            # Pass string keys through; reject anything else.
            if isinstance(key, str):
                return key
            raise TypeError("tool_configurations keys must be strings")

        # Since Python 3.8 a dict comprehension evaluates each key before its value,
        # so items are checked in order and a bad key is reported before its entry
        # is normalized.
        return {
            require_str_key(key): _validate_tool_configuration_entry(entry)
            for key, entry in value.items()
        }
|
|
|
|
|
|
# Node data for a tool node: tool identity/configuration (from ToolEntity) plus the
# per-parameter inputs in ``tool_parameters``.
# (Comments instead of a class docstring so the model's JSON schema is unchanged.)
class ToolNodeData(BaseNodeData, ToolEntity):
    type: NodeType = BuiltinNodeTypes.TOOL

    # Subclass of ToolInputPayload with no additions; addressable as ToolNodeData.ToolInput.
    class ToolInput(ToolInputPayload):
        pass

    tool_parameters: dict[str, ToolInput]
    # The version of the tool parameter.
    # If this value is None, it indicates this is a previous version
    # and requires using the legacy parameter parsing rules.
    tool_node_version: str | None = None

    @field_validator("tool_parameters", mode="before")
    @classmethod
    def filter_none_tool_inputs(cls, value: object) -> object:
        """Drop parameters that are ``None`` or whose ``value`` is ``None``.

        Non-dict input is returned untouched so pydantic reports the type error.
        """
        if isinstance(value, dict):
            # _has_valid_value already returns False for a None entry.
            return {name: entry for name, entry in value.items() if cls._has_valid_value(entry)}
        return value

    @staticmethod
    def _has_valid_value(tool_input: object) -> bool:
        """Return True if ``tool_input`` is a dict or ToolInput whose ``value`` is not None."""
        if isinstance(tool_input, dict):
            candidate = tool_input.get("value")
        elif isinstance(tool_input, ToolNodeData.ToolInput):
            candidate = tool_input.value
        else:
            # Anything else (including None) is treated as invalid.
            return False
        return candidate is not None
|