mirror of https://github.com/langgenius/dify.git
refactor: tool message transformer
This commit is contained in:
parent
4b4741f7ed
commit
c28998a6f0
|
|
@ -96,6 +96,9 @@ class ToolInvokeMessage(BaseModel):
|
|||
class JsonMessage(BaseModel):
|
||||
json_object: dict
|
||||
|
||||
class BlobMessage(BaseModel):
|
||||
blob: bytes
|
||||
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
|
|
@ -109,7 +112,7 @@ class ToolInvokeMessage(BaseModel):
|
|||
"""
|
||||
plain text, image url or link url
|
||||
"""
|
||||
message: JsonMessage | TextMessage | None
|
||||
message: JsonMessage | TextMessage | BlobMessage | None
|
||||
meta: dict[str, Any] | None = None
|
||||
save_as: str = ''
|
||||
|
||||
|
|
@ -321,7 +324,7 @@ class ToolRuntimeVariablePool(BaseModel):
|
|||
|
||||
self.pool.append(variable)
|
||||
|
||||
def set_file(self, tool_name: str, value: str, name: str = None) -> None:
|
||||
def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
|
||||
"""
|
||||
set an image variable
|
||||
|
||||
|
|
|
|||
|
|
@ -80,8 +80,8 @@ class ToolFileManager:
|
|||
def create_file_by_url(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str,
|
||||
file_url: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
|
|
@ -131,7 +131,7 @@ class ToolFileManager:
|
|||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile = (
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == id,
|
||||
|
|
@ -155,7 +155,7 @@ class ToolFileManager:
|
|||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile = (
|
||||
message_file: MessageFile | None = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
|
|
@ -173,7 +173,7 @@ class ToolFileManager:
|
|||
tool_file_id = None
|
||||
|
||||
|
||||
tool_file: ToolFile = (
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
|
|
@ -197,7 +197,7 @@ class ToolFileManager:
|
|||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile = (
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import logging
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension
|
||||
from typing import Optional
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileType
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
|
|
@ -13,7 +14,7 @@ class ToolFileMessageTransformer:
|
|||
def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str) -> Generator[ToolInvokeMessage, None, None]:
|
||||
conversation_id: Optional[str] = None) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Transform tool message and handle file download
|
||||
"""
|
||||
|
|
@ -25,18 +26,23 @@ class ToolFileMessageTransformer:
|
|||
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# try to download image
|
||||
try:
|
||||
if not conversation_id:
|
||||
raise
|
||||
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
file = ToolFileManager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
conversation_id=conversation_id,
|
||||
file_url=message.message
|
||||
)
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
||||
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
|
|
@ -44,57 +50,67 @@ class ToolFileMessageTransformer:
|
|||
logger.exception(e)
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
|
||||
message=ToolInvokeMessage.TextMessage(
|
||||
text=f"Failed to download image: {message.message}, you can try to download it yourself."
|
||||
),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
save_as=message.save_as,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
assert message.meta
|
||||
|
||||
mimetype = message.meta.get('mime_type', 'octet/stream')
|
||||
# if message is str, encode it to bytes
|
||||
if isinstance(message.message, str):
|
||||
message.message = message.message.encode('utf-8')
|
||||
|
||||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
file = ToolFileManager.create_file_by_raw(
|
||||
user_id=user_id, tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message,
|
||||
file_binary=message.message.blob,
|
||||
mimetype=mimetype
|
||||
)
|
||||
|
||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
||||
extension = guess_extension(file.mimetype) or ".bin"
|
||||
url = cls.get_tool_file_url(file.id, extension)
|
||||
|
||||
# check if file is image
|
||||
if 'image' in mimetype:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
||||
file_var = message.meta.get('file_var')
|
||||
assert message.meta
|
||||
|
||||
file_var: FileVar | None = message.meta.get('file_var')
|
||||
if file_var:
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file_var.related_id and file_var.extension
|
||||
|
||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
||||
if file_var.type == FileType.IMAGE:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from os import path
|
||||
from typing import Any, cast
|
||||
|
||||
|
|
@ -145,7 +145,7 @@ class ToolNode(BaseNode):
|
|||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
|
||||
def _convert_tool_messages(self, messages: Generator[ToolInvokeMessage, None, None]):
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
|
|
@ -277,7 +278,7 @@ class ToolConversationVariables(db.Model):
|
|||
def variables(self) -> dict:
|
||||
return json.loads(self.variables_str)
|
||||
|
||||
class ToolFile(db.Model):
|
||||
class ToolFile(DeclarativeBase):
|
||||
"""
|
||||
store the file created by agent
|
||||
"""
|
||||
|
|
@ -288,16 +289,17 @@ class ToolFile(db.Model):
|
|||
db.Index('tool_file_conversation_id_idx', 'conversation_id'),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
|
||||
# conversation user id
|
||||
user_id = db.Column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID)
|
||||
# tenant id
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[StringUUID] = mapped_column(StringUUID)
|
||||
# conversation id
|
||||
conversation_id = db.Column(StringUUID, nullable=True)
|
||||
conversation_id: Mapped[StringUUID] = mapped_column(nullable=True)
|
||||
# file key
|
||||
file_key = db.Column(db.String(255), nullable=False)
|
||||
file_key: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
# mime type
|
||||
mimetype = db.Column(db.String(255), nullable=False)
|
||||
mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
# original url
|
||||
original_url = db.Column(db.String(2048), nullable=True)
|
||||
original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
|
||||
|
||||
Loading…
Reference in New Issue