diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 5789965747..4c246b230d 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -1,5 +1,5 @@ import os -from collections.abc import Mapping, Sequence +from collections.abc import Iterable, Mapping from typing import Any, Optional, TextIO, Union from pydantic import BaseModel @@ -55,7 +55,7 @@ class DifyAgentCallbackHandler(BaseModel): self, tool_name: str, tool_inputs: Mapping[str, Any], - tool_outputs: Sequence[ToolInvokeMessage], + tool_outputs: Iterable[ToolInvokeMessage] | str, message_id: Optional[str] = None, timer: Optional[Any] = None, trace_manager: Optional[TraceQueueManager] = None diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 9397f22494..0fb4470498 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,9 +1,9 @@ import json -from collections.abc import Generator, Mapping +from collections.abc import Generator, Iterable from copy import deepcopy from datetime import datetime, timezone from mimetypes import guess_type -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from yarl import URL @@ -40,7 +40,7 @@ class ToolEngine: user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, trace_manager: Optional[TraceQueueManager] = None - ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]: + ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. """ @@ -67,9 +67,9 @@ class ToolEngine: ) messages = ToolEngine._invoke(tool, tool_parameters, user_id) - invocation_meta_dict = {'meta': None} + invocation_meta_dict: dict[str, ToolInvokeMeta] = {} - def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]): + def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]): for message in messages: if isinstance(message, ToolInvokeMeta): invocation_meta_dict['meta'] = message @@ -136,7 +136,7 @@ class ToolEngine: return error_response, [], ToolInvokeMeta.error_instance(error_response) @staticmethod - def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any], + def workflow_invoke(tool: Tool, tool_parameters: dict[str, Any], user_id: str, workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_call_depth: int, @@ -156,6 +156,7 @@ class ToolEngine: if tool.runtime and tool.runtime.runtime_parameters: tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} + response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters) # hit the callback handler @@ -204,6 +205,9 @@ class ToolEngine: """ Invoke the tool with the given arguments. """ + if not tool.runtime: + raise ValueError("missing runtime in tool") + started_at = datetime.now(timezone.utc) meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={ 'tool_name': tool.identity.name, @@ -223,42 +227,42 @@ class ToolEngine: yield meta @staticmethod - def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: + def _convert_tool_response_to_str(tool_response: Generator[ToolInvokeMessage, None, None]) -> str: """ Handle tool response """ result = '' for response in tool_response: if response.type == ToolInvokeMessage.MessageType.TEXT: - result += response.message + result += cast(ToolInvokeMessage.TextMessage, response.message).text elif response.type == ToolInvokeMessage.MessageType.LINK: - result += f"result link: {response.message}. please tell user to check it." + result += f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}. please tell user to check it." elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ response.type == ToolInvokeMessage.MessageType.IMAGE: result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." elif response.type == ToolInvokeMessage.MessageType.JSON: - result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." + result += f"tool response: {json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)}." else: result += f"tool response: {response.message}." return result @staticmethod - def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]: + def _extract_tool_response_binary(tool_response: Generator[ToolInvokeMessage, None, None]) -> Generator[ToolInvokeMessageBinary, None, None]: """ Extract tool response binary """ - result = [] - for response in tool_response: if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ response.type == ToolInvokeMessage.MessageType.IMAGE: mimetype = None + if not response.meta: + raise ValueError("missing meta data") if response.meta.get('mime_type'): mimetype = response.meta.get('mime_type') else: try: - url = URL(response.message) + url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text) extension = url.suffix guess_type_result, _ = guess_type(f'a{extension}') if guess_type_result: @@ -269,35 +273,36 @@ class ToolEngine: if not mimetype: mimetype = 'image/jpeg' - result.append(ToolInvokeMessageBinary( + yield ToolInvokeMessageBinary( mimetype=response.meta.get('mime_type', 'image/jpeg'), - url=response.message, + url=cast(ToolInvokeMessage.TextMessage, response.message).text, save_as=response.save_as, - )) + ) elif response.type == ToolInvokeMessage.MessageType.BLOB: - result.append(ToolInvokeMessageBinary( + if not response.meta: + raise ValueError("missing meta data") + + yield ToolInvokeMessageBinary( mimetype=response.meta.get('mime_type', 'octet/stream'), - url=response.message, + url=cast(ToolInvokeMessage.TextMessage, response.message).text, save_as=response.save_as, - )) + ) elif response.type == ToolInvokeMessage.MessageType.LINK: # check if there is a mime type in meta if response.meta and 'mime_type' in response.meta: - result.append(ToolInvokeMessageBinary( + yield ToolInvokeMessageBinary( mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream', - url=response.message, + url=cast(ToolInvokeMessage.TextMessage, response.message).text, save_as=response.save_as, - )) - - return result + ) @staticmethod def _create_message_files( - tool_messages: list[ToolInvokeMessageBinary], + tool_messages: Iterable[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str - ) -> list[tuple[Any, str]]: + ) -> list[tuple[MessageFile, str]]: """ Create message file diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 47498f4f5f..6ba7e7e09b 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,4 +1,4 @@ -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Generator, Sequence from os import path from typing import Any, cast @@ -100,7 +100,7 @@ class ToolNode(BaseNode): variable_pool: VariablePool, node_data: ToolNodeData, for_log: bool = False, - ) -> Mapping[str, Any]: + ) -> dict[str, Any]: """ Generate parameters based on the given tool parameters, variable pool, and node data. @@ -110,7 +110,7 @@ class ToolNode(BaseNode): node_data (ToolNodeData): The data associated with the tool node. Returns: - Mapping[str, Any]: A dictionary containing the generated parameters. + dict[str, Any]: A dictionary containing the generated parameters. """ tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} diff --git a/api/models/base.py b/api/models/base.py new file mode 100644 index 0000000000..1c2dcc40b9 --- /dev/null +++ b/api/models/base.py @@ -0,0 +1,5 @@ +from sqlalchemy.orm import DeclarativeBase + + +class Base(DeclarativeBase): + pass \ No newline at end of file diff --git a/api/models/model.py b/api/models/model.py index e2d1fcfc23..298bfbda12 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -14,6 +14,7 @@ from core.file.tool_file_parser import ToolFileParser from core.file.upload_file_parser import UploadFileParser from extensions.ext_database import db from libs.helper import generate_string +from models.base import Base from .account import Account, Tenant from .types import StringUUID @@ -211,7 +212,7 @@ class App(db.Model): return tags if tags else [] -class AppModelConfig(db.Model): +class AppModelConfig(Base): __tablename__ = 'app_model_configs' __table_args__ = ( db.PrimaryKeyConstraint('id', name='app_model_config_pkey'), @@ -550,6 +551,9 @@ class Conversation(db.Model): else: app_model_config = db.session.query(AppModelConfig).filter( AppModelConfig.id == self.app_model_config_id).first() + + if not app_model_config: + raise ValueError("app config not found") model_config = app_model_config.to_dict() @@ -640,7 +644,7 @@ class Conversation(db.Model): return self.override_model_configs is not None -class Message(db.Model): +class Message(Base): __tablename__ = 'messages' __table_args__ = ( db.PrimaryKeyConstraint('id', name='message_pkey'), @@ -932,7 +936,7 @@ class MessageFeedback(db.Model): return account -class MessageFile(db.Model): +class MessageFile(Base): __tablename__ = 'message_files' __table_args__ = ( db.PrimaryKeyConstraint('id', name='message_file_pkey'), @@ -940,15 +944,15 @@ class MessageFile(db.Model): db.Index('message_file_created_by_idx', 'created_by') ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) - message_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - transfer_method = db.Column(db.String(255), nullable=False) - url = db.Column(db.Text, nullable=True) - belongs_to = db.Column(db.String(255), nullable=True) - upload_file_id = db.Column(StringUUID, nullable=True) - created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()')) + message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + type: Mapped[str] = mapped_column(db.String(255), nullable=False) + transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) + url: Mapped[str] = mapped_column(db.Text, nullable=True) + belongs_to: Mapped[str] = mapped_column(db.String(255), nullable=True) + upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=True) + created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/models/tools.py b/api/models/tools.py index 937481583a..1e7421622a 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,12 +1,13 @@ import json from sqlalchemy import ForeignKey -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db +from models.base import Base from .model import Account, App, Tenant from .types import StringUUID @@ -277,9 +278,6 @@ class ToolConversationVariables(db.Model): @property def variables(self) -> dict: return json.loads(self.variables_str) - -class Base(DeclarativeBase): - pass class ToolFile(Base): """