mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
refactor: convert isinstance chains to match/case (part 5) (#36298)
Co-authored-by: Stephen Zhou <hi@hyoban.cc>
This commit is contained in:
parent
c07686928a
commit
7e8147295b
@ -195,22 +195,23 @@ class BaseAppGenerator:
|
||||
)
|
||||
|
||||
if variable_entity.type == VariableEntityType.NUMBER:
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
# handle empty string case
|
||||
if not value.strip():
|
||||
return None
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
|
||||
else:
|
||||
raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}")
|
||||
match value:
|
||||
case int() | float():
|
||||
return value
|
||||
case str():
|
||||
# handle empty string case
|
||||
if not value.strip():
|
||||
return None
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
|
||||
case _:
|
||||
raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}")
|
||||
|
||||
match variable_entity.type:
|
||||
case VariableEntityType.SELECT:
|
||||
@ -241,17 +242,18 @@ class BaseAppGenerator:
|
||||
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
|
||||
)
|
||||
case VariableEntityType.CHECKBOX:
|
||||
if isinstance(value, str):
|
||||
normalized_value = value.strip().lower()
|
||||
if normalized_value in {"true", "1", "yes", "on"}:
|
||||
value = True
|
||||
elif normalized_value in {"false", "0", "no", "off"}:
|
||||
value = False
|
||||
elif isinstance(value, (int, float)):
|
||||
if value == 1:
|
||||
value = True
|
||||
elif value == 0:
|
||||
value = False
|
||||
match value:
|
||||
case str():
|
||||
normalized_value = value.strip().lower()
|
||||
if normalized_value in {"true", "1", "yes", "on"}:
|
||||
value = True
|
||||
elif normalized_value in {"false", "0", "no", "off"}:
|
||||
value = False
|
||||
case int() | float():
|
||||
if value == 1:
|
||||
value = True
|
||||
elif value == 0:
|
||||
value = False
|
||||
case VariableEntityType.JSON_OBJECT:
|
||||
if value and not isinstance(value, dict):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
|
||||
|
||||
@ -105,52 +105,31 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
self._node_sequence = 0
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._handle_graph_run_started()
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunSucceededEvent):
|
||||
self._handle_graph_run_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self._handle_graph_run_partial_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
self._handle_graph_run_failed(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunAbortedEvent):
|
||||
self._handle_graph_run_aborted(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
self._handle_graph_run_paused(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunRetryEvent):
|
||||
self._handle_node_retry(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self._handle_node_started(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
self._handle_node_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunFailedEvent):
|
||||
self._handle_node_failed(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunExceptionEvent):
|
||||
self._handle_node_exception(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunPauseRequestedEvent):
|
||||
self._handle_node_pause_requested(event)
|
||||
match event:
|
||||
case GraphRunStartedEvent():
|
||||
self._handle_graph_run_started()
|
||||
case GraphRunSucceededEvent():
|
||||
self._handle_graph_run_succeeded(event)
|
||||
case GraphRunPartialSucceededEvent():
|
||||
self._handle_graph_run_partial_succeeded(event)
|
||||
case GraphRunFailedEvent():
|
||||
self._handle_graph_run_failed(event)
|
||||
case GraphRunAbortedEvent():
|
||||
self._handle_graph_run_aborted(event)
|
||||
case GraphRunPausedEvent():
|
||||
self._handle_graph_run_paused(event)
|
||||
case NodeRunRetryEvent():
|
||||
self._handle_node_retry(event)
|
||||
case NodeRunStartedEvent():
|
||||
self._handle_node_started(event)
|
||||
case NodeRunSucceededEvent():
|
||||
self._handle_node_succeeded(event)
|
||||
case NodeRunFailedEvent():
|
||||
self._handle_node_failed(event)
|
||||
case NodeRunExceptionEvent():
|
||||
self._handle_node_exception(event)
|
||||
case NodeRunPauseRequestedEvent():
|
||||
self._handle_node_pause_requested(event)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
return
|
||||
|
||||
@ -288,11 +288,13 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
|
||||
except ValidationError:
|
||||
# if the result_text is not a valid json, try to repair it
|
||||
temp_parsed = json_repair.loads(result_text)
|
||||
if not isinstance(temp_parsed, dict):
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
if isinstance(temp_parsed, list):
|
||||
match temp_parsed:
|
||||
case dict():
|
||||
pass
|
||||
case list():
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
|
||||
else:
|
||||
case _:
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = cast(dict, temp_parsed)
|
||||
return structured_output
|
||||
@ -341,12 +343,13 @@ def remove_additional_properties(schema: dict[str, Any]) -> None:
|
||||
|
||||
# Process nested structures recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
remove_additional_properties(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
remove_additional_properties(item)
|
||||
match value:
|
||||
case dict():
|
||||
remove_additional_properties(value)
|
||||
case list():
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
remove_additional_properties(item)
|
||||
|
||||
|
||||
def convert_boolean_to_string(schema: dict[str, Any]) -> None:
|
||||
@ -364,9 +367,10 @@ def convert_boolean_to_string(schema: dict[str, Any]) -> None:
|
||||
|
||||
# Process nested dictionaries and lists recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
convert_boolean_to_string(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
convert_boolean_to_string(item)
|
||||
match value:
|
||||
case dict():
|
||||
convert_boolean_to_string(value)
|
||||
case list():
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
convert_boolean_to_string(item)
|
||||
|
||||
@ -608,64 +608,67 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
|
||||
|
||||
if isinstance(message, CoreToolInvokeMessage.TextMessage):
|
||||
return ToolRuntimeMessage.TextMessage(text=message.text)
|
||||
if isinstance(message, CoreToolInvokeMessage.JsonMessage):
|
||||
return ToolRuntimeMessage.JsonMessage(
|
||||
json_object=message.json_object,
|
||||
suppress_output=message.suppress_output,
|
||||
)
|
||||
if isinstance(message, CoreToolInvokeMessage.BlobMessage):
|
||||
return ToolRuntimeMessage.BlobMessage(blob=message.blob)
|
||||
if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage):
|
||||
return ToolRuntimeMessage.BlobChunkMessage(
|
||||
id=message.id,
|
||||
sequence=message.sequence,
|
||||
total_length=message.total_length,
|
||||
blob=message.blob,
|
||||
end=message.end,
|
||||
)
|
||||
if isinstance(message, CoreToolInvokeMessage.FileMessage):
|
||||
return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker)
|
||||
if isinstance(message, CoreToolInvokeMessage.VariableMessage):
|
||||
return ToolRuntimeMessage.VariableMessage(
|
||||
variable_name=message.variable_name,
|
||||
variable_value=message.variable_value,
|
||||
stream=message.stream,
|
||||
)
|
||||
if isinstance(message, CoreToolInvokeMessage.LogMessage):
|
||||
return ToolRuntimeMessage.LogMessage(
|
||||
id=message.id,
|
||||
label=message.label,
|
||||
parent_id=message.parent_id,
|
||||
error=message.error,
|
||||
status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value),
|
||||
data=dict(message.data),
|
||||
metadata=dict(message.metadata),
|
||||
)
|
||||
if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage):
|
||||
retriever_resources = [
|
||||
resource.model_dump() if hasattr(resource, "model_dump") else dict(resource)
|
||||
for resource in message.retriever_resources
|
||||
]
|
||||
return ToolRuntimeMessage.RetrieverResourceMessage(
|
||||
retriever_resources=retriever_resources,
|
||||
context=message.context,
|
||||
)
|
||||
|
||||
raise TypeError(f"unsupported tool message payload: {type(message).__name__}")
|
||||
match message:
|
||||
case CoreToolInvokeMessage.TextMessage():
|
||||
return ToolRuntimeMessage.TextMessage(text=message.text)
|
||||
case CoreToolInvokeMessage.JsonMessage():
|
||||
return ToolRuntimeMessage.JsonMessage(
|
||||
json_object=message.json_object,
|
||||
suppress_output=message.suppress_output,
|
||||
)
|
||||
case CoreToolInvokeMessage.BlobMessage():
|
||||
return ToolRuntimeMessage.BlobMessage(blob=message.blob)
|
||||
case CoreToolInvokeMessage.BlobChunkMessage():
|
||||
return ToolRuntimeMessage.BlobChunkMessage(
|
||||
id=message.id,
|
||||
sequence=message.sequence,
|
||||
total_length=message.total_length,
|
||||
blob=message.blob,
|
||||
end=message.end,
|
||||
)
|
||||
case CoreToolInvokeMessage.FileMessage():
|
||||
return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker)
|
||||
case CoreToolInvokeMessage.VariableMessage():
|
||||
return ToolRuntimeMessage.VariableMessage(
|
||||
variable_name=message.variable_name,
|
||||
variable_value=message.variable_value,
|
||||
stream=message.stream,
|
||||
)
|
||||
case CoreToolInvokeMessage.LogMessage():
|
||||
return ToolRuntimeMessage.LogMessage(
|
||||
id=message.id,
|
||||
label=message.label,
|
||||
parent_id=message.parent_id,
|
||||
error=message.error,
|
||||
status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value),
|
||||
data=dict(message.data),
|
||||
metadata=dict(message.metadata),
|
||||
)
|
||||
case CoreToolInvokeMessage.RetrieverResourceMessage():
|
||||
retriever_resources = [
|
||||
resource.model_dump() if hasattr(resource, "model_dump") else dict(resource)
|
||||
for resource in message.retriever_resources
|
||||
]
|
||||
return ToolRuntimeMessage.RetrieverResourceMessage(
|
||||
retriever_resources=retriever_resources,
|
||||
context=message.context,
|
||||
)
|
||||
case _:
|
||||
raise TypeError(f"unsupported tool message payload: {type(message).__name__}")
|
||||
|
||||
@staticmethod
|
||||
def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError:
|
||||
if isinstance(exc, ToolNodeError):
|
||||
return exc
|
||||
if isinstance(exc, PluginInvokeError):
|
||||
return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name))
|
||||
if isinstance(exc, PluginDaemonClientSideError):
|
||||
return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}")
|
||||
if isinstance(exc, ToolInvokeError):
|
||||
return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}")
|
||||
return ToolRuntimeInvocationError(str(exc))
|
||||
match exc:
|
||||
case ToolNodeError():
|
||||
return exc
|
||||
case PluginInvokeError():
|
||||
return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name))
|
||||
case PluginDaemonClientSideError():
|
||||
return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}")
|
||||
case ToolInvokeError():
|
||||
return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}")
|
||||
case _:
|
||||
return ToolRuntimeInvocationError(str(exc))
|
||||
|
||||
|
||||
class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
|
||||
|
||||
@ -98,12 +98,13 @@ def extract_tenant_id(user: "Account | EndUser") -> str | None:
|
||||
from models import Account
|
||||
from models.model import EndUser
|
||||
|
||||
if isinstance(user, Account):
|
||||
return user.current_tenant_id
|
||||
elif isinstance(user, EndUser):
|
||||
return user.tenant_id
|
||||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")
|
||||
match user:
|
||||
case Account():
|
||||
return user.current_tenant_id
|
||||
case EndUser():
|
||||
return user.tenant_id
|
||||
case _:
|
||||
raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")
|
||||
|
||||
|
||||
def run(script):
|
||||
@ -422,18 +423,19 @@ def length_prefixed_response(
|
||||
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
|
||||
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
|
||||
|
||||
if isinstance(response, Mapping):
|
||||
return Response(
|
||||
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
|
||||
status=200,
|
||||
mimetype="application/json",
|
||||
)
|
||||
elif isinstance(response, BaseModel):
|
||||
return Response(
|
||||
response=pack_response_with_length_prefix(response.model_dump_json().encode("utf-8")),
|
||||
status=200,
|
||||
mimetype="application/json",
|
||||
)
|
||||
match response:
|
||||
case Mapping():
|
||||
return Response(
|
||||
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
|
||||
status=200,
|
||||
mimetype="application/json",
|
||||
)
|
||||
case BaseModel():
|
||||
return Response(
|
||||
response=pack_response_with_length_prefix(response.model_dump_json().encode("utf-8")),
|
||||
status=200,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
stream_response = response
|
||||
|
||||
|
||||
@ -48,20 +48,23 @@ def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -
|
||||
"""Build a graph `File` directly from serialized metadata."""
|
||||
|
||||
def _coerce_file_type(value: Any) -> FileType:
|
||||
if isinstance(value, FileType):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return FileType.value_of(value)
|
||||
raise ValueError("file type is required in file mapping")
|
||||
match value:
|
||||
case FileType():
|
||||
return value
|
||||
case str():
|
||||
return FileType.value_of(value)
|
||||
case _:
|
||||
raise ValueError("file type is required in file mapping")
|
||||
|
||||
mapping = dict(file_mapping)
|
||||
transfer_method_value = mapping.get("transfer_method")
|
||||
if isinstance(transfer_method_value, FileTransferMethod):
|
||||
transfer_method = transfer_method_value
|
||||
elif isinstance(transfer_method_value, str):
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
else:
|
||||
raise ValueError("transfer_method is required in file mapping")
|
||||
match transfer_method_value:
|
||||
case FileTransferMethod():
|
||||
transfer_method = transfer_method_value
|
||||
case str():
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
case _:
|
||||
raise ValueError("transfer_method is required in file mapping")
|
||||
|
||||
file_id = mapping.get("file_id")
|
||||
if not isinstance(file_id, str) or not file_id:
|
||||
@ -151,15 +154,15 @@ def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
|
||||
so historical JSON blobs remain readable without reintroducing global graph
|
||||
patches or test-local coercion.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
if maybe_file_object(value):
|
||||
return build_file_from_mapping_without_lookup(file_mapping=value)
|
||||
return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
|
||||
|
||||
return value
|
||||
match value:
|
||||
case list():
|
||||
return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
|
||||
case dict():
|
||||
if maybe_file_object(value):
|
||||
return build_file_from_mapping_without_lookup(file_mapping=value)
|
||||
return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
|
||||
case _:
|
||||
return value
|
||||
|
||||
|
||||
def build_file_from_stored_mapping(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user