diff --git a/api/controllers/common/human_input.py b/api/controllers/common/human_input.py index d9b8f8f9a37..4b2e70e2d03 100644 --- a/api/controllers/common/human_input.py +++ b/api/controllers/common/human_input.py @@ -42,10 +42,11 @@ def stringify_form_default_values(values: dict[str, object]) -> dict[str, str]: """Serialize default values into strings expected by human-input form clients.""" result: dict[str, str] = {} for key, value in values.items(): - if value is None: - result[key] = "" - elif isinstance(value, (dict, list)): - result[key] = json.dumps(value, ensure_ascii=False) - else: - result[key] = str(value) + match value: + case None: + result[key] = "" + case dict() | list(): + result[key] = json.dumps(value, ensure_ascii=False) + case _: + result[key] = str(value) return result diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 83442b316f9..234488885de 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -182,14 +182,15 @@ register_response_schema_models( def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict: status = _enum_value(workflow_run.status) raw_outputs = workflow_run.outputs_dict - if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None: - outputs: dict = {} - elif isinstance(raw_outputs, dict): - outputs = raw_outputs - elif isinstance(raw_outputs, Mapping): - outputs = dict(raw_outputs) - else: - outputs = {} + match raw_outputs: + case _ if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None: + outputs: dict = {} + case dict(): + outputs = raw_outputs + case _ if isinstance(raw_outputs, Mapping): + outputs = dict(raw_outputs) + case _: + outputs = {} return WorkflowRunResponse.model_validate( { "id": workflow_run.id, diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index a89a0cf70db..7b854fec34a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -231,22 +231,23 @@ class AppRunner: :param tenant_id: tenant id for multimodal output :return: """ - if not stream and isinstance(invoke_result, LLMResult): - self._handle_invoke_result_direct( - invoke_result=invoke_result, - queue_manager=queue_manager, - ) - elif stream and isinstance(invoke_result, Generator): - self._handle_invoke_result_stream( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent, - message_id=message_id, - user_id=user_id, - tenant_id=tenant_id, - ) - else: - raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") + match invoke_result: + case LLMResult() if not stream: + self._handle_invoke_result_direct( + invoke_result=invoke_result, + queue_manager=queue_manager, + ) + case _ if stream and isinstance(invoke_result, Generator): + self._handle_invoke_result_stream( + invoke_result=invoke_result, + queue_manager=queue_manager, + agent=agent, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + ) + case _: + raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") def _handle_invoke_result_direct( self, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 502b1907ba4..c9486b5821f 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -882,7 +882,7 @@ class WorkflowResponseConverter: return files @classmethod - def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None: + def _get_file_var_from_value(cls, value: object) -> Mapping[str, Any] | None: """ Get file var from value :param value: variable value @@ -891,10 +891,11 @@ class WorkflowResponseConverter: if not value: return None - if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - return value - elif isinstance(value, File): - return value.to_dict() + match value: + case dict() if value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + return value + case File(): + return value.to_dict() return None diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index 82b5f42885a..8b022c1d065 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -144,15 +144,16 @@ def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str, Returns an empty dict if the context is missing or incomplete. """ parent_trace_context = args.get("parent_trace_context") - if isinstance(parent_trace_context, ParentTraceContext): - context = parent_trace_context - elif isinstance(parent_trace_context, Mapping): - try: - context = ParentTraceContext.model_validate(parent_trace_context) - except ValidationError: + match parent_trace_context: + case ParentTraceContext(): + context = parent_trace_context + case Mapping(): + try: + context = ParentTraceContext.model_validate(parent_trace_context) + except ValidationError: + return {} + case _: return {} - else: - return {} if context.parent_node_execution_id is None: return {} diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index ba305690664..14ed8af3ef7 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -116,20 +116,21 @@ def cast_parameter_value(typ: StrEnum, value: Any, /): return value if isinstance(value, str) else str(value) case PluginParameterType.BOOLEAN: - if value is None: - return False - elif isinstance(value, str): - # Allowed YAML boolean value strings: https://yaml.org/type/bool.html - # and also '0' for False and '1' for True - match value.lower(): - case "true" | "yes" | "y" | "1": - return True - case "false" | "no" | "n" | "0": - return False - case _: - return bool(value) - else: - return value if isinstance(value, bool) else bool(value) + match value: + case None: + return False + case str(): + # Allowed YAML boolean value strings: https://yaml.org/type/bool.html + # and also '0' for False and '1' for True + match value.lower(): + case "true" | "yes" | "y" | "1": + return True + case "false" | "no" | "n" | "0": + return False + case _: + return bool(value) + case _: + return value if isinstance(value, bool) else bool(value) case PluginParameterType.NUMBER: match value: diff --git a/api/core/schemas/resolver.py b/api/core/schemas/resolver.py index cd86aebc060..6d959e0e87a 100644 --- a/api/core/schemas/resolver.py +++ b/api/core/schemas/resolver.py @@ -304,22 +304,23 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool: Returns: True if any Dify $ref is found, False otherwise """ - if isinstance(schema, dict): - # Check if this dict has a $ref field - ref_uri = schema.get("$ref") - if ref_uri and _is_dify_schema_ref(ref_uri): - return True - - # Check nested values - for value in schema.values(): - if _has_dify_refs_recursive(value): + match schema: + case dict(): + # Check if this dict has a $ref field + ref_uri = schema.get("$ref") + if ref_uri and _is_dify_schema_ref(ref_uri): return True - elif isinstance(schema, list): - # Check each item in the list - for item in schema: - if _has_dify_refs_recursive(item): - return True + # Check nested values + for value in schema.values(): + if _has_dify_refs_recursive(value): + return True + + case list(): + # Check each item in the list + for item in schema: + if _has_dify_refs_recursive(item): + return True # Primitive types don't contain refs return False diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index 1ebb7ab3a7f..57363349458 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from datetime import datetime -from typing import Any, override +from datetime import datetime, tzinfo +from typing import Any, cast, override import pytz # type: ignore[import-untyped] @@ -35,17 +35,26 @@ class LocaltimeToTimestampTool(BuiltinTool): yield self.create_text_message(f"{timestamp}") - # TODO: this method's type is messy @staticmethod - def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None: + def localtime_to_timestamp(localtime: str, time_format: str, local_tz: str | tzinfo | None = None) -> int | None: try: local_time = datetime.strptime(localtime, time_format) - if local_tz is None: - localtime = local_time.astimezone() # type: ignore - elif isinstance(local_tz, str): - local_tz = pytz.timezone(local_tz) - localtime = local_tz.localize(local_time) # type: ignore - timestamp = int(localtime.timestamp()) # type: ignore + converted_localtime: datetime + match local_tz: + case None: + converted_localtime = local_time.astimezone() + case str() as timezone_name: + timezone = pytz.timezone(timezone_name) + converted_localtime = timezone.localize(local_time) + case tzinfo(): + localize = getattr(local_tz, "localize", None) + if callable(localize): + converted_localtime = cast(datetime, localize(local_time)) + else: + converted_localtime = local_time.replace(tzinfo=local_tz) + case _: + raise ValueError("local_tz must be None, a timezone name, or a tzinfo instance") + timestamp = int(converted_localtime.timestamp()) return timestamp except Exception as e: raise ToolInvokeError(str(e)) diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 7a1553a4b15..195acd6e1ad 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -122,13 +122,14 @@ class MCPTool(Tool): def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]: """Process JSON content based on its type.""" - if isinstance(content_json, dict): - yield self.create_json_message(content_json) - elif isinstance(content_json, list): - yield from self._process_json_list(content_json) - else: - # For primitive types (str, int, bool, etc.), convert to string - yield self.create_text_message(str(content_json)) + match content_json: + case dict(): + yield self.create_json_message(content_json) + case list(): + yield from self._process_json_list(content_json) + case _: + # For primitive types (str, int, bool, etc.), convert to string + yield self.create_text_message(str(content_json)) def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]: """Process a list of JSON items.""" @@ -222,16 +223,17 @@ class MCPTool(Tool): # Recursively search through nested structures for value in payload.values(): - if isinstance(value, Mapping): - found = cls._extract_usage_dict(value) - if found is not None: - return found - elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)): - for item in value: - if isinstance(item, Mapping): - found = cls._extract_usage_dict(item) - if found is not None: - return found + match value: + case _ if isinstance(value, Mapping): + found = cls._extract_usage_dict(value) + if found is not None: + return found + case list() if not isinstance(value, (str, bytes, bytearray)): + for item in value: + if isinstance(item, Mapping): + found = cls._extract_usage_dict(item) + if found is not None: + return found return None @override diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 7e7b1e33008..97222f3cfae 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -196,16 +196,17 @@ class WorkflowTool(Tool): return usage_candidate for value in payload.values(): - if isinstance(value, Mapping): - found = cls._extract_usage_dict(value) - if found is not None: - return found - elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): - for item in value: - if isinstance(item, Mapping): - found = cls._extract_usage_dict(item) - if found is not None: - return found + match value: + case _ if isinstance(value, Mapping): + found = cls._extract_usage_dict(value) + if found is not None: + return found + case _ if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + for item in value: + if isinstance(item, Mapping): + found = cls._extract_usage_dict(item) + if found is not None: + return found return None @override @@ -393,24 +394,25 @@ class WorkflowTool(Tool): files: list[File] = [] result = {} for key, value in outputs.items(): - if isinstance(value, list): - for item in value: - if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY: - item = self._update_file_mapping(item) - file = build_from_mapping( - mapping=item, - tenant_id=str(self.runtime.tenant_id), - access_controller=_file_access_controller, - ) - files.append(file) - elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - value = self._update_file_mapping(value) - file = build_from_mapping( - mapping=value, - tenant_id=str(self.runtime.tenant_id), - access_controller=_file_access_controller, - ) - files.append(file) + match value: + case list(): + for item in value: + if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY: + item = self._update_file_mapping(item) + file = build_from_mapping( + mapping=item, + tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, + ) + files.append(file) + case dict() if value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + value = self._update_file_mapping(value) + file = build_from_mapping( + mapping=value, + tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, + ) + files.append(file) result[key] = value diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 1d938dd04c5..6f1660390d3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -297,25 +297,26 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD for cond in conditions.conditions or []: value = cond.value resolved_value: str | Sequence[str] | int | float | None - if isinstance(value, str): - segment_group = variable_pool.convert_template(value) - if len(segment_group.value) == 1: - resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object()) - else: - resolved_value = segment_group.text - elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value): - resolved_values: list[str] = [] - for v in value: - segment_group = variable_pool.convert_template(v) + match value: + case str(): + segment_group = variable_pool.convert_template(value) if len(segment_group.value) == 1: - resolved_values.append( - _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object()) - ) + resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object()) else: - resolved_values.append(segment_group.text) - resolved_value = resolved_values - else: - resolved_value = value + resolved_value = segment_group.text + case _ if isinstance(value, Sequence) and all(isinstance(v, str) for v in value): + resolved_values: list[str] = [] + for v in value: + segment_group = variable_pool.convert_template(v) + if len(segment_group.value) == 1: + resolved_values.append( + _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object()) + ) + else: + resolved_values.append(segment_group.text) + resolved_value = resolved_values + case _: + resolved_value = value resolved_conditions.append( Condition( name=cond.name, diff --git a/api/dev/lint_response_contracts.py b/api/dev/lint_response_contracts.py index 4ba79e0fedb..75c5f67b8ff 100644 --- a/api/dev/lint_response_contracts.py +++ b/api/dev/lint_response_contracts.py @@ -354,11 +354,12 @@ def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]: def target_names(target: ast.AST) -> Iterable[str]: - if isinstance(target, ast.Name): - yield target.id - elif isinstance(target, ast.Tuple | ast.List): - for item in target.elts: - yield from target_names(item) + match target: + case ast.Name(): + yield target.id + case ast.Tuple() | ast.List(): + for item in target.elts: + yield from target_names(item) def record_assignment( diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py index 15355a77620..5af42d12538 100644 --- a/api/libs/broadcast_channel/redis/_subscription.py +++ b/api/libs/broadcast_channel/redis/_subscription.py @@ -94,12 +94,13 @@ class RedisSubscriptionBase(Subscription): continue channel_field = raw_message.get("channel") - if isinstance(channel_field, bytes): - channel_name = channel_field.decode("utf-8") - elif isinstance(channel_field, str): - channel_name = channel_field - else: - channel_name = str(channel_field) + match channel_field: + case bytes(): + channel_name = channel_field.decode("utf-8") + case str(): + channel_name = channel_field + case _: + channel_name = str(channel_field) if channel_name != self._topic: _logger.warning( diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index a7303c07823..68e9f8b23ef 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -88,22 +88,23 @@ class _RedisShardedSubscription(RedisSubscriptionBase): # # Since we have already filtered at the caller's site, we can safely set # `ignore_subscribe_messages=False`. - if isinstance(self._client, RedisCluster): - # NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without - # specifying the `target_node` argument would use busy-looping to wait - # for incoming message, consuming excessive CPU quota. - # - # Here we specify the `target_node` to mitigate this problem. - node = self._client.get_node_from_key(self._topic) - return self._pubsub.get_sharded_message( # type: ignore[attr-defined] - ignore_subscribe_messages=False, - timeout=1, - target_node=node, - ) - elif isinstance(self._client, Redis): - return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined] - else: - raise AssertionError("client should be either Redis or RedisCluster.") + match self._client: + case RedisCluster(): + # NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without + # specifying the `target_node` argument would use busy-looping to wait + # for incoming message, consuming excessive CPU quota. + # + # Here we specify the `target_node` to mitigate this problem. + node = self._client.get_node_from_key(self._topic) + return self._pubsub.get_sharded_message( # type: ignore[attr-defined] + ignore_subscribe_messages=False, + timeout=1, + target_node=node, + ) + case Redis(): + return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined] + case _: + raise AssertionError("client should be either Redis or RedisCluster.") @override def _get_message_type(self) -> str: diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index 30c14585793..62e58798ab3 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -138,10 +138,11 @@ class _StreamsSubscription(Subscription): if isinstance(fields, dict): data = fields.get(b"data") data_bytes: bytes | None = None - if isinstance(data, str): - data_bytes = data.encode() - elif isinstance(data, (bytes, bytearray)): - data_bytes = bytes(data) + match data: + case str(): + data_bytes = data.encode() + case bytes() | bytearray(): + data_bytes = bytes(data) if data_bytes is not None: self._queue.put_nowait(data_bytes) last_id = entry_id diff --git a/api/models/model.py b/api/models/model.py index 09809b85f6b..69d2a4a7f19 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1174,34 +1174,32 @@ class Conversation(Base): # Convert file mapping to File object for key, value in inputs.items(): - if ( - isinstance(value, dict) - and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY - ): - value_dict = cast(dict[str, Any], value) - inputs[key] = build_file_from_input_mapping( - file_mapping=value_dict, - tenant_resolver=tenant_resolver, - ) - elif isinstance(value, list): - value_list = value - if all( - isinstance(item, dict) - and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY - for item in value_list - ): - file_list: list[File] = [] - for item in value_list: - if not isinstance(item, dict): - continue - item_dict = cast(dict[str, Any], item) - file_list.append( - build_file_from_input_mapping( - file_mapping=item_dict, - tenant_resolver=tenant_resolver, + match value: + case dict() if cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY: + value_dict = cast(dict[str, Any], value) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) + case list(): + value_list = value + if all( + isinstance(item, dict) + and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY + for item in value_list + ): + file_list: list[File] = [] + for item in value_list: + if not isinstance(item, dict): + continue + item_dict = cast(dict[str, Any], item) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) ) - ) - inputs[key] = file_list + inputs[key] = file_list return inputs @@ -1516,46 +1514,45 @@ class Message(Base): owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), ) for key, value in inputs.items(): - if ( - isinstance(value, dict) - and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY - ): - value_dict = cast(dict[str, Any], value) - inputs[key] = build_file_from_input_mapping( - file_mapping=value_dict, - tenant_resolver=tenant_resolver, - ) - elif isinstance(value, list): - value_list = value - if all( - isinstance(item, dict) - and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY - for item in value_list - ): - file_list: list[File] = [] - for item in value_list: - if not isinstance(item, dict): - continue - item_dict = cast(dict[str, Any], item) - file_list.append( - build_file_from_input_mapping( - file_mapping=item_dict, - tenant_resolver=tenant_resolver, + match value: + case dict() if cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY: + value_dict = cast(dict[str, Any], value) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) + case list(): + value_list = value + if all( + isinstance(item, dict) + and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY + for item in value_list + ): + file_list: list[File] = [] + for item in value_list: + if not isinstance(item, dict): + continue + item_dict = cast(dict[str, Any], item) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) ) - ) - inputs[key] = file_list + inputs[key] = file_list return inputs @inputs.setter def inputs(self, value: Mapping[str, Any]): inputs = dict(value) for k, v in inputs.items(): - if isinstance(v, File): - inputs[k] = v.model_dump() - elif isinstance(v, list): - v_list = v - if all(isinstance(item, File) for item in v_list): - inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] + match v: + case File(): + inputs[k] = v.model_dump() + case list(): + v_list = v + if all(isinstance(item, File) for item in v_list): + inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] self._inputs = inputs @property diff --git a/api/models/types.py b/api/models/types.py index 092db638565..c5a9231ad4a 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -96,12 +96,13 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]): def process_bind_param(self, value: T | dict[str, Any] | str | None, dialect: Dialect) -> str | None: if value is None: return None - if isinstance(value, self._model_class): - model = value - elif isinstance(value, str): - model = self._model_class.model_validate_json(value) - else: - model = self._model_class.model_validate(value) + match value: + case _ if isinstance(value, self._model_class): + model = value + case str(): + model = self._model_class.model_validate_json(value) + case _: + model = self._model_class.model_validate(value) return json.dumps(model.model_dump(mode="json"), ensure_ascii=False, sort_keys=True, separators=(",", ":")) @override diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index 0f24adfd92e..9f779fc43de 100644 --- a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -1150,13 +1150,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance): try: # Convert outputs to string based on type outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value - if isinstance(trace_info.outputs, dict | list): - outputs_str = safe_json_dumps(trace_info.outputs) - outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value - elif isinstance(trace_info.outputs, str): - outputs_str = trace_info.outputs - else: - outputs_str = str(trace_info.outputs) + match trace_info.outputs: + case dict() | list(): + outputs_str = safe_json_dumps(trace_info.outputs) + outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value + case str(): + outputs_str = trace_info.outputs + case _: + outputs_str = str(trace_info.outputs) llm_attributes: dict[str, Any] = { SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value, @@ -1553,25 +1554,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance): set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id) # Handle list of messages - if isinstance(prompts, list): - for message_index, message in enumerate(prompts): - if not isinstance(message, dict): - continue + match prompts: + case list(): + for message_index, message in enumerate(prompts): + if not isinstance(message, dict): + continue - role = message.get("role", "user") - content = message.get("text") or message.get("content") or "" + role = message.get("role", "user") + content = message.get("text") or message.get("content") or "" - set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role) - set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content) + set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role) + set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content) - tool_calls = message.get("tool_calls") or [] - if isinstance(tool_calls, list): - for tool_index, tool_call in enumerate(tool_calls): - set_tool_call_attributes(message_index, tool_index, tool_call) + tool_calls = message.get("tool_calls") or [] + if isinstance(tool_calls, list): + for tool_index, tool_call in enumerate(tool_calls): + set_tool_call_attributes(message_index, tool_index, tool_call) - # Handle single dict or plain string prompt - elif isinstance(prompts, (dict, str)): - set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts) - set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user") + # Handle single dict or plain string prompt + case dict() | str(): + set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts) + set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user") return attributes diff --git a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py index 76755bf7693..742938f09f4 100644 --- a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py @@ -18,24 +18,25 @@ def validate_input_output(v, field_name): """ if v == {} or v is None: return v - if isinstance(v, str): - return [ - { - "role": "assistant" if field_name == "output" else "user", - "content": v, - } - ] - elif isinstance(v, list): - if len(v) > 0 and isinstance(v[0], dict): - v = replace_text_with_content(data=v) - return v - else: + match v: + case str(): return [ { "role": "assistant" if field_name == "output" else "user", - "content": str(v), + "content": v, } ] + case list(): + if len(v) > 0 and isinstance(v[0], dict): + v = replace_text_with_content(data=v) + return v + else: + return [ + { + "role": "assistant" if field_name == "output" else "user", + "content": str(v), + } + ] return v diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py index be9d64ae018..07159d8a7e3 100644 --- a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py @@ -64,40 +64,20 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): "total_tokens": values.get("total_tokens", 0), } file_list = values.get("file_list", []) - if isinstance(v, str): - match field_name: - case "inputs": - return { - "messages": { - "role": "user", - "content": v, - "usage_metadata": usage_metadata, - "file_list": file_list, - }, - } - case "outputs": - return { - "choices": { - "role": "ai", - "content": v, - "usage_metadata": usage_metadata, - "file_list": file_list, - }, - } - case _: - pass - elif isinstance(v, list): - data = {} - if len(v) > 0 and isinstance(v[0], dict): - # rename text to content - v = replace_text_with_content(data=v) + match v: + case str(): match field_name: case "inputs": - data = { - "messages": v, + return { + "messages": { + "role": "user", + "content": v, + "usage_metadata": usage_metadata, + "file_list": file_list, + }, } case "outputs": - data = { + return { "choices": { "role": "ai", "content": v, @@ -107,16 +87,37 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): } case _: pass - return data - else: - return { - "choices": { - "role": "ai" if field_name == "outputs" else "user", - "content": str(v), - "usage_metadata": usage_metadata, - "file_list": file_list, - }, - } + case list(): + data = {} + if len(v) > 0 and isinstance(v[0], dict): + # rename text to content + v = replace_text_with_content(data=v) + match field_name: + case "inputs": + data = { + "messages": v, + } + case "outputs": + data = { + "choices": { + "role": "ai", + "content": v, + "usage_metadata": usage_metadata, + "file_list": file_list, + }, + } + case _: + pass + return data + else: + return { + "choices": { + "role": "ai" if field_name == "outputs" else "user", + "content": str(v), + "usage_metadata": usage_metadata, + "file_list": file_list, + }, + } if isinstance(v, dict): v["usage_metadata"] = usage_metadata v["file_list"] = file_list diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py index ed6a7dabbb0..98180c80a2c 100644 --- a/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py @@ -40,41 +40,19 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): "total_tokens": values.get("total_tokens", 0), } file_list = values.get("file_list", []) - if isinstance(v, str): - if field_name == "inputs": - return { - "messages": { - "role": "user", - "content": v, - "usage_metadata": usage_metadata, - "file_list": file_list, - }, - } - elif field_name == "outputs": - return { - "choices": { - "role": "ai", - "content": v, - "usage_metadata": usage_metadata, - "file_list": file_list, - }, - } - elif isinstance(v, list): - data = {} - if len(v) > 0 and isinstance(v[0], dict): - # rename text to content - v = replace_text_with_content(data=v) + match v: + case str(): if field_name == "inputs": - data = { - "messages": [ - dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore - for msg in v - ] - if isinstance(v, list) - else v, + return { + "messages": { + "role": "user", + "content": v, + "usage_metadata": usage_metadata, + "file_list": file_list, + }, } elif field_name == "outputs": - data = { + return { "choices": { "role": "ai", "content": v, @@ -82,16 +60,39 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): "file_list": file_list, }, } - return data - else: - return { - "choices": { - "role": "ai" if field_name == "outputs" else "user", - "content": str(v), - "usage_metadata": usage_metadata, - "file_list": file_list, - }, - } + case list(): + data = {} + if len(v) > 0 and isinstance(v[0], dict): + # rename text to content + v = replace_text_with_content(data=v) + if field_name == "inputs": + data = { + "messages": [ + dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore + for msg in v + ] + if isinstance(v, list) + else v, + } + elif field_name == "outputs": + data = { + "choices": { + "role": "ai", + "content": v, + "usage_metadata": usage_metadata, + "file_list": file_list, + }, + } + return data + else: + return { + "choices": { + "role": "ai" if field_name == "outputs" else "user", + "content": str(v), + "usage_metadata": usage_metadata, + "file_list": file_list, + }, + } if isinstance(v, dict): v["usage_metadata"] = usage_metadata v["file_list"] = file_list diff --git a/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py b/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py index 6231b9a9fad..6e80c6efa2c 100644 --- a/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py +++ b/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py @@ -361,12 +361,13 @@ class ClickzettaVector(BaseVector): first_pass = json.loads(raw_metadata) # Handle double-encoded JSON - if isinstance(first_pass, str): - metadata = parse_metadata_json(first_pass) - elif isinstance(first_pass, dict): - metadata = first_pass - else: - metadata = {} + match first_pass: + case str(): + metadata = parse_metadata_json(first_pass) + case dict(): + metadata = first_pass + case _: + metadata = {} else: metadata = {} except (json.JSONDecodeError, ValueError, TypeError): @@ -942,12 +943,13 @@ class ClickzettaVector(BaseVector): # First parse may yield a string (double-encoded JSON) first_pass = json.loads(row[2]) - if isinstance(first_pass, str): - metadata = parse_metadata_json(first_pass) - elif isinstance(first_pass, dict): - metadata = first_pass - else: - metadata = {} + match first_pass: + case str(): + metadata = parse_metadata_json(first_pass) + case dict(): + metadata = first_pass + case _: + metadata = {} else: metadata = {} except (json.JSONDecodeError, ValueError, TypeError): diff --git a/api/providers/vdb/vdb-hologres/tests/integration_tests/conftest.py b/api/providers/vdb/vdb-hologres/tests/integration_tests/conftest.py index d28ded01873..290eda67e94 100644 --- a/api/providers/vdb/vdb-hologres/tests/integration_tests/conftest.py +++ b/api/providers/vdb/vdb-hologres/tests/integration_tests/conftest.py @@ -98,14 +98,15 @@ def _extract_identifiers_and_literals(query) -> list[Any]: values: list[Any] = [] if isinstance(query, psql.Composed): for part in query: - if isinstance(part, psql.Identifier): - values.append(("ident", part._obj[0] if part._obj else "")) - elif isinstance(part, psql.Literal): - values.append(("literal", part._obj)) - elif isinstance(part, psql.Composed): - for sub in part: - if isinstance(sub, psql.Literal): - values.append(("literal", sub._obj)) + match part: + case psql.Identifier(): + values.append(("ident", part._obj[0] if part._obj else "")) + case psql.Literal(): + values.append(("literal", part._obj)) + case psql.Composed(): + for sub in part: + if isinstance(sub, psql.Literal): + values.append(("literal", sub._obj)) return values diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 98c605f0a17..b40eb4bdd8a 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -957,21 +957,22 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ) pause_reason_models = [] for reason in pause_reasons: - if isinstance(reason, HumanInputRequired): - # TODO(QuantumGhost): record node_id for `WorkflowPauseReason` - pause_reason_model = WorkflowPauseReason( - pause_id=pause_model.id, - type_=reason.TYPE, - form_id=reason.form_id, - ) - elif isinstance(reason, SchedulingPause): - pause_reason_model = WorkflowPauseReason( - pause_id=pause_model.id, - type_=reason.TYPE, - message=reason.message, - ) - else: - raise AssertionError(f"unkown reason type: {type(reason)}") + match reason: + case HumanInputRequired(): + # TODO(QuantumGhost): record node_id for `WorkflowPauseReason` + pause_reason_model = WorkflowPauseReason( + pause_id=pause_model.id, + type_=reason.TYPE, + form_id=reason.form_id, + ) + case SchedulingPause(): + pause_reason_model = WorkflowPauseReason( + pause_id=pause_model.id, + type_=reason.TYPE, + message=reason.message, + ) + case _: + raise AssertionError(f"unknown reason type: {type(reason)}") pause_reason_models.append(pause_reason_model) diff --git a/api/services/agent/composer_validator.py b/api/services/agent/composer_validator.py index 8554b5c1ab7..b9519272c4a 100644 --- a/api/services/agent/composer_validator.py +++ b/api/services/agent/composer_validator.py @@ -334,16 +334,17 @@ class ComposerConfigValidator: @classmethod def _reject_plaintext_secrets(cls, value: Any, *, path: str) -> None: - if isinstance(value, dict): - for key, nested in value.items(): - normalized_key = key.lower().replace("-", "_") - nested_path = f"{path}.{key}" - if normalized_key in _PLAINTEXT_SECRET_KEYS and isinstance(nested, str) and nested: - raise PlaintextSecretNotAllowedError(f"Plaintext secret is not allowed at {nested_path}") - cls._reject_plaintext_secrets(nested, path=nested_path) - elif isinstance(value, list): - for index, nested in enumerate(value): - cls._reject_plaintext_secrets(nested, path=f"{path}[{index}]") + match value: + case dict(): + for key, nested in value.items(): + normalized_key = key.lower().replace("-", "_") + nested_path = f"{path}.{key}" + if normalized_key in _PLAINTEXT_SECRET_KEYS and isinstance(nested, str) and nested: + raise PlaintextSecretNotAllowedError(f"Plaintext secret is not allowed at {nested_path}") + cls._reject_plaintext_secrets(nested, path=nested_path) + case list(): + for index, nested in enumerate(value): + cls._reject_plaintext_secrets(nested, path=f"{path}[{index}]") @classmethod def _has_install_command(cls, entry: dict[str, Any]) -> bool: diff --git a/api/services/attachment_service.py b/api/services/attachment_service.py index dad7163739f..4129613b2fe 100644 --- a/api/services/attachment_service.py +++ b/api/services/attachment_service.py @@ -14,12 +14,13 @@ class AttachmentService: _session_maker: sessionmaker def __init__(self, session_factory: sessionmaker | Engine | None = None): - if isinstance(session_factory, Engine): - self._session_maker = sessionmaker(bind=session_factory) - elif isinstance(session_factory, sessionmaker): - self._session_maker = session_factory - else: - raise AssertionError("must be a sessionmaker or an Engine.") + match session_factory: + case Engine(): + self._session_maker = sessionmaker(bind=session_factory) + case sessionmaker(): + self._session_maker = session_factory + case _: + raise AssertionError("must be a sessionmaker or an Engine.") def get_file_base64(self, file_id: str) -> str: with self._session_maker(expire_on_commit=False) as session: diff --git a/api/services/file_service.py b/api/services/file_service.py index 4d3afcc9ad4..1781f0c9727 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -39,12 +39,13 @@ class FileService: _session_maker: sessionmaker[Session] def __init__(self, session_factory: sessionmaker | Engine | None = None): - if isinstance(session_factory, Engine): - self._session_maker = sessionmaker(bind=session_factory) - elif isinstance(session_factory, sessionmaker): - self._session_maker = session_factory - else: - raise AssertionError("must be a sessionmaker or an Engine.") + match session_factory: + case Engine(): + self._session_maker = sessionmaker(bind=session_factory) + case sessionmaker(): + self._session_maker = session_factory + case _: + raise AssertionError("must be a sessionmaker or an Engine.") def upload_file( self, diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index c266d4f9586..995cb94c633 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -119,10 +119,11 @@ class HumanInputDeliveryTestService: class EmailDeliveryTestHandler: def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None: - if session_factory is None: - session_factory = sessionmaker(bind=db.engine) - elif isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) + match session_factory: + case None: + session_factory = sessionmaker(bind=db.engine) + case Engine(): + session_factory = sessionmaker(bind=session_factory) self._session_factory = session_factory def supports(self, method: DeliveryChannelConfig) -> bool: @@ -179,11 +180,12 @@ class EmailDeliveryTestHandler: emails: list[str] = [] bound_reference_ids: list[str] = [] for recipient in recipients.items: - if isinstance(recipient, MemberRecipient): - bound_reference_ids.append(recipient.reference_id) - elif isinstance(recipient, ExternalRecipient): - if recipient.email: - emails.append(recipient.email) + match recipient: + case MemberRecipient(): + bound_reference_ids.append(recipient.reference_id) + case ExternalRecipient(): + if recipient.email: + emails.append(recipient.email) if recipients.include_bound_group: bound_reference_ids = [] diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 7c9bd489bda..9c30f32a7ad 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -125,10 +125,11 @@ class DraftVarLoader(VariableLoader): # can be safely accessed before any offloading logic is applied. for draft_var in draft_vars: value = draft_var.get_value() - if isinstance(value, FileSegment): - files.append(value.value) - elif isinstance(value, ArrayFileSegment): - files.extend(value.value) + match value: + case FileSegment(): + files.append(value.value) + case ArrayFileSegment(): + files.extend(value.value) with Session(bind=self._engine) as session: storage_key_loader = StorageKeyLoader( session, diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 29b9e72a009..2499e6cc094 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -34,10 +34,11 @@ class WorkflowRunService: def __init__(self, session_factory: Engine | sessionmaker | None = None): """Initialize WorkflowRunService with repository dependencies.""" - if session_factory is None: - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - elif isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + match session_factory: + case None: + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + case Engine(): + session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) self._session_factory = session_factory self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( diff --git a/api/tests/helpers/legacy_model_type_migration.py b/api/tests/helpers/legacy_model_type_migration.py index 12f092a0fe3..4140119f326 100644 --- a/api/tests/helpers/legacy_model_type_migration.py +++ b/api/tests/helpers/legacy_model_type_migration.py @@ -131,10 +131,11 @@ def fetch_table_rows( for row in rows: normalized = dict(row) for key, value in normalized.items(): - if isinstance(value, datetime): - normalized[key] = value.isoformat() - elif isinstance(value, uuid.UUID): - normalized[key] = str(value) + match value: + case datetime(): + normalized[key] = value.isoformat() + case uuid.UUID(): + normalized[key] = str(value) result.append(normalized) return result diff --git a/api/tests/unit_tests/commands/test_generate_swagger_specs.py b/api/tests/unit_tests/commands/test_generate_swagger_specs.py index d333dd50ea7..c30386c9d65 100644 --- a/api/tests/unit_tests/commands/test_generate_swagger_specs.py +++ b/api/tests/unit_tests/commands/test_generate_swagger_specs.py @@ -8,12 +8,13 @@ from pathlib import Path def _walk_values(value): yield value - if isinstance(value, dict): - for child in value.values(): - yield from _walk_values(child) - elif isinstance(value, list): - for child in value: - yield from _walk_values(child) + match value: + case dict(): + for child in value.values(): + yield from _walk_values(child) + case list(): + for child in value: + yield from _walk_values(child) def _load_generate_swagger_specs_module(): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 3f6b1ec1545..a99fd1f248f 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -4,6 +4,7 @@ import calendar import math from datetime import date from types import SimpleNamespace +from zoneinfo import ZoneInfo import pytest @@ -66,6 +67,20 @@ def test_localtime_to_timestamp_tool(): ts_value = float(ts_message.strip()) assert math.isfinite(ts_value) assert ts_value >= 0 + assert ( + LocaltimeToTimestampTool.localtime_to_timestamp( + "2024-01-01 10:00:00", + "%Y-%m-%d %H:%M:%S", + ZoneInfo("UTC"), + ) + == 1704103200 + ) + with pytest.raises(ToolInvokeError, match="local_tz must be"): + LocaltimeToTimestampTool.localtime_to_timestamp( + "2024-01-01 10:00:00", + "%Y-%m-%d %H:%M:%S", + object(), # type: ignore[arg-type] + ) with pytest.raises(ToolInvokeError): LocaltimeToTimestampTool.localtime_to_timestamp("bad", "%Y-%m-%d %H:%M:%S", "UTC") diff --git a/api/tests/unit_tests/core/tools/utils/test_tool_engine_serialization.py b/api/tests/unit_tests/core/tools/utils/test_tool_engine_serialization.py index 4029edfb686..da699ef6101 100644 --- a/api/tests/unit_tests/core/tools/utils/test_tool_engine_serialization.py +++ b/api/tests/unit_tests/core/tools/utils/test_tool_engine_serialization.py @@ -459,23 +459,24 @@ class TestEndToEndSerialization: def _verify_all_complex_types_converted(self, data): """Helper method to verify all complex types were properly converted""" - if isinstance(data, dict): - for key, value in data.items(): - if key in ["id", "checksum"]: - # These should be strings (UUID/bytes converted) - assert isinstance(value, str) - elif key in ["created_at", "last_login", "timestamp", "uploaded_at"]: - # These should be strings (datetime converted) - assert isinstance(value, str) - elif key in ["total_time", "duration"]: - # These should be floats (Decimal converted) - assert isinstance(value, float) - elif key == "metrics": - # This should be a list (ndarray converted) - assert isinstance(value, list) - else: - # Recursively check nested structures - self._verify_all_complex_types_converted(value) - elif isinstance(data, list): - for item in data: - self._verify_all_complex_types_converted(item) + match data: + case dict(): + for key, value in data.items(): + if key in ["id", "checksum"]: + # These should be strings (UUID/bytes converted) + assert isinstance(value, str) + elif key in ["created_at", "last_login", "timestamp", "uploaded_at"]: + # These should be strings (datetime converted) + assert isinstance(value, str) + elif key in ["total_time", "duration"]: + # These should be floats (Decimal converted) + assert isinstance(value, float) + elif key == "metrics": + # This should be a list (ndarray converted) + assert isinstance(value, list) + else: + # Recursively check nested structures + self._verify_all_complex_types_converted(value) + case list(): + for item in data: + self._verify_all_complex_types_converted(item) diff --git a/api/tests/unit_tests/services/test_model_provider_service.py b/api/tests/unit_tests/services/test_model_provider_service.py index 806be013497..a8a976f4b07 100644 --- a/api/tests/unit_tests/services/test_model_provider_service.py +++ b/api/tests/unit_tests/services/test_model_provider_service.py @@ -215,12 +215,13 @@ class TestModelProviderServiceDelegation: get_provider_config_mock.assert_called_once_with("tenant-1", "openai") provider_method = getattr(provider_configuration, provider_method_name) - if isinstance(provider_call_kwargs, tuple): - provider_method.assert_called_once_with(*provider_call_kwargs) - elif isinstance(provider_call_kwargs, dict): - provider_method.assert_called_once_with(**provider_call_kwargs) - else: - provider_method.assert_called_once_with(provider_call_kwargs) + match provider_call_kwargs: + case tuple(): + provider_method.assert_called_once_with(*provider_call_kwargs) + case dict(): + provider_method.assert_called_once_with(**provider_call_kwargs) + case _: + provider_method.assert_called_once_with(provider_call_kwargs) if method_name == "get_provider_credential": assert result == {"token": "abc"}