dify/api/core/tools/tool_engine.py
EvanYao826 fe29a6585e fix: remove premature db.session operations and validate code language
Remove premature db.session.commit() and db.session.close() calls that
break transaction isolation and cause DetachedInstanceError:

- DatasetRetrieval._on_query: remove commit(), stage audit rows on the
  scoped session for the caller to commit (#37886)
- cloud_edition_billing_rate_limit_check: remove commit() after adding
  rate limit log, upstream session state is no longer flushed (#37885)
- ToolEngine._create_message_files: replace commit()+close() with
  flush() to obtain IDs without ending the transaction (#37884)

Add defensive validation for unsupported code languages in
CodeExecutor.execute_code to prevent sending null language to the
sandbox service (#37874)
2026-06-25 14:16:24 +08:00

390 lines
15 KiB
Python

import contextlib
import json
import logging
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import UTC, datetime
from mimetypes import guess_type
from typing import Any, Union, cast
from yarl import URL
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolInvokeMessage,
ToolInvokeMessageBinary,
ToolInvokeMeta,
ToolParameter,
)
from core.tools.errors import (
ToolEngineInvokeError,
ToolInvokeError,
ToolNotFoundError,
ToolNotSupportedError,
ToolParameterValidationError,
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from graphon.file import FileTransferMethod, FileType
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.model import Message, MessageFile
logger = logging.getLogger(__name__)
class ToolEngine:
"""
Tool runtime engine take care of the tool executions.
"""
@staticmethod
def agent_invoke(
tool: Tool,
tool_parameters: Union[str, dict[str, Any]],
user_id: str,
tenant_id: str,
message: Message,
invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler,
trace_manager: TraceQueueManager | None = None,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> tuple[str, list[str], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
# check if arguments is a string
if isinstance(tool_parameters, str):
# check if this tool has only one parameter
parameters = [
parameter
for parameter in tool.get_runtime_parameters()
if parameter.form == ToolParameter.ToolParameterForm.LLM
]
if parameters and len(parameters) == 1:
tool_parameters = {parameters[0].name: tool_parameters}
else:
with contextlib.suppress(Exception):
tool_parameters = json.loads(tool_parameters)
if not isinstance(tool_parameters, dict):
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
try:
# hit the callback handler
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
def message_callback(
invocation_meta_dict: dict[str, ToolInvokeMeta],
messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None],
):
for message in messages:
if isinstance(message, ToolInvokeMeta):
invocation_meta_dict["meta"] = message
else:
yield message
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=message_callback(invocation_meta_dict, messages),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=message.conversation_id,
)
message_list = list(messages)
# extract binary data from tool invoke message
binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
# create message file
message_files = ToolEngine._create_message_files(
tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
)
plain_text = ToolEngine._convert_tool_response_to_str(message_list)
meta = invocation_meta_dict["meta"]
# hit the callback handler
agent_tool_callback.on_tool_end(
tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters,
tool_outputs=plain_text,
message_id=message.id,
trace_manager=trace_manager,
)
# transform tool invoke message to get LLM friendly message
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
logger.error(e, exc_info=True)
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
error_response = f"there is not a tool named {tool.entity.identity.name}"
logger.error(e, exc_info=True)
agent_tool_callback.on_tool_error(e)
except ToolParameterValidationError as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
except ToolEngineInvokeError as e:
meta = e.meta
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
return error_response, [], meta
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod
def generic_invoke(
tool: Tool,
tool_parameters: dict[str, Any],
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
Workflow invokes the tool with the given arguments.
"""
try:
# hit the callback handler
workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
response = tool.invoke(
user_id=user_id,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
# hit the callback handler
response = workflow_tool_callback.on_tool_execution(
tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters,
tool_outputs=response,
)
return response
except Exception as e:
workflow_tool_callback.on_tool_error(e)
raise e
@staticmethod
def _invoke(
tool: Tool,
tool_parameters: dict[str, Any],
user_id: str,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
"""
Invoke the tool with the given arguments.
"""
started_at = datetime.now(UTC)
meta = ToolInvokeMeta(
time_cost=0.0,
error=None,
tool_config={
"tool_name": tool.entity.identity.name,
"tool_provider": tool.entity.identity.provider,
"tool_provider_type": tool.tool_provider_type().value,
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
"tool_icon": tool.entity.identity.icon,
},
)
try:
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
except Exception as e:
meta.error = str(e)
raise ToolEngineInvokeError(meta)
finally:
ended_at = datetime.now(UTC)
meta.time_cost = (ended_at - started_at).total_seconds()
yield meta
@staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
parts: list[str] = []
json_parts: list[str] = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
parts.append(cast(ToolInvokeMessage.TextMessage, response.message).text)
elif response.type == ToolInvokeMessage.MessageType.LINK:
parts.append(
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
+ " please tell user to check it."
)
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
parts.append(
"image has been created and sent to user already, "
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
json_message = cast(ToolInvokeMessage.JsonMessage, response.message)
if json_message.suppress_output:
continue
json_parts.append(
json.dumps(
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
)
)
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
continue
else:
parts.append(str(response.message))
# Add JSON parts, avoiding duplicates from text parts.
if json_parts:
existing_parts = set(parts)
parts.extend(p for p in json_parts if p not in existing_parts)
return "".join(parts)
@staticmethod
def _extract_tool_response_binary_and_text(
tool_response: list[ToolInvokeMessage],
) -> Generator[ToolInvokeMessageBinary, None, None]:
"""
Extract tool response binary
"""
for response in tool_response:
if response.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.IMAGE,
ToolInvokeMessage.MessageType.BINARY_LINK,
}:
mimetype = None
if not response.meta:
raise ValueError("missing meta data")
if response.meta.get("mime_type"):
mimetype = response.meta.get("mime_type")
else:
with contextlib.suppress(Exception):
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
extension = url.suffix
guess_type_result, _ = guess_type(f"a{extension}")
if guess_type_result:
mimetype = guess_type_result
if not mimetype:
mimetype = (
"image/jpeg"
if response.type != ToolInvokeMessage.MessageType.BINARY_LINK
else "application/octet-stream"
)
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", mimetype),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
if not response.meta:
raise ValueError("missing meta data")
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "application/octet-stream"),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
if response.meta and "mime_type" in response.meta:
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "application/octet-stream")
if response.meta
else "application/octet-stream",
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
)
@staticmethod
def _create_message_files(
tool_messages: Iterable[ToolInvokeMessageBinary],
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str,
) -> list[str]:
"""
Create message file records for binary tool outputs.
Stages ``MessageFile`` rows on the current scoped session and flushes
to obtain generated IDs, but does **not** commit or close the session.
The caller (or Flask request teardown) is responsible for the final
commit so that the file records participate in the same transaction as
the surrounding agent workflow. This prevents:
* Partial commits when a downstream step fails after files are written.
* ``DetachedInstanceError`` on lazy-loaded ORM properties caused by an
early ``db.session.close()``.
"""
result = []
for message in tool_messages:
if "image" in message.mimetype:
file_type = FileType.IMAGE
elif "video" in message.mimetype:
file_type = FileType.VIDEO
elif "audio" in message.mimetype:
file_type = FileType.AUDIO
elif "text" in message.mimetype or "pdf" in message.mimetype:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
# extract tool file id from url
tool_file_id = message.url.split("/")[-1].split(".")[0]
message_file = MessageFile(
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to=MessageFileBelongsTo.ASSISTANT,
url=message.url,
upload_file_id=tool_file_id,
created_by_role=(
CreatorUserRole.ACCOUNT
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatorUserRole.END_USER
),
created_by=user_id,
)
db.session.add(message_file)
db.session.flush()
result.append(str(message_file.id))
return result