Tools typing cleanup: add ToolInvokeMessage.BlobMessage, wrap the message
transformer's output in TextMessage/BlobMessage, make conversation_id optional,
and annotate the ToolFile columns with Mapped[...].

Review notes (this copy differs from the patch as captured):
  * Line breaks and indentation were rebuilt from the hunk headers because the
    captured text had been collapsed onto a few lines. Whitespace in context
    lines is a best-effort reconstruction; confirm it against the tree, or
    apply with `git apply --ignore-whitespace`.
  * message_transformer.py: the bare `raise` (no active exception, so it
    surfaced as "RuntimeError: No active exception to reraise") is now an
    explicit ValueError. It is raised inside the same try block, so the
    existing handler that logs and yields the fallback text still applies.
  * models/tools.py: ToolFile stays on db.Model. A class that subclasses
    DeclarativeBase directly becomes a declarative base, not a mapped model.
    `id` keeps server_default (it had been switched to default), and
    conversation_id keeps its StringUUID column type. Annotations for
    nullable columns are `str | None`. The file now ends with a newline.
  * The post-image blob ids on the `index` lines of those two files are the
    ones from the captured patch and no longer match the content.

diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 7d96db8652..c266db1cdb 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -96,6 +96,9 @@ class ToolInvokeMessage(BaseModel):
     class JsonMessage(BaseModel):
         json_object: dict
 
+    class BlobMessage(BaseModel):
+        blob: bytes
+
     class MessageType(Enum):
         TEXT = "text"
         IMAGE = "image"
@@ -109,7 +112,7 @@ class ToolInvokeMessage(BaseModel):
     """
         plain text, image url or link url
     """
-    message: JsonMessage | TextMessage | None
+    message: JsonMessage | TextMessage | BlobMessage | None
     meta: dict[str, Any] | None = None
     save_as: str = ''
 
@@ -321,7 +324,7 @@ class ToolRuntimeVariablePool(BaseModel):
 
         self.pool.append(variable)
 
-    def set_file(self, tool_name: str, value: str, name: str = None) -> None:
+    def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
         """
             set an image variable
 
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index f9f7c7d78a..078c58c662 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -80,8 +80,8 @@ class ToolFileManager:
     def create_file_by_url(
         user_id: str,
         tenant_id: str,
-        conversation_id: str,
         file_url: str,
+        conversation_id: Optional[str] = None,
     ) -> ToolFile:
         """
         create file
@@ -131,7 +131,7 @@ class ToolFileManager:
 
         :return: the binary of the file, mime type
         """
-        tool_file: ToolFile = (
+        tool_file: ToolFile | None = (
             db.session.query(ToolFile)
             .filter(
                 ToolFile.id == id,
@@ -155,7 +155,7 @@ class ToolFileManager:
 
         :return: the binary of the file, mime type
         """
-        message_file: MessageFile = (
+        message_file: MessageFile | None = (
             db.session.query(MessageFile)
             .filter(
                 MessageFile.id == id,
@@ -173,7 +173,7 @@ class ToolFileManager:
             tool_file_id = None
 
 
-        tool_file: ToolFile = (
+        tool_file: ToolFile | None = (
             db.session.query(ToolFile)
             .filter(
                 ToolFile.id == tool_file_id,
@@ -197,7 +197,7 @@ class ToolFileManager:
 
         :return: the binary of the file, mime type
         """
-        tool_file: ToolFile = (
+        tool_file: ToolFile | None = (
             db.session.query(ToolFile)
             .filter(
                 ToolFile.id == tool_file_id,
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index d22b88e58c..41a93a4f95 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -1,8 +1,9 @@
 import logging
 from collections.abc import Generator
 from mimetypes import guess_extension
+from typing import Optional
 
-from core.file.file_obj import FileTransferMethod, FileType
+from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool_file_manager import ToolFileManager
 
@@ -13,7 +14,7 @@ class ToolFileMessageTransformer:
     def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None],
                                        user_id: str,
                                        tenant_id: str,
-                                       conversation_id: str) -> Generator[ToolInvokeMessage, None, None]:
+                                       conversation_id: Optional[str] = None) -> Generator[ToolInvokeMessage, None, None]:
         """
         Transform tool message and handle file download
         """
@@ -25,18 +26,23 @@ class ToolFileMessageTransformer:
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                 # try to download image
                 try:
+                    if not conversation_id:
+                        raise ValueError("conversation_id is required to download the image")
+
+                    assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+
                     file = ToolFileManager.create_file_by_url(
                         user_id=user_id,
                         tenant_id=tenant_id,
+                        file_url=message.message.text,
                         conversation_id=conversation_id,
-                        file_url=message.message
                     )
 
                     url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
 
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=url,
+                        message=ToolInvokeMessage.TextMessage(text=url),
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
@@ -44,57 +50,67 @@ class ToolFileMessageTransformer:
                     logger.exception(e)
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.TEXT,
-                        message=f"Failed to download image: {message.message}, you can try to download it yourself.",
+                        message=ToolInvokeMessage.TextMessage(
+                            text=f"Failed to download image: {message.message}, you can try to download it yourself."
+                        ),
                         meta=message.meta.copy() if message.meta is not None else {},
                         save_as=message.save_as,
                     )
             elif message.type == ToolInvokeMessage.MessageType.BLOB:
                 # get mime type and save blob to storage
+                assert message.meta
+
                 mimetype = message.meta.get('mime_type', 'octet/stream')
                 # if message is str, encode it to bytes
-                if isinstance(message.message, str):
-                    message.message = message.message.encode('utf-8')
+
+                if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
+                    raise ValueError("unexpected message type")
 
                 file = ToolFileManager.create_file_by_raw(
                     user_id=user_id, tenant_id=tenant_id,
                     conversation_id=conversation_id,
-                    file_binary=message.message,
+                    file_binary=message.message.blob,
                     mimetype=mimetype
                 )
 
-                url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
+                extension = guess_extension(file.mimetype) or ".bin"
+                url = cls.get_tool_file_url(file.id, extension)
 
                 # check if file is image
                 if 'image' in mimetype:
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=url,
+                        message=ToolInvokeMessage.TextMessage(text=url),
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
                 else:
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.LINK,
-                        message=url,
+                        message=ToolInvokeMessage.TextMessage(text=url),
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
             elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
-                file_var = message.meta.get('file_var')
+                assert message.meta
+
+                file_var: FileVar | None = message.meta.get('file_var')
                 if file_var:
                     if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
+                        assert file_var.related_id and file_var.extension
+
                         url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
                         if file_var.type == FileType.IMAGE:
                             yield ToolInvokeMessage(
                                 type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                                message=url,
+                                message=ToolInvokeMessage.TextMessage(text=url),
                                 save_as=message.save_as,
                                 meta=message.meta.copy() if message.meta is not None else {},
                             )
                         else:
                             yield ToolInvokeMessage(
                                 type=ToolInvokeMessage.MessageType.LINK,
-                                message=url,
+                                message=ToolInvokeMessage.TextMessage(text=url),
                                 save_as=message.save_as,
                                 meta=message.meta.copy() if message.meta is not None else {},
                             )
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 52a4a216e0..0774fc4f3d 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -1,4 +1,4 @@
-from collections.abc import Mapping, Sequence
+from collections.abc import Generator, Mapping, Sequence
 from os import path
 from typing import Any, cast
 
@@ -145,7 +145,7 @@ class ToolNode(BaseNode):
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []
 
-    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
+    def _convert_tool_messages(self, messages: Generator[ToolInvokeMessage, None, None]):
         """
         Convert ToolInvokeMessages into tuple[plain_text, files]
         """
diff --git a/api/models/tools.py b/api/models/tools.py
index 069dc5bad0..3ee246eeb3 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -1,6 +1,7 @@
 import json
 
 from sqlalchemy import ForeignKey
+from sqlalchemy.orm import Mapped, mapped_column
 
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
@@ -288,16 +289,16 @@ class ToolFile(db.Model):
         db.Index('tool_file_conversation_id_idx', 'conversation_id'),
     )
 
-    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     # conversation user id
-    user_id = db.Column(StringUUID, nullable=False)
+    user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     # tenant id
-    tenant_id = db.Column(StringUUID, nullable=False)
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     # conversation id
-    conversation_id = db.Column(StringUUID, nullable=True)
+    conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
     # file key
-    file_key = db.Column(db.String(255), nullable=False)
+    file_key: Mapped[str] = mapped_column(db.String(255), nullable=False)
     # mime type
-    mimetype = db.Column(db.String(255), nullable=False)
+    mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
     # original url
-    original_url = db.Column(db.String(2048), nullable=True)
\ No newline at end of file
+    original_url: Mapped[str | None] = mapped_column(db.String(2048), nullable=True)