mirror of
https://github.com/langgenius/dify.git
synced 2026-06-22 19:21:13 +08:00
chore: port isinstance to match case (#37271)
Co-authored-by: WH-2099 <wh2099@pm.me> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
f0b34bdeb4
commit
af99414fc1
@ -42,10 +42,11 @@ def stringify_form_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
"""Serialize default values into strings expected by human-input form clients."""
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
match value:
|
||||
case None:
|
||||
result[key] = ""
|
||||
case dict() | list():
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
case _:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
@ -182,14 +182,15 @@ register_response_schema_models(
|
||||
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
|
||||
status = _enum_value(workflow_run.status)
|
||||
raw_outputs = workflow_run.outputs_dict
|
||||
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
|
||||
outputs: dict = {}
|
||||
elif isinstance(raw_outputs, dict):
|
||||
outputs = raw_outputs
|
||||
elif isinstance(raw_outputs, Mapping):
|
||||
outputs = dict(raw_outputs)
|
||||
else:
|
||||
outputs = {}
|
||||
match raw_outputs:
|
||||
case _ if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
|
||||
outputs: dict = {}
|
||||
case dict():
|
||||
outputs = raw_outputs
|
||||
case _ if isinstance(raw_outputs, Mapping):
|
||||
outputs = dict(raw_outputs)
|
||||
case _:
|
||||
outputs = {}
|
||||
return WorkflowRunResponse.model_validate(
|
||||
{
|
||||
"id": workflow_run.id,
|
||||
|
||||
@ -231,22 +231,23 @@ class AppRunner:
|
||||
:param tenant_id: tenant id for multimodal output
|
||||
:return:
|
||||
"""
|
||||
if not stream and isinstance(invoke_result, LLMResult):
|
||||
self._handle_invoke_result_direct(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
elif stream and isinstance(invoke_result, Generator):
|
||||
self._handle_invoke_result_stream(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
agent=agent,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
||||
match invoke_result:
|
||||
case LLMResult() if not stream:
|
||||
self._handle_invoke_result_direct(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
case _ if stream and isinstance(invoke_result, Generator):
|
||||
self._handle_invoke_result_stream(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
agent=agent,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case _:
|
||||
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
||||
|
||||
def _handle_invoke_result_direct(
|
||||
self,
|
||||
|
||||
@ -882,7 +882,7 @@ class WorkflowResponseConverter:
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None:
|
||||
def _get_file_var_from_value(cls, value: object) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
@ -891,10 +891,11 @@ class WorkflowResponseConverter:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
elif isinstance(value, File):
|
||||
return value.to_dict()
|
||||
match value:
|
||||
case dict() if value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
case File():
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@ -144,15 +144,16 @@ def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str,
|
||||
Returns an empty dict if the context is missing or incomplete.
|
||||
"""
|
||||
parent_trace_context = args.get("parent_trace_context")
|
||||
if isinstance(parent_trace_context, ParentTraceContext):
|
||||
context = parent_trace_context
|
||||
elif isinstance(parent_trace_context, Mapping):
|
||||
try:
|
||||
context = ParentTraceContext.model_validate(parent_trace_context)
|
||||
except ValidationError:
|
||||
match parent_trace_context:
|
||||
case ParentTraceContext():
|
||||
context = parent_trace_context
|
||||
case Mapping():
|
||||
try:
|
||||
context = ParentTraceContext.model_validate(parent_trace_context)
|
||||
except ValidationError:
|
||||
return {}
|
||||
case _:
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
if context.parent_node_execution_id is None:
|
||||
return {}
|
||||
|
||||
@ -116,20 +116,21 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
case PluginParameterType.BOOLEAN:
|
||||
if value is None:
|
||||
return False
|
||||
elif isinstance(value, str):
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
else:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
match value:
|
||||
case None:
|
||||
return False
|
||||
case str():
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
case _:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
|
||||
case PluginParameterType.NUMBER:
|
||||
match value:
|
||||
|
||||
@ -304,22 +304,23 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
|
||||
Returns:
|
||||
True if any Dify $ref is found, False otherwise
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
# Check if this dict has a $ref field
|
||||
ref_uri = schema.get("$ref")
|
||||
if ref_uri and _is_dify_schema_ref(ref_uri):
|
||||
return True
|
||||
|
||||
# Check nested values
|
||||
for value in schema.values():
|
||||
if _has_dify_refs_recursive(value):
|
||||
match schema:
|
||||
case dict():
|
||||
# Check if this dict has a $ref field
|
||||
ref_uri = schema.get("$ref")
|
||||
if ref_uri and _is_dify_schema_ref(ref_uri):
|
||||
return True
|
||||
|
||||
elif isinstance(schema, list):
|
||||
# Check each item in the list
|
||||
for item in schema:
|
||||
if _has_dify_refs_recursive(item):
|
||||
return True
|
||||
# Check nested values
|
||||
for value in schema.values():
|
||||
if _has_dify_refs_recursive(value):
|
||||
return True
|
||||
|
||||
case list():
|
||||
# Check each item in the list
|
||||
for item in schema:
|
||||
if _has_dify_refs_recursive(item):
|
||||
return True
|
||||
|
||||
# Primitive types don't contain refs
|
||||
return False
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any, override
|
||||
from datetime import datetime, tzinfo
|
||||
from typing import Any, cast, override
|
||||
|
||||
import pytz # type: ignore[import-untyped]
|
||||
|
||||
@ -35,17 +35,26 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
||||
|
||||
yield self.create_text_message(f"{timestamp}")
|
||||
|
||||
# TODO: this method's type is messy
|
||||
@staticmethod
|
||||
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
|
||||
def localtime_to_timestamp(localtime: str, time_format: str, local_tz: str | tzinfo | None = None) -> int | None:
|
||||
try:
|
||||
local_time = datetime.strptime(localtime, time_format)
|
||||
if local_tz is None:
|
||||
localtime = local_time.astimezone() # type: ignore
|
||||
elif isinstance(local_tz, str):
|
||||
local_tz = pytz.timezone(local_tz)
|
||||
localtime = local_tz.localize(local_time) # type: ignore
|
||||
timestamp = int(localtime.timestamp()) # type: ignore
|
||||
converted_localtime: datetime
|
||||
match local_tz:
|
||||
case None:
|
||||
converted_localtime = local_time.astimezone()
|
||||
case str() as timezone_name:
|
||||
timezone = pytz.timezone(timezone_name)
|
||||
converted_localtime = timezone.localize(local_time)
|
||||
case tzinfo():
|
||||
localize = getattr(local_tz, "localize", None)
|
||||
if callable(localize):
|
||||
converted_localtime = cast(datetime, localize(local_time))
|
||||
else:
|
||||
converted_localtime = local_time.replace(tzinfo=local_tz)
|
||||
case _:
|
||||
raise ValueError("local_tz must be None, a timezone name, or a tzinfo instance")
|
||||
timestamp = int(converted_localtime.timestamp())
|
||||
return timestamp
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
|
||||
@ -122,13 +122,14 @@ class MCPTool(Tool):
|
||||
|
||||
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process JSON content based on its type."""
|
||||
if isinstance(content_json, dict):
|
||||
yield self.create_json_message(content_json)
|
||||
elif isinstance(content_json, list):
|
||||
yield from self._process_json_list(content_json)
|
||||
else:
|
||||
# For primitive types (str, int, bool, etc.), convert to string
|
||||
yield self.create_text_message(str(content_json))
|
||||
match content_json:
|
||||
case dict():
|
||||
yield self.create_json_message(content_json)
|
||||
case list():
|
||||
yield from self._process_json_list(content_json)
|
||||
case _:
|
||||
# For primitive types (str, int, bool, etc.), convert to string
|
||||
yield self.create_text_message(str(content_json))
|
||||
|
||||
def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process a list of JSON items."""
|
||||
@ -222,16 +223,17 @@ class MCPTool(Tool):
|
||||
|
||||
# Recursively search through nested structures
|
||||
for value in payload.values():
|
||||
if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
match value:
|
||||
case _ if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
case list() if not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
return None
|
||||
|
||||
@override
|
||||
|
||||
@ -196,16 +196,17 @@ class WorkflowTool(Tool):
|
||||
return usage_candidate
|
||||
|
||||
for value in payload.values():
|
||||
if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
match value:
|
||||
case _ if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
case _ if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
return None
|
||||
|
||||
@override
|
||||
@ -393,24 +394,25 @@ class WorkflowTool(Tool):
|
||||
files: list[File] = []
|
||||
result = {}
|
||||
for key, value in outputs.items():
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
item = self._update_file_mapping(item)
|
||||
file = build_from_mapping(
|
||||
mapping=item,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
value = self._update_file_mapping(value)
|
||||
file = build_from_mapping(
|
||||
mapping=value,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
match value:
|
||||
case list():
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
item = self._update_file_mapping(item)
|
||||
file = build_from_mapping(
|
||||
mapping=item,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
case dict() if value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
value = self._update_file_mapping(value)
|
||||
file = build_from_mapping(
|
||||
mapping=value,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
result[key] = value
|
||||
|
||||
|
||||
@ -297,25 +297,26 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
for cond in conditions.conditions or []:
|
||||
value = cond.value
|
||||
resolved_value: str | Sequence[str] | int | float | None
|
||||
if isinstance(value, str):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_value = segment_group.text
|
||||
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values: list[str] = []
|
||||
for v in value:
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
match value:
|
||||
case str():
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(
|
||||
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
|
||||
)
|
||||
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
else:
|
||||
resolved_value = value
|
||||
resolved_value = segment_group.text
|
||||
case _ if isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values: list[str] = []
|
||||
for v in value:
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(
|
||||
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
|
||||
)
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
case _:
|
||||
resolved_value = value
|
||||
resolved_conditions.append(
|
||||
Condition(
|
||||
name=cond.name,
|
||||
|
||||
@ -354,11 +354,12 @@ def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]:
|
||||
|
||||
|
||||
def target_names(target: ast.AST) -> Iterable[str]:
|
||||
if isinstance(target, ast.Name):
|
||||
yield target.id
|
||||
elif isinstance(target, ast.Tuple | ast.List):
|
||||
for item in target.elts:
|
||||
yield from target_names(item)
|
||||
match target:
|
||||
case ast.Name():
|
||||
yield target.id
|
||||
case ast.Tuple() | ast.List():
|
||||
for item in target.elts:
|
||||
yield from target_names(item)
|
||||
|
||||
|
||||
def record_assignment(
|
||||
|
||||
@ -94,12 +94,13 @@ class RedisSubscriptionBase(Subscription):
|
||||
continue
|
||||
|
||||
channel_field = raw_message.get("channel")
|
||||
if isinstance(channel_field, bytes):
|
||||
channel_name = channel_field.decode("utf-8")
|
||||
elif isinstance(channel_field, str):
|
||||
channel_name = channel_field
|
||||
else:
|
||||
channel_name = str(channel_field)
|
||||
match channel_field:
|
||||
case bytes():
|
||||
channel_name = channel_field.decode("utf-8")
|
||||
case str():
|
||||
channel_name = channel_field
|
||||
case _:
|
||||
channel_name = str(channel_field)
|
||||
|
||||
if channel_name != self._topic:
|
||||
_logger.warning(
|
||||
|
||||
@ -88,22 +88,23 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
||||
#
|
||||
# Since we have already filtered at the caller's site, we can safely set
|
||||
# `ignore_subscribe_messages=False`.
|
||||
if isinstance(self._client, RedisCluster):
|
||||
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
|
||||
# specifying the `target_node` argument would use busy-looping to wait
|
||||
# for incoming message, consuming excessive CPU quota.
|
||||
#
|
||||
# Here we specify the `target_node` to mitigate this problem.
|
||||
node = self._client.get_node_from_key(self._topic)
|
||||
return self._pubsub.get_sharded_message( # type: ignore[attr-defined]
|
||||
ignore_subscribe_messages=False,
|
||||
timeout=1,
|
||||
target_node=node,
|
||||
)
|
||||
elif isinstance(self._client, Redis):
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
|
||||
else:
|
||||
raise AssertionError("client should be either Redis or RedisCluster.")
|
||||
match self._client:
|
||||
case RedisCluster():
|
||||
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
|
||||
# specifying the `target_node` argument would use busy-looping to wait
|
||||
# for incoming message, consuming excessive CPU quota.
|
||||
#
|
||||
# Here we specify the `target_node` to mitigate this problem.
|
||||
node = self._client.get_node_from_key(self._topic)
|
||||
return self._pubsub.get_sharded_message( # type: ignore[attr-defined]
|
||||
ignore_subscribe_messages=False,
|
||||
timeout=1,
|
||||
target_node=node,
|
||||
)
|
||||
case Redis():
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
|
||||
case _:
|
||||
raise AssertionError("client should be either Redis or RedisCluster.")
|
||||
|
||||
@override
|
||||
def _get_message_type(self) -> str:
|
||||
|
||||
@ -138,10 +138,11 @@ class _StreamsSubscription(Subscription):
|
||||
if isinstance(fields, dict):
|
||||
data = fields.get(b"data")
|
||||
data_bytes: bytes | None = None
|
||||
if isinstance(data, str):
|
||||
data_bytes = data.encode()
|
||||
elif isinstance(data, (bytes, bytearray)):
|
||||
data_bytes = bytes(data)
|
||||
match data:
|
||||
case str():
|
||||
data_bytes = data.encode()
|
||||
case bytes() | bytearray():
|
||||
data_bytes = bytes(data)
|
||||
if data_bytes is not None:
|
||||
self._queue.put_nowait(data_bytes)
|
||||
last_id = entry_id
|
||||
|
||||
@ -1174,34 +1174,32 @@ class Conversation(Base):
|
||||
|
||||
# Convert file mapping to File object
|
||||
for key, value in inputs.items():
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
):
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
inputs[key] = build_file_from_input_mapping(
|
||||
file_mapping=value_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
elif isinstance(value, list):
|
||||
value_list = value
|
||||
if all(
|
||||
isinstance(item, dict)
|
||||
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
for item in value_list
|
||||
):
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
file_list.append(
|
||||
build_file_from_input_mapping(
|
||||
file_mapping=item_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
match value:
|
||||
case dict() if cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
inputs[key] = build_file_from_input_mapping(
|
||||
file_mapping=value_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
case list():
|
||||
value_list = value
|
||||
if all(
|
||||
isinstance(item, dict)
|
||||
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
for item in value_list
|
||||
):
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
file_list.append(
|
||||
build_file_from_input_mapping(
|
||||
file_mapping=item_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
)
|
||||
)
|
||||
inputs[key] = file_list
|
||||
inputs[key] = file_list
|
||||
|
||||
return inputs
|
||||
|
||||
@ -1516,46 +1514,45 @@ class Message(Base):
|
||||
owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)),
|
||||
)
|
||||
for key, value in inputs.items():
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
):
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
inputs[key] = build_file_from_input_mapping(
|
||||
file_mapping=value_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
elif isinstance(value, list):
|
||||
value_list = value
|
||||
if all(
|
||||
isinstance(item, dict)
|
||||
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
for item in value_list
|
||||
):
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
file_list.append(
|
||||
build_file_from_input_mapping(
|
||||
file_mapping=item_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
match value:
|
||||
case dict() if cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
inputs[key] = build_file_from_input_mapping(
|
||||
file_mapping=value_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
case list():
|
||||
value_list = value
|
||||
if all(
|
||||
isinstance(item, dict)
|
||||
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
for item in value_list
|
||||
):
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
file_list.append(
|
||||
build_file_from_input_mapping(
|
||||
file_mapping=item_dict,
|
||||
tenant_resolver=tenant_resolver,
|
||||
)
|
||||
)
|
||||
)
|
||||
inputs[key] = file_list
|
||||
inputs[key] = file_list
|
||||
return inputs
|
||||
|
||||
@inputs.setter
|
||||
def inputs(self, value: Mapping[str, Any]):
|
||||
inputs = dict(value)
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, File):
|
||||
inputs[k] = v.model_dump()
|
||||
elif isinstance(v, list):
|
||||
v_list = v
|
||||
if all(isinstance(item, File) for item in v_list):
|
||||
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
|
||||
match v:
|
||||
case File():
|
||||
inputs[k] = v.model_dump()
|
||||
case list():
|
||||
v_list = v
|
||||
if all(isinstance(item, File) for item in v_list):
|
||||
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
|
||||
@ -96,12 +96,13 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]):
|
||||
def process_bind_param(self, value: T | dict[str, Any] | str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, self._model_class):
|
||||
model = value
|
||||
elif isinstance(value, str):
|
||||
model = self._model_class.model_validate_json(value)
|
||||
else:
|
||||
model = self._model_class.model_validate(value)
|
||||
match value:
|
||||
case _ if isinstance(value, self._model_class):
|
||||
model = value
|
||||
case str():
|
||||
model = self._model_class.model_validate_json(value)
|
||||
case _:
|
||||
model = self._model_class.model_validate(value)
|
||||
return json.dumps(model.model_dump(mode="json"), ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
@override
|
||||
|
||||
@ -1150,13 +1150,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
try:
|
||||
# Convert outputs to string based on type
|
||||
outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value
|
||||
if isinstance(trace_info.outputs, dict | list):
|
||||
outputs_str = safe_json_dumps(trace_info.outputs)
|
||||
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
|
||||
elif isinstance(trace_info.outputs, str):
|
||||
outputs_str = trace_info.outputs
|
||||
else:
|
||||
outputs_str = str(trace_info.outputs)
|
||||
match trace_info.outputs:
|
||||
case dict() | list():
|
||||
outputs_str = safe_json_dumps(trace_info.outputs)
|
||||
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
|
||||
case str():
|
||||
outputs_str = trace_info.outputs
|
||||
case _:
|
||||
outputs_str = str(trace_info.outputs)
|
||||
|
||||
llm_attributes: dict[str, Any] = {
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
|
||||
@ -1553,25 +1554,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id)
|
||||
|
||||
# Handle list of messages
|
||||
if isinstance(prompts, list):
|
||||
for message_index, message in enumerate(prompts):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
match prompts:
|
||||
case list():
|
||||
for message_index, message in enumerate(prompts):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
|
||||
role = message.get("role", "user")
|
||||
content = message.get("text") or message.get("content") or ""
|
||||
role = message.get("role", "user")
|
||||
content = message.get("text") or message.get("content") or ""
|
||||
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
|
||||
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if isinstance(tool_calls, list):
|
||||
for tool_index, tool_call in enumerate(tool_calls):
|
||||
set_tool_call_attributes(message_index, tool_index, tool_call)
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if isinstance(tool_calls, list):
|
||||
for tool_index, tool_call in enumerate(tool_calls):
|
||||
set_tool_call_attributes(message_index, tool_index, tool_call)
|
||||
|
||||
# Handle single dict or plain string prompt
|
||||
elif isinstance(prompts, (dict, str)):
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
|
||||
# Handle single dict or plain string prompt
|
||||
case dict() | str():
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
|
||||
|
||||
return attributes
|
||||
|
||||
@ -18,24 +18,25 @@ def validate_input_output(v, field_name):
|
||||
"""
|
||||
if v == {} or v is None:
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
return [
|
||||
{
|
||||
"role": "assistant" if field_name == "output" else "user",
|
||||
"content": v,
|
||||
}
|
||||
]
|
||||
elif isinstance(v, list):
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
v = replace_text_with_content(data=v)
|
||||
return v
|
||||
else:
|
||||
match v:
|
||||
case str():
|
||||
return [
|
||||
{
|
||||
"role": "assistant" if field_name == "output" else "user",
|
||||
"content": str(v),
|
||||
"content": v,
|
||||
}
|
||||
]
|
||||
case list():
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
v = replace_text_with_content(data=v)
|
||||
return v
|
||||
else:
|
||||
return [
|
||||
{
|
||||
"role": "assistant" if field_name == "output" else "user",
|
||||
"content": str(v),
|
||||
}
|
||||
]
|
||||
|
||||
return v
|
||||
|
||||
|
||||
@ -64,40 +64,20 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
"total_tokens": values.get("total_tokens", 0),
|
||||
}
|
||||
file_list = values.get("file_list", [])
|
||||
if isinstance(v, str):
|
||||
match field_name:
|
||||
case "inputs":
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case "outputs":
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case _:
|
||||
pass
|
||||
elif isinstance(v, list):
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
match v:
|
||||
case str():
|
||||
match field_name:
|
||||
case "inputs":
|
||||
data = {
|
||||
"messages": v,
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case "outputs":
|
||||
data = {
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
@ -107,16 +87,37 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
}
|
||||
case _:
|
||||
pass
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case list():
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
match field_name:
|
||||
case "inputs":
|
||||
data = {
|
||||
"messages": v,
|
||||
}
|
||||
case "outputs":
|
||||
data = {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case _:
|
||||
pass
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
if isinstance(v, dict):
|
||||
v["usage_metadata"] = usage_metadata
|
||||
v["file_list"] = file_list
|
||||
|
||||
@ -40,41 +40,19 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
"total_tokens": values.get("total_tokens", 0),
|
||||
}
|
||||
file_list = values.get("file_list", [])
|
||||
if isinstance(v, str):
|
||||
if field_name == "inputs":
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif isinstance(v, list):
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
match v:
|
||||
case str():
|
||||
if field_name == "inputs":
|
||||
data = {
|
||||
"messages": [
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
|
||||
for msg in v
|
||||
]
|
||||
if isinstance(v, list)
|
||||
else v,
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
data = {
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
@ -82,16 +60,39 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
case list():
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
if field_name == "inputs":
|
||||
data = {
|
||||
"messages": [
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
|
||||
for msg in v
|
||||
]
|
||||
if isinstance(v, list)
|
||||
else v,
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
data = {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
if isinstance(v, dict):
|
||||
v["usage_metadata"] = usage_metadata
|
||||
v["file_list"] = file_list
|
||||
|
||||
@ -361,12 +361,13 @@ class ClickzettaVector(BaseVector):
|
||||
first_pass = json.loads(raw_metadata)
|
||||
|
||||
# Handle double-encoded JSON
|
||||
if isinstance(first_pass, str):
|
||||
metadata = parse_metadata_json(first_pass)
|
||||
elif isinstance(first_pass, dict):
|
||||
metadata = first_pass
|
||||
else:
|
||||
metadata = {}
|
||||
match first_pass:
|
||||
case str():
|
||||
metadata = parse_metadata_json(first_pass)
|
||||
case dict():
|
||||
metadata = first_pass
|
||||
case _:
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
@ -942,12 +943,13 @@ class ClickzettaVector(BaseVector):
|
||||
# First parse may yield a string (double-encoded JSON)
|
||||
first_pass = json.loads(row[2])
|
||||
|
||||
if isinstance(first_pass, str):
|
||||
metadata = parse_metadata_json(first_pass)
|
||||
elif isinstance(first_pass, dict):
|
||||
metadata = first_pass
|
||||
else:
|
||||
metadata = {}
|
||||
match first_pass:
|
||||
case str():
|
||||
metadata = parse_metadata_json(first_pass)
|
||||
case dict():
|
||||
metadata = first_pass
|
||||
case _:
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
|
||||
@ -98,14 +98,15 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
|
||||
values: list[Any] = []
|
||||
if isinstance(query, psql.Composed):
|
||||
for part in query:
|
||||
if isinstance(part, psql.Identifier):
|
||||
values.append(("ident", part._obj[0] if part._obj else ""))
|
||||
elif isinstance(part, psql.Literal):
|
||||
values.append(("literal", part._obj))
|
||||
elif isinstance(part, psql.Composed):
|
||||
for sub in part:
|
||||
if isinstance(sub, psql.Literal):
|
||||
values.append(("literal", sub._obj))
|
||||
match part:
|
||||
case psql.Identifier():
|
||||
values.append(("ident", part._obj[0] if part._obj else ""))
|
||||
case psql.Literal():
|
||||
values.append(("literal", part._obj))
|
||||
case psql.Composed():
|
||||
for sub in part:
|
||||
if isinstance(sub, psql.Literal):
|
||||
values.append(("literal", sub._obj))
|
||||
return values
|
||||
|
||||
|
||||
|
||||
@ -957,21 +957,22 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
)
|
||||
pause_reason_models = []
|
||||
for reason in pause_reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
|
||||
pause_reason_model = WorkflowPauseReason(
|
||||
pause_id=pause_model.id,
|
||||
type_=reason.TYPE,
|
||||
form_id=reason.form_id,
|
||||
)
|
||||
elif isinstance(reason, SchedulingPause):
|
||||
pause_reason_model = WorkflowPauseReason(
|
||||
pause_id=pause_model.id,
|
||||
type_=reason.TYPE,
|
||||
message=reason.message,
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"unkown reason type: {type(reason)}")
|
||||
match reason:
|
||||
case HumanInputRequired():
|
||||
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
|
||||
pause_reason_model = WorkflowPauseReason(
|
||||
pause_id=pause_model.id,
|
||||
type_=reason.TYPE,
|
||||
form_id=reason.form_id,
|
||||
)
|
||||
case SchedulingPause():
|
||||
pause_reason_model = WorkflowPauseReason(
|
||||
pause_id=pause_model.id,
|
||||
type_=reason.TYPE,
|
||||
message=reason.message,
|
||||
)
|
||||
case _:
|
||||
raise AssertionError(f"unknown reason type: {type(reason)}")
|
||||
|
||||
pause_reason_models.append(pause_reason_model)
|
||||
|
||||
|
||||
@ -334,16 +334,17 @@ class ComposerConfigValidator:
|
||||
|
||||
@classmethod
|
||||
def _reject_plaintext_secrets(cls, value: Any, *, path: str) -> None:
|
||||
if isinstance(value, dict):
|
||||
for key, nested in value.items():
|
||||
normalized_key = key.lower().replace("-", "_")
|
||||
nested_path = f"{path}.{key}"
|
||||
if normalized_key in _PLAINTEXT_SECRET_KEYS and isinstance(nested, str) and nested:
|
||||
raise PlaintextSecretNotAllowedError(f"Plaintext secret is not allowed at {nested_path}")
|
||||
cls._reject_plaintext_secrets(nested, path=nested_path)
|
||||
elif isinstance(value, list):
|
||||
for index, nested in enumerate(value):
|
||||
cls._reject_plaintext_secrets(nested, path=f"{path}[{index}]")
|
||||
match value:
|
||||
case dict():
|
||||
for key, nested in value.items():
|
||||
normalized_key = key.lower().replace("-", "_")
|
||||
nested_path = f"{path}.{key}"
|
||||
if normalized_key in _PLAINTEXT_SECRET_KEYS and isinstance(nested, str) and nested:
|
||||
raise PlaintextSecretNotAllowedError(f"Plaintext secret is not allowed at {nested_path}")
|
||||
cls._reject_plaintext_secrets(nested, path=nested_path)
|
||||
case list():
|
||||
for index, nested in enumerate(value):
|
||||
cls._reject_plaintext_secrets(nested, path=f"{path}[{index}]")
|
||||
|
||||
@classmethod
|
||||
def _has_install_command(cls, entry: dict[str, Any]) -> bool:
|
||||
|
||||
@ -14,12 +14,13 @@ class AttachmentService:
|
||||
_session_maker: sessionmaker
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine | None = None):
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_maker = sessionmaker(bind=session_factory)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_maker = session_factory
|
||||
else:
|
||||
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||
match session_factory:
|
||||
case Engine():
|
||||
self._session_maker = sessionmaker(bind=session_factory)
|
||||
case sessionmaker():
|
||||
self._session_maker = session_factory
|
||||
case _:
|
||||
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||
|
||||
def get_file_base64(self, file_id: str) -> str:
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
|
||||
@ -39,12 +39,13 @@ class FileService:
|
||||
_session_maker: sessionmaker[Session]
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine | None = None):
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_maker = sessionmaker(bind=session_factory)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_maker = session_factory
|
||||
else:
|
||||
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||
match session_factory:
|
||||
case Engine():
|
||||
self._session_maker = sessionmaker(bind=session_factory)
|
||||
case sessionmaker():
|
||||
self._session_maker = session_factory
|
||||
case _:
|
||||
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
|
||||
@ -119,10 +119,11 @@ class HumanInputDeliveryTestService:
|
||||
|
||||
class EmailDeliveryTestHandler:
|
||||
def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None:
|
||||
if session_factory is None:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
elif isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
match session_factory:
|
||||
case None:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
case Engine():
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
|
||||
def supports(self, method: DeliveryChannelConfig) -> bool:
|
||||
@ -179,11 +180,12 @@ class EmailDeliveryTestHandler:
|
||||
emails: list[str] = []
|
||||
bound_reference_ids: list[str] = []
|
||||
for recipient in recipients.items:
|
||||
if isinstance(recipient, MemberRecipient):
|
||||
bound_reference_ids.append(recipient.reference_id)
|
||||
elif isinstance(recipient, ExternalRecipient):
|
||||
if recipient.email:
|
||||
emails.append(recipient.email)
|
||||
match recipient:
|
||||
case MemberRecipient():
|
||||
bound_reference_ids.append(recipient.reference_id)
|
||||
case ExternalRecipient():
|
||||
if recipient.email:
|
||||
emails.append(recipient.email)
|
||||
|
||||
if recipients.include_bound_group:
|
||||
bound_reference_ids = []
|
||||
|
||||
@ -125,10 +125,11 @@ class DraftVarLoader(VariableLoader):
|
||||
# can be safely accessed before any offloading logic is applied.
|
||||
for draft_var in draft_vars:
|
||||
value = draft_var.get_value()
|
||||
if isinstance(value, FileSegment):
|
||||
files.append(value.value)
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files.extend(value.value)
|
||||
match value:
|
||||
case FileSegment():
|
||||
files.append(value.value)
|
||||
case ArrayFileSegment():
|
||||
files.extend(value.value)
|
||||
with Session(bind=self._engine) as session:
|
||||
storage_key_loader = StorageKeyLoader(
|
||||
session,
|
||||
|
||||
@ -34,10 +34,11 @@ class WorkflowRunService:
|
||||
|
||||
def __init__(self, session_factory: Engine | sessionmaker | None = None):
|
||||
"""Initialize WorkflowRunService with repository dependencies."""
|
||||
if session_factory is None:
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
elif isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
match session_factory:
|
||||
case None:
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
case Engine():
|
||||
session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
|
||||
self._session_factory = session_factory
|
||||
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||
|
||||
@ -131,10 +131,11 @@ def fetch_table_rows(
|
||||
for row in rows:
|
||||
normalized = dict(row)
|
||||
for key, value in normalized.items():
|
||||
if isinstance(value, datetime):
|
||||
normalized[key] = value.isoformat()
|
||||
elif isinstance(value, uuid.UUID):
|
||||
normalized[key] = str(value)
|
||||
match value:
|
||||
case datetime():
|
||||
normalized[key] = value.isoformat()
|
||||
case uuid.UUID():
|
||||
normalized[key] = str(value)
|
||||
result.append(normalized)
|
||||
return result
|
||||
|
||||
|
||||
@ -8,12 +8,13 @@ from pathlib import Path
|
||||
|
||||
def _walk_values(value):
|
||||
yield value
|
||||
if isinstance(value, dict):
|
||||
for child in value.values():
|
||||
yield from _walk_values(child)
|
||||
elif isinstance(value, list):
|
||||
for child in value:
|
||||
yield from _walk_values(child)
|
||||
match value:
|
||||
case dict():
|
||||
for child in value.values():
|
||||
yield from _walk_values(child)
|
||||
case list():
|
||||
for child in value:
|
||||
yield from _walk_values(child)
|
||||
|
||||
|
||||
def _load_generate_swagger_specs_module():
|
||||
|
||||
@ -4,6 +4,7 @@ import calendar
|
||||
import math
|
||||
from datetime import date
|
||||
from types import SimpleNamespace
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
|
||||
@ -66,6 +67,20 @@ def test_localtime_to_timestamp_tool():
|
||||
ts_value = float(ts_message.strip())
|
||||
assert math.isfinite(ts_value)
|
||||
assert ts_value >= 0
|
||||
assert (
|
||||
LocaltimeToTimestampTool.localtime_to_timestamp(
|
||||
"2024-01-01 10:00:00",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
ZoneInfo("UTC"),
|
||||
)
|
||||
== 1704103200
|
||||
)
|
||||
with pytest.raises(ToolInvokeError, match="local_tz must be"):
|
||||
LocaltimeToTimestampTool.localtime_to_timestamp(
|
||||
"2024-01-01 10:00:00",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
object(), # type: ignore[arg-type]
|
||||
)
|
||||
with pytest.raises(ToolInvokeError):
|
||||
LocaltimeToTimestampTool.localtime_to_timestamp("bad", "%Y-%m-%d %H:%M:%S", "UTC")
|
||||
|
||||
|
||||
@ -459,23 +459,24 @@ class TestEndToEndSerialization:
|
||||
|
||||
def _verify_all_complex_types_converted(self, data):
|
||||
"""Helper method to verify all complex types were properly converted"""
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if key in ["id", "checksum"]:
|
||||
# These should be strings (UUID/bytes converted)
|
||||
assert isinstance(value, str)
|
||||
elif key in ["created_at", "last_login", "timestamp", "uploaded_at"]:
|
||||
# These should be strings (datetime converted)
|
||||
assert isinstance(value, str)
|
||||
elif key in ["total_time", "duration"]:
|
||||
# These should be floats (Decimal converted)
|
||||
assert isinstance(value, float)
|
||||
elif key == "metrics":
|
||||
# This should be a list (ndarray converted)
|
||||
assert isinstance(value, list)
|
||||
else:
|
||||
# Recursively check nested structures
|
||||
self._verify_all_complex_types_converted(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
self._verify_all_complex_types_converted(item)
|
||||
match data:
|
||||
case dict():
|
||||
for key, value in data.items():
|
||||
if key in ["id", "checksum"]:
|
||||
# These should be strings (UUID/bytes converted)
|
||||
assert isinstance(value, str)
|
||||
elif key in ["created_at", "last_login", "timestamp", "uploaded_at"]:
|
||||
# These should be strings (datetime converted)
|
||||
assert isinstance(value, str)
|
||||
elif key in ["total_time", "duration"]:
|
||||
# These should be floats (Decimal converted)
|
||||
assert isinstance(value, float)
|
||||
elif key == "metrics":
|
||||
# This should be a list (ndarray converted)
|
||||
assert isinstance(value, list)
|
||||
else:
|
||||
# Recursively check nested structures
|
||||
self._verify_all_complex_types_converted(value)
|
||||
case list():
|
||||
for item in data:
|
||||
self._verify_all_complex_types_converted(item)
|
||||
|
||||
@ -215,12 +215,13 @@ class TestModelProviderServiceDelegation:
|
||||
|
||||
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
|
||||
provider_method = getattr(provider_configuration, provider_method_name)
|
||||
if isinstance(provider_call_kwargs, tuple):
|
||||
provider_method.assert_called_once_with(*provider_call_kwargs)
|
||||
elif isinstance(provider_call_kwargs, dict):
|
||||
provider_method.assert_called_once_with(**provider_call_kwargs)
|
||||
else:
|
||||
provider_method.assert_called_once_with(provider_call_kwargs)
|
||||
match provider_call_kwargs:
|
||||
case tuple():
|
||||
provider_method.assert_called_once_with(*provider_call_kwargs)
|
||||
case dict():
|
||||
provider_method.assert_called_once_with(**provider_call_kwargs)
|
||||
case _:
|
||||
provider_method.assert_called_once_with(provider_call_kwargs)
|
||||
if method_name == "get_provider_credential":
|
||||
assert result == {"token": "abc"}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user