From 99351d2f981c6995ae99e34af425b29432f401c1 Mon Sep 17 00:00:00 2001 From: Evan <2869018789@qq.com> Date: Fri, 12 Jun 2026 00:14:30 +0800 Subject: [PATCH] refactor: convert remaining isinstance chains to match/case (part 9) (#35902) (#37340) --- api/core/agent/cot_agent_runner.py | 13 ++--- api/core/agent/cot_completion_agent_runner.py | 15 +++--- api/core/app/apps/base_app_runner.py | 51 +++++++++--------- .../output_parser/structured_output.py | 13 ++--- api/core/mcp/client/sse_client.py | 13 ++--- api/core/mcp/session/client_session.py | 9 ++-- api/core/ops/entities/trace_entity.py | 17 +++--- api/core/ops/utils.py | 25 ++++----- api/core/plugin/impl/base.py | 13 ++--- api/core/prompt/advanced_prompt_transform.py | 53 ++++++++++--------- .../celery_workflow_execution_repository.py | 17 +++--- ...lery_workflow_node_execution_repository.py | 17 +++--- .../repositories/human_input_repository.py | 31 +++++------ ...qlalchemy_workflow_execution_repository.py | 17 +++--- ...hemy_workflow_node_execution_repository.py | 17 +++--- api/core/schemas/resolver.py | 9 ++-- .../logstore_api_workflow_run_repository.py | 26 ++++----- 17 files changed, 187 insertions(+), 169 deletions(-) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 9b8bf566c15..9c9fa1092f6 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -208,12 +208,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): if scratchpad.action.action_name.lower() == "final answer": # action is final answer, return final answer directly try: - if isinstance(scratchpad.action.action_input, dict): - final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False) - elif isinstance(scratchpad.action.action_input, str): - final_answer = scratchpad.action.action_input - else: - final_answer = f"{scratchpad.action.action_input}" + match scratchpad.action.action_input: + case dict(): + final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False) + case str(): + final_answer = scratchpad.action.action_input + case _: + final_answer = f"{scratchpad.action.action_input}" except TypeError: final_answer = f"{scratchpad.action.action_input}" else: diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index fd46dbc2fa5..72d7831eb73 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -43,13 +43,14 @@ class CotCompletionAgentRunner(CotAgentRunner): case UserPromptMessage(): historic_prompt += f"Question: {message.content}\n\n" case AssistantPromptMessage(): - if isinstance(message.content, str): - historic_prompt += message.content + "\n\n" - elif isinstance(message.content, list): - for content in message.content: - if not isinstance(content, TextPromptMessageContent): - continue - historic_prompt += content.data + match message.content: + case str(): + historic_prompt += message.content + "\n\n" + case list(): + for content in message.content: + if not isinstance(content, TextPromptMessageContent): + continue + historic_prompt += content.data return historic_prompt diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index ea4a187a9c7..15f6359929a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -301,31 +301,32 @@ class AppRunner: queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) message = result.delta.message - if isinstance(message.content, str): - text += message.content - elif isinstance(message.content, list): - for content in message.content: - match content: - case str(): - text += content - case TextPromptMessageContent(): - text += content.data - case ImagePromptMessageContent(): - if message_id and user_id and tenant_id: - try: - self._handle_multimodal_image_content( - content=content, - message_id=message_id, - user_id=user_id, - tenant_id=tenant_id, - queue_manager=queue_manager, - ) - except Exception: - _logger.exception("Failed to handle multimodal image output") - else: - _logger.warning("Received multimodal output but missing required parameters") - case _: - text += content.data if hasattr(content, "data") else str(content) + match message.content: + case str(): + text += message.content + case list(): + for content in message.content: + match content: + case str(): + text += content + case TextPromptMessageContent(): + text += content.data + case ImagePromptMessageContent(): + if message_id and user_id and tenant_id: + try: + self._handle_multimodal_image_content( + content=content, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + queue_manager=queue_manager, + ) + except Exception: + _logger.exception("Failed to handle multimodal image output") + else: + _logger.warning("Received multimodal output but missing required parameters") + case _: + text += content.data if hasattr(content, "data") else str(content) if not model: model = result.model diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 6cba4fbdf66..f2e98244197 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -166,12 +166,13 @@ def invoke_llm_with_structured_output( prompt_messages = event.prompt_messages system_fingerprint = event.system_fingerprint - if isinstance(event.delta.message.content, str): - result_text += event.delta.message.content - elif isinstance(event.delta.message.content, list): - for item in event.delta.message.content: - if isinstance(item, TextPromptMessageContent): - result_text += item.data + match event.delta.message.content: + case str(): + result_text += event.delta.message.content + case list(): + for item in event.delta.message.content: + if isinstance(item, TextPromptMessageContent): + result_text += item.data yield LLMResultChunkWithStructuredOutput( model=model_schema.model, diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 19d977c8e58..28ecf290c32 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -211,12 +211,13 @@ class SSETransport: except queue.Empty: raise ValueError("failed to get endpoint URL") - if isinstance(status, _StatusReady): - return status.endpoint_url - elif isinstance(status, _StatusError): - raise status.exc - else: - raise ValueError("failed to get endpoint URL") + match status: + case _StatusReady(): + return status.endpoint_url + case _StatusError(): + raise status.exc + case _: + raise ValueError("failed to get endpoint URL") def connect( self, diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index f91295a4323..1f1f574afab 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -41,10 +41,11 @@ class MessageHandlerFnT(Protocol): def _default_message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ): - if isinstance(message, Exception): - raise ValueError(str(message)) - elif isinstance(message, (types.ServerNotification | RequestResponder)): - pass + match message: + case Exception(): + raise ValueError(str(message)) + case types.ServerNotification() | RequestResponder(): + pass def _default_sampling_callback( diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 98e87a0ceb0..e183ad7f340 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -62,15 +62,16 @@ class BaseTraceInfo(BaseModel): parent_span_id_source is the outer node_execution_id. """ parent_ctx = self.metadata.get("parent_trace_context") - if isinstance(parent_ctx, ParentTraceContext): - context = parent_ctx - elif isinstance(parent_ctx, Mapping): - try: - context = ParentTraceContext.model_validate(parent_ctx) - except ValueError: + match parent_ctx: + case ParentTraceContext(): + context = parent_ctx + case Mapping(): + try: + context = ParentTraceContext.model_validate(parent_ctx) + except ValueError: + return None, None + case _: return None, None - else: - return None, None return ( context.parent_workflow_run_id, context.parent_node_execution_id, diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index a6f10c09acc..50cccd9d088 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -38,18 +38,19 @@ def measure_time(): def replace_text_with_content(data): - if isinstance(data, dict): - new_data = {} - for key, value in data.items(): - if key == "text": - new_data["content"] = value - else: - new_data[key] = replace_text_with_content(value) - return new_data - elif isinstance(data, list): - return [replace_text_with_content(item) for item in data] - else: - return data + match data: + case dict(): + new_data = {} + for key, value in data.items(): + if key == "text": + new_data["content"] = value + else: + new_data[key] = replace_text_with_content(value) + return new_data + case list(): + return [replace_text_with_content(item) for item in data] + case _: + return data def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_dotted_order: str | None = None) -> str: diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index c034662cf4f..7a74b89cf51 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -45,12 +45,13 @@ _plugin_daemon_timeout_config = cast( getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0), ) plugin_daemon_request_timeout: httpx.Timeout | None -if _plugin_daemon_timeout_config is None: - plugin_daemon_request_timeout = None -elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout): - plugin_daemon_request_timeout = _plugin_daemon_timeout_config -else: - plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config) +match _plugin_daemon_timeout_config: + case None: + plugin_daemon_request_timeout = None + case httpx.Timeout(): + plugin_daemon_request_timeout = _plugin_daemon_timeout_config + case _: + plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config) logger = logging.getLogger(__name__) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 5a9914e6e4c..e7c88811fe5 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -50,32 +50,33 @@ class AdvancedPromptTransform(PromptTransform): ) -> list[PromptMessage]: prompt_messages = [] - if isinstance(prompt_template, CompletionModelPromptTemplate): - prompt_messages = self._get_completion_model_prompt_messages( - prompt_template=prompt_template, - inputs=inputs, - query=query, - files=files, - context=context, - memory_config=memory_config, - memory=memory, - model_config=model_config, - model_instance=model_instance, - image_detail_config=image_detail_config, - ) - elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template): - prompt_messages = self._get_chat_model_prompt_messages( - prompt_template=prompt_template, - inputs=inputs, - query=query, - files=files, - context=context, - memory_config=memory_config, - memory=memory, - model_config=model_config, - model_instance=model_instance, - image_detail_config=image_detail_config, - ) + match prompt_template: + case CompletionModelPromptTemplate(): + prompt_messages = self._get_completion_model_prompt_messages( + prompt_template=prompt_template, + inputs=inputs, + query=query, + files=files, + context=context, + memory_config=memory_config, + memory=memory, + model_config=model_config, + model_instance=model_instance, + image_detail_config=image_detail_config, + ) + case list() if all(isinstance(item, ChatModelMessage) for item in prompt_template): + prompt_messages = self._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs=inputs, + query=query, + files=files, + context=context, + memory_config=memory_config, + memory=memory, + model_config=model_config, + model_instance=model_instance, + image_detail_config=image_detail_config, + ) return prompt_messages diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index d65c71abc81..bf7c8d48f8e 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -61,14 +61,15 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN) """ # Store session factory for fallback operations - if isinstance(session_factory, Engine): - self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) - elif isinstance(session_factory, sessionmaker): - self._session_factory = session_factory - else: - raise ValueError( - f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" - ) + match session_factory: + case Engine(): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + case sessionmaker(): + self._session_factory = session_factory + case _: + raise ValueError( + f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" + ) # Extract tenant_id from user tenant_id = extract_tenant_id(user) diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index dc2588b489f..f48d92f8797 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -68,14 +68,15 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN) """ # Store session factory for fallback operations - if isinstance(session_factory, Engine): - self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) - elif isinstance(session_factory, sessionmaker): - self._session_factory = session_factory - else: - raise ValueError( - f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" - ) + match session_factory: + case Engine(): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + case sessionmaker(): + self._session_factory = session_factory + case _: + raise ValueError( + f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" + ) # Extract tenant_id from user tenant_id = extract_tenant_id(user) diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 7a97b328388..4d1a3ef0063 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -288,24 +288,25 @@ class HumanInputFormRepositoryImpl: channel_payload=delivery_method.model_dump_json(), ) recipients: list[HumanInputFormRecipient] = [] - if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod): - recipient_model = HumanInputFormRecipient( - form_id=form_id, - delivery_id=delivery_id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(), - ) - recipients.append(recipient_model) - elif isinstance(delivery_method, EmailDeliveryMethod): - email_recipients_config = delivery_method.config.recipients - recipients.extend( - self._build_email_recipients( - session=session, + match delivery_method: + case InteractiveSurfaceDeliveryMethod(): + recipient_model = HumanInputFormRecipient( form_id=form_id, delivery_id=delivery_id, - recipients_config=email_recipients_config, + recipient_type=RecipientType.STANDALONE_WEB_APP, + recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(), + ) + recipients.append(recipient_model) + case EmailDeliveryMethod(): + email_recipients_config = delivery_method.config.recipients + recipients.extend( + self._build_email_recipients( + session=session, + form_id=form_id, + delivery_id=delivery_id, + recipients_config=email_recipients_config, + ) ) - ) return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 4bca8b34e87..0e9f842731d 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -54,14 +54,15 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN) """ # If an engine is provided, create a sessionmaker from it - if isinstance(session_factory, Engine): - self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) - elif isinstance(session_factory, sessionmaker): - self._session_factory = session_factory - else: - raise ValueError( - f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" - ) + match session_factory: + case Engine(): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + case sessionmaker(): + self._session_factory = session_factory + case _: + raise ValueError( + f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" + ) # Extract tenant_id from user tenant_id = extract_tenant_id(user) diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 7eda458f85b..65324028f82 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -77,14 +77,15 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN) """ # If an engine is provided, create a sessionmaker from it - if isinstance(session_factory, Engine): - self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) - elif isinstance(session_factory, sessionmaker): - self._session_factory = session_factory - else: - raise ValueError( - f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" - ) + match session_factory: + case Engine(): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + case sessionmaker(): + self._session_factory = session_factory + case _: + raise ValueError( + f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" + ) # Extract tenant_id from user tenant_id = extract_tenant_id(user) diff --git a/api/core/schemas/resolver.py b/api/core/schemas/resolver.py index e267c1abd9a..cd86aebc060 100644 --- a/api/core/schemas/resolver.py +++ b/api/core/schemas/resolver.py @@ -125,10 +125,11 @@ class SchemaResolver: def _process_queue_item(self, queue: deque, item: QueueItem) -> None: """Process a single queue item""" - if isinstance(item.current, dict): - self._process_dict(queue, item) - elif isinstance(item.current, list): - self._process_list(queue, item) + match item.current: + case dict(): + self._process_dict(queue, item) + case list(): + self._process_list(queue, item) def _process_dict(self, queue: deque, item: QueueItem) -> None: """Process a dictionary item""" diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index cdc7d129fd4..66028fb85bb 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -106,24 +106,26 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: # Handle datetime fields started_at = data.get("started_at") or data.get("created_at") if started_at: - if isinstance(started_at, str): - model.created_at = datetime.fromisoformat(started_at) - elif isinstance(started_at, (int, float)): - model.created_at = datetime.fromtimestamp(started_at) - else: - model.created_at = started_at + match started_at: + case str(): + model.created_at = datetime.fromisoformat(started_at) + case int() | float(): + model.created_at = datetime.fromtimestamp(started_at) + case _: + model.created_at = started_at else: # Provide default created_at if missing model.created_at = datetime.now() finished_at = data.get("finished_at") if finished_at: - if isinstance(finished_at, str): - model.finished_at = datetime.fromisoformat(finished_at) - elif isinstance(finished_at, (int, float)): - model.finished_at = datetime.fromtimestamp(finished_at) - else: - model.finished_at = finished_at + match finished_at: + case str(): + model.finished_at = datetime.fromisoformat(finished_at) + case int() | float(): + model.finished_at = datetime.fromtimestamp(finished_at) + case _: + model.finished_at = finished_at # Compute elapsed_time from started_at and finished_at # LogStore doesn't store elapsed_time, it's computed in WorkflowExecution domain entity