chore: port isinstance to match case (#37271)

Co-authored-by: WH-2099 <wh2099@pm.me>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-06-18 09:18:03 +09:00 committed by GitHub
parent f0b34bdeb4
commit af99414fc1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 526 additions and 468 deletions

View File

@ -42,10 +42,11 @@ def stringify_form_default_values(values: dict[str, object]) -> dict[str, str]:
"""Serialize default values into strings expected by human-input form clients."""
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
match value:
case None:
result[key] = ""
case dict() | list():
result[key] = json.dumps(value, ensure_ascii=False)
case _:
result[key] = str(value)
return result

View File

@ -182,14 +182,15 @@ register_response_schema_models(
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
status = _enum_value(workflow_run.status)
raw_outputs = workflow_run.outputs_dict
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
outputs: dict = {}
elif isinstance(raw_outputs, dict):
outputs = raw_outputs
elif isinstance(raw_outputs, Mapping):
outputs = dict(raw_outputs)
else:
outputs = {}
match raw_outputs:
case _ if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
outputs: dict = {}
case dict():
outputs = raw_outputs
case _ if isinstance(raw_outputs, Mapping):
outputs = dict(raw_outputs)
case _:
outputs = {}
return WorkflowRunResponse.model_validate(
{
"id": workflow_run.id,

View File

@ -231,22 +231,23 @@ class AppRunner:
:param tenant_id: tenant id for multimodal output
:return:
"""
if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
)
elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
)
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
match invoke_result:
case LLMResult() if not stream:
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
)
case _ if stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
)
case _:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(
self,

View File

@ -882,7 +882,7 @@ class WorkflowResponseConverter:
return files
@classmethod
def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None:
def _get_file_var_from_value(cls, value: object) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value
@ -891,10 +891,11 @@ class WorkflowResponseConverter:
if not value:
return None
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
return value
elif isinstance(value, File):
return value.to_dict()
match value:
case dict() if value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
return value
case File():
return value.to_dict()
return None

View File

@ -144,15 +144,16 @@ def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str,
Returns an empty dict if the context is missing or incomplete.
"""
parent_trace_context = args.get("parent_trace_context")
if isinstance(parent_trace_context, ParentTraceContext):
context = parent_trace_context
elif isinstance(parent_trace_context, Mapping):
try:
context = ParentTraceContext.model_validate(parent_trace_context)
except ValidationError:
match parent_trace_context:
case ParentTraceContext():
context = parent_trace_context
case Mapping():
try:
context = ParentTraceContext.model_validate(parent_trace_context)
except ValidationError:
return {}
case _:
return {}
else:
return {}
if context.parent_node_execution_id is None:
return {}

View File

@ -116,20 +116,21 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
return value if isinstance(value, str) else str(value)
case PluginParameterType.BOOLEAN:
if value is None:
return False
elif isinstance(value, str):
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
# and also '0' for False and '1' for True
match value.lower():
case "true" | "yes" | "y" | "1":
return True
case "false" | "no" | "n" | "0":
return False
case _:
return bool(value)
else:
return value if isinstance(value, bool) else bool(value)
match value:
case None:
return False
case str():
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
# and also '0' for False and '1' for True
match value.lower():
case "true" | "yes" | "y" | "1":
return True
case "false" | "no" | "n" | "0":
return False
case _:
return bool(value)
case _:
return value if isinstance(value, bool) else bool(value)
case PluginParameterType.NUMBER:
match value:

View File

@ -304,22 +304,23 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
Returns:
True if any Dify $ref is found, False otherwise
"""
if isinstance(schema, dict):
# Check if this dict has a $ref field
ref_uri = schema.get("$ref")
if ref_uri and _is_dify_schema_ref(ref_uri):
return True
# Check nested values
for value in schema.values():
if _has_dify_refs_recursive(value):
match schema:
case dict():
# Check if this dict has a $ref field
ref_uri = schema.get("$ref")
if ref_uri and _is_dify_schema_ref(ref_uri):
return True
elif isinstance(schema, list):
# Check each item in the list
for item in schema:
if _has_dify_refs_recursive(item):
return True
# Check nested values
for value in schema.values():
if _has_dify_refs_recursive(value):
return True
case list():
# Check each item in the list
for item in schema:
if _has_dify_refs_recursive(item):
return True
# Primitive types don't contain refs
return False

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any, override
from datetime import datetime, tzinfo
from typing import Any, cast, override
import pytz # type: ignore[import-untyped]
@ -35,17 +35,26 @@ class LocaltimeToTimestampTool(BuiltinTool):
yield self.create_text_message(f"{timestamp}")
# TODO: this method's type is messy
@staticmethod
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
def localtime_to_timestamp(localtime: str, time_format: str, local_tz: str | tzinfo | None = None) -> int | None:
try:
local_time = datetime.strptime(localtime, time_format)
if local_tz is None:
localtime = local_time.astimezone() # type: ignore
elif isinstance(local_tz, str):
local_tz = pytz.timezone(local_tz)
localtime = local_tz.localize(local_time) # type: ignore
timestamp = int(localtime.timestamp()) # type: ignore
converted_localtime: datetime
match local_tz:
case None:
converted_localtime = local_time.astimezone()
case str() as timezone_name:
timezone = pytz.timezone(timezone_name)
converted_localtime = timezone.localize(local_time)
case tzinfo():
localize = getattr(local_tz, "localize", None)
if callable(localize):
converted_localtime = cast(datetime, localize(local_time))
else:
converted_localtime = local_time.replace(tzinfo=local_tz)
case _:
raise ValueError("local_tz must be None, a timezone name, or a tzinfo instance")
timestamp = int(converted_localtime.timestamp())
return timestamp
except Exception as e:
raise ToolInvokeError(str(e))

View File

@ -122,13 +122,14 @@ class MCPTool(Tool):
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
"""Process JSON content based on its type."""
if isinstance(content_json, dict):
yield self.create_json_message(content_json)
elif isinstance(content_json, list):
yield from self._process_json_list(content_json)
else:
# For primitive types (str, int, bool, etc.), convert to string
yield self.create_text_message(str(content_json))
match content_json:
case dict():
yield self.create_json_message(content_json)
case list():
yield from self._process_json_list(content_json)
case _:
# For primitive types (str, int, bool, etc.), convert to string
yield self.create_text_message(str(content_json))
def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]:
"""Process a list of JSON items."""
@ -222,16 +223,17 @@ class MCPTool(Tool):
# Recursively search through nested structures
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
match value:
case _ if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
case list() if not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
@override

View File

@ -196,16 +196,17 @@ class WorkflowTool(Tool):
return usage_candidate
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
match value:
case _ if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
case _ if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
@override
@ -393,24 +394,25 @@ class WorkflowTool(Tool):
files: list[File] = []
result = {}
for key, value in outputs.items():
if isinstance(value, list):
for item in value:
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(self.runtime.tenant_id),
access_controller=_file_access_controller,
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(self.runtime.tenant_id),
access_controller=_file_access_controller,
)
files.append(file)
match value:
case list():
for item in value:
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(self.runtime.tenant_id),
access_controller=_file_access_controller,
)
files.append(file)
case dict() if value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(self.runtime.tenant_id),
access_controller=_file_access_controller,
)
files.append(file)
result[key] = value

View File

@ -297,25 +297,26 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
for cond in conditions.conditions or []:
value = cond.value
resolved_value: str | Sequence[str] | int | float | None
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
resolved_values: list[str] = []
for v in value:
segment_group = variable_pool.convert_template(v)
match value:
case str():
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
resolved_values.append(
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
)
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
else:
resolved_value = value
resolved_value = segment_group.text
case _ if isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
resolved_values: list[str] = []
for v in value:
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
resolved_values.append(
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
)
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
case _:
resolved_value = value
resolved_conditions.append(
Condition(
name=cond.name,

View File

@ -354,11 +354,12 @@ def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]:
def target_names(target: ast.AST) -> Iterable[str]:
if isinstance(target, ast.Name):
yield target.id
elif isinstance(target, ast.Tuple | ast.List):
for item in target.elts:
yield from target_names(item)
match target:
case ast.Name():
yield target.id
case ast.Tuple() | ast.List():
for item in target.elts:
yield from target_names(item)
def record_assignment(

View File

@ -94,12 +94,13 @@ class RedisSubscriptionBase(Subscription):
continue
channel_field = raw_message.get("channel")
if isinstance(channel_field, bytes):
channel_name = channel_field.decode("utf-8")
elif isinstance(channel_field, str):
channel_name = channel_field
else:
channel_name = str(channel_field)
match channel_field:
case bytes():
channel_name = channel_field.decode("utf-8")
case str():
channel_name = channel_field
case _:
channel_name = str(channel_field)
if channel_name != self._topic:
_logger.warning(

View File

@ -88,22 +88,23 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
#
# Since we have already filtered at the caller's site, we can safely set
# `ignore_subscribe_messages=False`.
if isinstance(self._client, RedisCluster):
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
# specifying the `target_node` argument would use busy-looping to wait
# for incoming message, consuming excessive CPU quota.
#
# Here we specify the `target_node` to mitigate this problem.
node = self._client.get_node_from_key(self._topic)
return self._pubsub.get_sharded_message( # type: ignore[attr-defined]
ignore_subscribe_messages=False,
timeout=1,
target_node=node,
)
elif isinstance(self._client, Redis):
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
else:
raise AssertionError("client should be either Redis or RedisCluster.")
match self._client:
case RedisCluster():
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
# specifying the `target_node` argument would use busy-looping to wait
# for incoming message, consuming excessive CPU quota.
#
# Here we specify the `target_node` to mitigate this problem.
node = self._client.get_node_from_key(self._topic)
return self._pubsub.get_sharded_message( # type: ignore[attr-defined]
ignore_subscribe_messages=False,
timeout=1,
target_node=node,
)
case Redis():
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
case _:
raise AssertionError("client should be either Redis or RedisCluster.")
@override
def _get_message_type(self) -> str:

View File

@ -138,10 +138,11 @@ class _StreamsSubscription(Subscription):
if isinstance(fields, dict):
data = fields.get(b"data")
data_bytes: bytes | None = None
if isinstance(data, str):
data_bytes = data.encode()
elif isinstance(data, (bytes, bytearray)):
data_bytes = bytes(data)
match data:
case str():
data_bytes = data.encode()
case bytes() | bytearray():
data_bytes = bytes(data)
if data_bytes is not None:
self._queue.put_nowait(data_bytes)
last_id = entry_id

View File

@ -1174,34 +1174,32 @@ class Conversation(Base):
# Convert file mapping to File object
for key, value in inputs.items():
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
):
value_dict = cast(dict[str, Any], value)
inputs[key] = build_file_from_input_mapping(
file_mapping=value_dict,
tenant_resolver=tenant_resolver,
)
elif isinstance(value, list):
value_list = value
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
file_list.append(
build_file_from_input_mapping(
file_mapping=item_dict,
tenant_resolver=tenant_resolver,
match value:
case dict() if cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY:
value_dict = cast(dict[str, Any], value)
inputs[key] = build_file_from_input_mapping(
file_mapping=value_dict,
tenant_resolver=tenant_resolver,
)
case list():
value_list = value
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
file_list.append(
build_file_from_input_mapping(
file_mapping=item_dict,
tenant_resolver=tenant_resolver,
)
)
)
inputs[key] = file_list
inputs[key] = file_list
return inputs
@ -1516,46 +1514,45 @@ class Message(Base):
owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)),
)
for key, value in inputs.items():
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
):
value_dict = cast(dict[str, Any], value)
inputs[key] = build_file_from_input_mapping(
file_mapping=value_dict,
tenant_resolver=tenant_resolver,
)
elif isinstance(value, list):
value_list = value
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
file_list.append(
build_file_from_input_mapping(
file_mapping=item_dict,
tenant_resolver=tenant_resolver,
match value:
case dict() if cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY:
value_dict = cast(dict[str, Any], value)
inputs[key] = build_file_from_input_mapping(
file_mapping=value_dict,
tenant_resolver=tenant_resolver,
)
case list():
value_list = value
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
file_list.append(
build_file_from_input_mapping(
file_mapping=item_dict,
tenant_resolver=tenant_resolver,
)
)
)
inputs[key] = file_list
inputs[key] = file_list
return inputs
@inputs.setter
def inputs(self, value: Mapping[str, Any]):
inputs = dict(value)
for k, v in inputs.items():
if isinstance(v, File):
inputs[k] = v.model_dump()
elif isinstance(v, list):
v_list = v
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
match v:
case File():
inputs[k] = v.model_dump()
case list():
v_list = v
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
self._inputs = inputs
@property

View File

@ -96,12 +96,13 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]):
def process_bind_param(self, value: T | dict[str, Any] | str | None, dialect: Dialect) -> str | None:
if value is None:
return None
if isinstance(value, self._model_class):
model = value
elif isinstance(value, str):
model = self._model_class.model_validate_json(value)
else:
model = self._model_class.model_validate(value)
match value:
case _ if isinstance(value, self._model_class):
model = value
case str():
model = self._model_class.model_validate_json(value)
case _:
model = self._model_class.model_validate(value)
return json.dumps(model.model_dump(mode="json"), ensure_ascii=False, sort_keys=True, separators=(",", ":"))
@override

View File

@ -1150,13 +1150,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
try:
# Convert outputs to string based on type
outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value
if isinstance(trace_info.outputs, dict | list):
outputs_str = safe_json_dumps(trace_info.outputs)
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
elif isinstance(trace_info.outputs, str):
outputs_str = trace_info.outputs
else:
outputs_str = str(trace_info.outputs)
match trace_info.outputs:
case dict() | list():
outputs_str = safe_json_dumps(trace_info.outputs)
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
case str():
outputs_str = trace_info.outputs
case _:
outputs_str = str(trace_info.outputs)
llm_attributes: dict[str, Any] = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
@ -1553,25 +1554,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id)
# Handle list of messages
if isinstance(prompts, list):
for message_index, message in enumerate(prompts):
if not isinstance(message, dict):
continue
match prompts:
case list():
for message_index, message in enumerate(prompts):
if not isinstance(message, dict):
continue
role = message.get("role", "user")
content = message.get("text") or message.get("content") or ""
role = message.get("role", "user")
content = message.get("text") or message.get("content") or ""
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
tool_calls = message.get("tool_calls") or []
if isinstance(tool_calls, list):
for tool_index, tool_call in enumerate(tool_calls):
set_tool_call_attributes(message_index, tool_index, tool_call)
tool_calls = message.get("tool_calls") or []
if isinstance(tool_calls, list):
for tool_index, tool_call in enumerate(tool_calls):
set_tool_call_attributes(message_index, tool_index, tool_call)
# Handle single dict or plain string prompt
elif isinstance(prompts, (dict, str)):
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
# Handle single dict or plain string prompt
case dict() | str():
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
return attributes

View File

@ -18,24 +18,25 @@ def validate_input_output(v, field_name):
"""
if v == {} or v is None:
return v
if isinstance(v, str):
return [
{
"role": "assistant" if field_name == "output" else "user",
"content": v,
}
]
elif isinstance(v, list):
if len(v) > 0 and isinstance(v[0], dict):
v = replace_text_with_content(data=v)
return v
else:
match v:
case str():
return [
{
"role": "assistant" if field_name == "output" else "user",
"content": str(v),
"content": v,
}
]
case list():
if len(v) > 0 and isinstance(v[0], dict):
v = replace_text_with_content(data=v)
return v
else:
return [
{
"role": "assistant" if field_name == "output" else "user",
"content": str(v),
}
]
return v

View File

@ -64,40 +64,20 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
"total_tokens": values.get("total_tokens", 0),
}
file_list = values.get("file_list", [])
if isinstance(v, str):
match field_name:
case "inputs":
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case "outputs":
return {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case _:
pass
elif isinstance(v, list):
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
match v:
case str():
match field_name:
case "inputs":
data = {
"messages": v,
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case "outputs":
data = {
return {
"choices": {
"role": "ai",
"content": v,
@ -107,16 +87,37 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
}
case _:
pass
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case list():
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
match field_name:
case "inputs":
data = {
"messages": v,
}
case "outputs":
data = {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case _:
pass
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
if isinstance(v, dict):
v["usage_metadata"] = usage_metadata
v["file_list"] = file_list

View File

@ -40,41 +40,19 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
"total_tokens": values.get("total_tokens", 0),
}
file_list = values.get("file_list", [])
if isinstance(v, str):
if field_name == "inputs":
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif field_name == "outputs":
return {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif isinstance(v, list):
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
match v:
case str():
if field_name == "inputs":
data = {
"messages": [
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
for msg in v
]
if isinstance(v, list)
else v,
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif field_name == "outputs":
data = {
return {
"choices": {
"role": "ai",
"content": v,
@ -82,16 +60,39 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
"file_list": file_list,
},
}
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case list():
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
if field_name == "inputs":
data = {
"messages": [
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
for msg in v
]
if isinstance(v, list)
else v,
}
elif field_name == "outputs":
data = {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
if isinstance(v, dict):
v["usage_metadata"] = usage_metadata
v["file_list"] = file_list

View File

@ -361,12 +361,13 @@ class ClickzettaVector(BaseVector):
first_pass = json.loads(raw_metadata)
# Handle double-encoded JSON
if isinstance(first_pass, str):
metadata = parse_metadata_json(first_pass)
elif isinstance(first_pass, dict):
metadata = first_pass
else:
metadata = {}
match first_pass:
case str():
metadata = parse_metadata_json(first_pass)
case dict():
metadata = first_pass
case _:
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, ValueError, TypeError):
@ -942,12 +943,13 @@ class ClickzettaVector(BaseVector):
# First parse may yield a string (double-encoded JSON)
first_pass = json.loads(row[2])
if isinstance(first_pass, str):
metadata = parse_metadata_json(first_pass)
elif isinstance(first_pass, dict):
metadata = first_pass
else:
metadata = {}
match first_pass:
case str():
metadata = parse_metadata_json(first_pass)
case dict():
metadata = first_pass
case _:
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, ValueError, TypeError):

View File

@ -98,14 +98,15 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
values: list[Any] = []
if isinstance(query, psql.Composed):
for part in query:
if isinstance(part, psql.Identifier):
values.append(("ident", part._obj[0] if part._obj else ""))
elif isinstance(part, psql.Literal):
values.append(("literal", part._obj))
elif isinstance(part, psql.Composed):
for sub in part:
if isinstance(sub, psql.Literal):
values.append(("literal", sub._obj))
match part:
case psql.Identifier():
values.append(("ident", part._obj[0] if part._obj else ""))
case psql.Literal():
values.append(("literal", part._obj))
case psql.Composed():
for sub in part:
if isinstance(sub, psql.Literal):
values.append(("literal", sub._obj))
return values

View File

@ -957,21 +957,22 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
)
pause_reason_models = []
for reason in pause_reasons:
if isinstance(reason, HumanInputRequired):
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
form_id=reason.form_id,
)
elif isinstance(reason, SchedulingPause):
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
message=reason.message,
)
else:
raise AssertionError(f"unkown reason type: {type(reason)}")
match reason:
case HumanInputRequired():
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
form_id=reason.form_id,
)
case SchedulingPause():
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
message=reason.message,
)
case _:
raise AssertionError(f"unknown reason type: {type(reason)}")
pause_reason_models.append(pause_reason_model)

View File

@ -334,16 +334,17 @@ class ComposerConfigValidator:
@classmethod
def _reject_plaintext_secrets(cls, value: Any, *, path: str) -> None:
if isinstance(value, dict):
for key, nested in value.items():
normalized_key = key.lower().replace("-", "_")
nested_path = f"{path}.{key}"
if normalized_key in _PLAINTEXT_SECRET_KEYS and isinstance(nested, str) and nested:
raise PlaintextSecretNotAllowedError(f"Plaintext secret is not allowed at {nested_path}")
cls._reject_plaintext_secrets(nested, path=nested_path)
elif isinstance(value, list):
for index, nested in enumerate(value):
cls._reject_plaintext_secrets(nested, path=f"{path}[{index}]")
match value:
case dict():
for key, nested in value.items():
normalized_key = key.lower().replace("-", "_")
nested_path = f"{path}.{key}"
if normalized_key in _PLAINTEXT_SECRET_KEYS and isinstance(nested, str) and nested:
raise PlaintextSecretNotAllowedError(f"Plaintext secret is not allowed at {nested_path}")
cls._reject_plaintext_secrets(nested, path=nested_path)
case list():
for index, nested in enumerate(value):
cls._reject_plaintext_secrets(nested, path=f"{path}[{index}]")
@classmethod
def _has_install_command(cls, entry: dict[str, Any]) -> bool:

View File

@ -14,12 +14,13 @@ class AttachmentService:
_session_maker: sessionmaker
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
self._session_maker = sessionmaker(bind=session_factory)
elif isinstance(session_factory, sessionmaker):
self._session_maker = session_factory
else:
raise AssertionError("must be a sessionmaker or an Engine.")
match session_factory:
case Engine():
self._session_maker = sessionmaker(bind=session_factory)
case sessionmaker():
self._session_maker = session_factory
case _:
raise AssertionError("must be a sessionmaker or an Engine.")
def get_file_base64(self, file_id: str) -> str:
with self._session_maker(expire_on_commit=False) as session:

View File

@ -39,12 +39,13 @@ class FileService:
_session_maker: sessionmaker[Session]
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
self._session_maker = sessionmaker(bind=session_factory)
elif isinstance(session_factory, sessionmaker):
self._session_maker = session_factory
else:
raise AssertionError("must be a sessionmaker or an Engine.")
match session_factory:
case Engine():
self._session_maker = sessionmaker(bind=session_factory)
case sessionmaker():
self._session_maker = session_factory
case _:
raise AssertionError("must be a sessionmaker or an Engine.")
def upload_file(
self,

View File

@ -119,10 +119,11 @@ class HumanInputDeliveryTestService:
class EmailDeliveryTestHandler:
def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None:
if session_factory is None:
session_factory = sessionmaker(bind=db.engine)
elif isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
match session_factory:
case None:
session_factory = sessionmaker(bind=db.engine)
case Engine():
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
def supports(self, method: DeliveryChannelConfig) -> bool:
@ -179,11 +180,12 @@ class EmailDeliveryTestHandler:
emails: list[str] = []
bound_reference_ids: list[str] = []
for recipient in recipients.items:
if isinstance(recipient, MemberRecipient):
bound_reference_ids.append(recipient.reference_id)
elif isinstance(recipient, ExternalRecipient):
if recipient.email:
emails.append(recipient.email)
match recipient:
case MemberRecipient():
bound_reference_ids.append(recipient.reference_id)
case ExternalRecipient():
if recipient.email:
emails.append(recipient.email)
if recipients.include_bound_group:
bound_reference_ids = []

View File

@ -125,10 +125,11 @@ class DraftVarLoader(VariableLoader):
# can be safely accessed before any offloading logic is applied.
for draft_var in draft_vars:
value = draft_var.get_value()
if isinstance(value, FileSegment):
files.append(value.value)
elif isinstance(value, ArrayFileSegment):
files.extend(value.value)
match value:
case FileSegment():
files.append(value.value)
case ArrayFileSegment():
files.extend(value.value)
with Session(bind=self._engine) as session:
storage_key_loader = StorageKeyLoader(
session,

View File

@ -34,10 +34,11 @@ class WorkflowRunService:
def __init__(self, session_factory: Engine | sessionmaker | None = None):
"""Initialize WorkflowRunService with repository dependencies."""
if session_factory is None:
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
elif isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
match session_factory:
case None:
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
case Engine():
session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
self._session_factory = session_factory
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(

View File

@ -131,10 +131,11 @@ def fetch_table_rows(
for row in rows:
normalized = dict(row)
for key, value in normalized.items():
if isinstance(value, datetime):
normalized[key] = value.isoformat()
elif isinstance(value, uuid.UUID):
normalized[key] = str(value)
match value:
case datetime():
normalized[key] = value.isoformat()
case uuid.UUID():
normalized[key] = str(value)
result.append(normalized)
return result

View File

@ -8,12 +8,13 @@ from pathlib import Path
def _walk_values(value):
yield value
if isinstance(value, dict):
for child in value.values():
yield from _walk_values(child)
elif isinstance(value, list):
for child in value:
yield from _walk_values(child)
match value:
case dict():
for child in value.values():
yield from _walk_values(child)
case list():
for child in value:
yield from _walk_values(child)
def _load_generate_swagger_specs_module():

View File

@ -4,6 +4,7 @@ import calendar
import math
from datetime import date
from types import SimpleNamespace
from zoneinfo import ZoneInfo
import pytest
@ -66,6 +67,20 @@ def test_localtime_to_timestamp_tool():
ts_value = float(ts_message.strip())
assert math.isfinite(ts_value)
assert ts_value >= 0
assert (
LocaltimeToTimestampTool.localtime_to_timestamp(
"2024-01-01 10:00:00",
"%Y-%m-%d %H:%M:%S",
ZoneInfo("UTC"),
)
== 1704103200
)
with pytest.raises(ToolInvokeError, match="local_tz must be"):
LocaltimeToTimestampTool.localtime_to_timestamp(
"2024-01-01 10:00:00",
"%Y-%m-%d %H:%M:%S",
object(), # type: ignore[arg-type]
)
with pytest.raises(ToolInvokeError):
LocaltimeToTimestampTool.localtime_to_timestamp("bad", "%Y-%m-%d %H:%M:%S", "UTC")

View File

@ -459,23 +459,24 @@ class TestEndToEndSerialization:
def _verify_all_complex_types_converted(self, data):
"""Helper method to verify all complex types were properly converted"""
if isinstance(data, dict):
for key, value in data.items():
if key in ["id", "checksum"]:
# These should be strings (UUID/bytes converted)
assert isinstance(value, str)
elif key in ["created_at", "last_login", "timestamp", "uploaded_at"]:
# These should be strings (datetime converted)
assert isinstance(value, str)
elif key in ["total_time", "duration"]:
# These should be floats (Decimal converted)
assert isinstance(value, float)
elif key == "metrics":
# This should be a list (ndarray converted)
assert isinstance(value, list)
else:
# Recursively check nested structures
self._verify_all_complex_types_converted(value)
elif isinstance(data, list):
for item in data:
self._verify_all_complex_types_converted(item)
match data:
case dict():
for key, value in data.items():
if key in ["id", "checksum"]:
# These should be strings (UUID/bytes converted)
assert isinstance(value, str)
elif key in ["created_at", "last_login", "timestamp", "uploaded_at"]:
# These should be strings (datetime converted)
assert isinstance(value, str)
elif key in ["total_time", "duration"]:
# These should be floats (Decimal converted)
assert isinstance(value, float)
elif key == "metrics":
# This should be a list (ndarray converted)
assert isinstance(value, list)
else:
# Recursively check nested structures
self._verify_all_complex_types_converted(value)
case list():
for item in data:
self._verify_all_complex_types_converted(item)

View File

@ -215,12 +215,13 @@ class TestModelProviderServiceDelegation:
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
provider_method = getattr(provider_configuration, provider_method_name)
if isinstance(provider_call_kwargs, tuple):
provider_method.assert_called_once_with(*provider_call_kwargs)
elif isinstance(provider_call_kwargs, dict):
provider_method.assert_called_once_with(**provider_call_kwargs)
else:
provider_method.assert_called_once_with(provider_call_kwargs)
match provider_call_kwargs:
case tuple():
provider_method.assert_called_once_with(*provider_call_kwargs)
case dict():
provider_method.assert_called_once_with(**provider_call_kwargs)
case _:
provider_method.assert_called_once_with(provider_call_kwargs)
if method_name == "get_provider_credential":
assert result == {"token": "abc"}