mirror of https://github.com/langgenius/dify.git
fix: type errors
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
00a1af8506
commit
b4c1766932
|
|
@ -1,9 +1,33 @@
|
|||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeNodeProvider(BaseModel):
|
||||
class VariableConfig(TypedDict):
|
||||
variable: str
|
||||
value_selector: Sequence[str | int]
|
||||
|
||||
|
||||
class OutputConfig(TypedDict):
|
||||
type: str
|
||||
children: None
|
||||
|
||||
|
||||
class CodeConfig(TypedDict):
|
||||
variables: Sequence[VariableConfig]
|
||||
code_language: str
|
||||
code: str
|
||||
outputs: Mapping[str, OutputConfig]
|
||||
|
||||
|
||||
class DefaultConfig(TypedDict):
|
||||
type: str
|
||||
config: CodeConfig
|
||||
|
||||
|
||||
class CodeNodeProvider(BaseModel, ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_language() -> str:
|
||||
|
|
@ -22,11 +46,14 @@ class CodeNodeProvider(BaseModel):
|
|||
pass
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls):
|
||||
def get_default_config(cls) -> DefaultConfig:
|
||||
return {
|
||||
"type": "code",
|
||||
"config": {
|
||||
"variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
|
||||
"variables": [
|
||||
{"variable": "arg1", "value_selector": []},
|
||||
{"variable": "arg2", "value_selector": []},
|
||||
],
|
||||
"code_language": cls.get_language(),
|
||||
"code": cls.get_default_code(),
|
||||
"outputs": {"result": {"type": "string", "children": None}},
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
@ -7,4 +5,4 @@ class AgentNodeStrategyInit(BaseModel):
|
|||
"""Agent node strategy initialization data."""
|
||||
|
||||
name: str
|
||||
icon: Optional[str] = None
|
||||
icon: str | None = None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import hashlib
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -10,10 +10,10 @@ class RunCondition(BaseModel):
|
|||
type: Literal["branch_identify", "condition"]
|
||||
"""condition type"""
|
||||
|
||||
branch_identify: Optional[str] = None
|
||||
branch_identify: str | None = None
|
||||
"""branch identify like: sourceHandle, required when type is branch_identify"""
|
||||
|
||||
conditions: Optional[list[Condition]] = None
|
||||
conditions: list[Condition] | None = None
|
||||
"""conditions to run the node, required when type is condition"""
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ implementation details like tenant_id, app_id, etc.
|
|||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ class WorkflowExecution(BaseModel):
|
|||
graph: Mapping[str, Any] = Field(...)
|
||||
|
||||
inputs: Mapping[str, Any] = Field(...)
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
|
||||
status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
|
||||
error_message: str = Field(default="")
|
||||
|
|
@ -37,7 +37,7 @@ class WorkflowExecution(BaseModel):
|
|||
exceptions_count: int = Field(default=0)
|
||||
|
||||
started_at: datetime = Field(...)
|
||||
finished_at: Optional[datetime] = None
|
||||
finished_at: datetime | None = None
|
||||
|
||||
@property
|
||||
def elapsed_time(self) -> float:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ and don't contain implementation details like tenant_id, app_id, etc.
|
|||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -39,41 +39,41 @@ class WorkflowNodeExecution(BaseModel):
|
|||
# NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
|
||||
# While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
|
||||
# In most scenarios, `id` should be used as the primary identifier.
|
||||
node_execution_id: Optional[str] = None
|
||||
node_execution_id: str | None = None
|
||||
workflow_id: str # ID of the workflow this node belongs to
|
||||
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||
workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging)
|
||||
# --------- Core identification fields ends ---------
|
||||
|
||||
# Execution positioning and flow
|
||||
index: int # Sequence number for ordering in trace visualization
|
||||
predecessor_node_id: Optional[str] = None # ID of the node that executed before this one
|
||||
predecessor_node_id: str | None = None # ID of the node that executed before this one
|
||||
node_id: str # ID of the node being executed
|
||||
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
|
||||
title: str # Display title of the node
|
||||
|
||||
# Execution data
|
||||
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
|
||||
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
|
||||
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
|
||||
inputs: Mapping[str, Any] | None = None # Input variables used by this node
|
||||
process_data: Mapping[str, Any] | None = None # Intermediate processing data
|
||||
outputs: Mapping[str, Any] | None = None # Output variables produced by this node
|
||||
|
||||
# Execution state
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
|
||||
error: Optional[str] = None # Error message if execution failed
|
||||
error: str | None = None # Error message if execution failed
|
||||
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
|
||||
|
||||
# Additional metadata
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
|
||||
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.)
|
||||
|
||||
# Timing information
|
||||
created_at: datetime # When execution started
|
||||
finished_at: Optional[datetime] = None # When execution completed
|
||||
finished_at: datetime | None = None # When execution completed
|
||||
|
||||
def update_from_mapping(
|
||||
self,
|
||||
inputs: Optional[Mapping[str, Any]] = None,
|
||||
process_data: Optional[Mapping[str, Any]] = None,
|
||||
outputs: Optional[Mapping[str, Any]] = None,
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
|
||||
inputs: Mapping[str, Any] | None = None,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Update the model from mappings.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -14,4 +14,4 @@ class NodeRunAgentLogEvent(GraphAgentNodeEventBase):
|
|||
error: str | None = Field(..., description="error")
|
||||
status: str = Field(..., description="status")
|
||||
data: Mapping[str, Any] = Field(..., description="data")
|
||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
|
||||
metadata: Mapping[str, Any] | None = Field(default=None, description="metadata")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -10,7 +10,7 @@ class GraphRunStartedEvent(BaseGraphEvent):
|
|||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
outputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
|
|
@ -20,11 +20,11 @@ class GraphRunFailedEvent(BaseGraphEvent):
|
|||
|
||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
||||
exceptions_count: int = Field(..., description="exception count")
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
outputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class GraphRunAbortedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is aborted by user command."""
|
||||
|
||||
reason: Optional[str] = Field(default=None, description="reason for abort")
|
||||
outputs: Optional[dict[str, Any]] = Field(default=None, description="partial outputs if any")
|
||||
reason: str | None = Field(default=None, description="reason for abort")
|
||||
outputs: dict[str, Any] | None = Field(default=None, description="partial outputs if any")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -10,31 +10,31 @@ from .base import GraphNodeEventBase
|
|||
class NodeRunIterationStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class NodeRunIterationNextEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
index: int = Field(..., description="index")
|
||||
pre_iteration_output: Optional[Any] = None
|
||||
pre_iteration_output: Any = None
|
||||
|
||||
|
||||
class NodeRunIterationSucceededEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class NodeRunIterationFailedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -10,31 +10,31 @@ from .base import GraphNodeEventBase
|
|||
class NodeRunLoopStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class NodeRunLoopNextEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
index: int = Field(..., description="index")
|
||||
pre_loop_output: Optional[Any] = None
|
||||
pre_loop_output: Any = None
|
||||
|
||||
|
||||
class NodeRunLoopSucceededEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class NodeRunLoopFailedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -12,9 +11,9 @@ from .base import GraphNodeEventBase
|
|||
|
||||
class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
predecessor_node_id: Optional[str] = None
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
predecessor_node_id: str | None = None
|
||||
parallel_mode_run_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -14,5 +14,5 @@ class AgentLogEvent(NodeEventBase):
|
|||
error: str | None = Field(..., description="error")
|
||||
status: str = Field(..., description="status")
|
||||
data: Mapping[str, Any] = Field(..., description="data")
|
||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
|
||||
metadata: Mapping[str, Any] | None = Field(default=None, description="metadata")
|
||||
node_id: str = Field(..., description="node id")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -9,28 +9,28 @@ from .base import NodeEventBase
|
|||
|
||||
class IterationStartedEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class IterationNextEvent(NodeEventBase):
|
||||
index: int = Field(..., description="index")
|
||||
pre_iteration_output: Optional[Any] = None
|
||||
pre_iteration_output: Any = None
|
||||
|
||||
|
||||
class IterationSucceededEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class IterationFailedEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -9,28 +9,28 @@ from .base import NodeEventBase
|
|||
|
||||
class LoopStartedEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class LoopNextEvent(NodeEventBase):
|
||||
index: int = Field(..., description="index")
|
||||
pre_loop_output: Optional[Any] = None
|
||||
pre_loop_output: Any = None
|
||||
|
||||
|
||||
class LoopSucceededEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class LoopFailedEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
|
|
@ -77,7 +77,7 @@ class AgentNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AgentNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -86,7 +86,7 @@ class AgentNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -414,7 +414,7 @@ class AgentNode(Node):
|
|||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
|
||||
class AgentNodeError(Exception):
|
||||
"""Base exception for all agent node errors."""
|
||||
|
||||
|
|
@ -12,7 +9,7 @@ class AgentNodeError(Exception):
|
|||
class AgentStrategyError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the agent strategy."""
|
||||
|
||||
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
|
||||
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
|
||||
self.strategy_name = strategy_name
|
||||
self.provider_name = provider_name
|
||||
super().__init__(message)
|
||||
|
|
@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError):
|
|||
class AgentStrategyNotFoundError(AgentStrategyError):
|
||||
"""Exception raised when the specified agent strategy is not found."""
|
||||
|
||||
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
|
||||
def __init__(self, strategy_name: str, provider_name: str | None = None):
|
||||
super().__init__(
|
||||
f"Agent strategy '{strategy_name}' not found"
|
||||
+ (f" for provider '{provider_name}'" if provider_name else ""),
|
||||
|
|
@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError):
|
|||
class AgentInvocationError(AgentNodeError):
|
||||
"""Exception raised when there's an error invoking the agent."""
|
||||
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
def __init__(self, message: str, original_error: Exception | None = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
|
@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError):
|
|||
class AgentParameterError(AgentNodeError):
|
||||
"""Exception raised when there's an error with agent parameters."""
|
||||
|
||||
def __init__(self, message: str, parameter_name: Optional[str] = None):
|
||||
def __init__(self, message: str, parameter_name: str | None = None):
|
||||
self.parameter_name = parameter_name
|
||||
super().__init__(message)
|
||||
|
||||
|
|
@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError):
|
|||
class AgentVariableError(AgentNodeError):
|
||||
"""Exception raised when there's an error with variables in the agent node."""
|
||||
|
||||
def __init__(self, message: str, variable_name: Optional[str] = None):
|
||||
def __init__(self, message: str, variable_name: str | None = None):
|
||||
self.variable_name = variable_name
|
||||
super().__init__(message)
|
||||
|
||||
|
|
@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError):
|
|||
class ToolFileError(AgentNodeError):
|
||||
"""Exception raised when there's an error with a tool file."""
|
||||
|
||||
def __init__(self, message: str, file_id: Optional[str] = None):
|
||||
def __init__(self, message: str, file_id: str | None = None):
|
||||
self.file_id = file_id
|
||||
super().__init__(message)
|
||||
|
||||
|
|
@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError):
|
|||
class AgentMessageTransformError(AgentNodeError):
|
||||
"""Exception raised when there's an error transforming agent messages."""
|
||||
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
def __init__(self, message: str, original_error: Exception | None = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
|
@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError):
|
|||
class AgentModelError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the model used by the agent."""
|
||||
|
||||
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
|
||||
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
super().__init__(message)
|
||||
|
|
@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError):
|
|||
class AgentMemoryError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the agent's memory."""
|
||||
|
||||
def __init__(self, message: str, conversation_id: Optional[str] = None):
|
||||
def __init__(self, message: str, conversation_id: str | None = None):
|
||||
self.conversation_id = conversation_id
|
||||
super().__init__(message)
|
||||
|
||||
|
|
@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError):
|
|||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
variable_name: Optional[str] = None,
|
||||
expected_type: Optional[str] = None,
|
||||
actual_type: Optional[str] = None,
|
||||
variable_name: str | None = None,
|
||||
expected_type: str | None = None,
|
||||
actual_type: str | None = None,
|
||||
):
|
||||
self.variable_name = variable_name
|
||||
self.expected_type = expected_type
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.variables import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
|
|
@ -20,7 +20,7 @@ class AnswerNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AnswerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -29,7 +29,7 @@ class AnswerNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import json
|
|||
from abc import ABC
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
|
@ -128,10 +128,10 @@ class DefaultValue(BaseModel):
|
|||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: Optional[ErrorStrategy] = None
|
||||
default_value: Optional[list[DefaultValue]] = None
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
|
|
@ -142,7 +142,7 @@ class BaseNodeData(ABC, BaseModel):
|
|||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
|
|
@ -157,7 +157,7 @@ class BaseIterationState(BaseModel):
|
|||
|
||||
|
||||
class BaseLoopNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
class BaseLoopState(BaseModel):
|
||||
|
|
|
|||
|
|
@ -145,11 +145,10 @@ class Node:
|
|||
for event in result:
|
||||
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
|
||||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
event = self._dispatch(event)
|
||||
|
||||
if not event.in_iteration_id and not event.in_loop_id:
|
||||
yield self._dispatch(event)
|
||||
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
event.id = self._node_execution_id
|
||||
yield event
|
||||
yield event
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any, Optional
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
|
|
@ -30,7 +30,7 @@ class CodeNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = CodeNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -39,7 +39,7 @@ class CodeNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -49,7 +49,7 @@ class CodeNode(Node):
|
|||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
|
|
@ -57,7 +57,7 @@ class CodeNode(Node):
|
|||
"""
|
||||
code_language = CodeLanguage.PYTHON3
|
||||
if filters:
|
||||
code_language = filters.get("code_language", CodeLanguage.PYTHON3)
|
||||
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||
|
|
@ -154,7 +154,7 @@ class CodeNode(Node):
|
|||
def _transform_result(
|
||||
self,
|
||||
result: Mapping[str, Any],
|
||||
output_schema: Optional[dict[str, CodeNodeData.Output]],
|
||||
output_schema: dict[str, CodeNodeData.Output] | None,
|
||||
prefix: str = "",
|
||||
depth: int = 1,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Annotated, Literal, Optional
|
||||
from typing import Annotated, Literal, Self
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
|
|||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: Optional[dict[str, "CodeNodeData.Output"]] = None
|
||||
children: dict[str, Self] | None = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
|
|
@ -44,4 +44,4 @@ class CodeNodeData(BaseNodeData):
|
|||
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
|
||||
code: str
|
||||
outputs: dict[str, Output]
|
||||
dependencies: Optional[list[Dependency]] = None
|
||||
dependencies: list[Dependency] | None = None
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import logging
|
|||
import os
|
||||
import tempfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import chardet
|
||||
import docx
|
||||
|
|
@ -49,7 +49,7 @@ class DocumentExtractorNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = DocumentExtractorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -58,7 +58,7 @@ class DocumentExtractorNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
|
@ -18,7 +18,7 @@ class EndNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = EndNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -27,7 +27,7 @@ class EndNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import mimetypes
|
||||
from collections.abc import Sequence
|
||||
from email.message import Message
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
|
@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel):
|
|||
|
||||
class HttpRequestNodeAuthorization(BaseModel):
|
||||
type: Literal["no-auth", "api-key"]
|
||||
config: Optional[HttpRequestNodeAuthorizationConfig] = None
|
||||
config: HttpRequestNodeAuthorizationConfig | None = None
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData):
|
|||
authorization: HttpRequestNodeAuthorization
|
||||
headers: str
|
||||
params: str
|
||||
body: Optional[HttpRequestNodeBody] = None
|
||||
timeout: Optional[HttpRequestNodeTimeout] = None
|
||||
ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
body: HttpRequestNodeBody | None = None
|
||||
timeout: HttpRequestNodeTimeout | None = None
|
||||
ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
|
||||
|
||||
class Response:
|
||||
|
|
@ -183,7 +183,7 @@ class Response:
|
|||
return f"{(self.size / 1024 / 1024):.2f} MB"
|
||||
|
||||
@property
|
||||
def parsed_content_disposition(self) -> Optional[Message]:
|
||||
def parsed_content_disposition(self) -> Message | None:
|
||||
content_disposition = self.headers.get("content-disposition", "")
|
||||
if content_disposition:
|
||||
msg = Message()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
|
|
@ -39,7 +39,7 @@ class HttpRequestNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = HttpRequestNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -48,7 +48,7 @@ class HttpRequestNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -58,7 +58,7 @@ class HttpRequestNode(Node):
|
|||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict[str, Any]] = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "http-request",
|
||||
"config": {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData):
|
|||
logical_operator: Literal["and", "or"]
|
||||
conditions: list[Condition]
|
||||
|
||||
logical_operator: Optional[Literal["and", "or"]] = "and"
|
||||
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
cases: Optional[list[Case]] = None
|
||||
cases: list[Case] | None = None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ class IfElseNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IfElseNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -31,7 +31,7 @@ class IfElseNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData):
|
|||
Iteration Node Data.
|
||||
"""
|
||||
|
||||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
parent_loop_id: str | None = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
is_parallel: bool = False # open the parallel mode or not
|
||||
|
|
@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
|
|||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Optional[Any] = None
|
||||
current_output: Any = None
|
||||
|
||||
class MetaData(BaseIterationState.MetaData):
|
||||
"""
|
||||
|
|
@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
|
|||
|
||||
iterator_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
def get_last_output(self) -> Any:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
|
|
@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
|
|||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
def get_current_output(self) -> Any:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class IterationNode(Node):
|
|||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict[str, object] | None = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "iteration",
|
||||
"config": {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
|
@ -20,7 +20,7 @@ class IterationStartNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationStartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -29,7 +29,7 @@ class IterationStartNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel):
|
|||
"""
|
||||
|
||||
top_k: int
|
||||
score_threshold: Optional[float] = None
|
||||
score_threshold: float | None = None
|
||||
reranking_mode: str = "reranking_model"
|
||||
reranking_enable: bool = True
|
||||
reranking_model: Optional[RerankingModelConfig] = None
|
||||
weights: Optional[WeightedScoreConfig] = None
|
||||
reranking_model: RerankingModelConfig | None = None
|
||||
weights: WeightedScoreConfig | None = None
|
||||
|
||||
|
||||
class SingleRetrievalConfig(BaseModel):
|
||||
|
|
@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel):
|
|||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Optional[Literal["and", "or"]] = "and"
|
||||
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
|
|
@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
|||
query_variable_selector: list[str]
|
||||
dataset_ids: list[str]
|
||||
retrieval_mode: Literal["single", "multiple"]
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig] = None
|
||||
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
|
||||
metadata_model_config: Optional[ModelConfig] = None
|
||||
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
|
||||
multiple_retrieval_config: MultipleRetrievalConfig | None = None
|
||||
single_retrieval_config: SingleRetrievalConfig | None = None
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
|
||||
metadata_model_config: ModelConfig | None = None
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import re
|
|||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import Float, and_, func, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
|
|
@ -119,7 +119,7 @@ class KnowledgeRetrievalNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -128,7 +128,7 @@ class KnowledgeRetrievalNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -410,7 +410,7 @@ class KnowledgeRetrievalNode(Node):
|
|||
|
||||
def _get_metadata_filter_condition(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
||||
document_query = db.session.query(Document).where(
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
Document.indexing_status == "completed",
|
||||
|
|
@ -568,7 +568,7 @@ class KnowledgeRetrievalNode(Node):
|
|||
return automatic_metadata_filters
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list[Any]
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||
) -> list[Any]:
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, Optional, TypeAlias, TypeVar
|
||||
from typing import Any, TypeAlias, TypeVar
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
|
|
@ -43,7 +43,7 @@ class ListOperatorNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ListOperatorNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -52,7 +52,7 @@ class ListOperatorNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
|
@ -18,7 +18,7 @@ class ModelConfig(BaseModel):
|
|||
|
||||
class ContextConfig(BaseModel):
|
||||
enabled: bool
|
||||
variable_selector: Optional[list[str]] = None
|
||||
variable_selector: list[str] | None = None
|
||||
|
||||
|
||||
class VisionConfigOptions(BaseModel):
|
||||
|
|
@ -51,18 +51,18 @@ class PromptConfig(BaseModel):
|
|||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
text: str = ""
|
||||
jinja2_text: Optional[str] = None
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
jinja2_text: Optional[str] = None
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
memory: Optional[MemoryConfig] = None
|
||||
memory: MemoryConfig | None = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
|
|||
|
||||
|
||||
def fetch_memory(
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
|
||||
) -> Optional[TokenBufferMemory]:
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
|
||||
) -> TokenBufferMemory | None:
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import json
|
|||
import logging
|
||||
import re
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
|
|
@ -139,7 +139,7 @@ class LLMNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LLMNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -148,7 +148,7 @@ class LLMNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -354,10 +354,10 @@ class LLMNode(Node):
|
|||
node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
user_id: str,
|
||||
structured_output_enabled: bool,
|
||||
structured_output: Optional[Mapping[str, Any]] = None,
|
||||
structured_output: Mapping[str, Any] | None = None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
node_id: str,
|
||||
|
|
@ -716,7 +716,7 @@ class LLMNode(Node):
|
|||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
tenant_id: str,
|
||||
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
|
|
@ -959,7 +959,7 @@ class LLMNode(Node):
|
|||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "llm",
|
||||
"config": {
|
||||
|
|
@ -987,7 +987,7 @@ class LLMNode(Node):
|
|||
def handle_list_messages(
|
||||
*,
|
||||
messages: Sequence[LLMNodeChatModelMessage],
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||
|
|
@ -1175,7 +1175,7 @@ class LLMNode(Node):
|
|||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
|
||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
||||
):
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
|
|
@ -1281,7 +1281,7 @@ def _handle_memory_completion_mode(
|
|||
def _handle_completion_template(
|
||||
*,
|
||||
template: LLMNodeCompletionModelPromptTemplate,
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
) -> Sequence[PromptMessage]:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Annotated, Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
|
||||
|
|
@ -41,7 +41,7 @@ class LoopNodeData(BaseLoopNodeData):
|
|||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
|
||||
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("outputs", mode="before")
|
||||
|
|
@ -74,7 +74,7 @@ class LoopState(BaseLoopState):
|
|||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Optional[Any] = None
|
||||
current_output: Any = None
|
||||
|
||||
class MetaData(BaseLoopState.MetaData):
|
||||
"""
|
||||
|
|
@ -83,7 +83,7 @@ class LoopState(BaseLoopState):
|
|||
|
||||
loop_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
def get_last_output(self) -> Any:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
|
|
@ -91,7 +91,7 @@ class LoopState(BaseLoopState):
|
|||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
def get_current_output(self) -> Any:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
|
@ -20,7 +20,7 @@ class LoopEndNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopEndNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -29,7 +29,7 @@ class LoopEndNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import json
|
|||
import logging
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from core.variables import Segment, SegmentType
|
||||
from core.workflow.enums import (
|
||||
|
|
@ -51,7 +51,7 @@ class LoopNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -60,7 +60,7 @@ class LoopNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
|
@ -20,7 +20,7 @@ class LoopStartNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopStartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -29,7 +29,7 @@ class LoopStartNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Annotated, Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
|
|
@ -48,7 +48,7 @@ class ParameterConfig(BaseModel):
|
|||
|
||||
name: str
|
||||
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
|
||||
options: Optional[list[str]] = None
|
||||
options: list[str] | None = None
|
||||
description: str
|
||||
required: bool
|
||||
|
||||
|
|
@ -86,8 +86,8 @@ class ParameterExtractorNodeData(BaseNodeData):
|
|||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
instruction: str | None = None
|
||||
memory: MemoryConfig | None = None
|
||||
reasoning_mode: Literal["function_call", "prompt"]
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import File
|
||||
|
|
@ -96,7 +96,7 @@ class ParameterExtractorNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ParameterExtractorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -105,7 +105,7 @@ class ParameterExtractorNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -114,11 +114,11 @@ class ParameterExtractorNode(Node):
|
|||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
_model_instance: Optional[ModelInstance] = None
|
||||
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
|
||||
_model_instance: ModelInstance | None = None
|
||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"model": {
|
||||
"prompt_templates": {
|
||||
|
|
@ -293,7 +293,7 @@ class ParameterExtractorNode(Node):
|
|||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
|
||||
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
|
|
@ -323,9 +323,9 @@ class ParameterExtractorNode(Node):
|
|||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
"""
|
||||
Generate function call prompt.
|
||||
|
|
@ -405,9 +405,9 @@ class ParameterExtractorNode(Node):
|
|||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate prompt engineering prompt.
|
||||
|
|
@ -443,9 +443,9 @@ class ParameterExtractorNode(Node):
|
|||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate completion prompt.
|
||||
|
|
@ -477,9 +477,9 @@ class ParameterExtractorNode(Node):
|
|||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate chat prompt.
|
||||
|
|
@ -651,7 +651,7 @@ class ParameterExtractorNode(Node):
|
|||
|
||||
return transformed_result
|
||||
|
||||
def _extract_complete_json_response(self, result: str) -> Optional[dict]:
|
||||
def _extract_complete_json_response(self, result: str) -> dict | None:
|
||||
"""
|
||||
Extract complete json response.
|
||||
"""
|
||||
|
|
@ -666,7 +666,7 @@ class ParameterExtractorNode(Node):
|
|||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
|
||||
"""
|
||||
Extract json from tool call.
|
||||
"""
|
||||
|
|
@ -705,7 +705,7 @@ class ParameterExtractorNode(Node):
|
|||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
|
|
@ -732,7 +732,7 @@ class ParameterExtractorNode(Node):
|
|||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
|
|
@ -768,7 +768,7 @@ class ParameterExtractorNode(Node):
|
|||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
|
|
@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData):
|
|||
query_variable_selector: list[str]
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
instruction: str | None = None
|
||||
memory: MemoryConfig | None = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
|
@ -80,7 +80,7 @@ class QuestionClassifierNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = QuestionClassifierNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -89,7 +89,7 @@ class QuestionClassifierNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -271,7 +271,7 @@ class QuestionClassifierNode(Node):
|
|||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters (not used in this implementation).
|
||||
|
|
@ -285,7 +285,7 @@ class QuestionClassifierNode(Node):
|
|||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str],
|
||||
context: str | None,
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
|
||||
|
|
@ -328,7 +328,7 @@ class QuestionClassifierNode(Node):
|
|||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
|
|
@ -18,7 +18,7 @@ class StartNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = StartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -27,7 +27,7 @@ class StartNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
|
|
@ -20,7 +20,7 @@ class TemplateTransformNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = TemplateTransformNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -29,7 +29,7 @@ class TemplateTransformNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -39,7 +39,7 @@ class TemplateTransformNode(Node):
|
|||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -470,7 +470,7 @@ class ToolNode(Node):
|
|||
|
||||
return result
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -479,7 +479,7 @@ class ToolNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
|
|
@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData):
|
|||
type: str = "variable-assigner"
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
||||
advanced_settings: Optional[AdvancedSettings] = None
|
||||
advanced_settings: AdvancedSettings | None = None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
|
|
@ -17,7 +17,7 @@ class VariableAggregatorNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -26,7 +26,7 @@ class VariableAggregatorNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.segments import BooleanSegment
|
||||
|
|
@ -33,7 +33,7 @@ class VariableAssignerNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -42,7 +42,7 @@ class VariableAssignerNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class VariableOperationItem(BaseModel):
|
|||
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
|
||||
# 3. During the variable updating procedure: The `value` field is reassigned to hold
|
||||
# the resolved actual value that will be applied to the target variable.
|
||||
value: Any | None = None
|
||||
value: Any = None
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
|
|
@ -60,7 +60,7 @@ class VariableAssignerNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -69,7 +69,7 @@ class VariableAssignerNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Protocol
|
||||
from typing import Literal, Protocol
|
||||
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
|
||||
|
|
@ -10,7 +10,7 @@ class OrderConfig:
|
|||
"""Configuration for ordering NodeExecution instances."""
|
||||
|
||||
order_by: list[str]
|
||||
order_direction: Optional[Literal["asc", "desc"]] = None
|
||||
order_direction: Literal["asc", "desc"] | None = None
|
||||
|
||||
|
||||
class WorkflowNodeExecutionRepository(Protocol):
|
||||
|
|
@ -42,7 +42,7 @@ class WorkflowNodeExecutionRepository(Protocol):
|
|||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
order_config: OrderConfig | None = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all NodeExecution instances for a specific workflow run.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
|
|
@ -85,9 +85,9 @@ class WorkflowCycleManager:
|
|||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
external_trace_id: Optional[str] = None,
|
||||
conversation_id: str | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
external_trace_id: str | None = None,
|
||||
) -> WorkflowExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
|
||||
|
|
@ -112,9 +112,9 @@ class WorkflowCycleManager:
|
|||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
exceptions_count: int = 0,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
external_trace_id: Optional[str] = None,
|
||||
conversation_id: str | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
external_trace_id: str | None = None,
|
||||
) -> WorkflowExecution:
|
||||
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
|
||||
|
|
@ -140,10 +140,10 @@ class WorkflowCycleManager:
|
|||
total_steps: int,
|
||||
status: WorkflowExecutionStatus,
|
||||
error_message: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
conversation_id: str | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
exceptions_count: int = 0,
|
||||
external_trace_id: Optional[str] = None,
|
||||
external_trace_id: str | None = None,
|
||||
) -> WorkflowExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
now = naive_utc_now()
|
||||
|
|
@ -295,9 +295,9 @@ class WorkflowCycleManager:
|
|||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
error_message: Optional[str] = None,
|
||||
error_message: str | None = None,
|
||||
exceptions_count: int = 0,
|
||||
finished_at: Optional[datetime] = None,
|
||||
finished_at: datetime | None = None,
|
||||
):
|
||||
"""Update workflow execution with completion data."""
|
||||
execution.status = status
|
||||
|
|
@ -311,10 +311,10 @@ class WorkflowCycleManager:
|
|||
|
||||
def _add_trace_task_if_needed(
|
||||
self,
|
||||
trace_manager: Optional[TraceQueueManager],
|
||||
trace_manager: TraceQueueManager | None,
|
||||
workflow_execution: WorkflowExecution,
|
||||
conversation_id: Optional[str],
|
||||
external_trace_id: Optional[str],
|
||||
conversation_id: str | None,
|
||||
external_trace_id: str | None,
|
||||
):
|
||||
"""Add trace task if trace manager is provided."""
|
||||
if trace_manager:
|
||||
|
|
@ -401,7 +401,7 @@ class WorkflowCycleManager:
|
|||
QueueNodeExceptionEvent,
|
||||
],
|
||||
status: WorkflowNodeExecutionStatus,
|
||||
error: Optional[str] = None,
|
||||
error: str | None = None,
|
||||
handle_special_values: bool = False,
|
||||
):
|
||||
"""Update node execution with completion data."""
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
|
|
@ -43,7 +43,7 @@ class WorkflowEntry:
|
|||
call_depth: int,
|
||||
variable_pool: VariablePool,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
command_channel: Optional[CommandChannel] = None,
|
||||
command_channel: CommandChannel | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Init workflow entry
|
||||
|
|
@ -341,7 +341,7 @@ class WorkflowEntry:
|
|||
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
|
||||
|
||||
@staticmethod
|
||||
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
|
||||
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
# NOTE(QuantumGhost): Avoid using this function in new code.
|
||||
# Keep values structured as long as possible and only convert to dict
|
||||
# immediately before serialization (e.g., JSON serialization) to maintain
|
||||
|
|
|
|||
|
|
@ -543,12 +543,12 @@ class WorkflowService:
|
|||
# This will prevent validation errors from breaking the workflow
|
||||
return []
|
||||
|
||||
def get_default_block_configs(self) -> list[dict]:
|
||||
def get_default_block_configs(self) -> Sequence[Mapping[str, object]]:
|
||||
"""
|
||||
Get default block configs
|
||||
"""
|
||||
# return default block config
|
||||
default_block_configs = []
|
||||
default_block_configs: list[Mapping[str, object]] = []
|
||||
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
|
||||
node_class = node_class_mapping[LATEST_VERSION]
|
||||
default_config = node_class.get_default_config()
|
||||
|
|
|
|||
Loading…
Reference in New Issue