mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 05:29:50 +08:00
test: added test for core token buffer memory and model runtime (#32512)
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
This commit is contained in:
parent
60fe5e7f00
commit
e99628b76f
@ -276,7 +276,4 @@ class ToolPromptMessage(PromptMessage):
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_call_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
return super().is_empty() and not self.tool_call_id
|
||||
|
||||
@ -4,7 +4,8 @@ class InvokeError(ValueError):
|
||||
description: str | None = None
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description
|
||||
if description is not None:
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
return self.description or self.__class__.__name__
|
||||
|
||||
@ -282,7 +282,8 @@ class ModelProviderFactory:
|
||||
all_model_type_models.append(model_schema)
|
||||
|
||||
simple_provider_schema = provider_schema.to_simple_provider()
|
||||
simple_provider_schema.models.extend(all_model_type_models)
|
||||
if model_type:
|
||||
simple_provider_schema.models = all_model_type_models
|
||||
|
||||
providers.append(simple_provider_schema)
|
||||
|
||||
|
||||
969
api/tests/unit_tests/core/memory/test_token_buffer_memory.py
Normal file
969
api/tests/unit_tests/core/memory/test_token_buffer_memory.py
Normal file
@ -0,0 +1,969 @@
|
||||
"""Comprehensive unit tests for core/memory/token_buffer_memory.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from dify_graph.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from models.model import AppMode
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / shared fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_conversation(mode: AppMode = AppMode.CHAT) -> MagicMock:
|
||||
"""Return a minimal Conversation mock."""
|
||||
conv = MagicMock()
|
||||
conv.id = str(uuid4())
|
||||
conv.mode = mode
|
||||
conv.model_config = {}
|
||||
return conv
|
||||
|
||||
|
||||
def _make_model_instance() -> MagicMock:
|
||||
"""Return a ModelInstance mock whose token counter returns a constant."""
|
||||
mi = MagicMock()
|
||||
mi.get_llm_num_tokens.return_value = 100
|
||||
return mi
|
||||
|
||||
|
||||
def _make_message(answer: str = "hello", answer_tokens: int = 5) -> MagicMock:
|
||||
msg = MagicMock()
|
||||
msg.id = str(uuid4())
|
||||
msg.query = "user query"
|
||||
msg.answer = answer
|
||||
msg.answer_tokens = answer_tokens
|
||||
msg.workflow_run_id = str(uuid4())
|
||||
msg.created_at = MagicMock()
|
||||
return msg
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for __init__ and workflow_run_repo property
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestInit:
|
||||
def test_init_stores_conversation_and_model_instance(self):
|
||||
conv = _make_conversation()
|
||||
mi = _make_model_instance()
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
|
||||
assert mem.conversation is conv
|
||||
assert mem.model_instance is mi
|
||||
assert mem._workflow_run_repo is None
|
||||
|
||||
def test_workflow_run_repo_is_created_lazily(self):
|
||||
conv = _make_conversation()
|
||||
mi = _make_model_instance()
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
|
||||
|
||||
mock_repo = MagicMock()
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.sessionmaker") as mock_sm,
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=mock_repo,
|
||||
),
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
repo = mem.workflow_run_repo
|
||||
assert repo is mock_repo
|
||||
assert mem._workflow_run_repo is mock_repo
|
||||
|
||||
def test_workflow_run_repo_cached_after_first_access(self):
|
||||
conv = _make_conversation()
|
||||
mi = _make_model_instance()
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
|
||||
|
||||
existing_repo = MagicMock()
|
||||
mem._workflow_run_repo = existing_repo
|
||||
|
||||
with patch(
|
||||
"core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository"
|
||||
) as mock_factory:
|
||||
repo = mem.workflow_run_repo
|
||||
mock_factory.assert_not_called()
|
||||
assert repo is existing_repo
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for _build_prompt_message_with_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestBuildPromptMessageWithFiles:
|
||||
"""Tests for the private _build_prompt_message_with_files method."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mode: CHAT / AGENT_CHAT / COMPLETION (simple branch)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
|
||||
def test_chat_mode_no_files_user_message(self, mode):
|
||||
"""When file_extra_config is falsy or app_record is None → plain UserPromptMessage."""
|
||||
conv = _make_conversation(mode)
|
||||
mi = _make_model_instance()
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
|
||||
|
||||
with patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=None, # falsy → file_objs = []
|
||||
):
|
||||
result = mem._build_prompt_message_with_files(
|
||||
message_files=[],
|
||||
text_content="hello",
|
||||
message=_make_message(),
|
||||
app_record=MagicMock(),
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, UserPromptMessage)
|
||||
assert result.content == "hello"
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
|
||||
def test_chat_mode_no_files_assistant_message(self, mode):
|
||||
"""Plain AssistantPromptMessage when no files and is_user_message=False."""
|
||||
conv = _make_conversation(mode)
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
|
||||
with patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
):
|
||||
result = mem._build_prompt_message_with_files(
|
||||
message_files=[],
|
||||
text_content="ai reply",
|
||||
message=_make_message(),
|
||||
app_record=None,
|
||||
is_user_message=False,
|
||||
)
|
||||
|
||||
assert isinstance(result, AssistantPromptMessage)
|
||||
assert result.content == "ai reply"
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
|
||||
def test_chat_mode_with_files_user_message(self, mode):
|
||||
"""When files are present, returns UserPromptMessage with list content."""
|
||||
conv = _make_conversation(mode)
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
|
||||
mock_file_extra_config = MagicMock()
|
||||
mock_file_extra_config.image_config = None # no detail override
|
||||
|
||||
mock_file_obj = MagicMock()
|
||||
# Must be a real entity so Pydantic's tagged union discriminator can validate it
|
||||
real_image_content = ImagePromptMessageContent(
|
||||
url="http://example.com/img.png", format="png", mime_type="image/png"
|
||||
)
|
||||
|
||||
mock_message_file = MagicMock()
|
||||
mock_app_record = MagicMock()
|
||||
mock_app_record.tenant_id = "tenant-1"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=mock_file_extra_config,
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.file_factory.build_from_message_file",
|
||||
return_value=mock_file_obj,
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
|
||||
return_value=real_image_content,
|
||||
),
|
||||
):
|
||||
result = mem._build_prompt_message_with_files(
|
||||
message_files=[mock_message_file],
|
||||
text_content="user text",
|
||||
message=_make_message(),
|
||||
app_record=mock_app_record,
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, UserPromptMessage)
|
||||
assert isinstance(result.content, list)
|
||||
# Last element should be TextPromptMessageContent
|
||||
assert isinstance(result.content[-1], TextPromptMessageContent)
|
||||
assert result.content[-1].data == "user text"
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
|
||||
def test_chat_mode_with_files_assistant_message(self, mode):
|
||||
"""When files are present, returns AssistantPromptMessage with list content."""
|
||||
conv = _make_conversation(mode)
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
|
||||
mock_file_extra_config = MagicMock()
|
||||
mock_file_extra_config.image_config = None
|
||||
|
||||
mock_file_obj = MagicMock()
|
||||
real_image_content = ImagePromptMessageContent(
|
||||
url="http://example.com/img.png", format="png", mime_type="image/png"
|
||||
)
|
||||
mock_app_record = MagicMock()
|
||||
mock_app_record.tenant_id = "tenant-1"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=mock_file_extra_config,
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.file_factory.build_from_message_file",
|
||||
return_value=mock_file_obj,
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
|
||||
return_value=real_image_content,
|
||||
),
|
||||
):
|
||||
result = mem._build_prompt_message_with_files(
|
||||
message_files=[MagicMock()],
|
||||
text_content="ai text",
|
||||
message=_make_message(),
|
||||
app_record=mock_app_record,
|
||||
is_user_message=False,
|
||||
)
|
||||
|
||||
assert isinstance(result, AssistantPromptMessage)
|
||||
assert isinstance(result.content, list)
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
|
||||
def test_chat_mode_with_files_image_detail_overridden(self, mode):
|
||||
"""When image_config.detail is set, detail is taken from config."""
|
||||
conv = _make_conversation(mode)
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
|
||||
mock_image_config = MagicMock()
|
||||
mock_image_config.detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
mock_file_extra_config = MagicMock()
|
||||
mock_file_extra_config.image_config = mock_image_config
|
||||
|
||||
mock_app_record = MagicMock()
|
||||
mock_app_record.tenant_id = "tenant-1"
|
||||
|
||||
real_image_content = ImagePromptMessageContent(
|
||||
url="http://example.com/img.png", format="png", mime_type="image/png"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=mock_file_extra_config,
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.file_factory.build_from_message_file",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
|
||||
return_value=real_image_content,
|
||||
) as mock_to_prompt,
|
||||
):
|
||||
mem._build_prompt_message_with_files(
|
||||
message_files=[MagicMock()],
|
||||
text_content="user text",
|
||||
message=_make_message(),
|
||||
app_record=mock_app_record,
|
||||
is_user_message=True,
|
||||
)
|
||||
# Ensure the LOW detail was passed through
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
mock_to_prompt.call_args[0][0], image_detail_config=ImagePromptMessageContent.DETAIL.LOW
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
|
||||
def test_chat_mode_app_record_none_returns_empty_file_objs(self, mode):
|
||||
"""app_record=None path → file_objs stays empty → plain messages."""
|
||||
conv = _make_conversation(mode)
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
|
||||
mock_file_extra_config = MagicMock()
|
||||
|
||||
with patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=mock_file_extra_config,
|
||||
):
|
||||
result = mem._build_prompt_message_with_files(
|
||||
message_files=[MagicMock()],
|
||||
text_content="hello",
|
||||
message=_make_message(),
|
||||
app_record=None, # <-- forces the else branch → file_objs = []
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, UserPromptMessage)
|
||||
assert result.content == "hello"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mode: ADVANCED_CHAT / WORKFLOW
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def test_workflow_mode_no_app_raises(self, mode):
|
||||
"""Raises ValueError when conversation.app is falsy."""
|
||||
conv = _make_conversation(mode)
|
||||
conv.app = None
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
|
||||
with pytest.raises(ValueError, match="App not found for conversation"):
|
||||
mem._build_prompt_message_with_files(
|
||||
message_files=[],
|
||||
text_content="text",
|
||||
message=_make_message(),
|
||||
app_record=MagicMock(),
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def test_workflow_mode_no_workflow_run_id_raises(self, mode):
|
||||
"""Raises ValueError when message.workflow_run_id is falsy."""
|
||||
conv = _make_conversation(mode)
|
||||
conv.app = MagicMock()
|
||||
|
||||
message = _make_message()
|
||||
message.workflow_run_id = None # force missing
|
||||
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow run ID not found"):
|
||||
mem._build_prompt_message_with_files(
|
||||
message_files=[],
|
||||
text_content="text",
|
||||
message=message,
|
||||
app_record=MagicMock(),
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def test_workflow_mode_workflow_run_not_found_raises(self, mode):
|
||||
"""Raises ValueError when workflow_run_repo returns None."""
|
||||
conv = _make_conversation(mode)
|
||||
mock_app = MagicMock()
|
||||
conv.app = mock_app
|
||||
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
mem._workflow_run_repo = MagicMock()
|
||||
mem._workflow_run_repo.get_workflow_run_by_id.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow run not found"):
|
||||
mem._build_prompt_message_with_files(
|
||||
message_files=[],
|
||||
text_content="text",
|
||||
message=_make_message(),
|
||||
app_record=MagicMock(),
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def test_workflow_mode_workflow_not_found_raises(self, mode):
|
||||
"""Raises ValueError when Workflow lookup returns None."""
|
||||
conv = _make_conversation(mode)
|
||||
conv.app = MagicMock()
|
||||
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.workflow_id = str(uuid4())
|
||||
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
mem._workflow_run_repo = MagicMock()
|
||||
mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run
|
||||
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
):
|
||||
mock_db.session.scalar.return_value = None # workflow not found
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
mem._build_prompt_message_with_files(
|
||||
message_files=[],
|
||||
text_content="text",
|
||||
message=_make_message(),
|
||||
app_record=MagicMock(),
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def test_workflow_mode_success_no_files_user(self, mode):
|
||||
"""Happy path: workflow mode, no message files → plain UserPromptMessage."""
|
||||
conv = _make_conversation(mode)
|
||||
conv.app = MagicMock()
|
||||
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.workflow_id = str(uuid4())
|
||||
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.features_dict = {}
|
||||
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
mem._workflow_run_repo = MagicMock()
|
||||
mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run
|
||||
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_db.session.scalar.return_value = mock_workflow
|
||||
|
||||
result = mem._build_prompt_message_with_files(
|
||||
message_files=[],
|
||||
text_content="wf text",
|
||||
message=_make_message(),
|
||||
app_record=MagicMock(),
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, UserPromptMessage)
|
||||
assert result.content == "wf text"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Invalid mode
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_invalid_mode_raises_assertion(self):
|
||||
"""Any unknown AppMode raises AssertionError."""
|
||||
conv = _make_conversation()
|
||||
conv.mode = "unknown_mode" # not in any set
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
|
||||
with pytest.raises(AssertionError, match="Invalid app mode"):
|
||||
mem._build_prompt_message_with_files(
|
||||
message_files=[],
|
||||
text_content="text",
|
||||
message=_make_message(),
|
||||
app_record=MagicMock(),
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for get_history_prompt_messages
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGetHistoryPromptMessages:
|
||||
"""Tests for get_history_prompt_messages."""
|
||||
|
||||
def _make_memory(self, mode: AppMode = AppMode.CHAT) -> TokenBufferMemory:
|
||||
conv = _make_conversation(mode)
|
||||
conv.app = MagicMock()
|
||||
return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
|
||||
def test_returns_empty_when_no_messages(self):
|
||||
mem = self._make_memory()
|
||||
with patch("core.memory.token_buffer_memory.db") as mock_db:
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
result = mem.get_history_prompt_messages()
|
||||
assert result == []
|
||||
|
||||
def test_skips_first_message_without_answer(self):
|
||||
"""The newest message (index 0 after extraction) without answer and tokens==0 is skipped."""
|
||||
mem = self._make_memory()
|
||||
|
||||
msg_no_answer = _make_message(answer="", answer_tokens=0)
|
||||
msg_no_answer.parent_message_id = None # ensures extract_thread_messages returns it
|
||||
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.extract_thread_messages",
|
||||
return_value=[msg_no_answer],
|
||||
),
|
||||
):
|
||||
mock_db.session.scalars.return_value.all.side_effect = [
|
||||
[msg_no_answer], # first call: messages query
|
||||
[], # second call: user files query (never hit, but safe)
|
||||
]
|
||||
result = mem.get_history_prompt_messages()
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_message_with_answer_not_skipped(self):
|
||||
"""A message with a non-empty answer is NOT popped."""
|
||||
mem = self._make_memory()
|
||||
|
||||
msg = _make_message(answer="some answer", answer_tokens=10)
|
||||
msg.parent_message_id = None
|
||||
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.extract_thread_messages",
|
||||
return_value=[msg],
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
# user files query → empty; assistant files query → empty
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
result = mem.get_history_prompt_messages()
|
||||
|
||||
assert len(result) == 2 # one user + one assistant
|
||||
|
||||
def test_message_limit_default_is_500(self):
|
||||
"""When message_limit is None the stmt is limited to 500."""
|
||||
mem = self._make_memory()
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch("core.memory.token_buffer_memory.select") as mock_select,
|
||||
patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
|
||||
):
|
||||
mock_stmt = MagicMock()
|
||||
mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
|
||||
mock_stmt.limit.return_value = mock_stmt
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
mem.get_history_prompt_messages(message_limit=None)
|
||||
mock_stmt.limit.assert_called_with(500)
|
||||
|
||||
def test_message_limit_clipped_to_500(self):
|
||||
"""A message_limit > 500 is clamped to 500."""
|
||||
mem = self._make_memory()
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch("core.memory.token_buffer_memory.select") as mock_select,
|
||||
patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
|
||||
):
|
||||
mock_stmt = MagicMock()
|
||||
mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
|
||||
mock_stmt.limit.return_value = mock_stmt
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
mem.get_history_prompt_messages(message_limit=9999)
|
||||
mock_stmt.limit.assert_called_with(500)
|
||||
|
||||
def test_message_limit_positive_used(self):
|
||||
"""A positive message_limit < 500 is used as-is."""
|
||||
mem = self._make_memory()
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch("core.memory.token_buffer_memory.select") as mock_select,
|
||||
patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
|
||||
):
|
||||
mock_stmt = MagicMock()
|
||||
mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
|
||||
mock_stmt.limit.return_value = mock_stmt
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
mem.get_history_prompt_messages(message_limit=10)
|
||||
mock_stmt.limit.assert_called_with(10)
|
||||
|
||||
def test_message_limit_zero_uses_default(self):
|
||||
"""message_limit=0 triggers the else branch → default 500."""
|
||||
mem = self._make_memory()
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch("core.memory.token_buffer_memory.select") as mock_select,
|
||||
patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
|
||||
):
|
||||
mock_stmt = MagicMock()
|
||||
mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
|
||||
mock_stmt.limit.return_value = mock_stmt
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
mem.get_history_prompt_messages(message_limit=0)
|
||||
mock_stmt.limit.assert_called_with(500)
|
||||
|
||||
def test_user_files_cause_build_with_files_call(self):
|
||||
"""When user_files is non-empty _build_prompt_message_with_files is invoked."""
|
||||
mem = self._make_memory()
|
||||
msg = _make_message()
|
||||
msg.parent_message_id = None
|
||||
|
||||
mock_user_file = MagicMock()
|
||||
mock_user_prompt = UserPromptMessage(content="from build")
|
||||
mock_assistant_prompt = AssistantPromptMessage(content="answer")
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def scalars_side_effect(stmt):
|
||||
r = MagicMock()
|
||||
if call_count["n"] == 0:
|
||||
# messages query
|
||||
r.all.return_value = [msg]
|
||||
elif call_count["n"] == 1:
|
||||
# user files
|
||||
r.all.return_value = [mock_user_file]
|
||||
else:
|
||||
# assistant files
|
||||
r.all.return_value = []
|
||||
call_count["n"] += 1
|
||||
return r
|
||||
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.extract_thread_messages",
|
||||
return_value=[msg],
|
||||
),
|
||||
patch.object(
|
||||
mem,
|
||||
"_build_prompt_message_with_files",
|
||||
side_effect=[mock_user_prompt, mock_assistant_prompt],
|
||||
) as mock_build,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_db.session.scalars.side_effect = scalars_side_effect
|
||||
result = mem.get_history_prompt_messages()
|
||||
|
||||
assert mock_build.call_count >= 1
|
||||
# First call should be user message
|
||||
first_call_kwargs = mock_build.call_args_list[0][1]
|
||||
assert first_call_kwargs["is_user_message"] is True
|
||||
|
||||
def test_assistant_files_cause_build_with_files_call(self):
|
||||
"""When assistant_files is non-empty, build is called with is_user_message=False."""
|
||||
mem = self._make_memory()
|
||||
msg = _make_message()
|
||||
msg.parent_message_id = None
|
||||
|
||||
mock_assistant_file = MagicMock()
|
||||
mock_user_prompt = UserPromptMessage(content="query")
|
||||
mock_assistant_prompt = AssistantPromptMessage(content="built")
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def scalars_side_effect(stmt):
|
||||
r = MagicMock()
|
||||
if call_count["n"] == 0:
|
||||
r.all.return_value = [msg]
|
||||
elif call_count["n"] == 1:
|
||||
r.all.return_value = [] # no user files
|
||||
else:
|
||||
r.all.return_value = [mock_assistant_file]
|
||||
call_count["n"] += 1
|
||||
return r
|
||||
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.extract_thread_messages",
|
||||
return_value=[msg],
|
||||
),
|
||||
patch.object(
|
||||
mem,
|
||||
"_build_prompt_message_with_files",
|
||||
return_value=mock_assistant_prompt,
|
||||
) as mock_build,
|
||||
):
|
||||
mock_db.session.scalars.side_effect = scalars_side_effect
|
||||
result = mem.get_history_prompt_messages()
|
||||
|
||||
mock_build.assert_called_once()
|
||||
call_kwargs = mock_build.call_args[1]
|
||||
assert call_kwargs["is_user_message"] is False
|
||||
|
||||
def test_token_pruning_removes_oldest_messages(self):
|
||||
"""If tokens exceed limit, oldest messages are removed until within limit."""
|
||||
conv = _make_conversation()
|
||||
conv.app = MagicMock()
|
||||
|
||||
# Model returns tokens that decrease only after removing pairs
|
||||
token_values = [3000, 1500] # first call over limit, second within
|
||||
mi = MagicMock()
|
||||
mi.get_llm_num_tokens.side_effect = token_values
|
||||
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
|
||||
|
||||
msg = _make_message()
|
||||
msg.parent_message_id = None
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def scalars_side_effect(stmt):
|
||||
r = MagicMock()
|
||||
if call_count["n"] == 0:
|
||||
r.all.return_value = [msg]
|
||||
else:
|
||||
r.all.return_value = []
|
||||
call_count["n"] += 1
|
||||
return r
|
||||
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.extract_thread_messages",
|
||||
return_value=[msg],
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_db.session.scalars.side_effect = scalars_side_effect
|
||||
result = mem.get_history_prompt_messages(max_token_limit=2000)
|
||||
|
||||
# After pruning, we should have fewer than the 2 initial messages
|
||||
assert len(result) <= 1
|
||||
|
||||
def test_token_pruning_stops_at_single_message(self):
|
||||
"""Pruning stops when only 1 message remains (to prevent empty list)."""
|
||||
conv = _make_conversation()
|
||||
conv.app = MagicMock()
|
||||
|
||||
# Always over limit
|
||||
mi = MagicMock()
|
||||
mi.get_llm_num_tokens.return_value = 99999
|
||||
|
||||
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
|
||||
|
||||
msg = _make_message()
|
||||
msg.parent_message_id = None
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def scalars_side_effect(stmt):
|
||||
r = MagicMock()
|
||||
if call_count["n"] == 0:
|
||||
r.all.return_value = [msg]
|
||||
else:
|
||||
r.all.return_value = []
|
||||
call_count["n"] += 1
|
||||
return r
|
||||
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.extract_thread_messages",
|
||||
return_value=[msg],
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_db.session.scalars.side_effect = scalars_side_effect
|
||||
result = mem.get_history_prompt_messages(max_token_limit=1)
|
||||
|
||||
# At least 1 message should remain
|
||||
assert len(result) >= 1
|
||||
|
||||
def test_no_pruning_when_within_limit(self):
|
||||
"""When tokens ≤ limit, no pruning occurs."""
|
||||
mem = self._make_memory()
|
||||
mem.model_instance.get_llm_num_tokens.return_value = 50 # well under default 2000
|
||||
|
||||
msg = _make_message()
|
||||
msg.parent_message_id = None
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def scalars_side_effect(stmt):
|
||||
r = MagicMock()
|
||||
if call_count["n"] == 0:
|
||||
r.all.return_value = [msg]
|
||||
else:
|
||||
r.all.return_value = []
|
||||
call_count["n"] += 1
|
||||
return r
|
||||
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.extract_thread_messages",
|
||||
return_value=[msg],
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_db.session.scalars.side_effect = scalars_side_effect
|
||||
result = mem.get_history_prompt_messages(max_token_limit=2000)
|
||||
|
||||
assert len(result) == 2 # user + assistant
|
||||
|
||||
def test_plain_user_and_assistant_messages_returned(self):
|
||||
"""Without files, plain UserPromptMessage and AssistantPromptMessage appear."""
|
||||
mem = self._make_memory()
|
||||
|
||||
msg = _make_message(answer="My answer")
|
||||
msg.query = "My query"
|
||||
msg.parent_message_id = None
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def scalars_side_effect(stmt):
|
||||
r = MagicMock()
|
||||
if call_count["n"] == 0:
|
||||
r.all.return_value = [msg]
|
||||
else:
|
||||
r.all.return_value = []
|
||||
call_count["n"] += 1
|
||||
return r
|
||||
|
||||
with (
|
||||
patch("core.memory.token_buffer_memory.db") as mock_db,
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.extract_thread_messages",
|
||||
return_value=[msg],
|
||||
),
|
||||
patch(
|
||||
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_db.session.scalars.side_effect = scalars_side_effect
|
||||
result = mem.get_history_prompt_messages()
|
||||
|
||||
assert len(result) == 2
|
||||
user_msg, ai_msg = result
|
||||
assert isinstance(user_msg, UserPromptMessage)
|
||||
assert user_msg.content == "My query"
|
||||
assert isinstance(ai_msg, AssistantPromptMessage)
|
||||
assert ai_msg.content == "My answer"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for get_history_prompt_text
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGetHistoryPromptText:
|
||||
"""Tests for get_history_prompt_text."""
|
||||
|
||||
def _make_memory(self) -> TokenBufferMemory:
|
||||
conv = _make_conversation()
|
||||
conv.app = MagicMock()
|
||||
return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
|
||||
|
||||
def test_empty_messages_returns_empty_string(self):
|
||||
mem = self._make_memory()
|
||||
with patch.object(mem, "get_history_prompt_messages", return_value=[]):
|
||||
result = mem.get_history_prompt_text()
|
||||
assert result == ""
|
||||
|
||||
def test_user_and_assistant_messages_formatted(self):
|
||||
mem = self._make_memory()
|
||||
messages = [
|
||||
UserPromptMessage(content="Hello"),
|
||||
AssistantPromptMessage(content="World"),
|
||||
]
|
||||
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
|
||||
result = mem.get_history_prompt_text(human_prefix="H", ai_prefix="A")
|
||||
assert result == "H: Hello\nA: World"
|
||||
|
||||
def test_custom_prefixes_applied(self):
|
||||
mem = self._make_memory()
|
||||
messages = [
|
||||
UserPromptMessage(content="Hi"),
|
||||
AssistantPromptMessage(content="Bye"),
|
||||
]
|
||||
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
|
||||
result = mem.get_history_prompt_text(human_prefix="Human", ai_prefix="Bot")
|
||||
assert "Human: Hi" in result
|
||||
assert "Bot: Bye" in result
|
||||
|
||||
def test_list_content_with_text_and_image(self):
|
||||
"""List content: TextPromptMessageContent → text; ImagePromptMessageContent → [image]."""
|
||||
mem = self._make_memory()
|
||||
messages = [
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="caption"),
|
||||
ImagePromptMessageContent(url="http://img", format="png", mime_type="image/png"),
|
||||
]
|
||||
),
|
||||
]
|
||||
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
|
||||
result = mem.get_history_prompt_text()
|
||||
assert "caption" in result
|
||||
assert "[image]" in result
|
||||
|
||||
def test_list_content_text_only(self):
|
||||
mem = self._make_memory()
|
||||
messages = [
|
||||
UserPromptMessage(content=[TextPromptMessageContent(data="just text")]),
|
||||
]
|
||||
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
|
||||
result = mem.get_history_prompt_text()
|
||||
assert "just text" in result
|
||||
|
||||
def test_list_content_image_only(self):
|
||||
mem = self._make_memory()
|
||||
messages = [
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
ImagePromptMessageContent(url="http://img", format="jpg", mime_type="image/jpeg"),
|
||||
]
|
||||
),
|
||||
]
|
||||
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
|
||||
result = mem.get_history_prompt_text()
|
||||
assert "[image]" in result
|
||||
|
||||
def test_unknown_role_skipped(self):
|
||||
"""Messages with a role that is not USER or ASSISTANT are skipped."""
|
||||
mem = self._make_memory()
|
||||
|
||||
# Create a mock message with a SYSTEM role
|
||||
system_msg = MagicMock()
|
||||
system_msg.role = PromptMessageRole.SYSTEM
|
||||
system_msg.content = "system instruction"
|
||||
|
||||
user_msg = UserPromptMessage(content="hi")
|
||||
|
||||
with patch.object(mem, "get_history_prompt_messages", return_value=[system_msg, user_msg]):
|
||||
result = mem.get_history_prompt_text()
|
||||
|
||||
assert "system instruction" not in result
|
||||
assert "Human: hi" in result
|
||||
|
||||
def test_passes_max_token_limit_and_message_limit(self):
|
||||
"""Parameters are forwarded to get_history_prompt_messages."""
|
||||
mem = self._make_memory()
|
||||
with patch.object(mem, "get_history_prompt_messages", return_value=[]) as mock_get:
|
||||
mem.get_history_prompt_text(max_token_limit=500, message_limit=10)
|
||||
mock_get.assert_called_once_with(max_token_limit=500, message_limit=10)
|
||||
|
||||
def test_multiple_messages_joined_by_newline(self):
|
||||
mem = self._make_memory()
|
||||
messages = [
|
||||
UserPromptMessage(content="Q1"),
|
||||
AssistantPromptMessage(content="A1"),
|
||||
UserPromptMessage(content="Q2"),
|
||||
AssistantPromptMessage(content="A2"),
|
||||
]
|
||||
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
|
||||
result = mem.get_history_prompt_text()
|
||||
lines = result.split("\n")
|
||||
assert len(lines) == 4
|
||||
assert lines[0] == "Human: Q1"
|
||||
assert lines[1] == "Assistant: A1"
|
||||
assert lines[2] == "Human: Q2"
|
||||
assert lines[3] == "Assistant: A2"
|
||||
|
||||
def test_assistant_list_content_formatted(self):
|
||||
"""AssistantPromptMessage with list content is also handled."""
|
||||
mem = self._make_memory()
|
||||
messages = [
|
||||
AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="response text"),
|
||||
ImagePromptMessageContent(url="http://img2", format="png", mime_type="image/png"),
|
||||
]
|
||||
),
|
||||
]
|
||||
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
|
||||
result = mem.get_history_prompt_text()
|
||||
assert "response text" in result
|
||||
assert "[image]" in result
|
||||
@ -0,0 +1,964 @@
|
||||
"""Comprehensive unit tests for core/model_runtime/callbacks/base_callback.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.model_runtime.callbacks.base_callback import (
|
||||
_TEXT_COLOR_MAPPING,
|
||||
Callback,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concrete implementation of the abstract Callback for testing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConcreteCallback(Callback):
|
||||
"""A minimal concrete subclass that satisfies all abstract methods."""
|
||||
|
||||
def __init__(self, raise_error: bool = False):
|
||||
self.raise_error = raise_error
|
||||
# Track invocations
|
||||
self.before_invoke_calls: list[dict] = []
|
||||
self.new_chunk_calls: list[dict] = []
|
||||
self.after_invoke_calls: list[dict] = []
|
||||
self.invoke_error_calls: list[dict] = []
|
||||
|
||||
def on_before_invoke(
|
||||
self,
|
||||
llm_instance,
|
||||
model,
|
||||
credentials,
|
||||
prompt_messages,
|
||||
model_parameters,
|
||||
tools=None,
|
||||
stop=None,
|
||||
stream=True,
|
||||
user=None,
|
||||
):
|
||||
self.before_invoke_calls.append(
|
||||
{
|
||||
"llm_instance": llm_instance,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_parameters": model_parameters,
|
||||
"tools": tools,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
# To cover the 'raise NotImplementedError()' in the base class
|
||||
try:
|
||||
super().on_before_invoke(
|
||||
llm_instance, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
|
||||
)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
def on_new_chunk(
|
||||
self,
|
||||
llm_instance,
|
||||
chunk,
|
||||
model,
|
||||
credentials,
|
||||
prompt_messages,
|
||||
model_parameters,
|
||||
tools=None,
|
||||
stop=None,
|
||||
stream=True,
|
||||
user=None,
|
||||
):
|
||||
self.new_chunk_calls.append(
|
||||
{
|
||||
"llm_instance": llm_instance,
|
||||
"chunk": chunk,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_parameters": model_parameters,
|
||||
"tools": tools,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
try:
|
||||
super().on_new_chunk(
|
||||
llm_instance, chunk, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
|
||||
)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
def on_after_invoke(
|
||||
self,
|
||||
llm_instance,
|
||||
result,
|
||||
model,
|
||||
credentials,
|
||||
prompt_messages,
|
||||
model_parameters,
|
||||
tools=None,
|
||||
stop=None,
|
||||
stream=True,
|
||||
user=None,
|
||||
):
|
||||
self.after_invoke_calls.append(
|
||||
{
|
||||
"llm_instance": llm_instance,
|
||||
"result": result,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_parameters": model_parameters,
|
||||
"tools": tools,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
try:
|
||||
super().on_after_invoke(
|
||||
llm_instance, result, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
|
||||
)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
def on_invoke_error(
|
||||
self,
|
||||
llm_instance,
|
||||
ex,
|
||||
model,
|
||||
credentials,
|
||||
prompt_messages,
|
||||
model_parameters,
|
||||
tools=None,
|
||||
stop=None,
|
||||
stream=True,
|
||||
user=None,
|
||||
):
|
||||
self.invoke_error_calls.append(
|
||||
{
|
||||
"llm_instance": llm_instance,
|
||||
"ex": ex,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_parameters": model_parameters,
|
||||
"tools": tools,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
try:
|
||||
super().on_invoke_error(
|
||||
llm_instance, ex, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
|
||||
)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# A subclass that deliberately leaves abstract methods un-implemented,
|
||||
# used to verify that instantiation raises TypeError.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for _TEXT_COLOR_MAPPING module-level constant
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestTextColorMapping:
|
||||
"""Tests for the module-level _TEXT_COLOR_MAPPING dictionary."""
|
||||
|
||||
def test_contains_all_expected_colors(self):
|
||||
expected_keys = {"blue", "yellow", "pink", "green", "red"}
|
||||
assert set(_TEXT_COLOR_MAPPING.keys()) == expected_keys
|
||||
|
||||
def test_blue_escape_code(self):
|
||||
assert _TEXT_COLOR_MAPPING["blue"] == "36;1"
|
||||
|
||||
def test_yellow_escape_code(self):
|
||||
assert _TEXT_COLOR_MAPPING["yellow"] == "33;1"
|
||||
|
||||
def test_pink_escape_code(self):
|
||||
assert _TEXT_COLOR_MAPPING["pink"] == "38;5;200"
|
||||
|
||||
def test_green_escape_code(self):
|
||||
assert _TEXT_COLOR_MAPPING["green"] == "32;1"
|
||||
|
||||
def test_red_escape_code(self):
|
||||
assert _TEXT_COLOR_MAPPING["red"] == "31;1"
|
||||
|
||||
def test_mapping_is_dict(self):
|
||||
assert isinstance(_TEXT_COLOR_MAPPING, dict)
|
||||
|
||||
def test_all_values_are_strings(self):
|
||||
for key, value in _TEXT_COLOR_MAPPING.items():
|
||||
assert isinstance(value, str), f"Value for {key!r} should be str"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for the Callback ABC itself
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestCallbackAbstract:
|
||||
"""Tests verifying Callback is a proper ABC."""
|
||||
|
||||
def test_cannot_instantiate_abstract_class_directly(self):
|
||||
"""Callback cannot be instantiated since it has abstract methods."""
|
||||
with pytest.raises(TypeError):
|
||||
Callback() # type: ignore[abstract]
|
||||
|
||||
def test_concrete_subclass_can_be_instantiated(self):
|
||||
cb = ConcreteCallback()
|
||||
assert isinstance(cb, Callback)
|
||||
|
||||
def test_default_raise_error_is_false(self):
|
||||
cb = ConcreteCallback()
|
||||
assert cb.raise_error is False
|
||||
|
||||
def test_raise_error_can_be_set_to_true(self):
|
||||
cb = ConcreteCallback(raise_error=True)
|
||||
assert cb.raise_error is True
|
||||
|
||||
def test_subclass_missing_on_before_invoke_raises_type_error(self):
|
||||
"""A subclass missing any single abstract method cannot be instantiated."""
|
||||
|
||||
class IncompleteCallback(Callback):
|
||||
def on_new_chunk(self, *a, **kw): ...
|
||||
def on_after_invoke(self, *a, **kw): ...
|
||||
def on_invoke_error(self, *a, **kw): ...
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
IncompleteCallback() # type: ignore[abstract]
|
||||
|
||||
def test_subclass_missing_on_new_chunk_raises_type_error(self):
|
||||
class IncompleteCallback(Callback):
|
||||
def on_before_invoke(self, *a, **kw): ...
|
||||
def on_after_invoke(self, *a, **kw): ...
|
||||
def on_invoke_error(self, *a, **kw): ...
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
IncompleteCallback() # type: ignore[abstract]
|
||||
|
||||
def test_subclass_missing_on_after_invoke_raises_type_error(self):
|
||||
class IncompleteCallback(Callback):
|
||||
def on_before_invoke(self, *a, **kw): ...
|
||||
def on_new_chunk(self, *a, **kw): ...
|
||||
def on_invoke_error(self, *a, **kw): ...
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
IncompleteCallback() # type: ignore[abstract]
|
||||
|
||||
def test_subclass_missing_on_invoke_error_raises_type_error(self):
|
||||
class IncompleteCallback(Callback):
|
||||
def on_before_invoke(self, *a, **kw): ...
|
||||
def on_new_chunk(self, *a, **kw): ...
|
||||
def on_after_invoke(self, *a, **kw): ...
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
IncompleteCallback() # type: ignore[abstract]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for on_before_invoke
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOnBeforeInvoke:
|
||||
"""Tests for the on_before_invoke callback method."""
|
||||
|
||||
def setup_method(self):
|
||||
self.cb = ConcreteCallback()
|
||||
self.llm_instance = MagicMock()
|
||||
self.model = "gpt-4"
|
||||
self.credentials = {"api_key": "sk-test"}
|
||||
self.prompt_messages = [MagicMock(spec=PromptMessage)]
|
||||
self.model_parameters = {"temperature": 0.7}
|
||||
|
||||
def test_on_before_invoke_called_with_required_args(self):
|
||||
self.cb.on_before_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert len(self.cb.before_invoke_calls) == 1
|
||||
call = self.cb.before_invoke_calls[0]
|
||||
assert call["llm_instance"] is self.llm_instance
|
||||
assert call["model"] == self.model
|
||||
assert call["credentials"] == self.credentials
|
||||
assert call["prompt_messages"] is self.prompt_messages
|
||||
assert call["model_parameters"] is self.model_parameters
|
||||
|
||||
def test_on_before_invoke_defaults_tools_none(self):
|
||||
self.cb.on_before_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.before_invoke_calls[0]["tools"] is None
|
||||
|
||||
def test_on_before_invoke_defaults_stop_none(self):
|
||||
self.cb.on_before_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.before_invoke_calls[0]["stop"] is None
|
||||
|
||||
def test_on_before_invoke_defaults_stream_true(self):
|
||||
self.cb.on_before_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.before_invoke_calls[0]["stream"] is True
|
||||
|
||||
def test_on_before_invoke_defaults_user_none(self):
|
||||
self.cb.on_before_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.before_invoke_calls[0]["user"] is None
|
||||
|
||||
def test_on_before_invoke_with_all_optional_args(self):
|
||||
tools = [MagicMock(spec=PromptMessageTool)]
|
||||
stop = ["stop1", "stop2"]
|
||||
self.cb.on_before_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user="user-123",
|
||||
)
|
||||
call = self.cb.before_invoke_calls[0]
|
||||
assert call["tools"] is tools
|
||||
assert call["stop"] == stop
|
||||
assert call["stream"] is False
|
||||
assert call["user"] == "user-123"
|
||||
|
||||
def test_on_before_invoke_called_multiple_times(self):
|
||||
for i in range(3):
|
||||
self.cb.on_before_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
model=f"model-{i}",
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert len(self.cb.before_invoke_calls) == 3
|
||||
assert self.cb.before_invoke_calls[2]["model"] == "model-2"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for on_new_chunk
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOnNewChunk:
|
||||
"""Tests for the on_new_chunk callback method."""
|
||||
|
||||
def setup_method(self):
|
||||
self.cb = ConcreteCallback()
|
||||
self.llm_instance = MagicMock()
|
||||
self.chunk = MagicMock(spec=LLMResultChunk)
|
||||
self.model = "gpt-3.5-turbo"
|
||||
self.credentials = {"api_key": "sk-test"}
|
||||
self.prompt_messages = [MagicMock(spec=PromptMessage)]
|
||||
self.model_parameters = {"max_tokens": 256}
|
||||
|
||||
def test_on_new_chunk_called_with_required_args(self):
|
||||
self.cb.on_new_chunk(
|
||||
llm_instance=self.llm_instance,
|
||||
chunk=self.chunk,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert len(self.cb.new_chunk_calls) == 1
|
||||
call = self.cb.new_chunk_calls[0]
|
||||
assert call["llm_instance"] is self.llm_instance
|
||||
assert call["chunk"] is self.chunk
|
||||
assert call["model"] == self.model
|
||||
assert call["credentials"] == self.credentials
|
||||
|
||||
def test_on_new_chunk_defaults_tools_none(self):
|
||||
self.cb.on_new_chunk(
|
||||
llm_instance=self.llm_instance,
|
||||
chunk=self.chunk,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.new_chunk_calls[0]["tools"] is None
|
||||
|
||||
def test_on_new_chunk_defaults_stop_none(self):
|
||||
self.cb.on_new_chunk(
|
||||
llm_instance=self.llm_instance,
|
||||
chunk=self.chunk,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.new_chunk_calls[0]["stop"] is None
|
||||
|
||||
def test_on_new_chunk_defaults_stream_true(self):
|
||||
self.cb.on_new_chunk(
|
||||
llm_instance=self.llm_instance,
|
||||
chunk=self.chunk,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.new_chunk_calls[0]["stream"] is True
|
||||
|
||||
def test_on_new_chunk_defaults_user_none(self):
|
||||
self.cb.on_new_chunk(
|
||||
llm_instance=self.llm_instance,
|
||||
chunk=self.chunk,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.new_chunk_calls[0]["user"] is None
|
||||
|
||||
def test_on_new_chunk_with_all_optional_args(self):
|
||||
tools = [MagicMock(spec=PromptMessageTool)]
|
||||
stop = ["END"]
|
||||
self.cb.on_new_chunk(
|
||||
llm_instance=self.llm_instance,
|
||||
chunk=self.chunk,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user="chunk-user",
|
||||
)
|
||||
call = self.cb.new_chunk_calls[0]
|
||||
assert call["tools"] is tools
|
||||
assert call["stop"] == stop
|
||||
assert call["stream"] is False
|
||||
assert call["user"] == "chunk-user"
|
||||
|
||||
def test_on_new_chunk_called_multiple_times(self):
|
||||
for i in range(5):
|
||||
self.cb.on_new_chunk(
|
||||
llm_instance=self.llm_instance,
|
||||
chunk=self.chunk,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert len(self.cb.new_chunk_calls) == 5
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for on_after_invoke
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOnAfterInvoke:
|
||||
"""Tests for the on_after_invoke callback method."""
|
||||
|
||||
def setup_method(self):
|
||||
self.cb = ConcreteCallback()
|
||||
self.llm_instance = MagicMock()
|
||||
self.result = MagicMock(spec=LLMResult)
|
||||
self.model = "claude-3"
|
||||
self.credentials = {"api_key": "anthropic-key"}
|
||||
self.prompt_messages = [MagicMock(spec=PromptMessage)]
|
||||
self.model_parameters = {"temperature": 1.0}
|
||||
|
||||
def test_on_after_invoke_called_with_required_args(self):
|
||||
self.cb.on_after_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
result=self.result,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert len(self.cb.after_invoke_calls) == 1
|
||||
call = self.cb.after_invoke_calls[0]
|
||||
assert call["llm_instance"] is self.llm_instance
|
||||
assert call["result"] is self.result
|
||||
assert call["model"] == self.model
|
||||
assert call["credentials"] is self.credentials
|
||||
|
||||
def test_on_after_invoke_defaults_tools_none(self):
|
||||
self.cb.on_after_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
result=self.result,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.after_invoke_calls[0]["tools"] is None
|
||||
|
||||
def test_on_after_invoke_defaults_stop_none(self):
|
||||
self.cb.on_after_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
result=self.result,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.after_invoke_calls[0]["stop"] is None
|
||||
|
||||
def test_on_after_invoke_defaults_stream_true(self):
|
||||
self.cb.on_after_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
result=self.result,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.after_invoke_calls[0]["stream"] is True
|
||||
|
||||
def test_on_after_invoke_defaults_user_none(self):
|
||||
self.cb.on_after_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
result=self.result,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.after_invoke_calls[0]["user"] is None
|
||||
|
||||
def test_on_after_invoke_with_all_optional_args(self):
|
||||
tools = [MagicMock(spec=PromptMessageTool)]
|
||||
stop = ["STOP"]
|
||||
self.cb.on_after_invoke(
|
||||
llm_instance=self.llm_instance,
|
||||
result=self.result,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user="after-user",
|
||||
)
|
||||
call = self.cb.after_invoke_calls[0]
|
||||
assert call["tools"] is tools
|
||||
assert call["stop"] == stop
|
||||
assert call["stream"] is False
|
||||
assert call["user"] == "after-user"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for on_invoke_error
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOnInvokeError:
|
||||
"""Tests for the on_invoke_error callback method."""
|
||||
|
||||
def setup_method(self):
|
||||
self.cb = ConcreteCallback()
|
||||
self.llm_instance = MagicMock()
|
||||
self.ex = ValueError("something went wrong")
|
||||
self.model = "gemini-pro"
|
||||
self.credentials = {"api_key": "google-key"}
|
||||
self.prompt_messages = [MagicMock(spec=PromptMessage)]
|
||||
self.model_parameters = {"top_p": 0.9}
|
||||
|
||||
def test_on_invoke_error_called_with_required_args(self):
|
||||
self.cb.on_invoke_error(
|
||||
llm_instance=self.llm_instance,
|
||||
ex=self.ex,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert len(self.cb.invoke_error_calls) == 1
|
||||
call = self.cb.invoke_error_calls[0]
|
||||
assert call["llm_instance"] is self.llm_instance
|
||||
assert call["ex"] is self.ex
|
||||
assert call["model"] == self.model
|
||||
assert call["credentials"] is self.credentials
|
||||
|
||||
def test_on_invoke_error_defaults_tools_none(self):
|
||||
self.cb.on_invoke_error(
|
||||
llm_instance=self.llm_instance,
|
||||
ex=self.ex,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.invoke_error_calls[0]["tools"] is None
|
||||
|
||||
def test_on_invoke_error_defaults_stop_none(self):
|
||||
self.cb.on_invoke_error(
|
||||
llm_instance=self.llm_instance,
|
||||
ex=self.ex,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.invoke_error_calls[0]["stop"] is None
|
||||
|
||||
def test_on_invoke_error_defaults_stream_true(self):
|
||||
self.cb.on_invoke_error(
|
||||
llm_instance=self.llm_instance,
|
||||
ex=self.ex,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.invoke_error_calls[0]["stream"] is True
|
||||
|
||||
def test_on_invoke_error_defaults_user_none(self):
|
||||
self.cb.on_invoke_error(
|
||||
llm_instance=self.llm_instance,
|
||||
ex=self.ex,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert self.cb.invoke_error_calls[0]["user"] is None
|
||||
|
||||
def test_on_invoke_error_with_all_optional_args(self):
|
||||
tools = [MagicMock(spec=PromptMessageTool)]
|
||||
stop = ["HALT"]
|
||||
self.cb.on_invoke_error(
|
||||
llm_instance=self.llm_instance,
|
||||
ex=self.ex,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user="error-user",
|
||||
)
|
||||
call = self.cb.invoke_error_calls[0]
|
||||
assert call["tools"] is tools
|
||||
assert call["stop"] == stop
|
||||
assert call["stream"] is False
|
||||
assert call["user"] == "error-user"
|
||||
|
||||
def test_on_invoke_error_accepts_various_exception_types(self):
|
||||
for exc in [RuntimeError("r"), KeyError("k"), Exception("e")]:
|
||||
self.cb.on_invoke_error(
|
||||
llm_instance=self.llm_instance,
|
||||
ex=exc,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=self.prompt_messages,
|
||||
model_parameters=self.model_parameters,
|
||||
)
|
||||
assert len(self.cb.invoke_error_calls) == 3
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for print_text (concrete method on Callback)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestPrintText:
|
||||
"""Tests for the concrete print_text method."""
|
||||
|
||||
def setup_method(self):
|
||||
self.cb = ConcreteCallback()
|
||||
|
||||
def test_print_text_without_color_prints_plain_text(self, capsys):
|
||||
self.cb.print_text("hello world")
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "hello world"
|
||||
|
||||
def test_print_text_with_color_prints_colored_text(self, capsys):
|
||||
self.cb.print_text("colored text", color="blue")
|
||||
captured = capsys.readouterr()
|
||||
# Should contain ANSI escape sequences
|
||||
assert "colored text" in captured.out
|
||||
assert "\001b[" in captured.out or "\033[" in captured.out or "\x1b[" in captured.out
|
||||
|
||||
def test_print_text_without_color_no_ansi(self, capsys):
|
||||
self.cb.print_text("plain text", color=None)
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "plain text"
|
||||
# No ANSI escape sequences
|
||||
assert "\x1b" not in captured.out
|
||||
|
||||
def test_print_text_default_end_is_empty_string(self, capsys):
|
||||
self.cb.print_text("no newline")
|
||||
captured = capsys.readouterr()
|
||||
assert not captured.out.endswith("\n")
|
||||
|
||||
def test_print_text_with_custom_end(self, capsys):
|
||||
self.cb.print_text("with newline", end="\n")
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.endswith("\n")
|
||||
|
||||
def test_print_text_with_empty_string(self, capsys):
|
||||
self.cb.print_text("", color=None)
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == ""
|
||||
|
||||
@pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"])
|
||||
def test_print_text_all_colors_work(self, color, capsys):
|
||||
"""Verify no KeyError is thrown for any valid color."""
|
||||
self.cb.print_text("test", color=color)
|
||||
captured = capsys.readouterr()
|
||||
assert "test" in captured.out
|
||||
|
||||
def test_print_text_calls_get_colored_text_when_color_given(self):
|
||||
with patch.object(self.cb, "_get_colored_text", return_value="[COLORED]") as mock_gct:
|
||||
with patch("builtins.print") as mock_print:
|
||||
self.cb.print_text("hello", color="green")
|
||||
mock_gct.assert_called_once_with("hello", "green")
|
||||
mock_print.assert_called_once_with("[COLORED]", end="")
|
||||
|
||||
def test_print_text_does_not_call_get_colored_text_when_no_color(self):
|
||||
with patch.object(self.cb, "_get_colored_text") as mock_gct:
|
||||
with patch("builtins.print"):
|
||||
self.cb.print_text("hello", color=None)
|
||||
mock_gct.assert_not_called()
|
||||
|
||||
def test_print_text_passes_end_to_print(self):
|
||||
with patch("builtins.print") as mock_print:
|
||||
self.cb.print_text("text", end="---")
|
||||
mock_print.assert_called_once_with("text", end="---")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for _get_colored_text (private helper method)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGetColoredText:
|
||||
"""Tests for the _get_colored_text private method."""
|
||||
|
||||
def setup_method(self):
|
||||
self.cb = ConcreteCallback()
|
||||
|
||||
@pytest.mark.parametrize(("color", "expected_code"), list(_TEXT_COLOR_MAPPING.items()))
|
||||
def test_get_colored_text_uses_correct_escape_code(self, color, expected_code):
|
||||
result = self.cb._get_colored_text("text", color)
|
||||
assert expected_code in result
|
||||
|
||||
@pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"])
|
||||
def test_get_colored_text_contains_input_text(self, color):
|
||||
result = self.cb._get_colored_text("hello", color)
|
||||
assert "hello" in result
|
||||
|
||||
@pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"])
|
||||
def test_get_colored_text_starts_with_escape(self, color):
|
||||
result = self.cb._get_colored_text("text", color)
|
||||
# Should start with an ANSI escape (\x1b or \u001b)
|
||||
assert result.startswith("\x1b[") or result.startswith("\u001b[")
|
||||
|
||||
@pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"])
|
||||
def test_get_colored_text_ends_with_reset(self, color):
|
||||
result = self.cb._get_colored_text("text", color)
|
||||
# Should end with the ANSI reset code
|
||||
assert result.endswith("\x1b[0m") or result.endswith("\u001b[0m")
|
||||
|
||||
def test_get_colored_text_returns_string(self):
|
||||
result = self.cb._get_colored_text("text", "blue")
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_get_colored_text_blue_exact_format(self):
|
||||
result = self.cb._get_colored_text("hello", "blue")
|
||||
expected = f"\u001b[{_TEXT_COLOR_MAPPING['blue']}m\033[1;3mhello\u001b[0m"
|
||||
assert result == expected
|
||||
|
||||
def test_get_colored_text_red_exact_format(self):
|
||||
result = self.cb._get_colored_text("error", "red")
|
||||
expected = f"\u001b[{_TEXT_COLOR_MAPPING['red']}m\033[1;3merror\u001b[0m"
|
||||
assert result == expected
|
||||
|
||||
def test_get_colored_text_green_exact_format(self):
|
||||
result = self.cb._get_colored_text("ok", "green")
|
||||
expected = f"\u001b[{_TEXT_COLOR_MAPPING['green']}m\033[1;3mok\u001b[0m"
|
||||
assert result == expected
|
||||
|
||||
def test_get_colored_text_yellow_exact_format(self):
|
||||
result = self.cb._get_colored_text("warn", "yellow")
|
||||
expected = f"\u001b[{_TEXT_COLOR_MAPPING['yellow']}m\033[1;3mwarn\u001b[0m"
|
||||
assert result == expected
|
||||
|
||||
def test_get_colored_text_pink_exact_format(self):
|
||||
result = self.cb._get_colored_text("info", "pink")
|
||||
expected = f"\u001b[{_TEXT_COLOR_MAPPING['pink']}m\033[1;3minfo\u001b[0m"
|
||||
assert result == expected
|
||||
|
||||
def test_get_colored_text_empty_string(self):
|
||||
result = self.cb._get_colored_text("", "blue")
|
||||
assert isinstance(result, str)
|
||||
# Empty text should still have escape codes
|
||||
assert _TEXT_COLOR_MAPPING["blue"] in result
|
||||
|
||||
def test_get_colored_text_invalid_color_raises_key_error(self):
|
||||
with pytest.raises(KeyError):
|
||||
self.cb._get_colored_text("text", "purple")
|
||||
|
||||
def test_get_colored_text_with_special_characters(self):
|
||||
special = "hello\nworld\ttab"
|
||||
result = self.cb._get_colored_text(special, "blue")
|
||||
assert special in result
|
||||
|
||||
def test_get_colored_text_with_long_text(self):
|
||||
long_text = "a" * 10000
|
||||
result = self.cb._get_colored_text(long_text, "green")
|
||||
assert long_text in result
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Integration-style tests: full workflow through a ConcreteCallback
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestConcreteCallbackIntegration:
|
||||
"""End-to-end workflow tests using ConcreteCallback."""
|
||||
|
||||
def test_full_invocation_lifecycle(self):
|
||||
"""Simulate a complete LLM invocation lifecycle through all callbacks."""
|
||||
cb = ConcreteCallback()
|
||||
llm_instance = MagicMock()
|
||||
model = "gpt-4o"
|
||||
credentials = {"api_key": "sk-xyz"}
|
||||
prompt_messages = [MagicMock(spec=PromptMessage)]
|
||||
model_parameters = {"temperature": 0.5}
|
||||
tools = [MagicMock(spec=PromptMessageTool)]
|
||||
stop = ["<END>"]
|
||||
user = "user-abc"
|
||||
|
||||
# 1. Before invoke
|
||||
cb.on_before_invoke(
|
||||
llm_instance=llm_instance,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# 2. Multiple chunks during streaming
|
||||
for i in range(3):
|
||||
chunk = MagicMock(spec=LLMResultChunk)
|
||||
cb.on_new_chunk(
|
||||
llm_instance=llm_instance,
|
||||
chunk=chunk,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# 3. After invoke
|
||||
result = MagicMock(spec=LLMResult)
|
||||
cb.on_after_invoke(
|
||||
llm_instance=llm_instance,
|
||||
result=result,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user,
|
||||
)
|
||||
|
||||
assert len(cb.before_invoke_calls) == 1
|
||||
assert len(cb.new_chunk_calls) == 3
|
||||
assert len(cb.after_invoke_calls) == 1
|
||||
assert len(cb.invoke_error_calls) == 0
|
||||
|
||||
def test_error_lifecycle(self):
|
||||
"""Simulate an invoke that results in an error."""
|
||||
cb = ConcreteCallback()
|
||||
llm_instance = MagicMock()
|
||||
model = "gpt-4"
|
||||
credentials = {}
|
||||
prompt_messages = []
|
||||
model_parameters = {}
|
||||
|
||||
cb.on_before_invoke(
|
||||
llm_instance=llm_instance,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
)
|
||||
|
||||
ex = RuntimeError("API timeout")
|
||||
cb.on_invoke_error(
|
||||
llm_instance=llm_instance,
|
||||
ex=ex,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
)
|
||||
|
||||
assert len(cb.before_invoke_calls) == 1
|
||||
assert len(cb.invoke_error_calls) == 1
|
||||
assert cb.invoke_error_calls[0]["ex"] is ex
|
||||
assert len(cb.after_invoke_calls) == 0
|
||||
|
||||
def test_print_text_with_color_in_integration(self, capsys):
|
||||
"""verify print_text works correctly in a concrete instance."""
|
||||
cb = ConcreteCallback()
|
||||
cb.print_text("SUCCESS", color="green", end="\n")
|
||||
captured = capsys.readouterr()
|
||||
assert "SUCCESS" in captured.out
|
||||
assert "\n" in captured.out
|
||||
|
||||
def test_print_text_no_color_in_integration(self, capsys):
|
||||
cb = ConcreteCallback()
|
||||
cb.print_text("plain output")
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "plain output"
|
||||
@ -0,0 +1,700 @@
|
||||
"""
|
||||
Comprehensive unit tests for core/model_runtime/callbacks/logging_callback.py
|
||||
|
||||
Coverage targets:
|
||||
- LoggingCallback.on_before_invoke (all branches: stop, tools, user, stream,
|
||||
prompt_message.name, model_parameters)
|
||||
- LoggingCallback.on_new_chunk (writes to stdout)
|
||||
- LoggingCallback.on_after_invoke (all branches: tool_calls present / absent)
|
||||
- LoggingCallback.on_invoke_error (logs exception via logger.exception)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback
|
||||
from dify_graph.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_usage() -> LLMUsage:
|
||||
"""Return a minimal LLMUsage instance."""
|
||||
return LLMUsage(
|
||||
prompt_tokens=10,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.01"),
|
||||
completion_tokens=20,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.002"),
|
||||
completion_price=Decimal("0.04"),
|
||||
total_tokens=30,
|
||||
total_price=Decimal("0.05"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
|
||||
def _make_llm_result(
|
||||
content: str = "hello world",
|
||||
tool_calls: list | None = None,
|
||||
model: str = "gpt-4",
|
||||
system_fingerprint: str | None = "fp-abc",
|
||||
) -> LLMResult:
|
||||
"""Return an LLMResult with an AssistantPromptMessage."""
|
||||
assistant_msg = AssistantPromptMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls or [],
|
||||
)
|
||||
return LLMResult(
|
||||
model=model,
|
||||
message=assistant_msg,
|
||||
usage=_make_usage(),
|
||||
system_fingerprint=system_fingerprint,
|
||||
)
|
||||
|
||||
|
||||
def _make_chunk(content: str = "chunk-text") -> LLMResultChunk:
|
||||
"""Return a minimal LLMResultChunk."""
|
||||
return LLMResultChunk(
|
||||
model="gpt-4",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=content),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_user_prompt(content: str = "Hello!", name: str | None = None) -> UserPromptMessage:
|
||||
return UserPromptMessage(content=content, name=name)
|
||||
|
||||
|
||||
def _make_system_prompt(content: str = "You are helpful.") -> SystemPromptMessage:
|
||||
return SystemPromptMessage(content=content)
|
||||
|
||||
|
||||
def _make_tool(name: str = "my_tool") -> PromptMessageTool:
|
||||
return PromptMessageTool(name=name, description="A tool", parameters={})
|
||||
|
||||
|
||||
def _make_tool_call(
|
||||
call_id: str = "call-1",
|
||||
func_name: str = "some_func",
|
||||
arguments: str = '{"key": "value"}',
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
return AssistantPromptMessage.ToolCall(
|
||||
id=call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=func_name, arguments=arguments),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture: shared LoggingCallback instance (no heavy state)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cb() -> LoggingCallback:
|
||||
return LoggingCallback()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_instance() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for on_before_invoke
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOnBeforeInvoke:
|
||||
"""Tests for LoggingCallback.on_before_invoke."""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
cb: LoggingCallback,
|
||||
llm_instance: MagicMock,
|
||||
*,
|
||||
model: str = "gpt-4",
|
||||
credentials: dict | None = None,
|
||||
prompt_messages: list | None = None,
|
||||
model_parameters: dict | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
):
|
||||
cb.on_before_invoke(
|
||||
llm_instance=llm_instance,
|
||||
model=model,
|
||||
credentials=credentials or {},
|
||||
prompt_messages=prompt_messages or [],
|
||||
model_parameters=model_parameters or {},
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def test_minimal_call_does_not_raise(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""Calling with bare-minimum args should not raise."""
|
||||
self._invoke(cb, llm_instance)
|
||||
|
||||
def test_model_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""The model name must appear in print_text calls."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, model="claude-3")
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "claude-3" in calls_text
|
||||
|
||||
def test_model_parameters_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""Each key-value pair of model_parameters must be printed."""
|
||||
params = {"temperature": 0.7, "max_tokens": 512}
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, model_parameters=params)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "temperature" in calls_text
|
||||
assert "0.7" in calls_text
|
||||
assert "max_tokens" in calls_text
|
||||
assert "512" in calls_text
|
||||
|
||||
def test_empty_model_parameters(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""Empty model_parameters dict should not raise."""
|
||||
self._invoke(cb, llm_instance, model_parameters={})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# stop branch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_stop_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""stop words must appear in output when provided."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, stop=["STOP", "END"])
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "stop" in calls_text
|
||||
|
||||
def test_stop_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When stop=None the stop line must NOT appear."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, stop=None)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "\tstop:" not in calls_text
|
||||
|
||||
def test_stop_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When stop=[] (falsy) the stop line must NOT appear."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, stop=[])
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "\tstop:" not in calls_text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# tools branch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_tools_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""Tool names must appear in output when tools are provided."""
|
||||
tools = [_make_tool("search"), _make_tool("calculate")]
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, tools=tools)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "search" in calls_text
|
||||
assert "calculate" in calls_text
|
||||
|
||||
def test_tools_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When tools=None the Tools section must NOT appear."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, tools=None)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "Tools:" not in calls_text
|
||||
|
||||
def test_tools_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When tools=[] (falsy) the Tools section must NOT appear."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, tools=[])
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "Tools:" not in calls_text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# user branch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_user_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""User string must appear in output when provided."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, user="alice")
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "alice" in calls_text
|
||||
|
||||
def test_user_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When user=None the User line must NOT appear."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, user=None)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "User:" not in calls_text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# stream branch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_stream_true_prints_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When stream=True the [on_llm_new_chunk] marker must be printed."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, stream=True)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "[on_llm_new_chunk]" in calls_text
|
||||
|
||||
def test_stream_false_no_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When stream=False the [on_llm_new_chunk] marker must NOT appear."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, stream=False)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "[on_llm_new_chunk]" not in calls_text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# prompt_messages branch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_prompt_message_with_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When a PromptMessage has a name it must be printed."""
|
||||
msg = _make_user_prompt("hi", name="bob")
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, prompt_messages=[msg])
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "bob" in calls_text
|
||||
|
||||
def test_prompt_message_without_name_skips_name_line(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When a PromptMessage has no name the name line must NOT appear."""
|
||||
msg = _make_user_prompt("hi", name=None)
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, prompt_messages=[msg])
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "\tname:" not in calls_text
|
||||
|
||||
def test_prompt_message_role_and_content_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""Role and content of each PromptMessage must appear in output."""
|
||||
msg = _make_system_prompt("Be concise.")
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, prompt_messages=[msg])
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "system" in calls_text
|
||||
assert "Be concise." in calls_text
|
||||
|
||||
def test_multiple_prompt_messages_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""All entries in prompt_messages are iterated and printed."""
|
||||
msgs = [
|
||||
_make_system_prompt("sys"),
|
||||
_make_user_prompt("user msg"),
|
||||
]
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, prompt_messages=msgs)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "sys" in calls_text
|
||||
assert "user msg" in calls_text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Combination: everything provided
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_all_optional_fields_combined(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""Supply stop, tools, user, multiple params, named message – no exception."""
|
||||
msgs = [_make_user_prompt("question", name="alice")]
|
||||
tools = [_make_tool("tool_a")]
|
||||
with patch.object(cb, "print_text"):
|
||||
self._invoke(
|
||||
cb,
|
||||
llm_instance,
|
||||
model="gpt-3.5",
|
||||
model_parameters={"temperature": 1.0, "top_p": 0.9},
|
||||
tools=tools,
|
||||
stop=["DONE"],
|
||||
stream=True,
|
||||
user="alice",
|
||||
prompt_messages=msgs,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for on_new_chunk
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOnNewChunk:
|
||||
"""Tests for LoggingCallback.on_new_chunk."""
|
||||
|
||||
def test_chunk_content_written_to_stdout(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""on_new_chunk must write the chunk's text content to sys.stdout."""
|
||||
chunk = _make_chunk("hello from LLM")
|
||||
written = []
|
||||
|
||||
with patch("sys.stdout") as mock_stdout:
|
||||
mock_stdout.write.side_effect = written.append
|
||||
cb.on_new_chunk(
|
||||
llm_instance=llm_instance,
|
||||
chunk=chunk,
|
||||
model="gpt-4",
|
||||
credentials={},
|
||||
prompt_messages=[],
|
||||
model_parameters={},
|
||||
)
|
||||
mock_stdout.write.assert_called_once_with("hello from LLM")
|
||||
mock_stdout.flush.assert_called_once()
|
||||
|
||||
def test_chunk_content_empty_string(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""Works correctly even when the chunk content is an empty string."""
|
||||
chunk = _make_chunk("")
|
||||
with patch("sys.stdout") as mock_stdout:
|
||||
cb.on_new_chunk(
|
||||
llm_instance=llm_instance,
|
||||
chunk=chunk,
|
||||
model="gpt-4",
|
||||
credentials={},
|
||||
prompt_messages=[],
|
||||
model_parameters={},
|
||||
)
|
||||
mock_stdout.write.assert_called_once_with("")
|
||||
mock_stdout.flush.assert_called_once()
|
||||
|
||||
def test_chunk_passes_all_optional_params(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""All optional parameters are accepted without errors."""
|
||||
chunk = _make_chunk("data")
|
||||
with patch("sys.stdout"):
|
||||
cb.on_new_chunk(
|
||||
llm_instance=llm_instance,
|
||||
chunk=chunk,
|
||||
model="gpt-4",
|
||||
credentials={"key": "secret"},
|
||||
prompt_messages=[_make_user_prompt("q")],
|
||||
model_parameters={"temperature": 0.5},
|
||||
tools=[_make_tool("t1")],
|
||||
stop=["EOS"],
|
||||
stream=True,
|
||||
user="bob",
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for on_after_invoke
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOnAfterInvoke:
|
||||
"""Tests for LoggingCallback.on_after_invoke."""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
cb: LoggingCallback,
|
||||
llm_instance: MagicMock,
|
||||
result: LLMResult,
|
||||
**kwargs,
|
||||
):
|
||||
cb.on_after_invoke(
|
||||
llm_instance=llm_instance,
|
||||
result=result,
|
||||
model=result.model,
|
||||
credentials={},
|
||||
prompt_messages=[],
|
||||
model_parameters={},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def test_basic_result_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""After-invoke header, content, model, usage, fingerprint must be printed."""
|
||||
result = _make_llm_result()
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, result)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "[on_llm_after_invoke]" in calls_text
|
||||
assert "hello world" in calls_text
|
||||
assert "gpt-4" in calls_text
|
||||
assert "fp-abc" in calls_text
|
||||
|
||||
def test_no_tool_calls_skips_tool_call_block(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When there are no tool_calls the 'Tool calls:' block must NOT appear."""
|
||||
result = _make_llm_result(tool_calls=[])
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, result)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "Tool calls:" not in calls_text
|
||||
|
||||
def test_with_tool_calls_prints_all_fields(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When tool_calls exist their id, name, and JSON arguments must be printed."""
|
||||
tc = _make_tool_call(
|
||||
call_id="call-xyz",
|
||||
func_name="fetch_data",
|
||||
arguments='{"url": "https://example.com"}',
|
||||
)
|
||||
result = _make_llm_result(tool_calls=[tc])
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, result)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "Tool calls:" in calls_text
|
||||
assert "call-xyz" in calls_text
|
||||
assert "fetch_data" in calls_text
|
||||
# arguments should be JSON-dumped
|
||||
assert "https://example.com" in calls_text
|
||||
|
||||
def test_multiple_tool_calls_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""All tool calls in the list must be iterated."""
|
||||
tcs = [
|
||||
_make_tool_call("id-1", "func_a", '{"a": 1}'),
|
||||
_make_tool_call("id-2", "func_b", '{"b": 2}'),
|
||||
]
|
||||
result = _make_llm_result(tool_calls=tcs)
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, result)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "id-1" in calls_text
|
||||
assert "func_a" in calls_text
|
||||
assert "id-2" in calls_text
|
||||
assert "func_b" in calls_text
|
||||
|
||||
def test_system_fingerprint_none_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""When system_fingerprint is None it should still be printed (as None)."""
|
||||
result = _make_llm_result(system_fingerprint=None)
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, result)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "System Fingerprint: None" in calls_text
|
||||
|
||||
def test_usage_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""The usage object must appear in the printed output."""
|
||||
result = _make_llm_result()
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, result)
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "Usage:" in calls_text
|
||||
|
||||
def test_tool_call_arguments_are_json_dumped(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""Verify json.dumps is applied to the arguments field (a string)."""
|
||||
raw_args = '{"x": 42}'
|
||||
tc = _make_tool_call(arguments=raw_args)
|
||||
result = _make_llm_result(tool_calls=[tc])
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
self._invoke(cb, llm_instance, result)
|
||||
|
||||
# Check if any call to print_text included the expected (json-encoded) arguments
|
||||
# json.dumps(raw_args) produces a string starting and ending with quotes
|
||||
expected_substring = json.dumps(raw_args)
|
||||
found = any(expected_substring in str(call.args[0]) for call in mock_print.call_args_list)
|
||||
assert found, f"Expected {expected_substring} to be printed in one of the calls"
|
||||
|
||||
def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""All optional parameters should be accepted without error."""
|
||||
result = _make_llm_result()
|
||||
cb.on_after_invoke(
|
||||
llm_instance=llm_instance,
|
||||
result=result,
|
||||
model=result.model,
|
||||
credentials={"key": "secret"},
|
||||
prompt_messages=[_make_user_prompt("q")],
|
||||
model_parameters={"temperature": 0.9},
|
||||
tools=[_make_tool("t")],
|
||||
stop=["<EOS>"],
|
||||
stream=False,
|
||||
user="carol",
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for on_invoke_error
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOnInvokeError:
|
||||
"""Tests for LoggingCallback.on_invoke_error."""
|
||||
|
||||
def _invoke_error(
|
||||
self,
|
||||
cb: LoggingCallback,
|
||||
llm_instance: MagicMock,
|
||||
ex: Exception,
|
||||
**kwargs,
|
||||
):
|
||||
cb.on_invoke_error(
|
||||
llm_instance=llm_instance,
|
||||
ex=ex,
|
||||
model="gpt-4",
|
||||
credentials={},
|
||||
prompt_messages=[],
|
||||
model_parameters={},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def test_prints_error_header(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""The [on_llm_invoke_error] banner must be printed."""
|
||||
with patch.object(cb, "print_text") as mock_print:
|
||||
with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger:
|
||||
self._invoke_error(cb, llm_instance, RuntimeError("boom"))
|
||||
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
|
||||
assert "[on_llm_invoke_error]" in calls_text
|
||||
|
||||
def test_exception_logged_via_logger_exception(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""logger.exception must be called with the exception."""
|
||||
ex = ValueError("something went wrong")
|
||||
with patch.object(cb, "print_text"):
|
||||
with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger:
|
||||
self._invoke_error(cb, llm_instance, ex)
|
||||
mock_logger.exception.assert_called_once_with(ex)
|
||||
|
||||
def test_exception_type_variety(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""Works with any exception type (TypeError, IOError, etc.)."""
|
||||
for exc_cls in (TypeError, IOError, KeyError, Exception):
|
||||
ex = exc_cls("error")
|
||||
with patch.object(cb, "print_text"):
|
||||
with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger:
|
||||
self._invoke_error(cb, llm_instance, ex)
|
||||
mock_logger.exception.assert_called_once_with(ex)
|
||||
|
||||
def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock):
|
||||
"""All optional parameters should be accepted without error."""
|
||||
ex = RuntimeError("fail")
|
||||
with patch.object(cb, "print_text"):
|
||||
with patch("dify_graph.model_runtime.callbacks.logging_callback.logger"):
|
||||
cb.on_invoke_error(
|
||||
llm_instance=llm_instance,
|
||||
ex=ex,
|
||||
model="gpt-4",
|
||||
credentials={"key": "secret"},
|
||||
prompt_messages=[_make_user_prompt("q")],
|
||||
model_parameters={"temperature": 0.7},
|
||||
tools=[_make_tool("t")],
|
||||
stop=["STOP"],
|
||||
stream=True,
|
||||
user="dave",
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for print_text (inherited from Callback, exercised through LoggingCallback)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestPrintText:
|
||||
"""Verify that print_text from the Callback base class works correctly."""
|
||||
|
||||
def test_print_text_with_color(self, cb: LoggingCallback, capsys):
|
||||
"""print_text with a known colour should emit an ANSI escape sequence."""
|
||||
cb.print_text("hello", color="blue")
|
||||
captured = capsys.readouterr()
|
||||
assert "hello" in captured.out
|
||||
# ANSI escape codes should be present
|
||||
assert "\x1b[" in captured.out
|
||||
|
||||
def test_print_text_without_color(self, cb: LoggingCallback, capsys):
|
||||
"""print_text without colour should print plain text."""
|
||||
cb.print_text("plain text")
|
||||
captured = capsys.readouterr()
|
||||
assert "plain text" in captured.out
|
||||
|
||||
def test_print_text_all_colours(self, cb: LoggingCallback, capsys):
|
||||
"""Verify all supported colour keys don't raise."""
|
||||
for colour in ("blue", "yellow", "pink", "green", "red"):
|
||||
cb.print_text("x", color=colour)
|
||||
captured = capsys.readouterr()
|
||||
# All outputs should contain 'x' (5 calls)
|
||||
assert captured.out.count("x") >= 5
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Integration-style test: real print_text called (no mocking)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLoggingCallbackIntegration:
|
||||
"""Light integration tests – real print_text calls, just checking no exceptions."""
|
||||
|
||||
def test_on_before_invoke_full_run(self, capsys):
|
||||
"""Full on_before_invoke run with all optional fields – verifies real output."""
|
||||
cb = LoggingCallback()
|
||||
llm = MagicMock()
|
||||
msgs = [_make_user_prompt("Who are you?", name="tester")]
|
||||
tools = [_make_tool("calculator")]
|
||||
cb.on_before_invoke(
|
||||
llm_instance=llm,
|
||||
model="gpt-4-turbo",
|
||||
credentials={"api_key": "sk-xxx"},
|
||||
prompt_messages=msgs,
|
||||
model_parameters={"temperature": 0.8},
|
||||
tools=tools,
|
||||
stop=["STOP"],
|
||||
stream=True,
|
||||
user="test_user",
|
||||
)
|
||||
captured = capsys.readouterr()
|
||||
assert "gpt-4-turbo" in captured.out
|
||||
assert "calculator" in captured.out
|
||||
assert "test_user" in captured.out
|
||||
assert "STOP" in captured.out
|
||||
assert "tester" in captured.out
|
||||
|
||||
def test_on_new_chunk_full_run(self, capsys):
|
||||
"""Full on_new_chunk run – verifies real stdout write."""
|
||||
cb = LoggingCallback()
|
||||
chunk = _make_chunk("streaming token")
|
||||
cb.on_new_chunk(
|
||||
llm_instance=MagicMock(),
|
||||
chunk=chunk,
|
||||
model="gpt-4",
|
||||
credentials={},
|
||||
prompt_messages=[],
|
||||
model_parameters={},
|
||||
)
|
||||
captured = capsys.readouterr()
|
||||
assert "streaming token" in captured.out
|
||||
|
||||
def test_on_after_invoke_full_run_with_tool_calls(self, capsys):
|
||||
"""Full on_after_invoke run with tool calls – verifies real output."""
|
||||
cb = LoggingCallback()
|
||||
tc = _make_tool_call("call-99", "do_thing", '{"n": 5}')
|
||||
result = _make_llm_result(content="result content", tool_calls=[tc], system_fingerprint="fp-xyz")
|
||||
cb.on_after_invoke(
|
||||
llm_instance=MagicMock(),
|
||||
result=result,
|
||||
model=result.model,
|
||||
credentials={},
|
||||
prompt_messages=[],
|
||||
model_parameters={},
|
||||
)
|
||||
captured = capsys.readouterr()
|
||||
assert "result content" in captured.out
|
||||
assert "call-99" in captured.out
|
||||
assert "do_thing" in captured.out
|
||||
assert "fp-xyz" in captured.out
|
||||
|
||||
def test_on_invoke_error_full_run(self, capsys):
|
||||
"""Full on_invoke_error run – just verifies no exception is raised."""
|
||||
cb = LoggingCallback()
|
||||
ex = RuntimeError("something bad happened")
|
||||
# logger.exception writes to stderr; we just confirm it doesn't crash
|
||||
cb.on_invoke_error(
|
||||
llm_instance=MagicMock(),
|
||||
ex=ex,
|
||||
model="gpt-4",
|
||||
credentials={},
|
||||
prompt_messages=[],
|
||||
model_parameters={},
|
||||
)
|
||||
captured = capsys.readouterr()
|
||||
assert "[on_llm_invoke_error]" in captured.out
|
||||
@ -0,0 +1,35 @@
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class TestI18nObject:
|
||||
def test_i18n_object_with_both_languages(self):
|
||||
"""
|
||||
Test I18nObject when both zh_Hans and en_US are provided.
|
||||
"""
|
||||
i18n = I18nObject(zh_Hans="你好", en_US="Hello")
|
||||
assert i18n.zh_Hans == "你好"
|
||||
assert i18n.en_US == "Hello"
|
||||
|
||||
def test_i18n_object_fallback_to_en_us(self):
|
||||
"""
|
||||
Test I18nObject when zh_Hans is missing, it should fallback to en_US.
|
||||
"""
|
||||
i18n = I18nObject(en_US="Hello")
|
||||
assert i18n.zh_Hans == "Hello"
|
||||
assert i18n.en_US == "Hello"
|
||||
|
||||
def test_i18n_object_with_none_zh_hans(self):
|
||||
"""
|
||||
Test I18nObject when zh_Hans is None, it should fallback to en_US.
|
||||
"""
|
||||
i18n = I18nObject(zh_Hans=None, en_US="Hello")
|
||||
assert i18n.zh_Hans == "Hello"
|
||||
assert i18n.en_US == "Hello"
|
||||
|
||||
def test_i18n_object_with_empty_zh_hans(self):
|
||||
"""
|
||||
Test I18nObject when zh_Hans is an empty string, it should fallback to en_US.
|
||||
"""
|
||||
i18n = I18nObject(zh_Hans="", en_US="Hello")
|
||||
assert i18n.zh_Hans == "Hello"
|
||||
assert i18n.en_US == "Hello"
|
||||
@ -0,0 +1,210 @@
|
||||
import pytest
|
||||
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
AudioPromptMessageContent,
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageFunction,
|
||||
PromptMessageRole,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
|
||||
|
||||
class TestPromptMessageRole:
|
||||
def test_value_of(self):
|
||||
assert PromptMessageRole.value_of("system") == PromptMessageRole.SYSTEM
|
||||
assert PromptMessageRole.value_of("user") == PromptMessageRole.USER
|
||||
assert PromptMessageRole.value_of("assistant") == PromptMessageRole.ASSISTANT
|
||||
assert PromptMessageRole.value_of("tool") == PromptMessageRole.TOOL
|
||||
|
||||
with pytest.raises(ValueError, match="invalid prompt message type value invalid"):
|
||||
PromptMessageRole.value_of("invalid")
|
||||
|
||||
|
||||
class TestPromptMessageEntities:
|
||||
def test_prompt_message_tool(self):
|
||||
tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"})
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "test desc"
|
||||
assert tool.parameters == {"foo": "bar"}
|
||||
|
||||
def test_prompt_message_function(self):
|
||||
tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"})
|
||||
func = PromptMessageFunction(function=tool)
|
||||
assert func.type == "function"
|
||||
assert func.function == tool
|
||||
|
||||
|
||||
class TestPromptMessageContent:
|
||||
def test_text_content(self):
|
||||
content = TextPromptMessageContent(data="hello")
|
||||
assert content.type == PromptMessageContentType.TEXT
|
||||
assert content.data == "hello"
|
||||
|
||||
def test_image_content(self):
|
||||
content = ImagePromptMessageContent(
|
||||
format="jpg", base64_data="abc", mime_type="image/jpeg", detail=ImagePromptMessageContent.DETAIL.HIGH
|
||||
)
|
||||
assert content.type == PromptMessageContentType.IMAGE
|
||||
assert content.detail == ImagePromptMessageContent.DETAIL.HIGH
|
||||
assert content.data == "data:image/jpeg;base64,abc"
|
||||
|
||||
def test_image_content_url(self):
|
||||
content = ImagePromptMessageContent(format="jpg", url="https://example.com/image.jpg", mime_type="image/jpeg")
|
||||
assert content.data == "https://example.com/image.jpg"
|
||||
|
||||
def test_audio_content(self):
|
||||
content = AudioPromptMessageContent(format="mp3", base64_data="abc", mime_type="audio/mpeg")
|
||||
assert content.type == PromptMessageContentType.AUDIO
|
||||
assert content.data == "data:audio/mpeg;base64,abc"
|
||||
|
||||
def test_video_content(self):
|
||||
content = VideoPromptMessageContent(format="mp4", base64_data="abc", mime_type="video/mp4")
|
||||
assert content.type == PromptMessageContentType.VIDEO
|
||||
assert content.data == "data:video/mp4;base64,abc"
|
||||
|
||||
def test_document_content(self):
|
||||
content = DocumentPromptMessageContent(format="pdf", base64_data="abc", mime_type="application/pdf")
|
||||
assert content.type == PromptMessageContentType.DOCUMENT
|
||||
assert content.data == "data:application/pdf;base64,abc"
|
||||
|
||||
|
||||
class TestPromptMessages:
|
||||
def test_user_prompt_message(self):
|
||||
msg = UserPromptMessage(content="hello")
|
||||
assert msg.role == PromptMessageRole.USER
|
||||
assert msg.content == "hello"
|
||||
assert msg.is_empty() is False
|
||||
assert msg.get_text_content() == "hello"
|
||||
|
||||
def test_user_prompt_message_complex_content(self):
|
||||
content = [TextPromptMessageContent(data="hello "), TextPromptMessageContent(data="world")]
|
||||
msg = UserPromptMessage(content=content)
|
||||
assert msg.get_text_content() == "hello world"
|
||||
|
||||
# Test validation from dict
|
||||
msg2 = UserPromptMessage(content=[{"type": "text", "data": "hi"}])
|
||||
assert isinstance(msg2.content[0], TextPromptMessageContent)
|
||||
assert msg2.content[0].data == "hi"
|
||||
|
||||
def test_prompt_message_empty(self):
|
||||
msg = UserPromptMessage(content=None)
|
||||
assert msg.is_empty() is True
|
||||
assert msg.get_text_content() == ""
|
||||
|
||||
def test_assistant_prompt_message(self):
|
||||
msg = AssistantPromptMessage(content="thinking...")
|
||||
assert msg.role == PromptMessageRole.ASSISTANT
|
||||
assert msg.is_empty() is False
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id="call_1",
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"),
|
||||
)
|
||||
msg_with_tools = AssistantPromptMessage(content=None, tool_calls=[tool_call])
|
||||
assert msg_with_tools.is_empty() is False
|
||||
assert msg_with_tools.role == PromptMessageRole.ASSISTANT
|
||||
|
||||
def test_assistant_tool_call_id_transform(self):
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=123,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"),
|
||||
)
|
||||
assert tool_call.id == "123"
|
||||
|
||||
def test_system_prompt_message(self):
|
||||
msg = SystemPromptMessage(content="you are a bot")
|
||||
assert msg.role == PromptMessageRole.SYSTEM
|
||||
assert msg.content == "you are a bot"
|
||||
|
||||
def test_tool_prompt_message(self):
|
||||
# Case 1: Both content and tool_call_id are present
|
||||
msg = ToolPromptMessage(content="result", tool_call_id="call_1")
|
||||
assert msg.role == PromptMessageRole.TOOL
|
||||
assert msg.tool_call_id == "call_1"
|
||||
assert msg.is_empty() is False
|
||||
|
||||
# Case 2: Content is present, but tool_call_id is empty
|
||||
msg_content_only = ToolPromptMessage(content="result", tool_call_id="")
|
||||
assert msg_content_only.is_empty() is False
|
||||
|
||||
# Case 3: Content is None, but tool_call_id is present
|
||||
msg_id_only = ToolPromptMessage(content=None, tool_call_id="call_1")
|
||||
assert msg_id_only.is_empty() is False
|
||||
|
||||
# Case 4: Both content and tool_call_id are empty
|
||||
msg_empty = ToolPromptMessage(content=None, tool_call_id="")
|
||||
assert msg_empty.is_empty() is True
|
||||
|
||||
def test_prompt_message_validation_errors(self):
|
||||
with pytest.raises(KeyError):
|
||||
# Invalid content type in list
|
||||
UserPromptMessage(content=[{"type": "invalid", "data": "foo"}])
|
||||
|
||||
with pytest.raises(ValueError, match="invalid prompt message"):
|
||||
# Not a dict or PromptMessageContent
|
||||
UserPromptMessage(content=[123])
|
||||
|
||||
def test_prompt_message_serialization(self):
|
||||
# Case: content is None
|
||||
assert UserPromptMessage(content=None).serialize_content(None) is None
|
||||
|
||||
# Case: content is str
|
||||
assert UserPromptMessage(content="hello").serialize_content("hello") == "hello"
|
||||
|
||||
# Case: content is list of dict
|
||||
content_list = [{"type": "text", "data": "hi"}]
|
||||
msg = UserPromptMessage(content=content_list)
|
||||
assert msg.serialize_content(msg.content) == [{"type": PromptMessageContentType.TEXT, "data": "hi"}]
|
||||
|
||||
# Case: content is Sequence but not list (e.g. tuple)
|
||||
# To hit line 204, we can call serialize_content manually or
|
||||
# try to pass a type that pydantic doesn't convert to list in its internal state.
|
||||
# Actually, let's just call it manually on the instance.
|
||||
msg = UserPromptMessage(content="test")
|
||||
content_tuple = (TextPromptMessageContent(data="hi"),)
|
||||
assert msg.serialize_content(content_tuple) == content_tuple
|
||||
|
||||
def test_prompt_message_mixed_content_validation(self):
|
||||
# Test branch: isinstance(prompt, PromptMessageContent)
|
||||
# but not (TextPromptMessageContent | MultiModalPromptMessageContent)
|
||||
# Line 187: prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
|
||||
|
||||
# We need a PromptMessageContent that is NOT Text or MultiModal.
|
||||
# But PromptMessageContentUnionTypes discriminator handles this usually.
|
||||
# We can bypass high-level validation by passing the object directly in a list.
|
||||
|
||||
class MockContent(PromptMessageContent):
|
||||
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
||||
data: str
|
||||
|
||||
mock_item = MockContent(data="test")
|
||||
msg = UserPromptMessage(content=[mock_item])
|
||||
# It should hit line 187 and convert to TextPromptMessageContent
|
||||
assert isinstance(msg.content[0], TextPromptMessageContent)
|
||||
assert msg.content[0].data == "test"
|
||||
|
||||
def test_prompt_message_get_text_content_branches(self):
|
||||
# content is None
|
||||
msg_none = UserPromptMessage(content=None)
|
||||
assert msg_none.get_text_content() == ""
|
||||
|
||||
# content is list but no text content
|
||||
image = ImagePromptMessageContent(format="jpg", base64_data="abc", mime_type="image/jpeg")
|
||||
msg_image = UserPromptMessage(content=[image])
|
||||
assert msg_image.get_text_content() == ""
|
||||
|
||||
# content is list with mixed
|
||||
text = TextPromptMessageContent(data="hello")
|
||||
msg_mixed = UserPromptMessage(content=[text, image])
|
||||
assert msg_mixed.get_text_content() == "hello"
|
||||
@ -0,0 +1,220 @@
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ModelUsage,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
PriceConfig,
|
||||
PriceInfo,
|
||||
PriceType,
|
||||
ProviderModel,
|
||||
)
|
||||
|
||||
|
||||
class TestModelType:
|
||||
def test_value_of(self):
|
||||
assert ModelType.value_of("text-generation") == ModelType.LLM
|
||||
assert ModelType.value_of(ModelType.LLM) == ModelType.LLM
|
||||
assert ModelType.value_of("embeddings") == ModelType.TEXT_EMBEDDING
|
||||
assert ModelType.value_of(ModelType.TEXT_EMBEDDING) == ModelType.TEXT_EMBEDDING
|
||||
assert ModelType.value_of("reranking") == ModelType.RERANK
|
||||
assert ModelType.value_of(ModelType.RERANK) == ModelType.RERANK
|
||||
assert ModelType.value_of("speech2text") == ModelType.SPEECH2TEXT
|
||||
assert ModelType.value_of(ModelType.SPEECH2TEXT) == ModelType.SPEECH2TEXT
|
||||
assert ModelType.value_of("tts") == ModelType.TTS
|
||||
assert ModelType.value_of(ModelType.TTS) == ModelType.TTS
|
||||
assert ModelType.value_of(ModelType.MODERATION) == ModelType.MODERATION
|
||||
|
||||
with pytest.raises(ValueError, match="invalid origin model type invalid"):
|
||||
ModelType.value_of("invalid")
|
||||
|
||||
def test_to_origin_model_type(self):
|
||||
assert ModelType.LLM.to_origin_model_type() == "text-generation"
|
||||
assert ModelType.TEXT_EMBEDDING.to_origin_model_type() == "embeddings"
|
||||
assert ModelType.RERANK.to_origin_model_type() == "reranking"
|
||||
assert ModelType.SPEECH2TEXT.to_origin_model_type() == "speech2text"
|
||||
assert ModelType.TTS.to_origin_model_type() == "tts"
|
||||
assert ModelType.MODERATION.to_origin_model_type() == "moderation"
|
||||
|
||||
# Testing the else branch in to_origin_model_type
|
||||
# Since it's a StrEnum, it's hard to get an invalid value here unless we mock or Force it.
|
||||
# But if we look at the implementation:
|
||||
# if self == self.LLM: ... elif ... else: raise ValueError
|
||||
# We can try to create a "dummy" member if possible, or just skip it if we have 100% coverage otherwise.
|
||||
# Actually, adding a new member to an enum at runtime is possible but messy.
|
||||
# Let's see if we can trigger it.
|
||||
|
||||
|
||||
class TestFetchFrom:
|
||||
def test_values(self):
|
||||
assert FetchFrom.PREDEFINED_MODEL == "predefined-model"
|
||||
assert FetchFrom.CUSTOMIZABLE_MODEL == "customizable-model"
|
||||
|
||||
|
||||
class TestModelFeature:
|
||||
def test_values(self):
|
||||
assert ModelFeature.TOOL_CALL == "tool-call"
|
||||
assert ModelFeature.MULTI_TOOL_CALL == "multi-tool-call"
|
||||
assert ModelFeature.AGENT_THOUGHT == "agent-thought"
|
||||
assert ModelFeature.VISION == "vision"
|
||||
assert ModelFeature.STREAM_TOOL_CALL == "stream-tool-call"
|
||||
assert ModelFeature.DOCUMENT == "document"
|
||||
assert ModelFeature.VIDEO == "video"
|
||||
assert ModelFeature.AUDIO == "audio"
|
||||
assert ModelFeature.STRUCTURED_OUTPUT == "structured-output"
|
||||
|
||||
|
||||
class TestDefaultParameterName:
|
||||
def test_value_of(self):
|
||||
assert DefaultParameterName.value_of("temperature") == DefaultParameterName.TEMPERATURE
|
||||
assert DefaultParameterName.value_of("top_p") == DefaultParameterName.TOP_P
|
||||
|
||||
with pytest.raises(ValueError, match="invalid parameter name invalid"):
|
||||
DefaultParameterName.value_of("invalid")
|
||||
|
||||
|
||||
class TestParameterType:
|
||||
def test_values(self):
|
||||
assert ParameterType.FLOAT == "float"
|
||||
assert ParameterType.INT == "int"
|
||||
assert ParameterType.STRING == "string"
|
||||
assert ParameterType.BOOLEAN == "boolean"
|
||||
assert ParameterType.TEXT == "text"
|
||||
|
||||
|
||||
class TestModelPropertyKey:
|
||||
def test_values(self):
|
||||
assert ModelPropertyKey.MODE == "mode"
|
||||
assert ModelPropertyKey.CONTEXT_SIZE == "context_size"
|
||||
|
||||
|
||||
class TestProviderModel:
|
||||
def test_provider_model(self):
|
||||
model = ProviderModel(
|
||||
model="gpt-4",
|
||||
label=I18nObject(en_US="GPT-4"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
|
||||
)
|
||||
assert model.model == "gpt-4"
|
||||
assert model.support_structure_output is False
|
||||
|
||||
model_with_features = ProviderModel(
|
||||
model="gpt-4",
|
||||
label=I18nObject(en_US="GPT-4"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[ModelFeature.STRUCTURED_OUTPUT],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
|
||||
)
|
||||
assert model_with_features.support_structure_output is True
|
||||
|
||||
|
||||
class TestParameterRule:
|
||||
def test_parameter_rule(self):
|
||||
rule = ParameterRule(
|
||||
name="temperature",
|
||||
label=I18nObject(en_US="Temperature"),
|
||||
type=ParameterType.FLOAT,
|
||||
default=0.7,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
precision=2,
|
||||
)
|
||||
assert rule.name == "temperature"
|
||||
assert rule.default == 0.7
|
||||
|
||||
|
||||
class TestPriceConfig:
|
||||
def test_price_config(self):
|
||||
config = PriceConfig(input=Decimal("0.01"), output=Decimal("0.02"), unit=Decimal("0.001"), currency="USD")
|
||||
assert config.input == Decimal("0.01")
|
||||
assert config.output == Decimal("0.02")
|
||||
|
||||
|
||||
class TestAIModelEntity:
|
||||
def test_ai_model_entity_no_json_schema(self):
|
||||
entity = AIModelEntity(
|
||||
model="gpt-4",
|
||||
label=I18nObject(en_US="GPT-4"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
|
||||
parameter_rules=[
|
||||
ParameterRule(name="temperature", label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT)
|
||||
],
|
||||
)
|
||||
assert ModelFeature.STRUCTURED_OUTPUT not in (entity.features or [])
|
||||
|
||||
def test_ai_model_entity_with_json_schema(self):
|
||||
# Case: json_schema in parameter rules, features is None
|
||||
entity = AIModelEntity(
|
||||
model="gpt-4",
|
||||
label=I18nObject(en_US="GPT-4"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
|
||||
parameter_rules=[
|
||||
ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING)
|
||||
],
|
||||
)
|
||||
assert ModelFeature.STRUCTURED_OUTPUT in entity.features
|
||||
|
||||
def test_ai_model_entity_with_json_schema_and_features_empty(self):
|
||||
# Case: json_schema in parameter rules, features is empty list
|
||||
entity = AIModelEntity(
|
||||
model="gpt-4",
|
||||
label=I18nObject(en_US="GPT-4"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
|
||||
parameter_rules=[
|
||||
ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING)
|
||||
],
|
||||
)
|
||||
assert ModelFeature.STRUCTURED_OUTPUT in entity.features
|
||||
|
||||
def test_ai_model_entity_with_json_schema_and_other_features(self):
|
||||
# Case: json_schema in parameter rules, features has other things
|
||||
entity = AIModelEntity(
|
||||
model="gpt-4",
|
||||
label=I18nObject(en_US="GPT-4"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[ModelFeature.VISION],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
|
||||
parameter_rules=[
|
||||
ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING)
|
||||
],
|
||||
)
|
||||
assert ModelFeature.STRUCTURED_OUTPUT in entity.features
|
||||
assert ModelFeature.VISION in entity.features
|
||||
|
||||
|
||||
class TestModelUsage:
|
||||
def test_model_usage(self):
|
||||
usage = ModelUsage()
|
||||
assert isinstance(usage, ModelUsage)
|
||||
|
||||
|
||||
class TestPriceType:
|
||||
def test_values(self):
|
||||
assert PriceType.INPUT == "input"
|
||||
assert PriceType.OUTPUT == "output"
|
||||
|
||||
|
||||
class TestPriceInfo:
|
||||
def test_price_info(self):
|
||||
info = PriceInfo(unit_price=Decimal("0.01"), unit=Decimal(1000), total_amount=Decimal("0.05"), currency="USD")
|
||||
assert info.total_amount == Decimal("0.05")
|
||||
@ -0,0 +1,63 @@
|
||||
from dify_graph.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
|
||||
|
||||
class TestInvokeErrors:
|
||||
def test_invoke_error_with_description(self):
|
||||
error = InvokeError("Custom description")
|
||||
assert error.description == "Custom description"
|
||||
assert str(error) == "Custom description"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_invoke_error_without_description(self):
|
||||
error = InvokeError()
|
||||
assert error.description is None
|
||||
assert str(error) == "InvokeError"
|
||||
|
||||
def test_invoke_connection_error(self):
|
||||
# Now preserves class-level description
|
||||
error = InvokeConnectionError()
|
||||
assert error.description == "Connection Error"
|
||||
assert str(error) == "Connection Error"
|
||||
assert isinstance(error, InvokeError)
|
||||
|
||||
# Test with explicit description
|
||||
error_with_desc = InvokeConnectionError("Connection Error")
|
||||
assert error_with_desc.description == "Connection Error"
|
||||
assert str(error_with_desc) == "Connection Error"
|
||||
|
||||
def test_invoke_server_unavailable_error(self):
|
||||
error = InvokeServerUnavailableError()
|
||||
assert error.description == "Server Unavailable Error"
|
||||
assert str(error) == "Server Unavailable Error"
|
||||
assert isinstance(error, InvokeError)
|
||||
|
||||
def test_invoke_rate_limit_error(self):
|
||||
error = InvokeRateLimitError()
|
||||
assert error.description == "Rate Limit Error"
|
||||
assert str(error) == "Rate Limit Error"
|
||||
assert isinstance(error, InvokeError)
|
||||
|
||||
def test_invoke_authorization_error(self):
|
||||
error = InvokeAuthorizationError()
|
||||
assert error.description == "Incorrect model credentials provided, please check and try again. "
|
||||
assert str(error) == "Incorrect model credentials provided, please check and try again. "
|
||||
assert isinstance(error, InvokeError)
|
||||
|
||||
def test_invoke_bad_request_error(self):
|
||||
error = InvokeBadRequestError()
|
||||
assert error.description == "Bad Request Error"
|
||||
assert str(error) == "Bad Request Error"
|
||||
assert isinstance(error, InvokeError)
|
||||
|
||||
def test_invoke_error_inheritance(self):
|
||||
# Test that we can override the default description in subclasses
|
||||
error = InvokeBadRequestError("Overridden Error")
|
||||
assert error.description == "Overridden Error"
|
||||
assert str(error) == "Overridden Error"
|
||||
@ -0,0 +1,336 @@
|
||||
import decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis import RedisError
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
PriceConfig,
|
||||
PriceType,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class TestAIModel:
|
||||
@pytest.fixture
|
||||
def mock_plugin_model_provider(self):
|
||||
return MagicMock(spec=PluginModelProviderEntity)
|
||||
|
||||
@pytest.fixture
|
||||
def ai_model(self, mock_plugin_model_provider):
|
||||
return AIModel(
|
||||
tenant_id="tenant_123",
|
||||
model_type=ModelType.LLM,
|
||||
plugin_id="plugin_123",
|
||||
provider_name="test_provider",
|
||||
plugin_model_provider=mock_plugin_model_provider,
|
||||
)
|
||||
|
||||
def test_invoke_error_mapping(self, ai_model):
|
||||
mapping = ai_model._invoke_error_mapping
|
||||
assert InvokeConnectionError in mapping
|
||||
assert InvokeServerUnavailableError in mapping
|
||||
assert InvokeRateLimitError in mapping
|
||||
assert InvokeAuthorizationError in mapping
|
||||
assert InvokeBadRequestError in mapping
|
||||
assert PluginDaemonInnerError in mapping
|
||||
assert ValueError in mapping
|
||||
|
||||
def test_transform_invoke_error(self, ai_model):
|
||||
# Case: mapped error (InvokeAuthorizationError)
|
||||
err = Exception("Original error")
|
||||
with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}):
|
||||
transformed = ai_model._transform_invoke_error(err)
|
||||
assert isinstance(transformed, InvokeAuthorizationError)
|
||||
assert "Incorrect model credentials provided" in str(transformed.description)
|
||||
|
||||
# Case: mapped error (InvokeError subclass)
|
||||
with patch.object(AIModel, "_invoke_error_mapping", {InvokeRateLimitError("Rate limit"): [Exception]}):
|
||||
transformed = ai_model._transform_invoke_error(err)
|
||||
assert isinstance(transformed, InvokeError)
|
||||
assert "[test_provider]" in transformed.description
|
||||
|
||||
# Case: mapped error (not InvokeError)
|
||||
class CustomNonInvokeError(Exception):
|
||||
pass
|
||||
|
||||
with patch.object(AIModel, "_invoke_error_mapping", {CustomNonInvokeError: [Exception]}):
|
||||
transformed = ai_model._transform_invoke_error(err)
|
||||
assert transformed == err
|
||||
|
||||
# Case: unmapped error
|
||||
unmapped_err = Exception("Unmapped")
|
||||
transformed = ai_model._transform_invoke_error(unmapped_err)
|
||||
assert isinstance(transformed, InvokeError)
|
||||
assert "Error: Unmapped" in transformed.description
|
||||
|
||||
def test_get_price(self, ai_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"key": "value"}
|
||||
|
||||
# Mock get_model_schema
|
||||
mock_schema = MagicMock(spec=AIModelEntity)
|
||||
mock_schema.pricing = PriceConfig(
|
||||
input=decimal.Decimal("0.002"),
|
||||
output=decimal.Decimal("0.004"),
|
||||
unit=decimal.Decimal(1000), # 1000 tokens per unit
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
|
||||
# Test INPUT
|
||||
price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000)
|
||||
assert price_info.unit_price == decimal.Decimal("0.002")
|
||||
|
||||
# Test OUTPUT
|
||||
price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000)
|
||||
assert price_info.unit_price == decimal.Decimal("0.004")
|
||||
|
||||
# Case: unit_price is None (returns zeroed PriceInfo)
|
||||
mock_schema.pricing = None
|
||||
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
|
||||
price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000)
|
||||
assert price_info.total_amount == decimal.Decimal("0.0")
|
||||
|
||||
def test_get_price_no_price_config_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
# We need it to be truthy at line 107 and 112 but falsy at line 127.
|
||||
class ChangingPriceConfig:
|
||||
def __init__(self):
|
||||
self.input = decimal.Decimal("0.01")
|
||||
self.unit = decimal.Decimal(1)
|
||||
self.currency = "USD"
|
||||
self.called = 0
|
||||
|
||||
def __bool__(self):
|
||||
self.called += 1
|
||||
return self.called <= 2
|
||||
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.pricing = ChangingPriceConfig()
|
||||
|
||||
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
ai_model.get_price(model_name, {}, PriceType.INPUT, 1000)
|
||||
assert "Price config not found" in str(excinfo.value)
|
||||
|
||||
def test_get_model_schema_cache_hit(self, ai_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
|
||||
mock_schema = AIModelEntity(
|
||||
model="test_model",
|
||||
label=I18nObject(en_US="Test Model"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = mock_schema.model_dump_json().encode()
|
||||
|
||||
schema = ai_model.get_model_schema(model_name, credentials)
|
||||
|
||||
assert schema.model == "test_model"
|
||||
mock_redis.get.assert_called_once()
|
||||
|
||||
def test_get_model_schema_cache_miss(self, ai_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
|
||||
mock_schema = AIModelEntity(
|
||||
model="test_model",
|
||||
label=I18nObject(en_US="Test Model"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
mock_manager = mock_client.return_value
|
||||
mock_manager.get_model_schema.return_value = mock_schema
|
||||
|
||||
schema = ai_model.get_model_schema(model_name, credentials)
|
||||
|
||||
assert schema == mock_schema
|
||||
mock_manager.get_model_schema.assert_called_once()
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
def test_get_model_schema_redis_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
with (
|
||||
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
|
||||
):
|
||||
mock_redis.get.side_effect = RedisError("Connection refused")
|
||||
mock_manager = mock_client.return_value
|
||||
mock_manager.get_model_schema.return_value = None
|
||||
|
||||
schema = ai_model.get_model_schema(model_name, {})
|
||||
|
||||
assert schema is None
|
||||
mock_manager.get_model_schema.assert_called_once()
|
||||
|
||||
def test_get_model_schema_validation_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
with (
|
||||
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
|
||||
):
|
||||
mock_redis.get.return_value = b"invalid json"
|
||||
mock_manager = mock_client.return_value
|
||||
mock_manager.get_model_schema.return_value = None
|
||||
|
||||
# This should trigger ValidationError at line 166 and go to delete()
|
||||
schema = ai_model.get_model_schema(model_name, {})
|
||||
|
||||
assert schema is None
|
||||
mock_redis.delete.assert_called()
|
||||
|
||||
def test_get_model_schema_redis_delete_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
with (
|
||||
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
|
||||
):
|
||||
mock_redis.get.return_value = b'{"invalid": "schema"}'
|
||||
mock_redis.delete.side_effect = RedisError("Delete failed")
|
||||
mock_manager = mock_client.return_value
|
||||
mock_manager.get_model_schema.return_value = None
|
||||
|
||||
schema = ai_model.get_model_schema(model_name, {})
|
||||
|
||||
assert schema is None
|
||||
mock_redis.delete.assert_called()
|
||||
|
||||
def test_get_model_schema_redis_setex_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
mock_schema = AIModelEntity(
|
||||
model="test_model",
|
||||
label=I18nObject(en_US="Test Model"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.setex.side_effect = RuntimeError("Setex failed")
|
||||
mock_manager = mock_client.return_value
|
||||
mock_manager.get_model_schema.return_value = mock_schema
|
||||
|
||||
schema = ai_model.get_model_schema(model_name, {})
|
||||
|
||||
assert schema == mock_schema
|
||||
mock_redis.setex.assert_called()
|
||||
|
||||
def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
mock_schema = AIModelEntity(
|
||||
model="test_model",
|
||||
label=I18nObject(en_US="Test Model"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="invalid",
|
||||
use_template="invalid_template_name",
|
||||
label=I18nObject(en_US="Invalid"),
|
||||
type=ParameterType.FLOAT,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
|
||||
schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {})
|
||||
assert schema.parameter_rules[0].use_template == "invalid_template_name"
|
||||
|
||||
def test_get_customizable_model_schema_from_credentials(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
mock_schema = AIModelEntity(
|
||||
model="test_model",
|
||||
label=I18nObject(en_US="Test Model"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temp", use_template="temperature", label=I18nObject(en_US="Temp"), type=ParameterType.FLOAT
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
use_template="top_p",
|
||||
label=I18nObject(en_US="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(en_US=""),
|
||||
),
|
||||
ParameterRule(
|
||||
name="max_tokens",
|
||||
use_template="max_tokens",
|
||||
label=I18nObject(en_US="Max Tokens"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="", zh_Hans=""),
|
||||
),
|
||||
ParameterRule(name="custom", label=I18nObject(en_US="Custom"), type=ParameterType.STRING),
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
|
||||
schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {})
|
||||
|
||||
assert schema.parameter_rules[0].max == 1.0
|
||||
assert schema.parameter_rules[1].help.en_US != ""
|
||||
assert schema.parameter_rules[2].help.zh_Hans != ""
|
||||
assert schema.parameter_rules[3].use_template is None
|
||||
|
||||
def test_get_customizable_model_schema_from_credentials_none(self, ai_model):
|
||||
with patch.object(AIModel, "get_customizable_model_schema", return_value=None):
|
||||
schema = ai_model.get_customizable_model_schema_from_credentials("model", {})
|
||||
assert schema is None
|
||||
|
||||
def test_get_customizable_model_schema_default(self, ai_model):
|
||||
assert ai_model.get_customizable_model_schema("model", {}) is None
|
||||
|
||||
def test_get_default_parameter_rule_variable_map(self, ai_model):
|
||||
# Valid
|
||||
res = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE)
|
||||
assert res["default"] == 0.0
|
||||
|
||||
# Invalid
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
ai_model._get_default_parameter_rule_variable_map("invalid_name")
|
||||
assert "Invalid model parameter rule name" in str(excinfo.value)
|
||||
@ -0,0 +1,476 @@
|
||||
import logging
|
||||
from collections.abc import Generator, Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import dify_graph.model_runtime.model_providers.__base.large_language_model as llm_module
|
||||
|
||||
# Access large_language_model members via llm_module to avoid partial import issues in CI
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.callbacks.base_callback import Callback
|
||||
from dify_graph.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType, PriceInfo
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import _build_llm_result_from_chunks
|
||||
|
||||
|
||||
def _usage(prompt_tokens: int = 1, completion_tokens: int = 2) -> LLMUsage:
|
||||
return LLMUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal(1),
|
||||
prompt_price=Decimal(prompt_tokens) * Decimal("0.001"),
|
||||
completion_tokens=completion_tokens,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal(1),
|
||||
completion_price=Decimal(completion_tokens) * Decimal("0.002"),
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_price=Decimal(prompt_tokens) * Decimal("0.001") + Decimal(completion_tokens) * Decimal("0.002"),
|
||||
currency="USD",
|
||||
latency=0.0,
|
||||
)
|
||||
|
||||
|
||||
def _tool_call_delta(
|
||||
*,
|
||||
tool_call_id: str,
|
||||
tool_type: str = "function",
|
||||
function_name: str = "",
|
||||
function_arguments: str = "",
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
return AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type=tool_type,
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=function_name, arguments=function_arguments),
|
||||
)
|
||||
|
||||
|
||||
def _chunk(
|
||||
*,
|
||||
model: str = "test-model",
|
||||
content: str | list[Any] | None = None,
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
system_fingerprint: str | None = None,
|
||||
) -> LLMResultChunk:
|
||||
return LLMResultChunk(
|
||||
model=model,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=content, tool_calls=tool_calls or []),
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpyCallback(Callback):
|
||||
raise_error: bool = False
|
||||
before: list[dict[str, Any]] = field(default_factory=list)
|
||||
new_chunk: list[dict[str, Any]] = field(default_factory=list)
|
||||
after: list[dict[str, Any]] = field(default_factory=list)
|
||||
error: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def on_before_invoke(self, **kwargs: Any) -> None: # type: ignore[override]
|
||||
self.before.append(kwargs)
|
||||
|
||||
def on_new_chunk(self, **kwargs: Any) -> None: # type: ignore[override]
|
||||
self.new_chunk.append(kwargs)
|
||||
|
||||
def on_after_invoke(self, **kwargs: Any) -> None: # type: ignore[override]
|
||||
self.after.append(kwargs)
|
||||
|
||||
def on_invoke_error(self, **kwargs: Any) -> None: # type: ignore[override]
|
||||
self.error.append(kwargs)
|
||||
|
||||
|
||||
class _TestLLM(llm_module.LargeLanguageModel):
|
||||
def get_price(self, model: str, credentials: dict, price_type: Any, tokens: int) -> PriceInfo: # type: ignore[override]
|
||||
return PriceInfo(
|
||||
unit_price=Decimal("0.01"),
|
||||
unit=Decimal(1),
|
||||
total_amount=Decimal(tokens) * Decimal("0.01"),
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
def _transform_invoke_error(self, error: Exception) -> Exception: # type: ignore[override]
|
||||
return RuntimeError(f"transformed: {error}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm() -> _TestLLM:
|
||||
plugin_provider = PluginModelProviderEntity.model_construct(
|
||||
id="provider-id",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
provider="provider",
|
||||
tenant_id="tenant",
|
||||
plugin_unique_identifier="plugin-uid",
|
||||
plugin_id="plugin-id",
|
||||
declaration=MagicMock(),
|
||||
)
|
||||
return _TestLLM.model_construct(
|
||||
tenant_id="tenant",
|
||||
model_type=ModelType.LLM,
|
||||
plugin_id="plugin-id",
|
||||
provider_name="provider",
|
||||
plugin_model_provider=plugin_provider,
|
||||
started_at=1.0,
|
||||
)
|
||||
|
||||
|
||||
def test_gen_tool_call_id_is_uuid_based(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="abc123"))
|
||||
assert llm_module._gen_tool_call_id() == "chatcmpl-tool-abc123"
|
||||
|
||||
|
||||
def test_run_callbacks_no_callbacks_noop() -> None:
|
||||
invoked: list[int] = []
|
||||
llm_module._run_callbacks(None, event="x", invoke=lambda _: invoked.append(1))
|
||||
llm_module._run_callbacks([], event="x", invoke=lambda _: invoked.append(1))
|
||||
assert invoked == []
|
||||
|
||||
|
||||
def test_run_callbacks_swallows_error_when_raise_error_false(caplog: pytest.LogCaptureFixture) -> None:
|
||||
class Boom:
|
||||
raise_error = False
|
||||
|
||||
caplog.set_level(logging.WARNING)
|
||||
llm_module._run_callbacks(
|
||||
[Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom"))
|
||||
)
|
||||
assert any("Callback" in record.message and "failed with error" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
def test_run_callbacks_reraises_when_raise_error_true() -> None:
|
||||
class Boom:
|
||||
raise_error = True
|
||||
|
||||
with pytest.raises(ValueError, match="boom"):
|
||||
llm_module._run_callbacks(
|
||||
[Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom"))
|
||||
)
|
||||
|
||||
|
||||
def test_get_or_create_tool_call_empty_id_returns_last() -> None:
|
||||
calls = [
|
||||
_tool_call_delta(tool_call_id="id1", function_name="a"),
|
||||
_tool_call_delta(tool_call_id="id2", function_name="b"),
|
||||
]
|
||||
assert llm_module._get_or_create_tool_call(calls, "") is calls[-1]
|
||||
|
||||
|
||||
def test_get_or_create_tool_call_empty_id_without_existing_raises() -> None:
|
||||
with pytest.raises(ValueError, match="tool_call_id is empty"):
|
||||
llm_module._get_or_create_tool_call([], "")
|
||||
|
||||
|
||||
def test_get_or_create_tool_call_creates_if_missing() -> None:
|
||||
calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_call = llm_module._get_or_create_tool_call(calls, "new-id")
|
||||
assert tool_call.id == "new-id"
|
||||
assert tool_call.function.name == ""
|
||||
assert tool_call.function.arguments == ""
|
||||
assert calls == [tool_call]
|
||||
|
||||
|
||||
def test_get_or_create_tool_call_returns_existing_when_found() -> None:
|
||||
existing = _tool_call_delta(tool_call_id="same-id", function_name="fn", function_arguments="{}")
|
||||
calls = [existing]
|
||||
assert llm_module._get_or_create_tool_call(calls, "same-id") is existing
|
||||
|
||||
|
||||
def test_merge_tool_call_delta_updates_fields_and_appends_arguments() -> None:
|
||||
tool_call = _tool_call_delta(tool_call_id="id", tool_type="function", function_name="x", function_arguments="{")
|
||||
delta = _tool_call_delta(tool_call_id="id2", tool_type="function", function_name="y", function_arguments="}")
|
||||
llm_module._merge_tool_call_delta(tool_call, delta)
|
||||
assert tool_call.id == "id2"
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "y"
|
||||
assert tool_call.function.arguments == "{}"
|
||||
|
||||
|
||||
def test_increase_tool_call_generates_id_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="fixed"))
|
||||
delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{")
|
||||
existing: list[AssistantPromptMessage.ToolCall] = []
|
||||
llm_module._increase_tool_call([delta], existing)
|
||||
assert len(existing) == 1
|
||||
assert existing[0].id == "chatcmpl-tool-fixed"
|
||||
assert existing[0].function.name == "fn"
|
||||
assert existing[0].function.arguments == "{"
|
||||
|
||||
|
||||
def test_increase_tool_call_merges_incremental_arguments() -> None:
|
||||
existing: list[AssistantPromptMessage.ToolCall] = []
|
||||
llm_module._increase_tool_call(
|
||||
[_tool_call_delta(tool_call_id="id", function_name="fn", function_arguments="{")], existing
|
||||
)
|
||||
llm_module._increase_tool_call(
|
||||
[_tool_call_delta(tool_call_id="id", function_name="", function_arguments="}")], existing
|
||||
)
|
||||
assert len(existing) == 1
|
||||
assert existing[0].function.name == "fn"
|
||||
assert existing[0].function.arguments == "{}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("content", "expected_type"),
|
||||
[
|
||||
("hello", str),
|
||||
([TextPromptMessageContent(data="hello")], list),
|
||||
],
|
||||
)
|
||||
def test_build_llm_result_from_chunks_accumulates_and_raises_error(
|
||||
content: str | list[TextPromptMessageContent],
|
||||
expected_type: type,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="drain"))
|
||||
caplog.set_level(logging.DEBUG)
|
||||
|
||||
tool_delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{}")
|
||||
first = _chunk(content=content, tool_calls=[tool_delta], usage=_usage(3, 4), system_fingerprint="fp1")
|
||||
|
||||
def iter_with_error() -> Iterator[LLMResultChunk]:
|
||||
yield first
|
||||
raise RuntimeError("drain boom")
|
||||
|
||||
with pytest.raises(RuntimeError, match="drain boom"):
|
||||
_build_llm_result_from_chunks(
|
||||
model="m", prompt_messages=[UserPromptMessage(content="u")], chunks=iter_with_error()
|
||||
)
|
||||
|
||||
assert any("Error while consuming non-stream plugin chunk iterator" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
def test_build_llm_result_from_chunks_empty_iterator() -> None:
|
||||
def empty() -> Iterator[LLMResultChunk]:
|
||||
if False: # pragma: no cover
|
||||
yield _chunk()
|
||||
return
|
||||
|
||||
result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=empty())
|
||||
assert result.message.content == []
|
||||
assert result.usage.total_tokens == 0
|
||||
assert result.system_fingerprint is None
|
||||
|
||||
|
||||
def test_build_llm_result_from_chunks_accumulates_all_chunks() -> None:
|
||||
chunks = iter([_chunk(content="first"), _chunk(content="second")])
|
||||
result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=chunks)
|
||||
assert result.message.content == "firstsecond"
|
||||
|
||||
|
||||
def test_invoke_llm_via_plugin_passes_list_converted_stop(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
invoked: dict[str, Any] = {}
|
||||
|
||||
class FakePluginModelClient:
|
||||
def invoke_llm(self, **kwargs: Any) -> str:
|
||||
invoked.update(kwargs)
|
||||
return "ok"
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
|
||||
|
||||
prompt_messages: Sequence[PromptMessage] = (UserPromptMessage(content="hi"),)
|
||||
result = llm_module._invoke_llm_via_plugin(
|
||||
tenant_id="t",
|
||||
user_id="u",
|
||||
plugin_id="p",
|
||||
provider="prov",
|
||||
model="m",
|
||||
credentials={"k": "v"},
|
||||
model_parameters={"temp": 1},
|
||||
prompt_messages=prompt_messages,
|
||||
tools=None,
|
||||
stop=("a", "b"),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert result == "ok"
|
||||
assert invoked["prompt_messages"] == list(prompt_messages)
|
||||
assert invoked["stop"] == ["a", "b"]
|
||||
|
||||
|
||||
def test_normalize_non_stream_plugin_result_passthrough_llmresult() -> None:
|
||||
llm_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
|
||||
assert (
|
||||
llm_module._normalize_non_stream_plugin_result(model="m", prompt_messages=[], result=llm_result) is llm_result
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_non_stream_plugin_result_builds_from_chunks() -> None:
|
||||
chunks = iter([_chunk(content="hello", usage=_usage(1, 1))])
|
||||
result = llm_module._normalize_non_stream_plugin_result(
|
||||
model="m", prompt_messages=[UserPromptMessage(content="u")], result=chunks
|
||||
)
|
||||
assert isinstance(result, LLMResult)
|
||||
assert result.message.content == "hello"
|
||||
|
||||
|
||||
def test_invoke_non_stream_normalizes_and_sets_prompt_messages(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
plugin_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
|
||||
lambda **_: plugin_result,
|
||||
)
|
||||
cb = SpyCallback()
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
result = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=False, callbacks=[cb])
|
||||
assert isinstance(result, LLMResult)
|
||||
assert result.prompt_messages == prompt_messages
|
||||
assert len(cb.before) == 1
|
||||
assert len(cb.after) == 1
|
||||
assert cb.after[0]["result"].prompt_messages == prompt_messages
|
||||
|
||||
|
||||
def test_invoke_stream_wraps_generator_and_triggers_callbacks(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
plugin_chunks = iter(
|
||||
[
|
||||
_chunk(model="m1", content="a"),
|
||||
_chunk(
|
||||
model="m2", content=[TextPromptMessageContent(data="b")], usage=_usage(2, 3), system_fingerprint="fp"
|
||||
),
|
||||
_chunk(model="m3", content=None),
|
||||
]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
|
||||
lambda **_: plugin_chunks,
|
||||
)
|
||||
|
||||
cb = SpyCallback()
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
gen = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=True, callbacks=[cb])
|
||||
|
||||
assert isinstance(gen, Generator)
|
||||
chunks = list(gen)
|
||||
assert len(chunks) == 3
|
||||
assert all(chunk.prompt_messages == prompt_messages for chunk in chunks)
|
||||
assert len(cb.before) == 1
|
||||
assert len(cb.new_chunk) == 3
|
||||
assert len(cb.after) == 1
|
||||
final_result: LLMResult = cb.after[0]["result"]
|
||||
assert final_result.model == "m3"
|
||||
assert final_result.system_fingerprint == "fp"
|
||||
assert isinstance(final_result.message.content, list)
|
||||
assert [c.data for c in final_result.message.content] == ["a", "b"]
|
||||
assert final_result.usage.total_tokens == 5
|
||||
|
||||
|
||||
def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def boom(**_: Any) -> Any:
|
||||
raise ValueError("plugin down")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", boom
|
||||
)
|
||||
cb = SpyCallback()
|
||||
with pytest.raises(RuntimeError, match="transformed: plugin down"):
|
||||
llm.invoke(
|
||||
model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False, callbacks=[cb]
|
||||
)
|
||||
assert len(cb.error) == 1
|
||||
assert isinstance(cb.error[0]["ex"], ValueError)
|
||||
|
||||
|
||||
def test_invoke_raises_not_implemented_for_unsupported_result_type(
|
||||
llm: _TestLLM, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.setattr(llm_module, "_invoke_llm_via_plugin", lambda **_: "not-a-result")
|
||||
monkeypatch.setattr(llm_module, "_normalize_non_stream_plugin_result", lambda **_: "not-a-result")
|
||||
with pytest.raises(NotImplementedError, match="unsupported invoke result type"):
|
||||
llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False)
|
||||
|
||||
|
||||
def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured_callbacks: list[list[Callback]] = []
|
||||
|
||||
class FakeLoggingCallback(SpyCallback):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(llm_module, "LoggingCallback", FakeLoggingCallback)
|
||||
monkeypatch.setattr(llm_module.dify_config, "DEBUG", True)
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
|
||||
lambda **_: LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()),
|
||||
)
|
||||
|
||||
original_trigger = llm._trigger_before_invoke_callbacks
|
||||
|
||||
def spy_trigger(*args: Any, **kwargs: Any) -> None:
|
||||
captured_callbacks.append(list(kwargs["callbacks"]))
|
||||
original_trigger(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(llm, "_trigger_before_invoke_callbacks", spy_trigger)
|
||||
llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False)
|
||||
assert any(isinstance(cb, FakeLoggingCallback) for cb in captured_callbacks[0])
|
||||
|
||||
|
||||
def test_get_num_tokens_returns_0_when_plugin_disabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False)
|
||||
assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 0
|
||||
|
||||
|
||||
def test_get_num_tokens_uses_plugin_when_enabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", True)
|
||||
|
||||
class FakePluginModelClient:
|
||||
def get_llm_num_tokens(self, **kwargs: Any) -> int:
|
||||
assert kwargs["tenant_id"] == "tenant"
|
||||
assert kwargs["plugin_id"] == "plugin-id"
|
||||
assert kwargs["provider"] == "provider"
|
||||
assert kwargs["model_type"] == "llm"
|
||||
return 42
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
|
||||
assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 42
|
||||
|
||||
|
||||
def test_calc_response_usage_uses_prices_and_latency(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(llm_module.time, "perf_counter", lambda: 4.5)
|
||||
llm.started_at = 1.0
|
||||
usage = llm.calc_response_usage(model="m", credentials={}, prompt_tokens=10, completion_tokens=5)
|
||||
assert usage.total_tokens == 15
|
||||
assert usage.total_price == Decimal("0.15")
|
||||
assert usage.latency == 3.5
|
||||
|
||||
|
||||
def test_invoke_result_generator_raises_transformed_on_iteration_error(llm: _TestLLM) -> None:
|
||||
def broken() -> Iterator[LLMResultChunk]:
|
||||
yield _chunk(content="ok")
|
||||
raise ValueError("chunk stream broken")
|
||||
|
||||
gen = llm._invoke_result_generator(
|
||||
model="m",
|
||||
result=broken(),
|
||||
credentials={},
|
||||
prompt_messages=[UserPromptMessage(content="u")],
|
||||
model_parameters={},
|
||||
callbacks=[SpyCallback()],
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="transformed: chunk stream broken"):
|
||||
list(gen)
|
||||
@ -0,0 +1,90 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
|
||||
|
||||
class TestModerationModel:
|
||||
@pytest.fixture
|
||||
def mock_plugin_model_provider(self):
|
||||
return MagicMock(spec=PluginModelProviderEntity)
|
||||
|
||||
@pytest.fixture
|
||||
def moderation_model(self, mock_plugin_model_provider):
|
||||
return ModerationModel(
|
||||
tenant_id="tenant_123",
|
||||
model_type=ModelType.MODERATION,
|
||||
plugin_id="plugin_123",
|
||||
provider_name="test_provider",
|
||||
plugin_model_provider=mock_plugin_model_provider,
|
||||
)
|
||||
|
||||
def test_model_type(self, moderation_model):
|
||||
assert moderation_model.model_type == ModelType.MODERATION
|
||||
|
||||
def test_invoke_success(self, moderation_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
text = "test text"
|
||||
user = "user_123"
|
||||
|
||||
with (
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client_class,
|
||||
patch("time.perf_counter", return_value=1.0),
|
||||
):
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_moderation.return_value = True
|
||||
|
||||
result = moderation_model.invoke(model=model_name, credentials=credentials, text=text, user=user)
|
||||
|
||||
assert result is True
|
||||
assert moderation_model.started_at == 1.0
|
||||
mock_client.invoke_moderation.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="user_123",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
text=text,
|
||||
)
|
||||
|
||||
def test_invoke_success_no_user(self, moderation_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
text = "test text"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_moderation.return_value = False
|
||||
|
||||
result = moderation_model.invoke(model=model_name, credentials=credentials, text=text)
|
||||
|
||||
assert result is False
|
||||
mock_client.invoke_moderation.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
text=text,
|
||||
)
|
||||
|
||||
def test_invoke_exception(self, moderation_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
text = "test text"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_moderation.side_effect = Exception("Test error")
|
||||
|
||||
with pytest.raises(InvokeError) as excinfo:
|
||||
moderation_model.invoke(model=model_name, credentials=credentials, text=text)
|
||||
|
||||
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
|
||||
@ -0,0 +1,181 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rerank_model() -> RerankModel:
|
||||
plugin_provider = PluginModelProviderEntity.model_construct(
|
||||
id="provider-id",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
provider="provider",
|
||||
tenant_id="tenant",
|
||||
plugin_unique_identifier="plugin-uid",
|
||||
plugin_id="plugin-id",
|
||||
declaration=MagicMock(),
|
||||
)
|
||||
return RerankModel.model_construct(
|
||||
tenant_id="tenant",
|
||||
model_type=ModelType.RERANK,
|
||||
plugin_id="plugin-id",
|
||||
provider_name="provider",
|
||||
plugin_model_provider=plugin_provider,
|
||||
)
|
||||
|
||||
|
||||
def test_model_type_is_rerank_by_default() -> None:
|
||||
plugin_provider = PluginModelProviderEntity.model_construct(
|
||||
id="provider-id",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
provider="provider",
|
||||
tenant_id="tenant",
|
||||
plugin_unique_identifier="plugin-uid",
|
||||
plugin_id="plugin-id",
|
||||
declaration=MagicMock(),
|
||||
)
|
||||
model = RerankModel(
|
||||
tenant_id="tenant",
|
||||
plugin_id="plugin-id",
|
||||
provider_name="provider",
|
||||
plugin_model_provider=plugin_provider,
|
||||
)
|
||||
assert model.model_type == ModelType.RERANK
|
||||
|
||||
|
||||
def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)])
|
||||
|
||||
class FakePluginModelClient:
|
||||
def __init__(self) -> None:
|
||||
self.invoke_rerank_called_with: dict[str, Any] | None = None
|
||||
|
||||
def invoke_rerank(self, **kwargs: Any) -> RerankResult:
|
||||
self.invoke_rerank_called_with = kwargs
|
||||
return expected
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
fake_client = FakePluginModelClient()
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
|
||||
|
||||
result = rerank_model.invoke(
|
||||
model="rerank",
|
||||
credentials={"k": "v"},
|
||||
query="q",
|
||||
docs=["d1", "d2"],
|
||||
score_threshold=0.2,
|
||||
top_n=10,
|
||||
user="user-1",
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
assert fake_client.invoke_rerank_called_with == {
|
||||
"tenant_id": "tenant",
|
||||
"user_id": "user-1",
|
||||
"plugin_id": "plugin-id",
|
||||
"provider": "provider",
|
||||
"model": "rerank",
|
||||
"credentials": {"k": "v"},
|
||||
"query": "q",
|
||||
"docs": ["d1", "d2"],
|
||||
"score_threshold": 0.2,
|
||||
"top_n": 10,
|
||||
}
|
||||
|
||||
|
||||
def test_invoke_uses_unknown_user_when_not_provided(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class FakePluginModelClient:
|
||||
def __init__(self) -> None:
|
||||
self.kwargs: dict[str, Any] | None = None
|
||||
|
||||
def invoke_rerank(self, **kwargs: Any) -> RerankResult:
|
||||
self.kwargs = kwargs
|
||||
return RerankResult(model="m", docs=[])
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
fake_client = FakePluginModelClient()
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
|
||||
|
||||
rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
|
||||
assert fake_client.kwargs is not None
|
||||
assert fake_client.kwargs["user_id"] == "unknown"
|
||||
|
||||
|
||||
def test_invoke_transforms_and_raises_on_plugin_error(
|
||||
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
class FakePluginModelClient:
|
||||
def invoke_rerank(self, **_: Any) -> RerankResult:
|
||||
raise ValueError("plugin down")
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
|
||||
monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="transformed: plugin down"):
|
||||
rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
|
||||
|
||||
|
||||
def test_invoke_multimodal_calls_plugin_and_passes_args(
|
||||
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)])
|
||||
|
||||
class FakePluginModelClient:
|
||||
def __init__(self) -> None:
|
||||
self.invoke_multimodal_rerank_called_with: dict[str, Any] | None = None
|
||||
|
||||
def invoke_multimodal_rerank(self, **kwargs: Any) -> RerankResult:
|
||||
self.invoke_multimodal_rerank_called_with = kwargs
|
||||
return expected
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
fake_client = FakePluginModelClient()
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
|
||||
|
||||
query = {"type": "text", "text": "q"}
|
||||
docs = [{"type": "text", "text": "d1"}]
|
||||
result = rerank_model.invoke_multimodal_rerank(
|
||||
model="mm",
|
||||
credentials={"k": "v"},
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=None,
|
||||
top_n=None,
|
||||
user=None,
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
assert fake_client.invoke_multimodal_rerank_called_with is not None
|
||||
assert fake_client.invoke_multimodal_rerank_called_with["tenant_id"] == "tenant"
|
||||
assert fake_client.invoke_multimodal_rerank_called_with["user_id"] == "unknown"
|
||||
assert fake_client.invoke_multimodal_rerank_called_with["query"] == query
|
||||
assert fake_client.invoke_multimodal_rerank_called_with["docs"] == docs
|
||||
|
||||
|
||||
def test_invoke_multimodal_transforms_and_raises_on_plugin_error(
|
||||
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
class FakePluginModelClient:
|
||||
def invoke_multimodal_rerank(self, **_: Any) -> RerankResult:
|
||||
raise ValueError("plugin down")
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
|
||||
monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="transformed: plugin down"):
|
||||
rerank_model.invoke_multimodal_rerank(model="m", credentials={}, query={"q": 1}, docs=[{"d": 1}])
|
||||
@ -0,0 +1,87 @@
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
|
||||
|
||||
class TestSpeech2TextModel:
|
||||
@pytest.fixture
|
||||
def mock_plugin_model_provider(self):
|
||||
return MagicMock(spec=PluginModelProviderEntity)
|
||||
|
||||
@pytest.fixture
|
||||
def speech2text_model(self, mock_plugin_model_provider):
|
||||
return Speech2TextModel(
|
||||
tenant_id="tenant_123",
|
||||
model_type=ModelType.SPEECH2TEXT,
|
||||
plugin_id="plugin_123",
|
||||
provider_name="test_provider",
|
||||
plugin_model_provider=mock_plugin_model_provider,
|
||||
)
|
||||
|
||||
def test_model_type(self, speech2text_model):
|
||||
assert speech2text_model.model_type == ModelType.SPEECH2TEXT
|
||||
|
||||
def test_invoke_success(self, speech2text_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
file = BytesIO(b"audio data")
|
||||
user = "user_123"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_speech_to_text.return_value = "transcribed text"
|
||||
|
||||
result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file, user=user)
|
||||
|
||||
assert result == "transcribed text"
|
||||
mock_client.invoke_speech_to_text.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="user_123",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
file=file,
|
||||
)
|
||||
|
||||
def test_invoke_success_no_user(self, speech2text_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
file = BytesIO(b"audio data")
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_speech_to_text.return_value = "transcribed text"
|
||||
|
||||
result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file)
|
||||
|
||||
assert result == "transcribed text"
|
||||
mock_client.invoke_speech_to_text.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
file=file,
|
||||
)
|
||||
|
||||
def test_invoke_exception(self, speech2text_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
file = BytesIO(b"audio data")
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_speech_to_text.side_effect = Exception("Test error")
|
||||
|
||||
with pytest.raises(InvokeError) as excinfo:
|
||||
speech2text_model.invoke(model=model_name, credentials=credentials, file=file)
|
||||
|
||||
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
|
||||
@ -0,0 +1,185 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
|
||||
class TestTextEmbeddingModel:
|
||||
@pytest.fixture
|
||||
def mock_plugin_model_provider(self):
|
||||
return MagicMock(spec=PluginModelProviderEntity)
|
||||
|
||||
@pytest.fixture
|
||||
def text_embedding_model(self, mock_plugin_model_provider):
|
||||
return TextEmbeddingModel(
|
||||
tenant_id="tenant_123",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
plugin_id="plugin_123",
|
||||
provider_name="test_provider",
|
||||
plugin_model_provider=mock_plugin_model_provider,
|
||||
)
|
||||
|
||||
def test_model_type(self, text_embedding_model):
|
||||
assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING
|
||||
|
||||
def test_invoke_with_texts(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
texts = ["hello", "world"]
|
||||
user = "user_123"
|
||||
expected_result = MagicMock(spec=EmbeddingResult)
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_text_embedding.return_value = expected_result
|
||||
|
||||
result = text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts, user=user)
|
||||
|
||||
assert result == expected_result
|
||||
mock_client.invoke_text_embedding.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="user_123",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=EmbeddingInputType.DOCUMENT,
|
||||
)
|
||||
|
||||
def test_invoke_with_multimodel_documents(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
multimodel_documents = [{"type": "text", "text": "hello"}]
|
||||
expected_result = MagicMock(spec=EmbeddingResult)
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_multimodal_embedding.return_value = expected_result
|
||||
|
||||
result = text_embedding_model.invoke(
|
||||
model=model_name, credentials=credentials, multimodel_documents=multimodel_documents
|
||||
)
|
||||
|
||||
assert result == expected_result
|
||||
mock_client.invoke_multimodal_embedding.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
documents=multimodel_documents,
|
||||
input_type=EmbeddingInputType.DOCUMENT,
|
||||
)
|
||||
|
||||
def test_invoke_no_input(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
text_embedding_model.invoke(model=model_name, credentials=credentials)
|
||||
|
||||
assert "No texts or files provided" in str(excinfo.value)
|
||||
|
||||
def test_invoke_precedence(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
texts = ["hello"]
|
||||
multimodel_documents = [{"type": "text", "text": "world"}]
|
||||
expected_result = MagicMock(spec=EmbeddingResult)
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_text_embedding.return_value = expected_result
|
||||
|
||||
result = text_embedding_model.invoke(
|
||||
model=model_name, credentials=credentials, texts=texts, multimodel_documents=multimodel_documents
|
||||
)
|
||||
|
||||
assert result == expected_result
|
||||
mock_client.invoke_text_embedding.assert_called_once()
|
||||
mock_client.invoke_multimodal_embedding.assert_not_called()
|
||||
|
||||
def test_invoke_exception(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
texts = ["hello"]
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_text_embedding.side_effect = Exception("Test error")
|
||||
|
||||
with pytest.raises(InvokeError) as excinfo:
|
||||
text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts)
|
||||
|
||||
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
|
||||
|
||||
def test_get_num_tokens(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
texts = ["hello", "world"]
|
||||
expected_tokens = [1, 1]
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.get_text_embedding_num_tokens.return_value = expected_tokens
|
||||
|
||||
result = text_embedding_model.get_num_tokens(model=model_name, credentials=credentials, texts=texts)
|
||||
|
||||
assert result == expected_tokens
|
||||
mock_client.get_text_embedding_num_tokens.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
def test_get_context_size(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
|
||||
# Test case 1: Context size in schema
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048}
|
||||
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_context_size(model_name, credentials) == 2048
|
||||
|
||||
# Test case 2: No schema
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
|
||||
assert text_embedding_model._get_context_size(model_name, credentials) == 1000
|
||||
|
||||
# Test case 3: Context size NOT in schema properties
|
||||
mock_schema.model_properties = {}
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_context_size(model_name, credentials) == 1000
|
||||
|
||||
def test_get_max_chunks(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
|
||||
# Test case 1: Max chunks in schema
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10}
|
||||
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_max_chunks(model_name, credentials) == 10
|
||||
|
||||
# Test case 2: No schema
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
|
||||
assert text_embedding_model._get_max_chunks(model_name, credentials) == 1
|
||||
|
||||
# Test case 3: Max chunks NOT in schema properties
|
||||
mock_schema.model_properties = {}
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_max_chunks(model_name, credentials) == 1
|
||||
@ -0,0 +1,131 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
|
||||
|
||||
class TestTTSModel:
|
||||
@pytest.fixture
|
||||
def mock_plugin_model_provider(self):
|
||||
return MagicMock(spec=PluginModelProviderEntity)
|
||||
|
||||
@pytest.fixture
|
||||
def tts_model(self, mock_plugin_model_provider):
|
||||
return TTSModel(
|
||||
tenant_id="tenant_123",
|
||||
model_type=ModelType.TTS,
|
||||
plugin_id="plugin_123",
|
||||
provider_name="test_provider",
|
||||
plugin_model_provider=mock_plugin_model_provider,
|
||||
)
|
||||
|
||||
def test_model_type(self, tts_model):
|
||||
assert tts_model.model_type == ModelType.TTS
|
||||
|
||||
def test_invoke_success(self, tts_model):
|
||||
model_name = "test_model"
|
||||
tenant_id = "ignored_tenant_id"
|
||||
credentials = {"api_key": "abc"}
|
||||
content_text = "Hello world"
|
||||
voice = "alloy"
|
||||
user = "user_123"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_tts.return_value = [b"audio_chunk"]
|
||||
|
||||
result = tts_model.invoke(
|
||||
model=model_name,
|
||||
tenant_id=tenant_id,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
user=user,
|
||||
)
|
||||
|
||||
assert list(result) == [b"audio_chunk"]
|
||||
mock_client.invoke_tts.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="user_123",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def test_invoke_success_no_user(self, tts_model):
|
||||
model_name = "test_model"
|
||||
tenant_id = "ignored_tenant_id"
|
||||
credentials = {"api_key": "abc"}
|
||||
content_text = "Hello world"
|
||||
voice = "alloy"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_tts.return_value = [b"audio_chunk"]
|
||||
|
||||
result = tts_model.invoke(
|
||||
model=model_name, tenant_id=tenant_id, credentials=credentials, content_text=content_text, voice=voice
|
||||
)
|
||||
|
||||
assert list(result) == [b"audio_chunk"]
|
||||
mock_client.invoke_tts.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def test_invoke_exception(self, tts_model):
|
||||
model_name = "test_model"
|
||||
tenant_id = "ignored_tenant_id"
|
||||
credentials = {"api_key": "abc"}
|
||||
content_text = "Hello world"
|
||||
voice = "alloy"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_tts.side_effect = Exception("Test error")
|
||||
|
||||
with pytest.raises(InvokeError) as excinfo:
|
||||
tts_model.invoke(
|
||||
model=model_name,
|
||||
tenant_id=tenant_id,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
|
||||
|
||||
def test_get_tts_model_voices(self, tts_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
language = "en-US"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.get_tts_model_voices.return_value = [{"name": "Voice1"}]
|
||||
|
||||
result = tts_model.get_tts_model_voices(model=model_name, credentials=credentials, language=language)
|
||||
|
||||
assert result == [{"name": "Voice1"}]
|
||||
mock_client.get_tts_model_voices.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
language=language,
|
||||
)
|
||||
@ -0,0 +1,96 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer as gpt2_tokenizer_module
|
||||
from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
|
||||
|
||||
class TestGPT2Tokenizer:
|
||||
def setup_method(self):
|
||||
# Reset the global tokenizer before each test to ensure we test initialization
|
||||
gpt2_tokenizer_module._tokenizer = None
|
||||
|
||||
def test_get_encoder_tiktoken(self):
|
||||
"""
|
||||
Test that get_encoder successfully uses tiktoken when available.
|
||||
"""
|
||||
mock_encoding = MagicMock()
|
||||
# Mock tiktoken to be sure it's used
|
||||
with patch("tiktoken.get_encoding", return_value=mock_encoding) as mock_get_encoding:
|
||||
encoder = GPT2Tokenizer.get_encoder()
|
||||
assert encoder == mock_encoding
|
||||
mock_get_encoding.assert_called_once_with("gpt2")
|
||||
|
||||
# Verify singleton behavior within the same test
|
||||
encoder2 = GPT2Tokenizer.get_encoder()
|
||||
assert encoder2 is encoder
|
||||
assert mock_get_encoding.call_count == 1
|
||||
|
||||
def test_get_encoder_tiktoken_fallback(self):
|
||||
"""
|
||||
Test that get_encoder falls back to transformers when tiktoken fails.
|
||||
"""
|
||||
# patch tiktoken.get_encoding to raise an exception
|
||||
with patch("tiktoken.get_encoding", side_effect=Exception("Tiktoken failure")):
|
||||
# patch transformers.GPT2Tokenizer
|
||||
with patch("transformers.GPT2Tokenizer.from_pretrained") as mock_from_pretrained:
|
||||
mock_transformer_tokenizer = MagicMock()
|
||||
mock_from_pretrained.return_value = mock_transformer_tokenizer
|
||||
|
||||
with patch(
|
||||
"dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer.logger"
|
||||
) as mock_logger:
|
||||
encoder = GPT2Tokenizer.get_encoder()
|
||||
|
||||
assert encoder == mock_transformer_tokenizer
|
||||
mock_from_pretrained.assert_called_once()
|
||||
mock_logger.info.assert_called_once_with("Fallback to Transformers' GPT-2 tokenizer from tiktoken")
|
||||
|
||||
def test_get_num_tokens(self):
|
||||
"""
|
||||
Test get_num_tokens returns the correct count.
|
||||
"""
|
||||
mock_encoder = MagicMock()
|
||||
mock_encoder.encode.return_value = [1, 2, 3, 4, 5]
|
||||
|
||||
with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder):
|
||||
tokens_count = GPT2Tokenizer.get_num_tokens("test text")
|
||||
assert tokens_count == 5
|
||||
mock_encoder.encode.assert_called_once_with("test text")
|
||||
|
||||
def test_get_num_tokens_by_gpt2_direct(self):
|
||||
"""
|
||||
Test _get_num_tokens_by_gpt2 directly.
|
||||
"""
|
||||
mock_encoder = MagicMock()
|
||||
mock_encoder.encode.return_value = [1, 2]
|
||||
|
||||
with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder):
|
||||
tokens_count = GPT2Tokenizer._get_num_tokens_by_gpt2("hello")
|
||||
assert tokens_count == 2
|
||||
mock_encoder.encode.assert_called_once_with("hello")
|
||||
|
||||
def test_get_encoder_already_initialized(self):
|
||||
"""
|
||||
Test that if _tokenizer is already set, it returns it immediately.
|
||||
"""
|
||||
mock_existing_tokenizer = MagicMock()
|
||||
gpt2_tokenizer_module._tokenizer = mock_existing_tokenizer
|
||||
|
||||
# Tiktoken should not be called if already initialized
|
||||
with patch("tiktoken.get_encoding") as mock_get_encoding:
|
||||
encoder = GPT2Tokenizer.get_encoder()
|
||||
assert encoder == mock_existing_tokenizer
|
||||
mock_get_encoding.assert_not_called()
|
||||
|
||||
def test_get_encoder_thread_safety(self):
|
||||
"""
|
||||
Simple test to ensure the lock is used.
|
||||
"""
|
||||
mock_encoding = MagicMock()
|
||||
with patch("tiktoken.get_encoding", return_value=mock_encoding):
|
||||
# We patch the lock in the module
|
||||
with patch("dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer._lock") as mock_lock:
|
||||
encoder = GPT2Tokenizer.get_encoder()
|
||||
assert encoder == mock_encoding
|
||||
mock_lock.__enter__.assert_called_once()
|
||||
mock_lock.__exit__.assert_called_once()
|
||||
@ -0,0 +1,522 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis import RedisError
|
||||
|
||||
import contexts
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
|
||||
def _provider_entity(
|
||||
*,
|
||||
provider: str,
|
||||
supported_model_types: list[ModelType] | None = None,
|
||||
models: list[AIModelEntity] | None = None,
|
||||
icon_small: I18nObject | None = None,
|
||||
icon_small_dark: I18nObject | None = None,
|
||||
) -> ProviderEntity:
|
||||
return ProviderEntity(
|
||||
provider=provider,
|
||||
label=I18nObject(en_US=provider),
|
||||
supported_model_types=supported_model_types or [ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
models=models or [],
|
||||
icon_small=icon_small,
|
||||
icon_small_dark=icon_small_dark,
|
||||
)
|
||||
|
||||
|
||||
def _plugin_provider(
|
||||
*, plugin_id: str, declaration: ProviderEntity, provider: str = "provider"
|
||||
) -> PluginModelProviderEntity:
|
||||
return PluginModelProviderEntity.model_construct(
|
||||
id=f"{plugin_id}-id",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
provider=provider,
|
||||
tenant_id="tenant",
|
||||
plugin_unique_identifier=f"{plugin_id}-uid",
|
||||
plugin_id=plugin_id,
|
||||
declaration=declaration,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_plugin_model_provider_context() -> None:
|
||||
contexts.plugin_model_providers_lock.set(Lock())
|
||||
contexts.plugin_model_providers.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_plugin_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
manager = MagicMock()
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: manager)
|
||||
return manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def factory(fake_plugin_manager: MagicMock) -> ModelProviderFactory:
|
||||
return ModelProviderFactory(tenant_id="tenant")
|
||||
|
||||
|
||||
def test_get_plugin_model_providers_initializes_context_on_lookup_error(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
declaration = _provider_entity(provider="openai")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
|
||||
]
|
||||
|
||||
original_get = contexts.plugin_model_providers.get
|
||||
calls = {"n": 0}
|
||||
|
||||
def flaky_get() -> Any:
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
raise LookupError
|
||||
return original_get()
|
||||
|
||||
monkeypatch.setattr(contexts.plugin_model_providers, "get", flaky_get)
|
||||
|
||||
providers = factory.get_plugin_model_providers()
|
||||
assert len(providers) == 1
|
||||
assert providers[0].declaration.provider == "langgenius/openai/openai"
|
||||
|
||||
|
||||
def test_get_plugin_model_providers_caches_and_does_not_refetch(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
|
||||
) -> None:
|
||||
declaration = _provider_entity(provider="openai")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
|
||||
]
|
||||
|
||||
first = factory.get_plugin_model_providers()
|
||||
second = factory.get_plugin_model_providers()
|
||||
|
||||
assert first is second
|
||||
fake_plugin_manager.fetch_model_providers.assert_called_once_with("tenant")
|
||||
|
||||
|
||||
def test_get_providers_returns_declarations(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None:
|
||||
d1 = _provider_entity(provider="openai")
|
||||
d2 = _provider_entity(provider="anthropic")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=d1),
|
||||
_plugin_provider(plugin_id="langgenius/anthropic", declaration=d2),
|
||||
]
|
||||
|
||||
providers = factory.get_providers()
|
||||
assert [p.provider for p in providers] == ["langgenius/openai/openai", "langgenius/anthropic/anthropic"]
|
||||
|
||||
|
||||
def test_get_plugin_model_provider_converts_short_provider_id(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
|
||||
) -> None:
|
||||
declaration = _provider_entity(provider="openai")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
|
||||
]
|
||||
|
||||
provider = factory.get_plugin_model_provider("openai")
|
||||
assert provider.declaration.provider == "langgenius/openai/openai"
|
||||
|
||||
|
||||
def test_get_plugin_model_provider_raises_on_invalid_provider(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
|
||||
) -> None:
|
||||
declaration = _provider_entity(provider="openai")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid provider"):
|
||||
factory.get_plugin_model_provider("langgenius/unknown/unknown")
|
||||
|
||||
|
||||
def test_get_provider_schema_returns_declaration(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None:
|
||||
declaration = _provider_entity(provider="openai")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
|
||||
]
|
||||
|
||||
schema = factory.get_provider_schema("openai")
|
||||
assert schema.provider == "langgenius/openai/openai"
|
||||
|
||||
|
||||
def test_provider_credentials_validate_errors_when_schema_missing(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
schema = _provider_entity(provider="openai")
|
||||
schema.provider_credential_schema = None
|
||||
monkeypatch.setattr(
|
||||
factory,
|
||||
"get_plugin_model_provider",
|
||||
lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not have provider_credential_schema"):
|
||||
factory.provider_credentials_validate(provider="openai", credentials={"x": "y"})
|
||||
|
||||
|
||||
def test_provider_credentials_validate_filters_and_calls_plugin_validation(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
schema = _provider_entity(provider="openai")
|
||||
schema.provider_credential_schema = MagicMock()
|
||||
plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema)
|
||||
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider)
|
||||
|
||||
fake_validator = MagicMock()
|
||||
fake_validator.validate_and_filter.return_value = {"filtered": True}
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.model_runtime.model_providers.model_provider_factory.ProviderCredentialSchemaValidator",
|
||||
lambda _: fake_validator,
|
||||
)
|
||||
|
||||
filtered = factory.provider_credentials_validate(provider="openai", credentials={"raw": True})
|
||||
assert filtered == {"filtered": True}
|
||||
fake_plugin_manager.validate_provider_credentials.assert_called_once()
|
||||
kwargs = fake_plugin_manager.validate_provider_credentials.call_args.kwargs
|
||||
assert kwargs["plugin_id"] == "langgenius/openai"
|
||||
assert kwargs["provider"] == "provider"
|
||||
assert kwargs["credentials"] == {"filtered": True}
|
||||
|
||||
|
||||
def test_model_credentials_validate_errors_when_schema_missing(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
schema = _provider_entity(provider="openai")
|
||||
schema.model_credential_schema = None
|
||||
monkeypatch.setattr(
|
||||
factory,
|
||||
"get_plugin_model_provider",
|
||||
lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not have model_credential_schema"):
|
||||
factory.model_credentials_validate(
|
||||
provider="openai", model_type=ModelType.LLM, model="m", credentials={"x": "y"}
|
||||
)
|
||||
|
||||
|
||||
def test_model_credentials_validate_filters_and_calls_plugin_validation(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
schema = _provider_entity(provider="openai")
|
||||
schema.model_credential_schema = MagicMock()
|
||||
plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema)
|
||||
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider)
|
||||
|
||||
fake_validator = MagicMock()
|
||||
fake_validator.validate_and_filter.return_value = {"filtered": True}
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.model_runtime.model_providers.model_provider_factory.ModelCredentialSchemaValidator",
|
||||
lambda *_: fake_validator,
|
||||
)
|
||||
|
||||
filtered = factory.model_credentials_validate(
|
||||
provider="openai", model_type=ModelType.TEXT_EMBEDDING, model="m", credentials={"raw": True}
|
||||
)
|
||||
assert filtered == {"filtered": True}
|
||||
kwargs = fake_plugin_manager.validate_model_credentials.call_args.kwargs
|
||||
assert kwargs["plugin_id"] == "langgenius/openai"
|
||||
assert kwargs["provider"] == "provider"
|
||||
assert kwargs["model_type"] == "text-embedding"
|
||||
assert kwargs["model"] == "m"
|
||||
assert kwargs["credentials"] == {"filtered": True}
|
||||
|
||||
|
||||
def test_get_model_schema_cache_hit(factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model_schema = AIModelEntity(
|
||||
model="m",
|
||||
label=I18nObject(en_US="m"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = model_schema.model_dump_json().encode()
|
||||
assert (
|
||||
factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials={"k": "v"})
|
||||
== model_schema
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_schema_cache_invalid_json_deletes_key(
|
||||
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = b'{"model":"m"}'
|
||||
factory.plugin_model_manager.get_model_schema.return_value = None
|
||||
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
|
||||
assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None
|
||||
assert mock_redis.delete.called
|
||||
assert any("Failed to validate cached plugin model schema" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_get_model_schema_cache_delete_redis_error_is_logged(
|
||||
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = b'{"model":"m"}'
|
||||
mock_redis.delete.side_effect = RedisError("nope")
|
||||
factory.plugin_model_manager.get_model_schema.return_value = None
|
||||
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
|
||||
factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None)
|
||||
assert any("Failed to delete invalid plugin model schema cache" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_get_model_schema_redis_get_error_falls_back_to_plugin(
|
||||
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
caplog.set_level(logging.WARNING)
|
||||
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
|
||||
factory.plugin_model_manager.get_model_schema.return_value = None
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
|
||||
mock_redis.get.side_effect = RedisError("down")
|
||||
assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None
|
||||
assert any("Failed to read plugin model schema cache" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_get_model_schema_cache_miss_sets_cache_and_handles_setex_error(
|
||||
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
caplog.set_level(logging.WARNING)
|
||||
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
|
||||
|
||||
model_schema = AIModelEntity(
|
||||
model="m",
|
||||
label=I18nObject(en_US="m"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
factory.plugin_model_manager.get_model_schema.return_value = model_schema
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.setex.side_effect = RedisError("nope")
|
||||
assert (
|
||||
factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None)
|
||||
== model_schema
|
||||
)
|
||||
assert any("Failed to write plugin model schema cache" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_type", "expected_class"),
|
||||
[
|
||||
(ModelType.LLM, "LargeLanguageModel"),
|
||||
(ModelType.TEXT_EMBEDDING, "TextEmbeddingModel"),
|
||||
(ModelType.RERANK, "RerankModel"),
|
||||
(ModelType.SPEECH2TEXT, "Speech2TextModel"),
|
||||
(ModelType.MODERATION, "ModerationModel"),
|
||||
(ModelType.TTS, "TTSModel"),
|
||||
],
|
||||
)
|
||||
def test_get_model_type_instance_dispatches_by_type(
|
||||
factory: ModelProviderFactory, model_type: ModelType, expected_class: str, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
|
||||
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity))
|
||||
|
||||
sentinel = object()
|
||||
monkeypatch.setattr(
|
||||
f"dify_graph.model_runtime.model_providers.model_provider_factory.{expected_class}",
|
||||
MagicMock(model_validate=lambda _: sentinel),
|
||||
)
|
||||
|
||||
assert factory.get_model_type_instance("langgenius/openai/openai", model_type) is sentinel
|
||||
|
||||
|
||||
def test_get_model_type_instance_raises_on_unsupported(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
|
||||
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity))
|
||||
|
||||
class UnknownModelType:
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported model type"):
|
||||
factory.get_model_type_instance("langgenius/openai/openai", UnknownModelType()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_get_models_filters_by_provider_and_model_type(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
|
||||
) -> None:
|
||||
llm = AIModelEntity(
|
||||
model="m1",
|
||||
label=I18nObject(en_US="m1"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
embed = AIModelEntity(
|
||||
model="e1",
|
||||
label=I18nObject(en_US="e1"),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
openai = _provider_entity(
|
||||
provider="openai", supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], models=[llm, embed]
|
||||
)
|
||||
anthropic = _provider_entity(provider="anthropic", supported_model_types=[ModelType.LLM], models=[llm])
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=openai),
|
||||
_plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic),
|
||||
]
|
||||
|
||||
# ModelType filter picks only matching models
|
||||
providers = factory.get_models(model_type=ModelType.TEXT_EMBEDDING)
|
||||
assert len(providers) == 1
|
||||
assert providers[0].provider == "langgenius/openai/openai"
|
||||
assert [m.model for m in providers[0].models] == ["e1"]
|
||||
|
||||
# Provider filter excludes others
|
||||
providers = factory.get_models(provider="langgenius/anthropic/anthropic", model_type=ModelType.LLM)
|
||||
assert len(providers) == 1
|
||||
assert providers[0].provider == "langgenius/anthropic/anthropic"
|
||||
|
||||
|
||||
def test_get_models_provider_filter_skips_non_matching(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
|
||||
) -> None:
|
||||
openai = _provider_entity(provider="openai")
|
||||
anthropic = _provider_entity(provider="anthropic")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=openai),
|
||||
_plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic),
|
||||
]
|
||||
|
||||
providers = factory.get_models(provider="langgenius/not-exist/not-exist", model_type=ModelType.LLM)
|
||||
assert providers == []
|
||||
|
||||
|
||||
def test_get_provider_icon_fetches_asset_and_returns_mime_type(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
provider_schema = _provider_entity(
|
||||
provider="langgenius/openai/openai",
|
||||
icon_small=I18nObject(en_US="icon.png", zh_Hans="icon-zh.png"),
|
||||
icon_small_dark=I18nObject(en_US="dark.svg", zh_Hans="dark-zh.svg"),
|
||||
)
|
||||
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
|
||||
|
||||
class FakePluginAssetManager:
|
||||
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
|
||||
assert tenant_id == "tenant"
|
||||
return f"bytes:{id}".encode()
|
||||
|
||||
import core.plugin.impl.asset as asset_module
|
||||
|
||||
monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager)
|
||||
|
||||
data, mime = factory.get_provider_icon("openai", "icon_small", "en_US")
|
||||
assert data == b"bytes:icon.png"
|
||||
assert mime == "image/png"
|
||||
|
||||
data, mime = factory.get_provider_icon("openai", "icon_small_dark", "zh_Hans")
|
||||
assert data == b"bytes:dark-zh.svg"
|
||||
assert mime == "image/svg+xml"
|
||||
|
||||
|
||||
def test_get_provider_icon_uses_zh_hans_for_small_and_en_us_for_dark(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
provider_schema = _provider_entity(
|
||||
provider="langgenius/openai/openai",
|
||||
icon_small=I18nObject(en_US="icon-en.png", zh_Hans="icon-zh.png"),
|
||||
icon_small_dark=I18nObject(en_US="dark-en.svg", zh_Hans="dark-zh.svg"),
|
||||
)
|
||||
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
|
||||
|
||||
class FakePluginAssetManager:
|
||||
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
|
||||
return id.encode()
|
||||
|
||||
import core.plugin.impl.asset as asset_module
|
||||
|
||||
monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager)
|
||||
|
||||
data, _ = factory.get_provider_icon("openai", "icon_small", "zh_Hans")
|
||||
assert data == b"icon-zh.png"
|
||||
|
||||
data, _ = factory.get_provider_icon("openai", "icon_small_dark", "en_US")
|
||||
assert data == b"dark-en.svg"
|
||||
|
||||
|
||||
def test_get_provider_icon_raises_for_missing_icons(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
provider_schema = _provider_entity(provider="langgenius/openai/openai")
|
||||
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
|
||||
|
||||
with pytest.raises(ValueError, match="does not have small icon"):
|
||||
factory.get_provider_icon("openai", "icon_small", "en_US")
|
||||
|
||||
with pytest.raises(ValueError, match="does not have small dark icon"):
|
||||
factory.get_provider_icon("openai", "icon_small_dark", "en_US")
|
||||
|
||||
|
||||
def test_get_provider_icon_raises_for_unsupported_icon_type(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
provider_schema = _provider_entity(
|
||||
provider="langgenius/openai/openai",
|
||||
icon_small=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
|
||||
with pytest.raises(ValueError, match="Unsupported icon type"):
|
||||
factory.get_provider_icon("openai", "nope", "en_US")
|
||||
|
||||
|
||||
def test_get_provider_icon_raises_when_file_name_missing(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
provider_schema = _provider_entity(
|
||||
provider="langgenius/openai/openai",
|
||||
icon_small=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
|
||||
with pytest.raises(ValueError, match="does not have icon"):
|
||||
factory.get_provider_icon("openai", "icon_small", "en_US")
|
||||
|
||||
|
||||
def test_get_plugin_id_and_provider_name_from_provider_handles_google_special_case(
|
||||
factory: ModelProviderFactory,
|
||||
) -> None:
|
||||
plugin_id, provider_name = factory.get_plugin_id_and_provider_name_from_provider("google")
|
||||
assert plugin_id == "langgenius/gemini"
|
||||
assert provider_name == "google"
|
||||
@ -0,0 +1,201 @@
|
||||
import pytest
|
||||
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.provider_entities import (
|
||||
CredentialFormSchema,
|
||||
FormOption,
|
||||
FormShowOnObject,
|
||||
FormType,
|
||||
)
|
||||
from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator
|
||||
|
||||
|
||||
class TestCommonValidator:
|
||||
def test_validate_credential_form_schema_required_missing(self):
|
||||
validator = CommonValidator()
|
||||
schema = CredentialFormSchema(
|
||||
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True
|
||||
)
|
||||
with pytest.raises(ValueError, match="Variable api_key is required"):
|
||||
validator._validate_credential_form_schema(schema, {})
|
||||
|
||||
def test_validate_credential_form_schema_not_required_missing_with_default(self):
|
||||
validator = CommonValidator()
|
||||
schema = CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=False,
|
||||
default="default_value",
|
||||
)
|
||||
assert validator._validate_credential_form_schema(schema, {}) == "default_value"
|
||||
|
||||
def test_validate_credential_form_schema_not_required_missing_no_default(self):
|
||||
validator = CommonValidator()
|
||||
schema = CredentialFormSchema(
|
||||
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=False
|
||||
)
|
||||
assert validator._validate_credential_form_schema(schema, {}) is None
|
||||
|
||||
def test_validate_credential_form_schema_max_length_exceeded(self):
|
||||
validator = CommonValidator()
|
||||
schema = CredentialFormSchema(
|
||||
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, max_length=5
|
||||
)
|
||||
with pytest.raises(ValueError, match="Variable api_key length should not be greater than 5"):
|
||||
validator._validate_credential_form_schema(schema, {"api_key": "123456"})
|
||||
|
||||
def test_validate_credential_form_schema_not_string(self):
|
||||
validator = CommonValidator()
|
||||
schema = CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT)
|
||||
with pytest.raises(ValueError, match="Variable api_key should be string"):
|
||||
validator._validate_credential_form_schema(schema, {"api_key": 123})
|
||||
|
||||
def test_validate_credential_form_schema_select_invalid_option(self):
|
||||
validator = CommonValidator()
|
||||
schema = CredentialFormSchema(
|
||||
variable="mode",
|
||||
label=I18nObject(en_US="Mode"),
|
||||
type=FormType.SELECT,
|
||||
options=[
|
||||
FormOption(label=I18nObject(en_US="Fast"), value="fast"),
|
||||
FormOption(label=I18nObject(en_US="Slow"), value="slow"),
|
||||
],
|
||||
)
|
||||
with pytest.raises(ValueError, match="Variable mode is not in options"):
|
||||
validator._validate_credential_form_schema(schema, {"mode": "medium"})
|
||||
|
||||
def test_validate_credential_form_schema_select_valid_option(self):
|
||||
validator = CommonValidator()
|
||||
schema = CredentialFormSchema(
|
||||
variable="mode",
|
||||
label=I18nObject(en_US="Mode"),
|
||||
type=FormType.SELECT,
|
||||
options=[
|
||||
FormOption(label=I18nObject(en_US="Fast"), value="fast"),
|
||||
FormOption(label=I18nObject(en_US="Slow"), value="slow"),
|
||||
],
|
||||
)
|
||||
assert validator._validate_credential_form_schema(schema, {"mode": "fast"}) == "fast"
|
||||
|
||||
def test_validate_credential_form_schema_switch_invalid(self):
|
||||
validator = CommonValidator()
|
||||
schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH)
|
||||
with pytest.raises(ValueError, match="Variable enabled should be true or false"):
|
||||
validator._validate_credential_form_schema(schema, {"enabled": "maybe"})
|
||||
|
||||
def test_validate_credential_form_schema_switch_valid(self):
|
||||
validator = CommonValidator()
|
||||
schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH)
|
||||
assert validator._validate_credential_form_schema(schema, {"enabled": "true"}) is True
|
||||
assert validator._validate_credential_form_schema(schema, {"enabled": "FALSE"}) is False
|
||||
|
||||
def test_validate_and_filter_credential_form_schemas_with_show_on(self):
|
||||
validator = CommonValidator()
|
||||
schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="auth_type",
|
||||
label=I18nObject(en_US="Auth Type"),
|
||||
type=FormType.SELECT,
|
||||
options=[
|
||||
FormOption(label=I18nObject(en_US="API Key"), value="api_key"),
|
||||
FormOption(label=I18nObject(en_US="OAuth"), value="oauth"),
|
||||
],
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
show_on=[FormShowOnObject(variable="auth_type", value="api_key")],
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="client_id",
|
||||
label=I18nObject(en_US="Client ID"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
show_on=[FormShowOnObject(variable="auth_type", value="oauth")],
|
||||
),
|
||||
]
|
||||
|
||||
# Case 1: auth_type = api_key
|
||||
credentials = {"auth_type": "api_key", "api_key": "my_secret"}
|
||||
result = validator._validate_and_filter_credential_form_schemas(schemas, credentials)
|
||||
assert "auth_type" in result
|
||||
assert "api_key" in result
|
||||
assert "client_id" not in result
|
||||
assert result["api_key"] == "my_secret"
|
||||
|
||||
# Case 2: auth_type = oauth
|
||||
credentials = {"auth_type": "oauth", "client_id": "my_client"}
|
||||
result = validator._validate_and_filter_credential_form_schemas(schemas, credentials)
|
||||
# Note: 'auth_type' contains 'oauth'. 'result' contains keys that pass validation.
|
||||
# Since 'oauth' is not an empty string, it is in result.
|
||||
assert "auth_type" in result
|
||||
assert "api_key" not in result
|
||||
assert "client_id" in result
|
||||
assert result["client_id"] == "my_client"
|
||||
|
||||
def test_validate_and_filter_show_on_missing_variable(self):
|
||||
validator = CommonValidator()
|
||||
schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
show_on=[FormShowOnObject(variable="auth_type", value="api_key")],
|
||||
)
|
||||
]
|
||||
# auth_type is missing in credentials, so api_key should be filtered out
|
||||
result = validator._validate_and_filter_credential_form_schemas(schemas, {})
|
||||
assert result == {}
|
||||
|
||||
def test_validate_and_filter_show_on_mismatch_value(self):
|
||||
validator = CommonValidator()
|
||||
schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
show_on=[FormShowOnObject(variable="auth_type", value="api_key")],
|
||||
)
|
||||
]
|
||||
# auth_type is oauth, which doesn't match show_on
|
||||
result = validator._validate_and_filter_credential_form_schemas(schemas, {"auth_type": "oauth"})
|
||||
assert result == {}
|
||||
|
||||
def test_validate_and_filter_multiple_show_on(self):
|
||||
validator = CommonValidator()
|
||||
schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="target",
|
||||
label=I18nObject(en_US="Target"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
show_on=[FormShowOnObject(variable="v1", value="a"), FormShowOnObject(variable="v2", value="b")],
|
||||
)
|
||||
]
|
||||
# Both match
|
||||
assert "target" in validator._validate_and_filter_credential_form_schemas(
|
||||
schemas, {"v1": "a", "v2": "b", "target": "val"}
|
||||
)
|
||||
# One mismatch
|
||||
assert "target" not in validator._validate_and_filter_credential_form_schemas(
|
||||
schemas, {"v1": "a", "v2": "c", "target": "val"}
|
||||
)
|
||||
# One missing
|
||||
assert "target" not in validator._validate_and_filter_credential_form_schemas(
|
||||
schemas, {"v1": "a", "target": "val"}
|
||||
)
|
||||
|
||||
def test_validate_and_filter_skips_falsy_results(self):
|
||||
validator = CommonValidator()
|
||||
schemas = [
|
||||
CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH),
|
||||
CredentialFormSchema(
|
||||
variable="empty_str", label=I18nObject(en_US="Empty"), type=FormType.TEXT_INPUT, required=False
|
||||
),
|
||||
]
|
||||
# Result of false switch is False. if result: is false. Not added.
|
||||
# Result of empty string is "", if result: is false. Not added.
|
||||
credentials = {"enabled": "false", "empty_str": ""}
|
||||
result = validator._validate_and_filter_credential_form_schemas(schemas, credentials)
|
||||
assert "enabled" not in result
|
||||
assert "empty_str" not in result
|
||||
@ -0,0 +1,233 @@
|
||||
import pytest
|
||||
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import (
|
||||
CredentialFormSchema,
|
||||
FieldModelSchema,
|
||||
FormOption,
|
||||
FormShowOnObject,
|
||||
FormType,
|
||||
ModelCredentialSchema,
|
||||
)
|
||||
from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
||||
|
||||
|
||||
def test_validate_and_filter_with_none_schema():
|
||||
validator = ModelCredentialSchemaValidator(ModelType.LLM, None)
|
||||
with pytest.raises(ValueError, match="Model credential schema is None"):
|
||||
validator.validate_and_filter({})
|
||||
|
||||
|
||||
def test_validate_and_filter_success():
|
||||
schema = ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key", zh_Hans="API Key"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="optional_field",
|
||||
label=I18nObject(en_US="Optional", zh_Hans="可选"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=False,
|
||||
default="default_val",
|
||||
),
|
||||
],
|
||||
)
|
||||
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
|
||||
|
||||
credentials = {"api_key": "sk-123456"}
|
||||
result = validator.validate_and_filter(credentials)
|
||||
|
||||
assert result["api_key"] == "sk-123456"
|
||||
assert result["optional_field"] == "default_val"
|
||||
assert credentials["__model_type"] == ModelType.LLM.value
|
||||
|
||||
|
||||
def test_validate_and_filter_with_show_on():
|
||||
schema = ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="mode", label=I18nObject(en_US="Mode", zh_Hans="模式"), type=FormType.TEXT_INPUT, required=True
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="conditional_field",
|
||||
label=I18nObject(en_US="Conditional", zh_Hans="条件"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
show_on=[FormShowOnObject(variable="mode", value="advanced")],
|
||||
),
|
||||
],
|
||||
)
|
||||
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
|
||||
|
||||
# mode is 'simple', conditional_field should be filtered out
|
||||
credentials = {"mode": "simple", "conditional_field": "secret"}
|
||||
result = validator.validate_and_filter(credentials)
|
||||
assert "conditional_field" not in result
|
||||
assert result["mode"] == "simple"
|
||||
|
||||
# mode is 'advanced', conditional_field should be kept
|
||||
credentials = {"mode": "advanced", "conditional_field": "secret"}
|
||||
result = validator.validate_and_filter(credentials)
|
||||
assert result["conditional_field"] == "secret"
|
||||
assert result["mode"] == "advanced"
|
||||
|
||||
# show_on variable missing in credentials
|
||||
credentials = {"conditional_field": "secret"} # mode missing
|
||||
with pytest.raises(ValueError, match="Variable mode is required"): # because mode is required in schema
|
||||
validator.validate_and_filter(credentials)
|
||||
|
||||
|
||||
def test_validate_and_filter_show_on_missing_trigger_var():
|
||||
# specifically test all_show_on_match = False when variable not in credentials
|
||||
schema = ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="optional_trigger",
|
||||
label=I18nObject(en_US="Optional Trigger", zh_Hans="可选触发"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=False,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="conditional_field",
|
||||
label=I18nObject(en_US="Conditional", zh_Hans="条件"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=False,
|
||||
show_on=[FormShowOnObject(variable="optional_trigger", value="active")],
|
||||
),
|
||||
],
|
||||
)
|
||||
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
|
||||
|
||||
# optional_trigger missing, conditional_field should be skipped
|
||||
result = validator.validate_and_filter({"conditional_field": "val"})
|
||||
assert "conditional_field" not in result
|
||||
|
||||
|
||||
def test_common_validator_logic_required():
|
||||
schema = ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key", zh_Hans="API Key"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
|
||||
|
||||
with pytest.raises(ValueError, match="Variable api_key is required"):
|
||||
validator.validate_and_filter({})
|
||||
|
||||
with pytest.raises(ValueError, match="Variable api_key is required"):
|
||||
validator.validate_and_filter({"api_key": ""})
|
||||
|
||||
|
||||
def test_common_validator_logic_max_length():
|
||||
schema = ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="key",
|
||||
label=I18nObject(en_US="Key", zh_Hans="Key"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
max_length=5,
|
||||
)
|
||||
],
|
||||
)
|
||||
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
|
||||
|
||||
with pytest.raises(ValueError, match="Variable key length should not be greater than 5"):
|
||||
validator.validate_and_filter({"key": "123456"})
|
||||
|
||||
|
||||
def test_common_validator_logic_invalid_type():
|
||||
schema = ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="key", label=I18nObject(en_US="Key", zh_Hans="Key"), type=FormType.TEXT_INPUT, required=True
|
||||
)
|
||||
],
|
||||
)
|
||||
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
|
||||
|
||||
with pytest.raises(ValueError, match="Variable key should be string"):
|
||||
validator.validate_and_filter({"key": 123})
|
||||
|
||||
|
||||
def test_common_validator_logic_switch():
|
||||
schema = ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="enabled",
|
||||
label=I18nObject(en_US="Enabled", zh_Hans="启用"),
|
||||
type=FormType.SWITCH,
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
|
||||
|
||||
result = validator.validate_and_filter({"enabled": "true"})
|
||||
assert result["enabled"] is True
|
||||
|
||||
result = validator.validate_and_filter({"enabled": "false"})
|
||||
assert "enabled" not in result
|
||||
|
||||
with pytest.raises(ValueError, match="Variable enabled should be true or false"):
|
||||
validator.validate_and_filter({"enabled": "not_a_bool"})
|
||||
|
||||
|
||||
def test_common_validator_logic_options():
|
||||
schema = ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="choice",
|
||||
label=I18nObject(en_US="Choice", zh_Hans="选择"),
|
||||
type=FormType.SELECT,
|
||||
required=True,
|
||||
options=[
|
||||
FormOption(label=I18nObject(en_US="A", zh_Hans="A"), value="a"),
|
||||
FormOption(label=I18nObject(en_US="B", zh_Hans="B"), value="b"),
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
|
||||
|
||||
result = validator.validate_and_filter({"choice": "a"})
|
||||
assert result["choice"] == "a"
|
||||
|
||||
with pytest.raises(ValueError, match="Variable choice is not in options"):
|
||||
validator.validate_and_filter({"choice": "c"})
|
||||
|
||||
|
||||
def test_validate_and_filter_optional_no_default():
|
||||
schema = ModelCredentialSchema(
|
||||
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="optional",
|
||||
label=I18nObject(en_US="Optional", zh_Hans="可选"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
|
||||
|
||||
result = validator.validate_and_filter({})
|
||||
assert "optional" not in result
|
||||
@ -0,0 +1,72 @@
|
||||
import pytest
|
||||
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderCredentialSchema
|
||||
from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import (
|
||||
ProviderCredentialSchemaValidator,
|
||||
)
|
||||
|
||||
|
||||
class TestProviderCredentialSchemaValidator:
|
||||
def test_validate_and_filter_success(self):
|
||||
# Setup schema
|
||||
schema = ProviderCredentialSchema(
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="endpoint",
|
||||
label=I18nObject(en_US="Endpoint"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=False,
|
||||
default="https://api.example.com",
|
||||
),
|
||||
]
|
||||
)
|
||||
validator = ProviderCredentialSchemaValidator(schema)
|
||||
|
||||
# Test valid credentials
|
||||
credentials = {"api_key": "my-secret-key"}
|
||||
result = validator.validate_and_filter(credentials)
|
||||
|
||||
assert result == {"api_key": "my-secret-key", "endpoint": "https://api.example.com"}
|
||||
|
||||
def test_validate_and_filter_missing_required(self):
|
||||
# Setup schema
|
||||
schema = ProviderCredentialSchema(
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True
|
||||
)
|
||||
]
|
||||
)
|
||||
validator = ProviderCredentialSchemaValidator(schema)
|
||||
|
||||
# Test missing required credentials
|
||||
with pytest.raises(ValueError, match="Variable api_key is required"):
|
||||
validator.validate_and_filter({})
|
||||
|
||||
def test_validate_and_filter_extra_fields_filtered(self):
|
||||
# Setup schema
|
||||
schema = ProviderCredentialSchema(
|
||||
credential_form_schemas=[
|
||||
CredentialFormSchema(
|
||||
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True
|
||||
)
|
||||
]
|
||||
)
|
||||
validator = ProviderCredentialSchemaValidator(schema)
|
||||
|
||||
# Test credentials with extra fields
|
||||
credentials = {"api_key": "my-secret-key", "extra_field": "should-be-filtered"}
|
||||
result = validator.validate_and_filter(credentials)
|
||||
|
||||
assert "api_key" in result
|
||||
assert "extra_field" not in result
|
||||
assert result == {"api_key": "my-secret-key"}
|
||||
|
||||
def test_init(self):
|
||||
schema = ProviderCredentialSchema(credential_form_schemas=[])
|
||||
validator = ProviderCredentialSchemaValidator(schema)
|
||||
assert validator.provider_credential_schema == schema
|
||||
@ -0,0 +1,231 @@
|
||||
import dataclasses
|
||||
import datetime
|
||||
from collections import deque
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from pathlib import Path, PurePath
|
||||
from re import compile
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic.networks import AnyUrl, NameEmail
|
||||
from pydantic.types import SecretBytes, SecretStr
|
||||
from pydantic_core import Url
|
||||
from pydantic_extra_types.color import Color
|
||||
|
||||
from dify_graph.model_runtime.utils.encoders import (
|
||||
_model_dump,
|
||||
decimal_encoder,
|
||||
generate_encoders_by_class_tuples,
|
||||
isoformat,
|
||||
jsonable_encoder,
|
||||
)
|
||||
|
||||
|
||||
class MockEnum(Enum):
|
||||
A = "a"
|
||||
B = "b"
|
||||
|
||||
|
||||
class MockPydanticModel(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
name: str
|
||||
age: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MockDataclass:
|
||||
name: str
|
||||
value: Any
|
||||
|
||||
|
||||
class MockWithDict:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.data.items())
|
||||
|
||||
|
||||
class MockWithVars:
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class TestEncoders:
|
||||
def test_model_dump(self):
|
||||
model = MockPydanticModel(name="test", age=20)
|
||||
result = _model_dump(model)
|
||||
assert result == {"name": "test", "age": 20}
|
||||
|
||||
def test_isoformat(self):
|
||||
d = datetime.date(2023, 1, 1)
|
||||
assert isoformat(d) == "2023-01-01"
|
||||
t = datetime.time(12, 0, 0)
|
||||
assert isoformat(t) == "12:00:00"
|
||||
|
||||
def test_decimal_encoder(self):
|
||||
assert decimal_encoder(Decimal("1.0")) == 1.0
|
||||
assert decimal_encoder(Decimal(1)) == 1
|
||||
assert decimal_encoder(Decimal("1.5")) == 1.5
|
||||
assert decimal_encoder(Decimal(0)) == 0
|
||||
assert decimal_encoder(Decimal(-1)) == -1
|
||||
|
||||
def test_generate_encoders_by_class_tuples(self):
|
||||
type_map = {int: str, float: str, str: int}
|
||||
result = generate_encoders_by_class_tuples(type_map)
|
||||
assert result[str] == (int, float)
|
||||
assert result[int] == (str,)
|
||||
|
||||
def test_jsonable_encoder_basic_types(self):
|
||||
assert jsonable_encoder("string") == "string"
|
||||
assert jsonable_encoder(123) == 123
|
||||
assert jsonable_encoder(1.23) == 1.23
|
||||
assert jsonable_encoder(None) is None
|
||||
|
||||
def test_jsonable_encoder_pydantic(self):
|
||||
model = MockPydanticModel(name="test", age=20)
|
||||
assert jsonable_encoder(model) == {"name": "test", "age": 20}
|
||||
|
||||
def test_jsonable_encoder_pydantic_root(self):
|
||||
# Manually create a mock that behaves like a model with __root__
|
||||
# because Pydantic v2 handles root differently, but the code checks for "__root__"
|
||||
model = MagicMock(spec=BaseModel)
|
||||
# _model_dump(obj, mode="json", ...) -> model.model_dump(mode="json", ...)
|
||||
model.model_dump.return_value = {"__root__": [1, 2, 3]}
|
||||
assert jsonable_encoder(model) == [1, 2, 3]
|
||||
|
||||
def test_jsonable_encoder_dataclass(self):
|
||||
obj = MockDataclass(name="test", value=1)
|
||||
assert jsonable_encoder(obj) == {"name": "test", "value": 1}
|
||||
# Test dataclass type (should not be treated as instance)
|
||||
# It should fall back to vars() or dict() or at least not crash
|
||||
with pytest.raises(ValueError):
|
||||
jsonable_encoder(MockDataclass)
|
||||
|
||||
def test_jsonable_encoder_enum(self):
|
||||
assert jsonable_encoder(MockEnum.A) == "a"
|
||||
|
||||
def test_jsonable_encoder_path(self):
|
||||
assert jsonable_encoder(Path("/tmp/test")) == "/tmp/test"
|
||||
assert jsonable_encoder(PurePath("/tmp/test")) == "/tmp/test"
|
||||
|
||||
def test_jsonable_encoder_decimal(self):
|
||||
# In jsonable_encoder, Decimal is formatted as string via format(obj, "f")
|
||||
assert jsonable_encoder(Decimal("1.23")) == "1.23"
|
||||
assert jsonable_encoder(Decimal("1.000")) == "1.000"
|
||||
|
||||
def test_jsonable_encoder_dict(self):
|
||||
d = {"a": 1, "b": [2, 3], "_sa_instance": "hidden"}
|
||||
assert jsonable_encoder(d) == {"a": 1, "b": [2, 3]}
|
||||
assert jsonable_encoder(d, sqlalchemy_safe=False) == {"a": 1, "b": [2, 3], "_sa_instance": "hidden"}
|
||||
|
||||
d_with_none = {"a": 1, "b": None}
|
||||
assert jsonable_encoder(d_with_none, exclude_none=True) == {"a": 1}
|
||||
assert jsonable_encoder(d_with_none, exclude_none=False) == {"a": 1, "b": None}
|
||||
|
||||
def test_jsonable_encoder_collections(self):
|
||||
assert jsonable_encoder([1, 2]) == [1, 2]
|
||||
assert jsonable_encoder((1, 2)) == [1, 2]
|
||||
assert jsonable_encoder({1, 2}) == [1, 2]
|
||||
assert jsonable_encoder(frozenset([1, 2])) == [1, 2]
|
||||
assert jsonable_encoder(deque([1, 2])) == [1, 2]
|
||||
|
||||
def gen():
|
||||
yield 1
|
||||
yield 2
|
||||
|
||||
assert jsonable_encoder(gen()) == [1, 2]
|
||||
|
||||
def test_jsonable_encoder_custom_encoder(self):
|
||||
custom = {int: lambda x: str(x + 1)}
|
||||
assert jsonable_encoder(1, custom_encoder=custom) == "2"
|
||||
|
||||
# Test subclass matching for custom encoder
|
||||
class SubInt(int):
|
||||
pass
|
||||
|
||||
assert jsonable_encoder(SubInt(1), custom_encoder=custom) == "2"
|
||||
|
||||
def test_jsonable_encoder_special_types(self):
|
||||
# These hit ENCODERS_BY_TYPE or encoders_by_class_tuples
|
||||
assert jsonable_encoder(b"bytes") == "bytes"
|
||||
assert jsonable_encoder(Color("red")) == "red"
|
||||
|
||||
dt = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
assert jsonable_encoder(dt) == dt.isoformat()
|
||||
|
||||
date = datetime.date(2023, 1, 1)
|
||||
assert jsonable_encoder(date) == date.isoformat()
|
||||
|
||||
time = datetime.time(12, 0, 0)
|
||||
assert jsonable_encoder(time) == time.isoformat()
|
||||
|
||||
td = datetime.timedelta(seconds=60)
|
||||
assert jsonable_encoder(td) == 60.0
|
||||
|
||||
assert jsonable_encoder(IPv4Address("127.0.0.1")) == "127.0.0.1"
|
||||
assert jsonable_encoder(IPv4Interface("127.0.0.1/24")) == "127.0.0.1/24"
|
||||
assert jsonable_encoder(IPv4Network("127.0.0.0/24")) == "127.0.0.0/24"
|
||||
assert jsonable_encoder(IPv6Address("::1")) == "::1"
|
||||
assert jsonable_encoder(IPv6Interface("::1/128")) == "::1/128"
|
||||
assert jsonable_encoder(IPv6Network("::/128")) == "::/128"
|
||||
|
||||
assert jsonable_encoder(NameEmail(name="test", email="test@example.com")) == "test <test@example.com>"
|
||||
|
||||
assert jsonable_encoder(compile("abc")) == "abc"
|
||||
|
||||
# Secret types
|
||||
# Check what they actually return in this environment
|
||||
res_bytes = jsonable_encoder(SecretBytes(b"secret"))
|
||||
assert "**********" in res_bytes
|
||||
|
||||
res_str = jsonable_encoder(SecretStr("secret"))
|
||||
assert res_str == "**********"
|
||||
|
||||
u = UUID("12345678-1234-5678-1234-567812345678")
|
||||
assert jsonable_encoder(u) == str(u)
|
||||
|
||||
url = AnyUrl("https://example.com")
|
||||
assert jsonable_encoder(url) == "https://example.com/"
|
||||
|
||||
purl = Url("https://example.com")
|
||||
assert jsonable_encoder(purl) == "https://example.com/"
|
||||
|
||||
def test_jsonable_encoder_fallback(self):
|
||||
# dict(obj) success
|
||||
obj_dict = MockWithDict({"a": 1})
|
||||
assert jsonable_encoder(obj_dict) == {"a": 1}
|
||||
|
||||
# vars(obj) success
|
||||
obj_vars = MockWithVars(x=10, y=20)
|
||||
assert jsonable_encoder(obj_vars) == {"x": 10, "y": 20}
|
||||
|
||||
# error fallback
|
||||
class ReallyUnserializable:
|
||||
__slots__ = ["__weakref__"] # No __dict__
|
||||
|
||||
def __iter__(self):
|
||||
raise TypeError("not iterable")
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
jsonable_encoder(ReallyUnserializable())
|
||||
assert "not iterable" in str(exc.value)
|
||||
|
||||
def test_jsonable_encoder_nested(self):
|
||||
data = {
|
||||
"model": MockPydanticModel(name="test", age=20),
|
||||
"list": [Decimal("1.1"), {MockEnum.A: Path("/tmp")}],
|
||||
"set": {1, 2},
|
||||
}
|
||||
expected = {
|
||||
"model": {"name": "test", "age": 20},
|
||||
"list": ["1.1", {"a": "/tmp"}],
|
||||
"set": [1, 2],
|
||||
}
|
||||
assert jsonable_encoder(data) == expected
|
||||
Loading…
Reference in New Issue
Block a user