mirror of https://github.com/langgenius/dify.git

[autofix.ci] apply automated fixes

This commit is contained in:
    parent 2dd893e60d
    commit a099a35e51
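The automated fixes below are a mechanical type-annotation cleanup: every `Optional[X]` annotation is rewritten to the equivalent PEP 604 form `X | None` (an inner `Union[...]` is left untouched, so `Optional[Union[UUID, str]]` becomes `Union[UUID, str] | None`). The two spellings resolve to the same runtime object; a minimal sketch of the pattern, using illustrative names that are not taken from the repository:

    from typing import Optional

    # Before the autofix: the typing.Optional wrapper.
    def find_name_legacy(user_id: Optional[str] = None) -> Optional[str]:
        return user_id

    # After the autofix: the PEP 604 union with None (no typing import needed).
    def find_name(user_id: str | None = None) -> str | None:
        return user_id

    # PEP 604 unions compare equal to the typing.Optional/Union spelling.
    assert find_name_legacy.__annotations__["user_id"] == find_name.__annotations__["user_id"]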
@@ -220,7 +220,7 @@ class OpsTraceManager:
         :param tracing_provider: tracing provider
         :return:
         """
-        trace_config_data: Optional[TraceAppConfig] = (
+        trace_config_data: TraceAppConfig | None = (
             db.session.query(TraceAppConfig)
             .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
             .first()
@@ -244,7 +244,7 @@ class OpsTraceManager:
     @classmethod
     def get_ops_trace_instance(
         cls,
-        app_id: Optional[Union[UUID, str]] = None,
+        app_id: Union[UUID, str] | None = None,
     ):
         """
         Get ops trace through model config
@@ -257,7 +257,7 @@ class OpsTraceManager:
         if app_id is None:
             return None
 
-        app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
+        app: App | None = db.session.query(App).where(App.id == app_id).first()
 
         if app is None:
             return None
@@ -331,7 +331,7 @@ class OpsTraceManager:
         except KeyError:
             raise ValueError(f"Invalid tracing provider: {tracing_provider}")
 
-        app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
+        app_config: App | None = db.session.query(App).where(App.id == app_id).first()
         if not app_config:
             raise ValueError("App not found")
         app_config.tracing = json.dumps(
@@ -349,7 +349,7 @@ class OpsTraceManager:
         :param app_id: app id
         :return:
         """
-        app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
+        app: App | None = db.session.query(App).where(App.id == app_id).first()
         if not app:
             raise ValueError("App not found")
         if not app.tracing:
@@ -407,11 +407,11 @@ class TraceTask:
     def __init__(
         self,
         trace_type: Any,
-        message_id: Optional[str] = None,
+        message_id: str | None = None,
         workflow_execution: Optional["WorkflowExecution"] = None,
-        conversation_id: Optional[str] = None,
-        user_id: Optional[str] = None,
-        timer: Optional[Any] = None,
+        conversation_id: str | None = None,
+        user_id: str | None = None,
+        timer: Any | None = None,
         **kwargs,
     ):
         self.trace_type = trace_type
@@ -825,7 +825,7 @@ class TraceTask:
         return generate_name_trace_info
 
 
-trace_manager_timer: Optional[threading.Timer] = None
+trace_manager_timer: threading.Timer | None = None
 trace_manager_queue: queue.Queue = queue.Queue()
 trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
 trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))

@@ -33,7 +33,7 @@ class BaseIndexProcessor(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
+    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
         raise NotImplementedError
 
     @abstractmethod

@@ -51,10 +51,10 @@ class ToolEngine:
         message: Message,
         invoke_from: InvokeFrom,
         agent_tool_callback: DifyAgentCallbackHandler,
-        trace_manager: Optional[TraceQueueManager] = None,
-        conversation_id: Optional[str] = None,
-        app_id: Optional[str] = None,
-        message_id: Optional[str] = None,
+        trace_manager: TraceQueueManager | None = None,
+        conversation_id: str | None = None,
+        app_id: str | None = None,
+        message_id: str | None = None,
     ) -> tuple[str, list[str], ToolInvokeMeta]:
         """
         Agent invokes the tool with the given arguments.
@@ -152,9 +152,9 @@ class ToolEngine:
         user_id: str,
         workflow_tool_callback: DifyWorkflowCallbackHandler,
         workflow_call_depth: int,
-        conversation_id: Optional[str] = None,
-        app_id: Optional[str] = None,
-        message_id: Optional[str] = None,
+        conversation_id: str | None = None,
+        app_id: str | None = None,
+        message_id: str | None = None,
     ) -> Generator[ToolInvokeMessage, None, None]:
         """
         Workflow invokes the tool with the given arguments.
@@ -194,9 +194,9 @@ class ToolEngine:
         tool: Tool,
         tool_parameters: dict,
         user_id: str,
-        conversation_id: Optional[str] = None,
-        app_id: Optional[str] = None,
-        message_id: Optional[str] = None,
+        conversation_id: str | None = None,
+        app_id: str | None = None,
+        message_id: str | None = None,
     ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
         """
         Invoke the tool with the given arguments.

@@ -157,7 +157,7 @@ class ToolManager:
         tenant_id: str,
         invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
         tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
-        credential_id: Optional[str] = None,
+        credential_id: str | None = None,
     ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
         """
         get the tool runtime
@@ -446,7 +446,7 @@ class ToolManager:
         provider: str,
         tool_name: str,
         tool_parameters: dict[str, Any],
-        credential_id: Optional[str] = None,
+        credential_id: str | None = None,
     ) -> Tool:
         """
         get tool runtime from plugin

@@ -61,9 +61,9 @@ class WorkflowTool(Tool):
         self,
         user_id: str,
         tool_parameters: dict[str, Any],
-        conversation_id: Optional[str] = None,
-        app_id: Optional[str] = None,
-        message_id: Optional[str] = None,
+        conversation_id: str | None = None,
+        app_id: str | None = None,
+        message_id: str | None = None,
     ) -> Generator[ToolInvokeMessage, None, None]:
         """
         invoke the tool

@@ -74,9 +74,9 @@ class App(Base):
     name: Mapped[str] = mapped_column(String(255))
     description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying"))
     mode: Mapped[str] = mapped_column(String(255))
-    icon_type: Mapped[Optional[str]] = mapped_column(String(255))  # image, emoji
+    icon_type: Mapped[str | None] = mapped_column(String(255))  # image, emoji
     icon = mapped_column(String(255))
-    icon_background: Mapped[Optional[str]] = mapped_column(String(255))
+    icon_background: Mapped[str | None] = mapped_column(String(255))
     app_model_config_id = mapped_column(StringUUID, nullable=True)
     workflow_id = mapped_column(StringUUID, nullable=True)
     status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
@@ -88,7 +88,7 @@ class App(Base):
     is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
     is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
     tracing = mapped_column(sa.Text, nullable=True)
-    max_active_requests: Mapped[Optional[int]]
+    max_active_requests: Mapped[int | None]
     created_by = mapped_column(StringUUID, nullable=True)
     created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_by = mapped_column(StringUUID, nullable=True)
@@ -132,7 +132,7 @@ class App(Base):
         return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"
 
     @property
-    def tenant(self) -> Optional[Tenant]:
+    def tenant(self) -> Tenant | None:
         tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
         return tenant
 
@@ -290,7 +290,7 @@ class App(Base):
         return tags or []
 
     @property
-    def author_name(self) -> Optional[str]:
+    def author_name(self) -> str | None:
         if self.created_by:
             account = db.session.query(Account).where(Account.id == self.created_by).first()
             if account:
@@ -333,7 +333,7 @@ class AppModelConfig(Base):
     file_upload = mapped_column(sa.Text)
 
     @property
-    def app(self) -> Optional[App]:
+    def app(self) -> App | None:
         app = db.session.query(App).where(App.id == self.app_id).first()
         return app
 
@@ -545,7 +545,7 @@ class RecommendedApp(Base):
     updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
 
     @property
-    def app(self) -> Optional[App]:
+    def app(self) -> App | None:
         app = db.session.query(App).where(App.id == self.app_id).first()
         return app
 
@@ -569,12 +569,12 @@ class InstalledApp(Base):
     created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
 
     @property
-    def app(self) -> Optional[App]:
+    def app(self) -> App | None:
         app = db.session.query(App).where(App.id == self.app_id).first()
         return app
 
     @property
-    def tenant(self) -> Optional[Tenant]:
+    def tenant(self) -> Tenant | None:
         tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
         return tenant
 
@@ -710,7 +710,7 @@ class Conversation(Base):
     @property
     def model_config(self):
         model_config = {}
-        app_model_config: Optional[AppModelConfig] = None
+        app_model_config: AppModelConfig | None = None
 
         if self.mode == AppMode.ADVANCED_CHAT:
             if self.override_model_configs:
@@ -844,7 +844,7 @@ class Conversation(Base):
         )
 
     @property
-    def app(self) -> Optional[App]:
+    def app(self) -> App | None:
         with Session(db.engine, expire_on_commit=False) as session:
             return session.query(App).where(App.id == self.app_id).first()
 
@@ -858,7 +858,7 @@ class Conversation(Base):
         return None
 
     @property
-    def from_account_name(self) -> Optional[str]:
+    def from_account_name(self) -> str | None:
         if self.from_account_id:
             account = db.session.query(Account).where(Account.id == self.from_account_id).first()
             if account:
@@ -933,14 +933,14 @@ class Message(Base):
     status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
     error = mapped_column(sa.Text)
     message_metadata = mapped_column(sa.Text)
-    invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
+    invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
     from_source: Mapped[str] = mapped_column(String(255), nullable=False)
-    from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID)
-    from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
+    from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
+    from_account_id: Mapped[str | None] = mapped_column(StringUUID)
     created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
     updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
     agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
-    workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
+    workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
 
     @property
     def inputs(self) -> dict[str, Any]:
@@ -1337,9 +1337,9 @@ class MessageFile(Base):
     message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     type: Mapped[str] = mapped_column(String(255), nullable=False)
     transfer_method: Mapped[str] = mapped_column(String(255), nullable=False)
-    url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
-    belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
-    upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
+    url: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
+    belongs_to: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
     created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
     created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
     created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -1356,8 +1356,8 @@ class MessageAnnotation(Base):
 
     id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
     app_id: Mapped[str] = mapped_column(StringUUID)
-    conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
-    message_id: Mapped[Optional[str]] = mapped_column(StringUUID)
+    conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
+    message_id: Mapped[str | None] = mapped_column(StringUUID)
     question = mapped_column(sa.Text, nullable=True)
     content = mapped_column(sa.Text, nullable=False)
     hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
@@ -1729,18 +1729,18 @@ class MessageAgentThought(Base):
     # plugin_id = mapped_column(StringUUID, nullable=True)  ## for future design
     tool_process_data = mapped_column(sa.Text, nullable=True)
     message = mapped_column(sa.Text, nullable=True)
-    message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
+    message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
     message_unit_price = mapped_column(sa.Numeric, nullable=True)
     message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
     message_files = mapped_column(sa.Text, nullable=True)
     answer = mapped_column(sa.Text, nullable=True)
-    answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
+    answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
     answer_unit_price = mapped_column(sa.Numeric, nullable=True)
     answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
-    tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
+    tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
     total_price = mapped_column(sa.Numeric, nullable=True)
     currency = mapped_column(String, nullable=True)
-    latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
+    latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
     created_by_role = mapped_column(String, nullable=False)
     created_by = mapped_column(StringUUID, nullable=False)
     created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
@@ -1838,11 +1838,11 @@ class DatasetRetrieverResource(Base):
     document_name = mapped_column(sa.Text, nullable=False)
     data_source_type = mapped_column(sa.Text, nullable=True)
     segment_id = mapped_column(StringUUID, nullable=True)
-    score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
+    score: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
     content = mapped_column(sa.Text, nullable=False)
-    hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
-    word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
-    segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
+    hit_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
+    word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
+    segment_position: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
     index_node_hash = mapped_column(sa.Text, nullable=True)
     retriever_from = mapped_column(sa.Text, nullable=False)
     created_by = mapped_column(StringUUID, nullable=False)

@@ -502,13 +502,13 @@ class ToolFile(TypeBase):
     # tenant id
     tenant_id: Mapped[str] = mapped_column(StringUUID)
     # conversation id
-    conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
+    conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
     # file key
     file_key: Mapped[str] = mapped_column(String(255), nullable=False)
     # mime type
     mimetype: Mapped[str] = mapped_column(String(255), nullable=False)
     # original url
-    original_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True, default=None)
+    original_url: Mapped[str | None] = mapped_column(String(2048), nullable=True, default=None)
     # name
     name: Mapped[str] = mapped_column(default="")
     # size
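For the SQLAlchemy models above, the change is annotation-only: SQLAlchemy 2.x resolves Mapped[str | None] exactly as it resolves Mapped[Optional[str]], inferring a nullable column from either spelling. A minimal sketch under that assumption; the table and column names are illustrative, not the repository's models:

    import sqlalchemy as sa
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class ExampleRecord(Base):
        __tablename__ = "example_record"

        id: Mapped[int] = mapped_column(primary_key=True)
        # The union-with-None annotation alone marks the column nullable ...
        icon_type: Mapped[str | None] = mapped_column(sa.String(255))
        # ... and an explicit nullable=True, as in the models above, remains valid.
        invoke_from: Mapped[str | None] = mapped_column(sa.String(255), nullable=True)

    assert ExampleRecord.__table__.c.icon_type.nullable
    assert ExampleRecord.__table__.c.invoke_from.nullable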
@@ -84,7 +84,7 @@ class WorkflowService:
         )
         return db.session.execute(stmt).scalar_one()
 
-    def get_draft_workflow(self, app_model: App, workflow_id: Optional[str] = None) -> Optional[Workflow]:
+    def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
         """
         Get draft workflow
         """
@@ -104,7 +104,7 @@ class WorkflowService:
         # return draft workflow
         return workflow
 
-    def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
+    def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
         """
         fetch published workflow by workflow_id
         """
@@ -126,7 +126,7 @@ class WorkflowService:
         )
         return workflow
 
-    def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
+    def get_published_workflow(self, app_model: App) -> Workflow | None:
         """
         Get published workflow
         """
@@ -191,7 +191,7 @@ class WorkflowService:
         app_model: App,
         graph: dict,
         features: dict,
-        unique_hash: Optional[str],
+        unique_hash: str | None,
         account: Account,
         environment_variables: Sequence[Variable],
         conversation_variables: Sequence[Variable],
@@ -883,7 +883,7 @@ class WorkflowService:
 
     def update_workflow(
         self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
-    ) -> Optional[Workflow]:
+    ) -> Workflow | None:
         """
         Update workflow attributes
 

@@ -59,7 +59,7 @@ class TestContextPreservation:
         context = contextvars.copy_context()
 
         # Variable to store value from worker
-        worker_value: Optional[str] = None
+        worker_value: str | None = None
 
         def worker_task() -> None:
             nonlocal worker_value
@@ -120,7 +120,7 @@ class TestContextPreservation:
         test_node = MagicMock(spec=Node)
 
         # Variable to capture context inside node execution
-        captured_value: Optional[str] = None
+        captured_value: str | None = None
         context_available_in_node = False
 
         def mock_run() -> list[GraphNodeEventBase]:

@@ -18,9 +18,9 @@ class NodeMockConfig:
 
     node_id: str
     outputs: dict[str, Any] = field(default_factory=dict)
-    error: Optional[str] = None
+    error: str | None = None
     delay: float = 0.0  # Simulated execution delay in seconds
-    custom_handler: Optional[Callable[..., dict[str, Any]]] = None
+    custom_handler: Callable[..., dict[str, Any]] | None = None
 
 
 @dataclass
@@ -51,7 +51,7 @@ class MockConfig:
     default_template_transform_response: str = "This is mocked template transform output"
     default_code_response: dict[str, Any] = field(default_factory=lambda: {"result": "mocked code execution result"})
 
-    def get_node_config(self, node_id: str) -> Optional[NodeMockConfig]:
+    def get_node_config(self, node_id: str) -> NodeMockConfig | None:
         """Get configuration for a specific node."""
         return self.node_configs.get(node_id)
 

@@ -64,7 +64,7 @@ class MockNodeMixin:
 
         return default_outputs
 
-    def _should_simulate_error(self) -> Optional[str]:
+    def _should_simulate_error(self) -> str | None:
         """Check if this node should simulate an error."""
         if not self.mock_config:
             return None

@@ -60,14 +60,14 @@ class WorkflowTestCase:
     query: str = ""
     description: str = ""
     timeout: float = 30.0
-    mock_config: Optional[MockConfig] = None
+    mock_config: MockConfig | None = None
     use_auto_mock: bool = False
-    expected_event_sequence: Optional[Sequence[type[GraphEngineEvent]]] = None
+    expected_event_sequence: Sequence[type[GraphEngineEvent]] | None = None
     tags: list[str] = field(default_factory=list)
     skip: bool = False
     skip_reason: str = ""
     retry_count: int = 0
-    custom_validator: Optional[Callable[[dict[str, Any]], bool]] = None
+    custom_validator: Callable[[dict[str, Any]], bool] | None = None
 
 
 @dataclass
@@ -76,14 +76,14 @@ class WorkflowTestResult:
 
     test_case: WorkflowTestCase
    success: bool
-    error: Optional[Exception] = None
-    actual_outputs: Optional[dict[str, Any]] = None
+    error: Exception | None = None
+    actual_outputs: dict[str, Any] | None = None
     execution_time: float = 0.0
-    event_sequence_match: Optional[bool] = None
-    event_mismatch_details: Optional[str] = None
+    event_sequence_match: bool | None = None
+    event_mismatch_details: str | None = None
     events: list[GraphEngineEvent] = field(default_factory=list)
     retry_attempts: int = 0
-    validation_details: Optional[str] = None
+    validation_details: str | None = None
 
 
 @dataclass
@@ -116,7 +116,7 @@ class TestSuiteResult:
 class WorkflowRunner:
     """Core workflow execution engine for tests."""
 
-    def __init__(self, fixtures_dir: Optional[Path] = None):
+    def __init__(self, fixtures_dir: Path | None = None):
         """Initialize the workflow runner."""
         if fixtures_dir is None:
             # Use the new central fixtures location
@@ -147,9 +147,9 @@ class WorkflowRunner:
         self,
         fixture_data: dict[str, Any],
         query: str = "",
-        inputs: Optional[dict[str, Any]] = None,
+        inputs: dict[str, Any] | None = None,
         use_mock_factory: bool = False,
-        mock_config: Optional[MockConfig] = None,
+        mock_config: MockConfig | None = None,
     ) -> tuple[Graph, GraphRuntimeState]:
         """Create a Graph instance from fixture data."""
         workflow_config = fixture_data.get("workflow", {})
@@ -240,7 +240,7 @@ class TableTestRunner:
 
     def __init__(
         self,
-        fixtures_dir: Optional[Path] = None,
+        fixtures_dir: Path | None = None,
         max_workers: int = 4,
         enable_logging: bool = False,
         log_level: str = "INFO",
@@ -467,8 +467,8 @@ class TableTestRunner:
         self,
         expected_outputs: dict[str, Any],
         actual_outputs: dict[str, Any],
-        custom_validator: Optional[Callable[[dict[str, Any]], bool]] = None,
-    ) -> tuple[bool, Optional[str]]:
+        custom_validator: Callable[[dict[str, Any]], bool] | None = None,
+    ) -> tuple[bool, str | None]:
         """
         Validate actual outputs against expected outputs.
 
@@ -517,7 +517,7 @@ class TableTestRunner:
 
     def _validate_event_sequence(
         self, expected_sequence: list[type[GraphEngineEvent]], actual_events: list[GraphEngineEvent]
-    ) -> tuple[bool, Optional[str]]:
+    ) -> tuple[bool, str | None]:
         """
         Validate that actual events match the expected event sequence.
 
@@ -549,7 +549,7 @@ class TableTestRunner:
         self,
         test_cases: list[WorkflowTestCase],
         parallel: bool = False,
-        tags_filter: Optional[list[str]] = None,
+        tags_filter: list[str] | None = None,
         fail_fast: bool = False,
     ) -> TestSuiteResult:
         """