diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 1694f856cf..44cd1ec3c6 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -698,12 +698,19 @@ def _refresh_model(session: Session, model: Workflow) -> Workflow: ... def _refresh_model(session: Session, model: Message) -> Message: ... -def _refresh_model(session: Session, model: Workflow | Message) -> Workflow | Message: - if isinstance(model, Workflow): - detached_workflow = session.get(Workflow, model.id) - assert detached_workflow is not None - return detached_workflow +def _refresh_model(session: Session, model: Any) -> Any: + del session + with Session(bind=db.engine, expire_on_commit=False) as refresh_session: + if isinstance(model, Workflow): + detached_workflow = refresh_session.get(Workflow, model.id) + assert detached_workflow is not None + return detached_workflow - detached_message = session.get(Message, model.id) - assert detached_message is not None - return detached_message + if isinstance(model, Message): + detached_message = refresh_session.get(Message, model.id) + assert detached_message is not None + return detached_message + + detached_model = refresh_session.get(type(model), model.id) + assert detached_model is not None + return detached_model diff --git a/api/dify_graph/nodes/llm/entities.py b/api/dify_graph/nodes/llm/entities.py index ec6f572807..add95e7ab5 100644 --- a/api/dify_graph/nodes/llm/entities.py +++ b/api/dify_graph/nodes/llm/entities.py @@ -102,6 +102,25 @@ class LLMNodeData(BaseNodeData): return PromptConfig() return v + @field_validator("structured_output", mode="before") + @classmethod + def convert_legacy_structured_output(cls, v: object) -> StructuredOutputConfig | None | object: + if not isinstance(v, Mapping): + return v + + schema = v.get("schema") + if schema is None: + return None + + normalized: StructuredOutputConfig = {"schema": schema} + name = v.get("name") + description = v.get("description") + if isinstance(name, str): + normalized["name"] = name + if isinstance(description, str): + normalized["description"] = description + return normalized + @property def structured_output_enabled(self) -> bool: return self.structured_output_switch_on and self.structured_output is not None diff --git a/api/dify_graph/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py index 56ff3f58d5..2692dbc61b 100644 --- a/api/dify_graph/nodes/tool/entities.py +++ b/api/dify_graph/nodes/tool/entities.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Literal, TypeAlias +from typing import Literal, TypeAlias, TypedDict, cast from pydantic import BaseModel, TypeAdapter, field_validator from pydantic_core.core_schema import ValidationInfo @@ -10,16 +10,57 @@ from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType ToolConfigurationValue: TypeAlias = str | int | float | bool -ToolConfigurations: TypeAlias = dict[str, ToolConfigurationValue] ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None VariableSelector: TypeAlias = list[str] -_TOOL_CONFIGURATIONS_ADAPTER: TypeAdapter[ToolConfigurations] = TypeAdapter(ToolConfigurations) _TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str) _TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue) _VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector) +class WorkflowToolInputValue(TypedDict): + type: Literal["mixed", "variable", "constant"] + value: ToolInputConstantValue | VariableSelector + + +ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue +ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry] + + +class ToolInputPayload(BaseModel): + type: Literal["mixed", "variable", "constant"] + value: ToolInputConstantValue | VariableSelector + + @field_validator("value", mode="before") + @classmethod + def validate_value( + cls, value: object, validation_info: ValidationInfo + ) -> ToolInputConstantValue | VariableSelector: + input_type = validation_info.data.get("type") + if input_type == "mixed": + return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value) + if input_type == "variable": + return _VARIABLE_SELECTOR_ADAPTER.validate_python(value) + if input_type == "constant": + return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value) + raise ValueError(f"Unknown tool input type: {input_type}") + + def require_variable_selector(self) -> VariableSelector: + if self.type != "variable": + raise ValueError(f"Expected variable tool input, got {self.type}") + return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value) + + +def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry: + if isinstance(value, (str, int, float, bool)): + return cast(ToolConfigurationEntry, value) + + if isinstance(value, dict): + return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump()) + + raise TypeError("Tool configuration values must be primitives or workflow tool input objects") + + class ToolEntity(BaseModel): provider_id: str provider_type: ToolProviderType @@ -33,34 +74,22 @@ class ToolEntity(BaseModel): @field_validator("tool_configurations", mode="before") @classmethod def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations: - return _TOOL_CONFIGURATIONS_ADAPTER.validate_python(value) + if not isinstance(value, dict): + raise TypeError("tool_configurations must be a dictionary") + + normalized: ToolConfigurations = {} + for key, item in value.items(): + if not isinstance(key, str): + raise TypeError("tool_configurations keys must be strings") + normalized[key] = _validate_tool_configuration_entry(item) + return normalized class ToolNodeData(BaseNodeData, ToolEntity): type: NodeType = BuiltinNodeTypes.TOOL - class ToolInput(BaseModel): - type: Literal["mixed", "variable", "constant"] - value: ToolInputConstantValue | VariableSelector - - @field_validator("value", mode="before") - @classmethod - def validate_value( - cls, value: object, validation_info: ValidationInfo - ) -> ToolInputConstantValue | VariableSelector: - input_type = validation_info.data.get("type") - if input_type == "mixed": - return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value) - if input_type == "variable": - return _VARIABLE_SELECTOR_ADAPTER.validate_python(value) - if input_type == "constant": - return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value) - raise ValueError(f"Unknown tool input type: {input_type}") - - def require_variable_selector(self) -> VariableSelector: - if self.type != "variable": - raise ValueError(f"Expected variable tool input, got {self.type}") - return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value) + class ToolInput(ToolInputPayload): + pass tool_parameters: dict[str, ToolInput] # The version of the tool parameter.