fix(api): restore workflow node compatibility

This commit is contained in:
Yanli 盐粒 2026-03-18 18:43:35 +08:00
parent 9a86f280eb
commit 4c1d27431b
3 changed files with 89 additions and 34 deletions

View File

@ -698,12 +698,19 @@ def _refresh_model(session: Session, model: Workflow) -> Workflow: ...
def _refresh_model(session: Session, model: Message) -> Message: ...
def _refresh_model(session: Session, model: Workflow | Message) -> Workflow | Message:
if isinstance(model, Workflow):
detached_workflow = session.get(Workflow, model.id)
assert detached_workflow is not None
return detached_workflow
def _refresh_model(session: Session, model: Any) -> Any:
del session
with Session(bind=db.engine, expire_on_commit=False) as refresh_session:
if isinstance(model, Workflow):
detached_workflow = refresh_session.get(Workflow, model.id)
assert detached_workflow is not None
return detached_workflow
detached_message = session.get(Message, model.id)
assert detached_message is not None
return detached_message
if isinstance(model, Message):
detached_message = refresh_session.get(Message, model.id)
assert detached_message is not None
return detached_message
detached_model = refresh_session.get(type(model), model.id)
assert detached_model is not None
return detached_model

View File

@ -102,6 +102,25 @@ class LLMNodeData(BaseNodeData):
return PromptConfig()
return v
@field_validator("structured_output", mode="before")
@classmethod
def convert_legacy_structured_output(cls, v: object) -> StructuredOutputConfig | None | object:
if not isinstance(v, Mapping):
return v
schema = v.get("schema")
if schema is None:
return None
normalized: StructuredOutputConfig = {"schema": schema}
name = v.get("name")
description = v.get("description")
if isinstance(name, str):
normalized["name"] = name
if isinstance(description, str):
normalized["description"] = description
return normalized
@property
def structured_output_enabled(self) -> bool:
return self.structured_output_switch_on and self.structured_output is not None

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Literal, TypeAlias
from typing import Literal, TypeAlias, TypedDict, cast
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
@ -10,16 +10,57 @@ from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
ToolConfigurationValue: TypeAlias = str | int | float | bool
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationValue]
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
VariableSelector: TypeAlias = list[str]
_TOOL_CONFIGURATIONS_ADAPTER: TypeAdapter[ToolConfigurations] = TypeAdapter(ToolConfigurations)
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class WorkflowToolInputValue(TypedDict):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry]
class ToolInputPayload(BaseModel):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> ToolInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "mixed":
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
if input_type == "variable":
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type == "constant":
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
raise ValueError(f"Unknown tool input type: {input_type}")
def require_variable_selector(self) -> VariableSelector:
if self.type != "variable":
raise ValueError(f"Expected variable tool input, got {self.type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry:
if isinstance(value, (str, int, float, bool)):
return cast(ToolConfigurationEntry, value)
if isinstance(value, dict):
return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump())
raise TypeError("Tool configuration values must be primitives or workflow tool input objects")
class ToolEntity(BaseModel):
provider_id: str
provider_type: ToolProviderType
@ -33,34 +74,22 @@ class ToolEntity(BaseModel):
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
return _TOOL_CONFIGURATIONS_ADAPTER.validate_python(value)
if not isinstance(value, dict):
raise TypeError("tool_configurations must be a dictionary")
normalized: ToolConfigurations = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("tool_configurations keys must be strings")
normalized[key] = _validate_tool_configuration_entry(item)
return normalized
class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = BuiltinNodeTypes.TOOL
class ToolInput(BaseModel):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> ToolInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "mixed":
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
if input_type == "variable":
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type == "constant":
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
raise ValueError(f"Unknown tool input type: {input_type}")
def require_variable_selector(self) -> VariableSelector:
if self.type != "variable":
raise ValueError(f"Expected variable tool input, got {self.type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
class ToolInput(ToolInputPayload):
pass
tool_parameters: dict[str, ToolInput]
# The version of the tool parameter.