mirror of
https://github.com/langgenius/dify.git
synced 2026-06-12 11:32:08 +08:00
This commit is contained in:
parent
7ec295fd66
commit
99351d2f98
@ -208,12 +208,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
if scratchpad.action.action_name.lower() == "final answer":
|
||||
# action is final answer, return final answer directly
|
||||
try:
|
||||
if isinstance(scratchpad.action.action_input, dict):
|
||||
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
|
||||
elif isinstance(scratchpad.action.action_input, str):
|
||||
final_answer = scratchpad.action.action_input
|
||||
else:
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
match scratchpad.action.action_input:
|
||||
case dict():
|
||||
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
|
||||
case str():
|
||||
final_answer = scratchpad.action.action_input
|
||||
case _:
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
except TypeError:
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
else:
|
||||
|
||||
@ -43,13 +43,14 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
case UserPromptMessage():
|
||||
historic_prompt += f"Question: {message.content}\n\n"
|
||||
case AssistantPromptMessage():
|
||||
if isinstance(message.content, str):
|
||||
historic_prompt += message.content + "\n\n"
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
historic_prompt += content.data
|
||||
match message.content:
|
||||
case str():
|
||||
historic_prompt += message.content + "\n\n"
|
||||
case list():
|
||||
for content in message.content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
historic_prompt += content.data
|
||||
|
||||
return historic_prompt
|
||||
|
||||
|
||||
@ -301,31 +301,32 @@ class AppRunner:
|
||||
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
message = result.delta.message
|
||||
if isinstance(message.content, str):
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
match content:
|
||||
case str():
|
||||
text += content
|
||||
case TextPromptMessageContent():
|
||||
text += content.data
|
||||
case ImagePromptMessageContent():
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
case _:
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
match message.content:
|
||||
case str():
|
||||
text += message.content
|
||||
case list():
|
||||
for content in message.content:
|
||||
match content:
|
||||
case str():
|
||||
text += content
|
||||
case TextPromptMessageContent():
|
||||
text += content.data
|
||||
case ImagePromptMessageContent():
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
case _:
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
@ -166,12 +166,13 @@ def invoke_llm_with_structured_output(
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
|
||||
if isinstance(event.delta.message.content, str):
|
||||
result_text += event.delta.message.content
|
||||
elif isinstance(event.delta.message.content, list):
|
||||
for item in event.delta.message.content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
result_text += item.data
|
||||
match event.delta.message.content:
|
||||
case str():
|
||||
result_text += event.delta.message.content
|
||||
case list():
|
||||
for item in event.delta.message.content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
result_text += item.data
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=model_schema.model,
|
||||
|
||||
@ -211,12 +211,13 @@ class SSETransport:
|
||||
except queue.Empty:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if isinstance(status, _StatusReady):
|
||||
return status.endpoint_url
|
||||
elif isinstance(status, _StatusError):
|
||||
raise status.exc
|
||||
else:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
match status:
|
||||
case _StatusReady():
|
||||
return status.endpoint_url
|
||||
case _StatusError():
|
||||
raise status.exc
|
||||
case _:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
def connect(
|
||||
self,
|
||||
|
||||
@ -41,10 +41,11 @@ class MessageHandlerFnT(Protocol):
|
||||
def _default_message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
):
|
||||
if isinstance(message, Exception):
|
||||
raise ValueError(str(message))
|
||||
elif isinstance(message, (types.ServerNotification | RequestResponder)):
|
||||
pass
|
||||
match message:
|
||||
case Exception():
|
||||
raise ValueError(str(message))
|
||||
case types.ServerNotification() | RequestResponder():
|
||||
pass
|
||||
|
||||
|
||||
def _default_sampling_callback(
|
||||
|
||||
@ -62,15 +62,16 @@ class BaseTraceInfo(BaseModel):
|
||||
parent_span_id_source is the outer node_execution_id.
|
||||
"""
|
||||
parent_ctx = self.metadata.get("parent_trace_context")
|
||||
if isinstance(parent_ctx, ParentTraceContext):
|
||||
context = parent_ctx
|
||||
elif isinstance(parent_ctx, Mapping):
|
||||
try:
|
||||
context = ParentTraceContext.model_validate(parent_ctx)
|
||||
except ValueError:
|
||||
match parent_ctx:
|
||||
case ParentTraceContext():
|
||||
context = parent_ctx
|
||||
case Mapping():
|
||||
try:
|
||||
context = ParentTraceContext.model_validate(parent_ctx)
|
||||
except ValueError:
|
||||
return None, None
|
||||
case _:
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
return (
|
||||
context.parent_workflow_run_id,
|
||||
context.parent_node_execution_id,
|
||||
|
||||
@ -38,18 +38,19 @@ def measure_time():
|
||||
|
||||
|
||||
def replace_text_with_content(data):
|
||||
if isinstance(data, dict):
|
||||
new_data = {}
|
||||
for key, value in data.items():
|
||||
if key == "text":
|
||||
new_data["content"] = value
|
||||
else:
|
||||
new_data[key] = replace_text_with_content(value)
|
||||
return new_data
|
||||
elif isinstance(data, list):
|
||||
return [replace_text_with_content(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
match data:
|
||||
case dict():
|
||||
new_data = {}
|
||||
for key, value in data.items():
|
||||
if key == "text":
|
||||
new_data["content"] = value
|
||||
else:
|
||||
new_data[key] = replace_text_with_content(value)
|
||||
return new_data
|
||||
case list():
|
||||
return [replace_text_with_content(item) for item in data]
|
||||
case _:
|
||||
return data
|
||||
|
||||
|
||||
def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_dotted_order: str | None = None) -> str:
|
||||
|
||||
@ -45,12 +45,13 @@ _plugin_daemon_timeout_config = cast(
|
||||
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
|
||||
)
|
||||
plugin_daemon_request_timeout: httpx.Timeout | None
|
||||
if _plugin_daemon_timeout_config is None:
|
||||
plugin_daemon_request_timeout = None
|
||||
elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
|
||||
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
|
||||
else:
|
||||
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
|
||||
match _plugin_daemon_timeout_config:
|
||||
case None:
|
||||
plugin_daemon_request_timeout = None
|
||||
case httpx.Timeout():
|
||||
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
|
||||
case _:
|
||||
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -50,32 +50,33 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
) -> list[PromptMessage]:
|
||||
prompt_messages = []
|
||||
|
||||
if isinstance(prompt_template, CompletionModelPromptTemplate):
|
||||
prompt_messages = self._get_completion_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
|
||||
prompt_messages = self._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
match prompt_template:
|
||||
case CompletionModelPromptTemplate():
|
||||
prompt_messages = self._get_completion_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
case list() if all(isinstance(item, ChatModelMessage) for item in prompt_template):
|
||||
prompt_messages = self._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
@ -61,14 +61,15 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
|
||||
"""
|
||||
# Store session factory for fallback operations
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_factory = session_factory
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
match session_factory:
|
||||
case Engine():
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
case sessionmaker():
|
||||
self._session_factory = session_factory
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id = extract_tenant_id(user)
|
||||
|
||||
@ -68,14 +68,15 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
||||
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
|
||||
"""
|
||||
# Store session factory for fallback operations
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_factory = session_factory
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
match session_factory:
|
||||
case Engine():
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
case sessionmaker():
|
||||
self._session_factory = session_factory
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id = extract_tenant_id(user)
|
||||
|
||||
@ -288,24 +288,25 @@ class HumanInputFormRepositoryImpl:
|
||||
channel_payload=delivery_method.model_dump_json(),
|
||||
)
|
||||
recipients: list[HumanInputFormRecipient] = []
|
||||
if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod):
|
||||
recipient_model = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
|
||||
)
|
||||
recipients.append(recipient_model)
|
||||
elif isinstance(delivery_method, EmailDeliveryMethod):
|
||||
email_recipients_config = delivery_method.config.recipients
|
||||
recipients.extend(
|
||||
self._build_email_recipients(
|
||||
session=session,
|
||||
match delivery_method:
|
||||
case InteractiveSurfaceDeliveryMethod():
|
||||
recipient_model = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipients_config=email_recipients_config,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
|
||||
)
|
||||
recipients.append(recipient_model)
|
||||
case EmailDeliveryMethod():
|
||||
email_recipients_config = delivery_method.config.recipients
|
||||
recipients.extend(
|
||||
self._build_email_recipients(
|
||||
session=session,
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipients_config=email_recipients_config,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients)
|
||||
|
||||
|
||||
@ -54,14 +54,15 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
|
||||
"""
|
||||
# If an engine is provided, create a sessionmaker from it
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_factory = session_factory
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
match session_factory:
|
||||
case Engine():
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
case sessionmaker():
|
||||
self._session_factory = session_factory
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id = extract_tenant_id(user)
|
||||
|
||||
@ -77,14 +77,15 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
|
||||
"""
|
||||
# If an engine is provided, create a sessionmaker from it
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_factory = session_factory
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
match session_factory:
|
||||
case Engine():
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
case sessionmaker():
|
||||
self._session_factory = session_factory
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id = extract_tenant_id(user)
|
||||
|
||||
@ -125,10 +125,11 @@ class SchemaResolver:
|
||||
|
||||
def _process_queue_item(self, queue: deque, item: QueueItem) -> None:
|
||||
"""Process a single queue item"""
|
||||
if isinstance(item.current, dict):
|
||||
self._process_dict(queue, item)
|
||||
elif isinstance(item.current, list):
|
||||
self._process_list(queue, item)
|
||||
match item.current:
|
||||
case dict():
|
||||
self._process_dict(queue, item)
|
||||
case list():
|
||||
self._process_list(queue, item)
|
||||
|
||||
def _process_dict(self, queue: deque, item: QueueItem) -> None:
|
||||
"""Process a dictionary item"""
|
||||
|
||||
@ -106,24 +106,26 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
|
||||
# Handle datetime fields
|
||||
started_at = data.get("started_at") or data.get("created_at")
|
||||
if started_at:
|
||||
if isinstance(started_at, str):
|
||||
model.created_at = datetime.fromisoformat(started_at)
|
||||
elif isinstance(started_at, (int, float)):
|
||||
model.created_at = datetime.fromtimestamp(started_at)
|
||||
else:
|
||||
model.created_at = started_at
|
||||
match started_at:
|
||||
case str():
|
||||
model.created_at = datetime.fromisoformat(started_at)
|
||||
case int() | float():
|
||||
model.created_at = datetime.fromtimestamp(started_at)
|
||||
case _:
|
||||
model.created_at = started_at
|
||||
else:
|
||||
# Provide default created_at if missing
|
||||
model.created_at = datetime.now()
|
||||
|
||||
finished_at = data.get("finished_at")
|
||||
if finished_at:
|
||||
if isinstance(finished_at, str):
|
||||
model.finished_at = datetime.fromisoformat(finished_at)
|
||||
elif isinstance(finished_at, (int, float)):
|
||||
model.finished_at = datetime.fromtimestamp(finished_at)
|
||||
else:
|
||||
model.finished_at = finished_at
|
||||
match finished_at:
|
||||
case str():
|
||||
model.finished_at = datetime.fromisoformat(finished_at)
|
||||
case int() | float():
|
||||
model.finished_at = datetime.fromtimestamp(finished_at)
|
||||
case _:
|
||||
model.finished_at = finished_at
|
||||
|
||||
# Compute elapsed_time from started_at and finished_at
|
||||
# LogStore doesn't store elapsed_time, it's computed in WorkflowExecution domain entity
|
||||
|
||||
Loading…
Reference in New Issue
Block a user