refactor: convert remaining isinstance chains to match/case (part 9) (#35902) (#37340)

This commit is contained in:
Evan 2026-06-12 00:14:30 +08:00 committed by GitHub
parent 7ec295fd66
commit 99351d2f98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 187 additions and 169 deletions

View File

@ -208,12 +208,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
final_answer = f"{scratchpad.action.action_input}"
match scratchpad.action.action_input:
case dict():
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
case str():
final_answer = scratchpad.action.action_input
case _:
final_answer = f"{scratchpad.action.action_input}"
except TypeError:
final_answer = f"{scratchpad.action.action_input}"
else:

View File

@ -43,13 +43,14 @@ class CotCompletionAgentRunner(CotAgentRunner):
case UserPromptMessage():
historic_prompt += f"Question: {message.content}\n\n"
case AssistantPromptMessage():
if isinstance(message.content, str):
historic_prompt += message.content + "\n\n"
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
match message.content:
case str():
historic_prompt += message.content + "\n\n"
case list():
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
return historic_prompt

View File

@ -301,31 +301,32 @@ class AppRunner:
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
message = result.delta.message
if isinstance(message.content, str):
text += message.content
elif isinstance(message.content, list):
for content in message.content:
match content:
case str():
text += content
case TextPromptMessageContent():
text += content.data
case ImagePromptMessageContent():
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
case _:
text += content.data if hasattr(content, "data") else str(content)
match message.content:
case str():
text += message.content
case list():
for content in message.content:
match content:
case str():
text += content
case TextPromptMessageContent():
text += content.data
case ImagePromptMessageContent():
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
case _:
text += content.data if hasattr(content, "data") else str(content)
if not model:
model = result.model

View File

@ -166,12 +166,13 @@ def invoke_llm_with_structured_output(
prompt_messages = event.prompt_messages
system_fingerprint = event.system_fingerprint
if isinstance(event.delta.message.content, str):
result_text += event.delta.message.content
elif isinstance(event.delta.message.content, list):
for item in event.delta.message.content:
if isinstance(item, TextPromptMessageContent):
result_text += item.data
match event.delta.message.content:
case str():
result_text += event.delta.message.content
case list():
for item in event.delta.message.content:
if isinstance(item, TextPromptMessageContent):
result_text += item.data
yield LLMResultChunkWithStructuredOutput(
model=model_schema.model,

View File

@ -211,12 +211,13 @@ class SSETransport:
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if isinstance(status, _StatusReady):
return status.endpoint_url
elif isinstance(status, _StatusError):
raise status.exc
else:
raise ValueError("failed to get endpoint URL")
match status:
case _StatusReady():
return status.endpoint_url
case _StatusError():
raise status.exc
case _:
raise ValueError("failed to get endpoint URL")
def connect(
self,

View File

@ -41,10 +41,11 @@ class MessageHandlerFnT(Protocol):
def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
):
if isinstance(message, Exception):
raise ValueError(str(message))
elif isinstance(message, (types.ServerNotification | RequestResponder)):
pass
match message:
case Exception():
raise ValueError(str(message))
case types.ServerNotification() | RequestResponder():
pass
def _default_sampling_callback(

View File

@ -62,15 +62,16 @@ class BaseTraceInfo(BaseModel):
parent_span_id_source is the outer node_execution_id.
"""
parent_ctx = self.metadata.get("parent_trace_context")
if isinstance(parent_ctx, ParentTraceContext):
context = parent_ctx
elif isinstance(parent_ctx, Mapping):
try:
context = ParentTraceContext.model_validate(parent_ctx)
except ValueError:
match parent_ctx:
case ParentTraceContext():
context = parent_ctx
case Mapping():
try:
context = ParentTraceContext.model_validate(parent_ctx)
except ValueError:
return None, None
case _:
return None, None
else:
return None, None
return (
context.parent_workflow_run_id,
context.parent_node_execution_id,

View File

@ -38,18 +38,19 @@ def measure_time():
def replace_text_with_content(data):
if isinstance(data, dict):
new_data = {}
for key, value in data.items():
if key == "text":
new_data["content"] = value
else:
new_data[key] = replace_text_with_content(value)
return new_data
elif isinstance(data, list):
return [replace_text_with_content(item) for item in data]
else:
return data
match data:
case dict():
new_data = {}
for key, value in data.items():
if key == "text":
new_data["content"] = value
else:
new_data[key] = replace_text_with_content(value)
return new_data
case list():
return [replace_text_with_content(item) for item in data]
case _:
return data
def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_dotted_order: str | None = None) -> str:

View File

@ -45,12 +45,13 @@ _plugin_daemon_timeout_config = cast(
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
)
plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None:
plugin_daemon_request_timeout = None
elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
else:
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
match _plugin_daemon_timeout_config:
case None:
plugin_daemon_request_timeout = None
case httpx.Timeout():
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
case _:
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
logger = logging.getLogger(__name__)

View File

@ -50,32 +50,33 @@ class AdvancedPromptTransform(PromptTransform):
) -> list[PromptMessage]:
prompt_messages = []
if isinstance(prompt_template, CompletionModelPromptTemplate):
prompt_messages = self._get_completion_model_prompt_messages(
prompt_template=prompt_template,
inputs=inputs,
query=query,
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=image_detail_config,
)
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
prompt_messages = self._get_chat_model_prompt_messages(
prompt_template=prompt_template,
inputs=inputs,
query=query,
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=image_detail_config,
)
match prompt_template:
case CompletionModelPromptTemplate():
prompt_messages = self._get_completion_model_prompt_messages(
prompt_template=prompt_template,
inputs=inputs,
query=query,
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=image_detail_config,
)
case list() if all(isinstance(item, ChatModelMessage) for item in prompt_template):
prompt_messages = self._get_chat_model_prompt_messages(
prompt_template=prompt_template,
inputs=inputs,
query=query,
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=image_detail_config,
)
return prompt_messages

View File

@ -61,14 +61,15 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
"""
# Store session factory for fallback operations
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
match session_factory:
case Engine():
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
case sessionmaker():
self._session_factory = session_factory
case _:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)

View File

@ -68,14 +68,15 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
"""
# Store session factory for fallback operations
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
match session_factory:
case Engine():
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
case sessionmaker():
self._session_factory = session_factory
case _:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)

View File

@ -288,24 +288,25 @@ class HumanInputFormRepositoryImpl:
channel_payload=delivery_method.model_dump_json(),
)
recipients: list[HumanInputFormRecipient] = []
if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod):
recipient_model = HumanInputFormRecipient(
form_id=form_id,
delivery_id=delivery_id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
)
recipients.append(recipient_model)
elif isinstance(delivery_method, EmailDeliveryMethod):
email_recipients_config = delivery_method.config.recipients
recipients.extend(
self._build_email_recipients(
session=session,
match delivery_method:
case InteractiveSurfaceDeliveryMethod():
recipient_model = HumanInputFormRecipient(
form_id=form_id,
delivery_id=delivery_id,
recipients_config=email_recipients_config,
recipient_type=RecipientType.STANDALONE_WEB_APP,
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
)
recipients.append(recipient_model)
case EmailDeliveryMethod():
email_recipients_config = delivery_method.config.recipients
recipients.extend(
self._build_email_recipients(
session=session,
form_id=form_id,
delivery_id=delivery_id,
recipients_config=email_recipients_config,
)
)
)
return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients)

View File

@ -54,14 +54,15 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
"""
# If an engine is provided, create a sessionmaker from it
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
match session_factory:
case Engine():
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
case sessionmaker():
self._session_factory = session_factory
case _:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)

View File

@ -77,14 +77,15 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
"""
# If an engine is provided, create a sessionmaker from it
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
match session_factory:
case Engine():
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
case sessionmaker():
self._session_factory = session_factory
case _:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)

View File

@ -125,10 +125,11 @@ class SchemaResolver:
def _process_queue_item(self, queue: deque, item: QueueItem) -> None:
"""Process a single queue item"""
if isinstance(item.current, dict):
self._process_dict(queue, item)
elif isinstance(item.current, list):
self._process_list(queue, item)
match item.current:
case dict():
self._process_dict(queue, item)
case list():
self._process_list(queue, item)
def _process_dict(self, queue: deque, item: QueueItem) -> None:
"""Process a dictionary item"""

View File

@ -106,24 +106,26 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
# Handle datetime fields
started_at = data.get("started_at") or data.get("created_at")
if started_at:
if isinstance(started_at, str):
model.created_at = datetime.fromisoformat(started_at)
elif isinstance(started_at, (int, float)):
model.created_at = datetime.fromtimestamp(started_at)
else:
model.created_at = started_at
match started_at:
case str():
model.created_at = datetime.fromisoformat(started_at)
case int() | float():
model.created_at = datetime.fromtimestamp(started_at)
case _:
model.created_at = started_at
else:
# Provide default created_at if missing
model.created_at = datetime.now()
finished_at = data.get("finished_at")
if finished_at:
if isinstance(finished_at, str):
model.finished_at = datetime.fromisoformat(finished_at)
elif isinstance(finished_at, (int, float)):
model.finished_at = datetime.fromtimestamp(finished_at)
else:
model.finished_at = finished_at
match finished_at:
case str():
model.finished_at = datetime.fromisoformat(finished_at)
case int() | float():
model.finished_at = datetime.fromtimestamp(finished_at)
case _:
model.finished_at = finished_at
# Compute elapsed_time from started_at and finished_at
# LogStore doesn't store elapsed_time, it's computed in WorkflowExecution domain entity