From 18abc6658589e6755cee42bd49a33dc9ea678bda Mon Sep 17 00:00:00 2001 From: Novice Date: Fri, 16 Jan 2026 17:01:19 +0800 Subject: [PATCH] feat: add context file support --- api/core/file/file_manager.py | 146 +++++++++++++- api/core/memory/node_token_buffer_memory.py | 31 ++- .../entities/message_entities.py | 3 + api/core/workflow/nodes/llm/llm_utils.py | 54 +++++- api/core/workflow/nodes/llm/node.py | 4 +- .../unit_tests/core/file/test_file_manager.py | 182 ++++++++++++++++++ .../core/workflow/nodes/llm/test_llm_utils.py | 174 +++++++++++++++++ 7 files changed, 585 insertions(+), 9 deletions(-) create mode 100644 api/tests/unit_tests/core/file/test_file_manager.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 120fb73cdb..93c1a9be99 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -1,4 +1,5 @@ import base64 +import logging from collections.abc import Mapping from configs import dify_config @@ -10,7 +11,10 @@ from core.model_runtime.entities import ( TextPromptMessageContent, VideoPromptMessageContent, ) -from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from core.model_runtime.entities.message_entities import ( + MultiModalPromptMessageContent, + PromptMessageContentUnionTypes, +) from core.tools.signature import sign_tool_file from extensions.ext_storage import storage @@ -18,6 +22,8 @@ from . import helpers from .enums import FileAttribute from .models import File, FileTransferMethod, FileType +logger = logging.getLogger(__name__) + def get_attr(*, file: File, attr: FileAttribute): match attr: @@ -89,6 +95,8 @@ def to_prompt_message_content( "format": f.extension.removeprefix("."), "mime_type": f.mime_type, "filename": f.filename or "", + # Encoded file reference for context restoration: "transfer_method:related_id" or "remote:url" + "file_ref": _encode_file_ref(f), } if f.type == FileType.IMAGE: params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW @@ -96,6 +104,17 @@ def to_prompt_message_content( return prompt_class_map[f.type].model_validate(params) +def _encode_file_ref(f: File) -> str | None: + """Encode file reference as 'transfer_method:id_or_url' string.""" + if f.transfer_method == FileTransferMethod.REMOTE_URL: + return f"remote:{f.remote_url}" if f.remote_url else None + elif f.transfer_method == FileTransferMethod.LOCAL_FILE: + return f"local:{f.related_id}" if f.related_id else None + elif f.transfer_method == FileTransferMethod.TOOL_FILE: + return f"tool:{f.related_id}" if f.related_id else None + return None + + def download(f: File, /): if f.transfer_method in ( FileTransferMethod.TOOL_FILE, @@ -164,3 +183,128 @@ def _to_url(f: File, /): return sign_tool_file(tool_file_id=f.related_id, extension=f.extension) else: raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + +def restore_multimodal_content( + content: MultiModalPromptMessageContent, +) -> MultiModalPromptMessageContent: + """ + Restore base64_data or url for multimodal content from file_ref. + + file_ref format: "transfer_method:id_or_url" (e.g., "local:abc123", "remote:https://...") + + Args: + content: MultiModalPromptMessageContent with file_ref field + + Returns: + MultiModalPromptMessageContent with restored base64_data or url + """ + # Skip if no file reference or content already has data + if not content.file_ref: + return content + if content.base64_data or content.url: + return content + + try: + file = _build_file_from_ref( + file_ref=content.file_ref, + file_format=content.format, + mime_type=content.mime_type, + filename=content.filename, + ) + if not file: + return content + + # Restore content based on config + if dify_config.MULTIMODAL_SEND_FORMAT == "base64": + restored_base64 = _get_encoded_string(file) + return content.model_copy(update={"base64_data": restored_base64}) + else: + restored_url = _to_url(file) + return content.model_copy(update={"url": restored_url}) + + except Exception as e: + logger.warning("Failed to restore multimodal content: %s", e) + return content + + +def _build_file_from_ref( + file_ref: str, + file_format: str | None, + mime_type: str | None, + filename: str | None, +) -> File | None: + """ + Build a File object from encoded file_ref string. + + Args: + file_ref: Encoded reference "transfer_method:id_or_url" + file_format: The file format/extension (without dot) + mime_type: The mime type + filename: The filename + + Returns: + File object with storage_key loaded, or None if not found + """ + from sqlalchemy import select + from sqlalchemy.orm import Session + + from extensions.ext_database import db + from models.model import UploadFile + from models.tools import ToolFile + + # Parse file_ref: "method:value" + if ":" not in file_ref: + logger.warning("Invalid file_ref format: %s", file_ref) + return None + + method, value = file_ref.split(":", 1) + extension = f".{file_format}" if file_format else None + + if method == "remote": + return File( + tenant_id="", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=value, + extension=extension, + mime_type=mime_type, + filename=filename, + storage_key="", + ) + + # Query database for storage_key + with Session(db.engine) as session: + if method == "local": + stmt = select(UploadFile).where(UploadFile.id == value) + upload_file = session.scalar(stmt) + if upload_file: + return File( + tenant_id=upload_file.tenant_id, + type=FileType(upload_file.extension) + if hasattr(FileType, upload_file.extension.upper()) + else FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id=value, + extension=extension or ("." + upload_file.extension if upload_file.extension else None), + mime_type=mime_type or upload_file.mime_type, + filename=filename or upload_file.name, + storage_key=upload_file.key, + ) + elif method == "tool": + stmt = select(ToolFile).where(ToolFile.id == value) + tool_file = session.scalar(stmt) + if tool_file: + return File( + tenant_id=tool_file.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=value, + extension=extension, + mime_type=mime_type or tool_file.mimetype, + filename=filename or tool_file.name, + storage_key=tool_file.file_key, + ) + + logger.warning("File not found for file_ref: %s", file_ref) + return None diff --git a/api/core/memory/node_token_buffer_memory.py b/api/core/memory/node_token_buffer_memory.py index 386dde9c89..ec6b04b13e 100644 --- a/api/core/memory/node_token_buffer_memory.py +++ b/api/core/memory/node_token_buffer_memory.py @@ -15,20 +15,24 @@ Design: import logging from collections.abc import Sequence +from typing import cast from sqlalchemy import select from sqlalchemy.orm import Session +from core.file import file_manager from core.memory.base import BaseMemory from core.model_manager import ModelInstance from core.model_runtime.entities import ( AssistantPromptMessage, + MultiModalPromptMessageContent, PromptMessage, PromptMessageRole, SystemPromptMessage, ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from models.model import Message @@ -108,11 +112,36 @@ class NodeTokenBufferMemory(BaseMemory): messages = [] for msg_dict in context_data: try: - messages.append(self._deserialize_prompt_message(msg_dict)) + msg = self._deserialize_prompt_message(msg_dict) + msg = self._restore_multimodal_content(msg) + messages.append(msg) except Exception as e: logger.warning("Failed to deserialize prompt message: %s", e) return messages + def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage: + """ + Restore multimodal content (base64 or url) from file_ref. + + When context is saved, base64_data is cleared to save storage space. + This method restores the content by parsing file_ref (format: "method:id_or_url"). + """ + content = message.content + if content is None or isinstance(content, str): + return message + + # Process list content, restoring multimodal data from file references + restored_content: list[PromptMessageContentUnionTypes] = [] + for item in content: + if isinstance(item, MultiModalPromptMessageContent): + # restore_multimodal_content preserves the concrete subclass type + restored_item = file_manager.restore_multimodal_content(item) + restored_content.append(cast(PromptMessageContentUnionTypes, restored_item)) + else: + restored_content.append(item) + + return message.model_copy(update={"content": restored_content}) + def get_history_prompt_messages( self, *, diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 5a07a22023..284f4dba01 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -91,6 +91,9 @@ class MultiModalPromptMessageContent(PromptMessageContent): mime_type: str = Field(default=..., description="the mime type of multi-modal file") filename: str = Field(default="", description="the filename of multi-modal file") + # File reference for context restoration, format: "transfer_method:related_id" or "remote:url" + file_ref: str | None = Field(default=None, description="Encoded file reference for restoration") + @property def data(self): return self.url or f"data:{self.mime_type};base64,{self.base64_data}" diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 1b412df0ea..966c34a0d7 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -233,21 +233,63 @@ def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage: """ Truncate multi-modal content base64 data in a message to avoid storing large data. Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility. + + If file_ref is present, clears base64_data and url (they can be restored later). + Otherwise, truncates base64_data as fallback for legacy data. """ content = message.content if content is None or isinstance(content, str): return message - # Process list content, truncating multi-modal base64 data + # Process list content, handling multi-modal data based on file_ref availability new_content: list[PromptMessageContentUnionTypes] = [] for item in content: if isinstance(item, MultiModalPromptMessageContent): - # Truncate base64_data similar to prompt_messages_to_prompt_for_saving - truncated_base64 = "" - if item.base64_data: - truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:] - new_content.append(item.model_copy(update={"base64_data": truncated_base64})) + if item.file_ref: + # Clear base64 and url, keep file_ref for later restoration + new_content.append(item.model_copy(update={"base64_data": "", "url": ""})) + else: + # Fallback: truncate base64_data if no file_ref (legacy data) + truncated_base64 = "" + if item.base64_data: + truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:] + new_content.append(item.model_copy(update={"base64_data": truncated_base64})) else: new_content.append(item) return message.model_copy(update={"content": new_content}) + + +def restore_multimodal_content_in_messages(messages: Sequence[PromptMessage]) -> list[PromptMessage]: + """ + Restore multimodal content (base64 or url) in a list of PromptMessages. + + When context is saved, base64_data is cleared to save storage space. + This function restores the content by parsing file_ref in each MultiModalPromptMessageContent. + + Args: + messages: List of PromptMessages that may contain truncated multimodal content + + Returns: + List of PromptMessages with restored multimodal content + """ + from core.file import file_manager + + return [_restore_message_content(msg, file_manager) for msg in messages] + + +def _restore_message_content(message: PromptMessage, file_manager) -> PromptMessage: + """Restore multimodal content in a single PromptMessage.""" + content = message.content + if content is None or isinstance(content, str): + return message + + restored_content: list[PromptMessageContentUnionTypes] = [] + for item in content: + if isinstance(item, MultiModalPromptMessageContent): + restored_item = file_manager.restore_multimodal_content(item) + restored_content.append(cast(PromptMessageContentUnionTypes, restored_item)) + else: + restored_content.append(item) + + return message.model_copy(update={"content": restored_content}) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index cce0e0679a..bde43d8f08 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -675,7 +675,9 @@ class LLMNode(Node[LLMNodeData]): raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found") if not isinstance(ctx_var, ArrayPromptMessageSegment): raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]") - combined_messages.extend(ctx_var.value) + # Restore multimodal content (base64/url) that was truncated when saving context + restored_messages = llm_utils.restore_multimodal_content_in_messages(ctx_var.value) + combined_messages.extend(restored_messages) context_idx += 1 else: # Handle static message diff --git a/api/tests/unit_tests/core/file/test_file_manager.py b/api/tests/unit_tests/core/file/test_file_manager.py new file mode 100644 index 0000000000..018bdee4d7 --- /dev/null +++ b/api/tests/unit_tests/core/file/test_file_manager.py @@ -0,0 +1,182 @@ +"""Tests for file_manager module, specifically multimodal content handling.""" + +from unittest.mock import patch + +from core.file import File, FileTransferMethod, FileType +from core.file.file_manager import ( + _encode_file_ref, + restore_multimodal_content, + to_prompt_message_content, +) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent + + +class TestEncodeFileRef: + """Tests for _encode_file_ref function.""" + + def test_encodes_local_file(self): + """Local file should be encoded as 'local:id'.""" + file = File( + tenant_id="t", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="abc123", + storage_key="key", + ) + assert _encode_file_ref(file) == "local:abc123" + + def test_encodes_tool_file(self): + """Tool file should be encoded as 'tool:id'.""" + file = File( + tenant_id="t", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="xyz789", + storage_key="key", + ) + assert _encode_file_ref(file) == "tool:xyz789" + + def test_encodes_remote_url(self): + """Remote URL should be encoded as 'remote:url'.""" + file = File( + tenant_id="t", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.png", + storage_key="", + ) + assert _encode_file_ref(file) == "remote:https://example.com/image.png" + + +class TestToPromptMessageContent: + """Tests for to_prompt_message_content function with file_ref field.""" + + @patch("core.file.file_manager.dify_config") + @patch("core.file.file_manager._get_encoded_string") + def test_includes_file_ref(self, mock_get_encoded, mock_config): + """Generated content should include file_ref field.""" + mock_config.MULTIMODAL_SEND_FORMAT = "base64" + mock_get_encoded.return_value = "base64data" + + file = File( + id="test-message-file-id", + tenant_id="test-tenant", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test-related-id", + remote_url=None, + extension=".png", + mime_type="image/png", + filename="test.png", + storage_key="test-key", + ) + + result = to_prompt_message_content(file) + + assert isinstance(result, ImagePromptMessageContent) + assert result.file_ref == "local:test-related-id" + assert result.base64_data == "base64data" + + +class TestRestoreMultimodalContent: + """Tests for restore_multimodal_content function.""" + + def test_returns_content_unchanged_when_no_file_ref(self): + """Content without file_ref should pass through unchanged.""" + content = ImagePromptMessageContent( + format="png", + base64_data="existing-data", + mime_type="image/png", + file_ref=None, + ) + + result = restore_multimodal_content(content) + + assert result.base64_data == "existing-data" + + def test_returns_content_unchanged_when_already_has_data(self): + """Content that already has base64_data should not be reloaded.""" + content = ImagePromptMessageContent( + format="png", + base64_data="existing-data", + mime_type="image/png", + file_ref="local:file-id", + ) + + result = restore_multimodal_content(content) + + assert result.base64_data == "existing-data" + + def test_returns_content_unchanged_when_already_has_url(self): + """Content that already has url should not be reloaded.""" + content = ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + file_ref="local:file-id", + ) + + result = restore_multimodal_content(content) + + assert result.url == "https://example.com/image.png" + + @patch("core.file.file_manager.dify_config") + @patch("core.file.file_manager._build_file_from_ref") + @patch("core.file.file_manager._to_url") + def test_restores_url_from_file_ref(self, mock_to_url, mock_build_file, mock_config): + """Content should be restored from file_ref when url is empty (url mode).""" + mock_config.MULTIMODAL_SEND_FORMAT = "url" + mock_build_file.return_value = "mock_file" + mock_to_url.return_value = "https://restored-url.com/image.png" + + content = ImagePromptMessageContent( + format="png", + base64_data="", + url="", + mime_type="image/png", + filename="test.png", + file_ref="local:test-file-id", + ) + + result = restore_multimodal_content(content) + + assert result.url == "https://restored-url.com/image.png" + mock_build_file.assert_called_once() + + @patch("core.file.file_manager.dify_config") + @patch("core.file.file_manager._build_file_from_ref") + @patch("core.file.file_manager._get_encoded_string") + def test_restores_base64_from_file_ref(self, mock_get_encoded, mock_build_file, mock_config): + """Content should be restored as base64 when in base64 mode.""" + mock_config.MULTIMODAL_SEND_FORMAT = "base64" + mock_build_file.return_value = "mock_file" + mock_get_encoded.return_value = "restored-base64-data" + + content = ImagePromptMessageContent( + format="png", + base64_data="", + url="", + mime_type="image/png", + filename="test.png", + file_ref="local:test-file-id", + ) + + result = restore_multimodal_content(content) + + assert result.base64_data == "restored-base64-data" + mock_build_file.assert_called_once() + + def test_handles_invalid_file_ref_gracefully(self): + """Invalid file_ref format should be handled gracefully.""" + content = ImagePromptMessageContent( + format="png", + base64_data="", + url="", + mime_type="image/png", + file_ref="invalid_format_no_colon", + ) + + result = restore_multimodal_content(content) + + # Should return unchanged on error + assert result.base64_data == "" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py new file mode 100644 index 0000000000..e327e03159 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -0,0 +1,174 @@ +"""Tests for llm_utils module, specifically multimodal content handling.""" + +import string +from unittest.mock import patch + +from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) +from core.workflow.nodes.llm.llm_utils import ( + _truncate_multimodal_content, + build_context, + restore_multimodal_content_in_messages, +) + + +class TestTruncateMultimodalContent: + """Tests for _truncate_multimodal_content function.""" + + def test_returns_message_unchanged_for_string_content(self): + """String content should pass through unchanged.""" + message = UserPromptMessage(content="Hello, world!") + result = _truncate_multimodal_content(message) + assert result.content == "Hello, world!" + + def test_returns_message_unchanged_for_none_content(self): + """None content should pass through unchanged.""" + message = UserPromptMessage(content=None) + result = _truncate_multimodal_content(message) + assert result.content is None + + def test_clears_base64_when_file_ref_present(self): + """When file_ref is present, base64_data and url should be cleared.""" + image_content = ImagePromptMessageContent( + format="png", + base64_data=string.ascii_lowercase, + url="https://example.com/image.png", + mime_type="image/png", + filename="test.png", + file_ref="local:test-file-id", + ) + message = UserPromptMessage(content=[image_content]) + + result = _truncate_multimodal_content(message) + + assert isinstance(result.content, list) + assert len(result.content) == 1 + result_content = result.content[0] + assert isinstance(result_content, ImagePromptMessageContent) + assert result_content.base64_data == "" + assert result_content.url == "" + # file_ref should be preserved + assert result_content.file_ref == "local:test-file-id" + + def test_truncates_base64_when_no_file_ref(self): + """When file_ref is missing (legacy), base64_data should be truncated.""" + long_base64 = "a" * 100 + image_content = ImagePromptMessageContent( + format="png", + base64_data=long_base64, + mime_type="image/png", + filename="test.png", + file_ref=None, + ) + message = UserPromptMessage(content=[image_content]) + + result = _truncate_multimodal_content(message) + + assert isinstance(result.content, list) + result_content = result.content[0] + assert isinstance(result_content, ImagePromptMessageContent) + # Should be truncated with marker + assert "...[TRUNCATED]..." in result_content.base64_data + assert len(result_content.base64_data) < len(long_base64) + + def test_preserves_text_content(self): + """Text content should pass through unchanged.""" + text_content = TextPromptMessageContent(data="Hello!") + image_content = ImagePromptMessageContent( + format="png", + base64_data="test123", + mime_type="image/png", + file_ref="local:file-id", + ) + message = UserPromptMessage(content=[text_content, image_content]) + + result = _truncate_multimodal_content(message) + + assert isinstance(result.content, list) + assert len(result.content) == 2 + # Text content unchanged + assert result.content[0].data == "Hello!" + # Image content base64 cleared + assert result.content[1].base64_data == "" + + +class TestBuildContext: + """Tests for build_context function.""" + + def test_excludes_system_messages(self): + """System messages should be excluded from context.""" + from core.model_runtime.entities.message_entities import SystemPromptMessage + + messages = [ + SystemPromptMessage(content="You are a helpful assistant."), + UserPromptMessage(content="Hello!"), + ] + + context = build_context(messages, "Hi there!") + + # Should have user message + assistant response, no system message + assert len(context) == 2 + assert context[0].content == "Hello!" + assert context[1].content == "Hi there!" + + def test_appends_assistant_response(self): + """Assistant response should be appended to context.""" + messages = [UserPromptMessage(content="What is 2+2?")] + + context = build_context(messages, "The answer is 4.") + + assert len(context) == 2 + assert context[1].content == "The answer is 4." + + +class TestRestoreMultimodalContentInMessages: + """Tests for restore_multimodal_content_in_messages function.""" + + @patch("core.file.file_manager.restore_multimodal_content") + def test_restores_multimodal_content(self, mock_restore): + """Should restore multimodal content in messages.""" + # Setup mock + restored_content = ImagePromptMessageContent( + format="png", + base64_data="restored-base64", + mime_type="image/png", + file_ref="local:abc123", + ) + mock_restore.return_value = restored_content + + # Create message with truncated content + truncated_content = ImagePromptMessageContent( + format="png", + base64_data="", + mime_type="image/png", + file_ref="local:abc123", + ) + message = UserPromptMessage(content=[truncated_content]) + + result = restore_multimodal_content_in_messages([message]) + + assert len(result) == 1 + assert result[0].content[0].base64_data == "restored-base64" + mock_restore.assert_called_once() + + def test_passes_through_string_content(self): + """String content should pass through unchanged.""" + message = UserPromptMessage(content="Hello!") + + result = restore_multimodal_content_in_messages([message]) + + assert len(result) == 1 + assert result[0].content == "Hello!" + + def test_passes_through_text_content(self): + """TextPromptMessageContent should pass through unchanged.""" + text_content = TextPromptMessageContent(data="Hello!") + message = UserPromptMessage(content=[text_content]) + + result = restore_multimodal_content_in_messages([message]) + + assert len(result) == 1 + assert result[0].content[0].data == "Hello!"