refactor: convert isinstance chains to match/case (part 5) (#36298)

Co-authored-by: Stephen Zhou <hi@hyoban.cc>
This commit is contained in:
EvanYao 2026-05-18 15:16:31 +08:00 committed by GitHub
parent c07686928a
commit 7e8147295b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 175 additions and 182 deletions

View File

@ -195,22 +195,23 @@ class BaseAppGenerator:
)
if variable_entity.type == VariableEntityType.NUMBER:
if isinstance(value, (int, float)):
return value
elif isinstance(value, str):
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
else:
raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}")
match value:
case int() | float():
return value
case str():
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
case _:
raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}")
match variable_entity.type:
case VariableEntityType.SELECT:
@ -241,17 +242,18 @@ class BaseAppGenerator:
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
)
case VariableEntityType.CHECKBOX:
if isinstance(value, str):
normalized_value = value.strip().lower()
if normalized_value in {"true", "1", "yes", "on"}:
value = True
elif normalized_value in {"false", "0", "no", "off"}:
value = False
elif isinstance(value, (int, float)):
if value == 1:
value = True
elif value == 0:
value = False
match value:
case str():
normalized_value = value.strip().lower()
if normalized_value in {"true", "1", "yes", "on"}:
value = True
elif normalized_value in {"false", "0", "no", "off"}:
value = False
case int() | float():
if value == 1:
value = True
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if value and not isinstance(value, dict):
raise ValueError(f"{variable_entity.variable} in input form must be a dict")

View File

@ -105,52 +105,31 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._node_sequence = 0
def on_event(self, event: GraphEngineEvent) -> None:
if isinstance(event, GraphRunStartedEvent):
self._handle_graph_run_started()
return
if isinstance(event, GraphRunSucceededEvent):
self._handle_graph_run_succeeded(event)
return
if isinstance(event, GraphRunPartialSucceededEvent):
self._handle_graph_run_partial_succeeded(event)
return
if isinstance(event, GraphRunFailedEvent):
self._handle_graph_run_failed(event)
return
if isinstance(event, GraphRunAbortedEvent):
self._handle_graph_run_aborted(event)
return
if isinstance(event, GraphRunPausedEvent):
self._handle_graph_run_paused(event)
return
if isinstance(event, NodeRunRetryEvent):
self._handle_node_retry(event)
return
if isinstance(event, NodeRunStartedEvent):
self._handle_node_started(event)
return
if isinstance(event, NodeRunSucceededEvent):
self._handle_node_succeeded(event)
return
if isinstance(event, NodeRunFailedEvent):
self._handle_node_failed(event)
return
if isinstance(event, NodeRunExceptionEvent):
self._handle_node_exception(event)
return
if isinstance(event, NodeRunPauseRequestedEvent):
self._handle_node_pause_requested(event)
match event:
case GraphRunStartedEvent():
self._handle_graph_run_started()
case GraphRunSucceededEvent():
self._handle_graph_run_succeeded(event)
case GraphRunPartialSucceededEvent():
self._handle_graph_run_partial_succeeded(event)
case GraphRunFailedEvent():
self._handle_graph_run_failed(event)
case GraphRunAbortedEvent():
self._handle_graph_run_aborted(event)
case GraphRunPausedEvent():
self._handle_graph_run_paused(event)
case NodeRunRetryEvent():
self._handle_node_retry(event)
case NodeRunStartedEvent():
self._handle_node_started(event)
case NodeRunSucceededEvent():
self._handle_node_succeeded(event)
case NodeRunFailedEvent():
self._handle_node_failed(event)
case NodeRunExceptionEvent():
self._handle_node_exception(event)
case NodeRunPauseRequestedEvent():
self._handle_node_pause_requested(event)
def on_graph_end(self, error: Exception | None) -> None:
return

View File

@ -288,11 +288,13 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
except ValidationError:
# if the result_text is not a valid json, try to repair it
temp_parsed = json_repair.loads(result_text)
if not isinstance(temp_parsed, dict):
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
if isinstance(temp_parsed, list):
match temp_parsed:
case dict():
pass
case list():
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
else:
case _:
raise OutputParserError(f"Failed to parse structured output: {result_text}")
structured_output = cast(dict, temp_parsed)
return structured_output
@ -341,12 +343,13 @@ def remove_additional_properties(schema: dict[str, Any]) -> None:
# Process nested structures recursively
for value in schema.values():
if isinstance(value, dict):
remove_additional_properties(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
remove_additional_properties(item)
match value:
case dict():
remove_additional_properties(value)
case list():
for item in value:
if isinstance(item, dict):
remove_additional_properties(item)
def convert_boolean_to_string(schema: dict[str, Any]) -> None:
@ -364,9 +367,10 @@ def convert_boolean_to_string(schema: dict[str, Any]) -> None:
# Process nested dictionaries and lists recursively
for value in schema.values():
if isinstance(value, dict):
convert_boolean_to_string(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
convert_boolean_to_string(item)
match value:
case dict():
convert_boolean_to_string(value)
case list():
for item in value:
if isinstance(item, dict):
convert_boolean_to_string(item)

View File

@ -608,64 +608,67 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
if isinstance(message, CoreToolInvokeMessage.TextMessage):
return ToolRuntimeMessage.TextMessage(text=message.text)
if isinstance(message, CoreToolInvokeMessage.JsonMessage):
return ToolRuntimeMessage.JsonMessage(
json_object=message.json_object,
suppress_output=message.suppress_output,
)
if isinstance(message, CoreToolInvokeMessage.BlobMessage):
return ToolRuntimeMessage.BlobMessage(blob=message.blob)
if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage):
return ToolRuntimeMessage.BlobChunkMessage(
id=message.id,
sequence=message.sequence,
total_length=message.total_length,
blob=message.blob,
end=message.end,
)
if isinstance(message, CoreToolInvokeMessage.FileMessage):
return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker)
if isinstance(message, CoreToolInvokeMessage.VariableMessage):
return ToolRuntimeMessage.VariableMessage(
variable_name=message.variable_name,
variable_value=message.variable_value,
stream=message.stream,
)
if isinstance(message, CoreToolInvokeMessage.LogMessage):
return ToolRuntimeMessage.LogMessage(
id=message.id,
label=message.label,
parent_id=message.parent_id,
error=message.error,
status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value),
data=dict(message.data),
metadata=dict(message.metadata),
)
if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage):
retriever_resources = [
resource.model_dump() if hasattr(resource, "model_dump") else dict(resource)
for resource in message.retriever_resources
]
return ToolRuntimeMessage.RetrieverResourceMessage(
retriever_resources=retriever_resources,
context=message.context,
)
raise TypeError(f"unsupported tool message payload: {type(message).__name__}")
match message:
case CoreToolInvokeMessage.TextMessage():
return ToolRuntimeMessage.TextMessage(text=message.text)
case CoreToolInvokeMessage.JsonMessage():
return ToolRuntimeMessage.JsonMessage(
json_object=message.json_object,
suppress_output=message.suppress_output,
)
case CoreToolInvokeMessage.BlobMessage():
return ToolRuntimeMessage.BlobMessage(blob=message.blob)
case CoreToolInvokeMessage.BlobChunkMessage():
return ToolRuntimeMessage.BlobChunkMessage(
id=message.id,
sequence=message.sequence,
total_length=message.total_length,
blob=message.blob,
end=message.end,
)
case CoreToolInvokeMessage.FileMessage():
return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker)
case CoreToolInvokeMessage.VariableMessage():
return ToolRuntimeMessage.VariableMessage(
variable_name=message.variable_name,
variable_value=message.variable_value,
stream=message.stream,
)
case CoreToolInvokeMessage.LogMessage():
return ToolRuntimeMessage.LogMessage(
id=message.id,
label=message.label,
parent_id=message.parent_id,
error=message.error,
status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value),
data=dict(message.data),
metadata=dict(message.metadata),
)
case CoreToolInvokeMessage.RetrieverResourceMessage():
retriever_resources = [
resource.model_dump() if hasattr(resource, "model_dump") else dict(resource)
for resource in message.retriever_resources
]
return ToolRuntimeMessage.RetrieverResourceMessage(
retriever_resources=retriever_resources,
context=message.context,
)
case _:
raise TypeError(f"unsupported tool message payload: {type(message).__name__}")
@staticmethod
def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError:
if isinstance(exc, ToolNodeError):
return exc
if isinstance(exc, PluginInvokeError):
return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name))
if isinstance(exc, PluginDaemonClientSideError):
return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}")
if isinstance(exc, ToolInvokeError):
return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}")
return ToolRuntimeInvocationError(str(exc))
match exc:
case ToolNodeError():
return exc
case PluginInvokeError():
return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name))
case PluginDaemonClientSideError():
return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}")
case ToolInvokeError():
return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}")
case _:
return ToolRuntimeInvocationError(str(exc))
class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):

View File

@ -98,12 +98,13 @@ def extract_tenant_id(user: "Account | EndUser") -> str | None:
from models import Account
from models.model import EndUser
if isinstance(user, Account):
return user.current_tenant_id
elif isinstance(user, EndUser):
return user.tenant_id
else:
raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")
match user:
case Account():
return user.current_tenant_id
case EndUser():
return user.tenant_id
case _:
raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")
def run(script):
@ -422,18 +423,19 @@ def length_prefixed_response(
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
if isinstance(response, Mapping):
return Response(
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
status=200,
mimetype="application/json",
)
elif isinstance(response, BaseModel):
return Response(
response=pack_response_with_length_prefix(response.model_dump_json().encode("utf-8")),
status=200,
mimetype="application/json",
)
match response:
case Mapping():
return Response(
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
status=200,
mimetype="application/json",
)
case BaseModel():
return Response(
response=pack_response_with_length_prefix(response.model_dump_json().encode("utf-8")),
status=200,
mimetype="application/json",
)
stream_response = response

View File

@ -48,20 +48,23 @@ def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -
"""Build a graph `File` directly from serialized metadata."""
def _coerce_file_type(value: Any) -> FileType:
if isinstance(value, FileType):
return value
if isinstance(value, str):
return FileType.value_of(value)
raise ValueError("file type is required in file mapping")
match value:
case FileType():
return value
case str():
return FileType.value_of(value)
case _:
raise ValueError("file type is required in file mapping")
mapping = dict(file_mapping)
transfer_method_value = mapping.get("transfer_method")
if isinstance(transfer_method_value, FileTransferMethod):
transfer_method = transfer_method_value
elif isinstance(transfer_method_value, str):
transfer_method = FileTransferMethod.value_of(transfer_method_value)
else:
raise ValueError("transfer_method is required in file mapping")
match transfer_method_value:
case FileTransferMethod():
transfer_method = transfer_method_value
case str():
transfer_method = FileTransferMethod.value_of(transfer_method_value)
case _:
raise ValueError("transfer_method is required in file mapping")
file_id = mapping.get("file_id")
if not isinstance(file_id, str) or not file_id:
@ -151,15 +154,15 @@ def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
so historical JSON blobs remain readable without reintroducing global graph
patches or test-local coercion.
"""
if isinstance(value, list):
return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
if isinstance(value, dict):
if maybe_file_object(value):
return build_file_from_mapping_without_lookup(file_mapping=value)
return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
return value
match value:
case list():
return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
case dict():
if maybe_file_object(value):
return build_file_from_mapping_without_lookup(file_mapping=value)
return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
case _:
return value
def build_file_from_stored_mapping(