mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 22:28:55 +08:00
fix(api): restore workflow node compatibility
This commit is contained in:
parent
9a86f280eb
commit
4c1d27431b
@ -698,12 +698,19 @@ def _refresh_model(session: Session, model: Workflow) -> Workflow: ...
|
||||
def _refresh_model(session: Session, model: Message) -> Message: ...
|
||||
|
||||
|
||||
def _refresh_model(session: Session, model: Workflow | Message) -> Workflow | Message:
|
||||
if isinstance(model, Workflow):
|
||||
detached_workflow = session.get(Workflow, model.id)
|
||||
assert detached_workflow is not None
|
||||
return detached_workflow
|
||||
def _refresh_model(session: Session, model: Any) -> Any:
|
||||
del session
|
||||
with Session(bind=db.engine, expire_on_commit=False) as refresh_session:
|
||||
if isinstance(model, Workflow):
|
||||
detached_workflow = refresh_session.get(Workflow, model.id)
|
||||
assert detached_workflow is not None
|
||||
return detached_workflow
|
||||
|
||||
detached_message = session.get(Message, model.id)
|
||||
assert detached_message is not None
|
||||
return detached_message
|
||||
if isinstance(model, Message):
|
||||
detached_message = refresh_session.get(Message, model.id)
|
||||
assert detached_message is not None
|
||||
return detached_message
|
||||
|
||||
detached_model = refresh_session.get(type(model), model.id)
|
||||
assert detached_model is not None
|
||||
return detached_model
|
||||
|
||||
@ -102,6 +102,25 @@ class LLMNodeData(BaseNodeData):
|
||||
return PromptConfig()
|
||||
return v
|
||||
|
||||
@field_validator("structured_output", mode="before")
|
||||
@classmethod
|
||||
def convert_legacy_structured_output(cls, v: object) -> StructuredOutputConfig | None | object:
|
||||
if not isinstance(v, Mapping):
|
||||
return v
|
||||
|
||||
schema = v.get("schema")
|
||||
if schema is None:
|
||||
return None
|
||||
|
||||
normalized: StructuredOutputConfig = {"schema": schema}
|
||||
name = v.get("name")
|
||||
description = v.get("description")
|
||||
if isinstance(name, str):
|
||||
normalized["name"] = name
|
||||
if isinstance(description, str):
|
||||
normalized["description"] = description
|
||||
return normalized
|
||||
|
||||
@property
|
||||
def structured_output_enabled(self) -> bool:
|
||||
return self.structured_output_switch_on and self.structured_output is not None
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, TypeAlias
|
||||
from typing import Literal, TypeAlias, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
@ -10,16 +10,57 @@ from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
|
||||
ToolConfigurationValue: TypeAlias = str | int | float | bool
|
||||
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationValue]
|
||||
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_TOOL_CONFIGURATIONS_ADAPTER: TypeAdapter[ToolConfigurations] = TypeAdapter(ToolConfigurations)
|
||||
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
|
||||
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
|
||||
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class WorkflowToolInputValue(TypedDict):
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: ToolInputConstantValue | VariableSelector
|
||||
|
||||
|
||||
ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue
|
||||
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry]
|
||||
|
||||
|
||||
class ToolInputPayload(BaseModel):
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: ToolInputConstantValue | VariableSelector
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(
|
||||
cls, value: object, validation_info: ValidationInfo
|
||||
) -> ToolInputConstantValue | VariableSelector:
|
||||
input_type = validation_info.data.get("type")
|
||||
if input_type == "mixed":
|
||||
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
|
||||
if input_type == "variable":
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if input_type == "constant":
|
||||
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown tool input type: {input_type}")
|
||||
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
if self.type != "variable":
|
||||
raise ValueError(f"Expected variable tool input, got {self.type}")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||
|
||||
|
||||
def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry:
|
||||
if isinstance(value, (str, int, float, bool)):
|
||||
return cast(ToolConfigurationEntry, value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump())
|
||||
|
||||
raise TypeError("Tool configuration values must be primitives or workflow tool input objects")
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: ToolProviderType
|
||||
@ -33,34 +74,22 @@ class ToolEntity(BaseModel):
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
|
||||
return _TOOL_CONFIGURATIONS_ADAPTER.validate_python(value)
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError("tool_configurations must be a dictionary")
|
||||
|
||||
normalized: ToolConfigurations = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("tool_configurations keys must be strings")
|
||||
normalized[key] = _validate_tool_configuration_entry(item)
|
||||
return normalized
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
type: NodeType = BuiltinNodeTypes.TOOL
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: ToolInputConstantValue | VariableSelector
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(
|
||||
cls, value: object, validation_info: ValidationInfo
|
||||
) -> ToolInputConstantValue | VariableSelector:
|
||||
input_type = validation_info.data.get("type")
|
||||
if input_type == "mixed":
|
||||
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
|
||||
if input_type == "variable":
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if input_type == "constant":
|
||||
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown tool input type: {input_type}")
|
||||
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
if self.type != "variable":
|
||||
raise ValueError(f"Expected variable tool input, got {self.type}")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||
class ToolInput(ToolInputPayload):
|
||||
pass
|
||||
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
# The version of the tool parameter.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user