test: added test for core token buffer memory and model runtime (#32512)

Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
This commit is contained in:
mahammadasim 2026-03-12 09:16:46 +05:30 committed by GitHub
parent 60fe5e7f00
commit e99628b76f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 6007 additions and 6 deletions

View File

@ -276,7 +276,4 @@ class ToolPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_call_id:
return False
return True
return super().is_empty() and not self.tool_call_id

View File

@ -4,7 +4,8 @@ class InvokeError(ValueError):
description: str | None = None
def __init__(self, description: str | None = None):
self.description = description
if description is not None:
self.description = description
def __str__(self):
return self.description or self.__class__.__name__

View File

@ -282,7 +282,8 @@ class ModelProviderFactory:
all_model_type_models.append(model_schema)
simple_provider_schema = provider_schema.to_simple_provider()
simple_provider_schema.models.extend(all_model_type_models)
if model_type:
simple_provider_schema.models = all_model_type_models
providers.append(simple_provider_schema)

View File

@ -0,0 +1,969 @@
"""Comprehensive unit tests for core/memory/token_buffer_memory.py"""
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from core.memory.token_buffer_memory import TokenBufferMemory
from dify_graph.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessageRole,
TextPromptMessageContent,
UserPromptMessage,
)
from models.model import AppMode
# ---------------------------------------------------------------------------
# Helpers / shared fixtures
# ---------------------------------------------------------------------------
def _make_conversation(mode: AppMode = AppMode.CHAT) -> MagicMock:
"""Return a minimal Conversation mock."""
conv = MagicMock()
conv.id = str(uuid4())
conv.mode = mode
conv.model_config = {}
return conv
def _make_model_instance() -> MagicMock:
"""Return a ModelInstance mock whose token counter returns a constant."""
mi = MagicMock()
mi.get_llm_num_tokens.return_value = 100
return mi
def _make_message(answer: str = "hello", answer_tokens: int = 5) -> MagicMock:
msg = MagicMock()
msg.id = str(uuid4())
msg.query = "user query"
msg.answer = answer
msg.answer_tokens = answer_tokens
msg.workflow_run_id = str(uuid4())
msg.created_at = MagicMock()
return msg
# ===========================================================================
# Tests for __init__ and workflow_run_repo property
# ===========================================================================
class TestInit:
def test_init_stores_conversation_and_model_instance(self):
conv = _make_conversation()
mi = _make_model_instance()
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
assert mem.conversation is conv
assert mem.model_instance is mi
assert mem._workflow_run_repo is None
def test_workflow_run_repo_is_created_lazily(self):
conv = _make_conversation()
mi = _make_model_instance()
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
mock_repo = MagicMock()
with (
patch("core.memory.token_buffer_memory.sessionmaker") as mock_sm,
patch("core.memory.token_buffer_memory.db") as mock_db,
patch(
"core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
return_value=mock_repo,
),
):
mock_db.engine = MagicMock()
repo = mem.workflow_run_repo
assert repo is mock_repo
assert mem._workflow_run_repo is mock_repo
def test_workflow_run_repo_cached_after_first_access(self):
conv = _make_conversation()
mi = _make_model_instance()
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
existing_repo = MagicMock()
mem._workflow_run_repo = existing_repo
with patch(
"core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository"
) as mock_factory:
repo = mem.workflow_run_repo
mock_factory.assert_not_called()
assert repo is existing_repo
# ===========================================================================
# Tests for _build_prompt_message_with_files
# ===========================================================================
class TestBuildPromptMessageWithFiles:
"""Tests for the private _build_prompt_message_with_files method."""
# ------------------------------------------------------------------
# Mode: CHAT / AGENT_CHAT / COMPLETION (simple branch)
# ------------------------------------------------------------------
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
def test_chat_mode_no_files_user_message(self, mode):
"""When file_extra_config is falsy or app_record is None → plain UserPromptMessage."""
conv = _make_conversation(mode)
mi = _make_model_instance()
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
with patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=None, # falsy → file_objs = []
):
result = mem._build_prompt_message_with_files(
message_files=[],
text_content="hello",
message=_make_message(),
app_record=MagicMock(),
is_user_message=True,
)
assert isinstance(result, UserPromptMessage)
assert result.content == "hello"
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
def test_chat_mode_no_files_assistant_message(self, mode):
"""Plain AssistantPromptMessage when no files and is_user_message=False."""
conv = _make_conversation(mode)
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
with patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=None,
):
result = mem._build_prompt_message_with_files(
message_files=[],
text_content="ai reply",
message=_make_message(),
app_record=None,
is_user_message=False,
)
assert isinstance(result, AssistantPromptMessage)
assert result.content == "ai reply"
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
def test_chat_mode_with_files_user_message(self, mode):
"""When files are present, returns UserPromptMessage with list content."""
conv = _make_conversation(mode)
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
mock_file_extra_config = MagicMock()
mock_file_extra_config.image_config = None # no detail override
mock_file_obj = MagicMock()
# Must be a real entity so Pydantic's tagged union discriminator can validate it
real_image_content = ImagePromptMessageContent(
url="http://example.com/img.png", format="png", mime_type="image/png"
)
mock_message_file = MagicMock()
mock_app_record = MagicMock()
mock_app_record.tenant_id = "tenant-1"
with (
patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=mock_file_extra_config,
),
patch(
"core.memory.token_buffer_memory.file_factory.build_from_message_file",
return_value=mock_file_obj,
),
patch(
"core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
return_value=real_image_content,
),
):
result = mem._build_prompt_message_with_files(
message_files=[mock_message_file],
text_content="user text",
message=_make_message(),
app_record=mock_app_record,
is_user_message=True,
)
assert isinstance(result, UserPromptMessage)
assert isinstance(result.content, list)
# Last element should be TextPromptMessageContent
assert isinstance(result.content[-1], TextPromptMessageContent)
assert result.content[-1].data == "user text"
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
def test_chat_mode_with_files_assistant_message(self, mode):
"""When files are present, returns AssistantPromptMessage with list content."""
conv = _make_conversation(mode)
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
mock_file_extra_config = MagicMock()
mock_file_extra_config.image_config = None
mock_file_obj = MagicMock()
real_image_content = ImagePromptMessageContent(
url="http://example.com/img.png", format="png", mime_type="image/png"
)
mock_app_record = MagicMock()
mock_app_record.tenant_id = "tenant-1"
with (
patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=mock_file_extra_config,
),
patch(
"core.memory.token_buffer_memory.file_factory.build_from_message_file",
return_value=mock_file_obj,
),
patch(
"core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
return_value=real_image_content,
),
):
result = mem._build_prompt_message_with_files(
message_files=[MagicMock()],
text_content="ai text",
message=_make_message(),
app_record=mock_app_record,
is_user_message=False,
)
assert isinstance(result, AssistantPromptMessage)
assert isinstance(result.content, list)
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
def test_chat_mode_with_files_image_detail_overridden(self, mode):
"""When image_config.detail is set, detail is taken from config."""
conv = _make_conversation(mode)
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
mock_image_config = MagicMock()
mock_image_config.detail = ImagePromptMessageContent.DETAIL.LOW
mock_file_extra_config = MagicMock()
mock_file_extra_config.image_config = mock_image_config
mock_app_record = MagicMock()
mock_app_record.tenant_id = "tenant-1"
real_image_content = ImagePromptMessageContent(
url="http://example.com/img.png", format="png", mime_type="image/png"
)
with (
patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=mock_file_extra_config,
),
patch(
"core.memory.token_buffer_memory.file_factory.build_from_message_file",
return_value=MagicMock(),
),
patch(
"core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
return_value=real_image_content,
) as mock_to_prompt,
):
mem._build_prompt_message_with_files(
message_files=[MagicMock()],
text_content="user text",
message=_make_message(),
app_record=mock_app_record,
is_user_message=True,
)
# Ensure the LOW detail was passed through
mock_to_prompt.assert_called_once_with(
mock_to_prompt.call_args[0][0], image_detail_config=ImagePromptMessageContent.DETAIL.LOW
)
@pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
def test_chat_mode_app_record_none_returns_empty_file_objs(self, mode):
"""app_record=None path → file_objs stays empty → plain messages."""
conv = _make_conversation(mode)
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
mock_file_extra_config = MagicMock()
with patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=mock_file_extra_config,
):
result = mem._build_prompt_message_with_files(
message_files=[MagicMock()],
text_content="hello",
message=_make_message(),
app_record=None, # <-- forces the else branch → file_objs = []
is_user_message=True,
)
assert isinstance(result, UserPromptMessage)
assert result.content == "hello"
# ------------------------------------------------------------------
# Mode: ADVANCED_CHAT / WORKFLOW
# ------------------------------------------------------------------
@pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def test_workflow_mode_no_app_raises(self, mode):
"""Raises ValueError when conversation.app is falsy."""
conv = _make_conversation(mode)
conv.app = None
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
with pytest.raises(ValueError, match="App not found for conversation"):
mem._build_prompt_message_with_files(
message_files=[],
text_content="text",
message=_make_message(),
app_record=MagicMock(),
is_user_message=True,
)
@pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def test_workflow_mode_no_workflow_run_id_raises(self, mode):
"""Raises ValueError when message.workflow_run_id is falsy."""
conv = _make_conversation(mode)
conv.app = MagicMock()
message = _make_message()
message.workflow_run_id = None # force missing
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
with pytest.raises(ValueError, match="Workflow run ID not found"):
mem._build_prompt_message_with_files(
message_files=[],
text_content="text",
message=message,
app_record=MagicMock(),
is_user_message=True,
)
@pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def test_workflow_mode_workflow_run_not_found_raises(self, mode):
"""Raises ValueError when workflow_run_repo returns None."""
conv = _make_conversation(mode)
mock_app = MagicMock()
conv.app = mock_app
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
mem._workflow_run_repo = MagicMock()
mem._workflow_run_repo.get_workflow_run_by_id.return_value = None
with pytest.raises(ValueError, match="Workflow run not found"):
mem._build_prompt_message_with_files(
message_files=[],
text_content="text",
message=_make_message(),
app_record=MagicMock(),
is_user_message=True,
)
@pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def test_workflow_mode_workflow_not_found_raises(self, mode):
"""Raises ValueError when Workflow lookup returns None."""
conv = _make_conversation(mode)
conv.app = MagicMock()
mock_workflow_run = MagicMock()
mock_workflow_run.workflow_id = str(uuid4())
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
mem._workflow_run_repo = MagicMock()
mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
):
mock_db.session.scalar.return_value = None # workflow not found
with pytest.raises(ValueError, match="Workflow not found"):
mem._build_prompt_message_with_files(
message_files=[],
text_content="text",
message=_make_message(),
app_record=MagicMock(),
is_user_message=True,
)
@pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def test_workflow_mode_success_no_files_user(self, mode):
"""Happy path: workflow mode, no message files → plain UserPromptMessage."""
conv = _make_conversation(mode)
conv.app = MagicMock()
mock_workflow_run = MagicMock()
mock_workflow_run.workflow_id = str(uuid4())
mock_workflow = MagicMock()
mock_workflow.features_dict = {}
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
mem._workflow_run_repo = MagicMock()
mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=None,
),
):
mock_db.session.scalar.return_value = mock_workflow
result = mem._build_prompt_message_with_files(
message_files=[],
text_content="wf text",
message=_make_message(),
app_record=MagicMock(),
is_user_message=True,
)
assert isinstance(result, UserPromptMessage)
assert result.content == "wf text"
# ------------------------------------------------------------------
# Invalid mode
# ------------------------------------------------------------------
def test_invalid_mode_raises_assertion(self):
"""Any unknown AppMode raises AssertionError."""
conv = _make_conversation()
conv.mode = "unknown_mode" # not in any set
mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
with pytest.raises(AssertionError, match="Invalid app mode"):
mem._build_prompt_message_with_files(
message_files=[],
text_content="text",
message=_make_message(),
app_record=MagicMock(),
is_user_message=True,
)
# ===========================================================================
# Tests for get_history_prompt_messages
# ===========================================================================
class TestGetHistoryPromptMessages:
"""Tests for get_history_prompt_messages."""
def _make_memory(self, mode: AppMode = AppMode.CHAT) -> TokenBufferMemory:
conv = _make_conversation(mode)
conv.app = MagicMock()
return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
def test_returns_empty_when_no_messages(self):
mem = self._make_memory()
with patch("core.memory.token_buffer_memory.db") as mock_db:
mock_db.session.scalars.return_value.all.return_value = []
result = mem.get_history_prompt_messages()
assert result == []
def test_skips_first_message_without_answer(self):
"""The newest message (index 0 after extraction) without answer and tokens==0 is skipped."""
mem = self._make_memory()
msg_no_answer = _make_message(answer="", answer_tokens=0)
msg_no_answer.parent_message_id = None # ensures extract_thread_messages returns it
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch(
"core.memory.token_buffer_memory.extract_thread_messages",
return_value=[msg_no_answer],
),
):
mock_db.session.scalars.return_value.all.side_effect = [
[msg_no_answer], # first call: messages query
[], # second call: user files query (never hit, but safe)
]
result = mem.get_history_prompt_messages()
assert result == []
def test_message_with_answer_not_skipped(self):
"""A message with a non-empty answer is NOT popped."""
mem = self._make_memory()
msg = _make_message(answer="some answer", answer_tokens=10)
msg.parent_message_id = None
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch(
"core.memory.token_buffer_memory.extract_thread_messages",
return_value=[msg],
),
patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=None,
),
):
# user files query → empty; assistant files query → empty
mock_db.session.scalars.return_value.all.return_value = []
result = mem.get_history_prompt_messages()
assert len(result) == 2 # one user + one assistant
def test_message_limit_default_is_500(self):
"""When message_limit is None the stmt is limited to 500."""
mem = self._make_memory()
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch("core.memory.token_buffer_memory.select") as mock_select,
patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
):
mock_stmt = MagicMock()
mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
mock_stmt.limit.return_value = mock_stmt
mock_db.session.scalars.return_value.all.return_value = []
mem.get_history_prompt_messages(message_limit=None)
mock_stmt.limit.assert_called_with(500)
def test_message_limit_clipped_to_500(self):
"""A message_limit > 500 is clamped to 500."""
mem = self._make_memory()
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch("core.memory.token_buffer_memory.select") as mock_select,
patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
):
mock_stmt = MagicMock()
mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
mock_stmt.limit.return_value = mock_stmt
mock_db.session.scalars.return_value.all.return_value = []
mem.get_history_prompt_messages(message_limit=9999)
mock_stmt.limit.assert_called_with(500)
def test_message_limit_positive_used(self):
"""A positive message_limit < 500 is used as-is."""
mem = self._make_memory()
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch("core.memory.token_buffer_memory.select") as mock_select,
patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
):
mock_stmt = MagicMock()
mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
mock_stmt.limit.return_value = mock_stmt
mock_db.session.scalars.return_value.all.return_value = []
mem.get_history_prompt_messages(message_limit=10)
mock_stmt.limit.assert_called_with(10)
def test_message_limit_zero_uses_default(self):
"""message_limit=0 triggers the else branch → default 500."""
mem = self._make_memory()
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch("core.memory.token_buffer_memory.select") as mock_select,
patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
):
mock_stmt = MagicMock()
mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
mock_stmt.limit.return_value = mock_stmt
mock_db.session.scalars.return_value.all.return_value = []
mem.get_history_prompt_messages(message_limit=0)
mock_stmt.limit.assert_called_with(500)
def test_user_files_cause_build_with_files_call(self):
"""When user_files is non-empty _build_prompt_message_with_files is invoked."""
mem = self._make_memory()
msg = _make_message()
msg.parent_message_id = None
mock_user_file = MagicMock()
mock_user_prompt = UserPromptMessage(content="from build")
mock_assistant_prompt = AssistantPromptMessage(content="answer")
call_count = {"n": 0}
def scalars_side_effect(stmt):
r = MagicMock()
if call_count["n"] == 0:
# messages query
r.all.return_value = [msg]
elif call_count["n"] == 1:
# user files
r.all.return_value = [mock_user_file]
else:
# assistant files
r.all.return_value = []
call_count["n"] += 1
return r
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch(
"core.memory.token_buffer_memory.extract_thread_messages",
return_value=[msg],
),
patch.object(
mem,
"_build_prompt_message_with_files",
side_effect=[mock_user_prompt, mock_assistant_prompt],
) as mock_build,
patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=None,
),
):
mock_db.session.scalars.side_effect = scalars_side_effect
result = mem.get_history_prompt_messages()
assert mock_build.call_count >= 1
# First call should be user message
first_call_kwargs = mock_build.call_args_list[0][1]
assert first_call_kwargs["is_user_message"] is True
def test_assistant_files_cause_build_with_files_call(self):
"""When assistant_files is non-empty, build is called with is_user_message=False."""
mem = self._make_memory()
msg = _make_message()
msg.parent_message_id = None
mock_assistant_file = MagicMock()
mock_user_prompt = UserPromptMessage(content="query")
mock_assistant_prompt = AssistantPromptMessage(content="built")
call_count = {"n": 0}
def scalars_side_effect(stmt):
r = MagicMock()
if call_count["n"] == 0:
r.all.return_value = [msg]
elif call_count["n"] == 1:
r.all.return_value = [] # no user files
else:
r.all.return_value = [mock_assistant_file]
call_count["n"] += 1
return r
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch(
"core.memory.token_buffer_memory.extract_thread_messages",
return_value=[msg],
),
patch.object(
mem,
"_build_prompt_message_with_files",
return_value=mock_assistant_prompt,
) as mock_build,
):
mock_db.session.scalars.side_effect = scalars_side_effect
result = mem.get_history_prompt_messages()
mock_build.assert_called_once()
call_kwargs = mock_build.call_args[1]
assert call_kwargs["is_user_message"] is False
def test_token_pruning_removes_oldest_messages(self):
"""If tokens exceed limit, oldest messages are removed until within limit."""
conv = _make_conversation()
conv.app = MagicMock()
# Model returns tokens that decrease only after removing pairs
token_values = [3000, 1500] # first call over limit, second within
mi = MagicMock()
mi.get_llm_num_tokens.side_effect = token_values
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
msg = _make_message()
msg.parent_message_id = None
call_count = {"n": 0}
def scalars_side_effect(stmt):
r = MagicMock()
if call_count["n"] == 0:
r.all.return_value = [msg]
else:
r.all.return_value = []
call_count["n"] += 1
return r
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch(
"core.memory.token_buffer_memory.extract_thread_messages",
return_value=[msg],
),
patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=None,
),
):
mock_db.session.scalars.side_effect = scalars_side_effect
result = mem.get_history_prompt_messages(max_token_limit=2000)
# After pruning, we should have fewer than the 2 initial messages
assert len(result) <= 1
def test_token_pruning_stops_at_single_message(self):
"""Pruning stops when only 1 message remains (to prevent empty list)."""
conv = _make_conversation()
conv.app = MagicMock()
# Always over limit
mi = MagicMock()
mi.get_llm_num_tokens.return_value = 99999
mem = TokenBufferMemory(conversation=conv, model_instance=mi)
msg = _make_message()
msg.parent_message_id = None
call_count = {"n": 0}
def scalars_side_effect(stmt):
r = MagicMock()
if call_count["n"] == 0:
r.all.return_value = [msg]
else:
r.all.return_value = []
call_count["n"] += 1
return r
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch(
"core.memory.token_buffer_memory.extract_thread_messages",
return_value=[msg],
),
patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=None,
),
):
mock_db.session.scalars.side_effect = scalars_side_effect
result = mem.get_history_prompt_messages(max_token_limit=1)
# At least 1 message should remain
assert len(result) >= 1
def test_no_pruning_when_within_limit(self):
"""When tokens ≤ limit, no pruning occurs."""
mem = self._make_memory()
mem.model_instance.get_llm_num_tokens.return_value = 50 # well under default 2000
msg = _make_message()
msg.parent_message_id = None
call_count = {"n": 0}
def scalars_side_effect(stmt):
r = MagicMock()
if call_count["n"] == 0:
r.all.return_value = [msg]
else:
r.all.return_value = []
call_count["n"] += 1
return r
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch(
"core.memory.token_buffer_memory.extract_thread_messages",
return_value=[msg],
),
patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=None,
),
):
mock_db.session.scalars.side_effect = scalars_side_effect
result = mem.get_history_prompt_messages(max_token_limit=2000)
assert len(result) == 2 # user + assistant
def test_plain_user_and_assistant_messages_returned(self):
"""Without files, plain UserPromptMessage and AssistantPromptMessage appear."""
mem = self._make_memory()
msg = _make_message(answer="My answer")
msg.query = "My query"
msg.parent_message_id = None
call_count = {"n": 0}
def scalars_side_effect(stmt):
r = MagicMock()
if call_count["n"] == 0:
r.all.return_value = [msg]
else:
r.all.return_value = []
call_count["n"] += 1
return r
with (
patch("core.memory.token_buffer_memory.db") as mock_db,
patch(
"core.memory.token_buffer_memory.extract_thread_messages",
return_value=[msg],
),
patch(
"core.memory.token_buffer_memory.FileUploadConfigManager.convert",
return_value=None,
),
):
mock_db.session.scalars.side_effect = scalars_side_effect
result = mem.get_history_prompt_messages()
assert len(result) == 2
user_msg, ai_msg = result
assert isinstance(user_msg, UserPromptMessage)
assert user_msg.content == "My query"
assert isinstance(ai_msg, AssistantPromptMessage)
assert ai_msg.content == "My answer"
# ===========================================================================
# Tests for get_history_prompt_text
# ===========================================================================
class TestGetHistoryPromptText:
"""Tests for get_history_prompt_text."""
def _make_memory(self) -> TokenBufferMemory:
conv = _make_conversation()
conv.app = MagicMock()
return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
def test_empty_messages_returns_empty_string(self):
mem = self._make_memory()
with patch.object(mem, "get_history_prompt_messages", return_value=[]):
result = mem.get_history_prompt_text()
assert result == ""
def test_user_and_assistant_messages_formatted(self):
mem = self._make_memory()
messages = [
UserPromptMessage(content="Hello"),
AssistantPromptMessage(content="World"),
]
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
result = mem.get_history_prompt_text(human_prefix="H", ai_prefix="A")
assert result == "H: Hello\nA: World"
def test_custom_prefixes_applied(self):
mem = self._make_memory()
messages = [
UserPromptMessage(content="Hi"),
AssistantPromptMessage(content="Bye"),
]
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
result = mem.get_history_prompt_text(human_prefix="Human", ai_prefix="Bot")
assert "Human: Hi" in result
assert "Bot: Bye" in result
def test_list_content_with_text_and_image(self):
"""List content: TextPromptMessageContent → text; ImagePromptMessageContent → [image]."""
mem = self._make_memory()
messages = [
UserPromptMessage(
content=[
TextPromptMessageContent(data="caption"),
ImagePromptMessageContent(url="http://img", format="png", mime_type="image/png"),
]
),
]
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
result = mem.get_history_prompt_text()
assert "caption" in result
assert "[image]" in result
def test_list_content_text_only(self):
mem = self._make_memory()
messages = [
UserPromptMessage(content=[TextPromptMessageContent(data="just text")]),
]
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
result = mem.get_history_prompt_text()
assert "just text" in result
def test_list_content_image_only(self):
mem = self._make_memory()
messages = [
UserPromptMessage(
content=[
ImagePromptMessageContent(url="http://img", format="jpg", mime_type="image/jpeg"),
]
),
]
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
result = mem.get_history_prompt_text()
assert "[image]" in result
def test_unknown_role_skipped(self):
"""Messages with a role that is not USER or ASSISTANT are skipped."""
mem = self._make_memory()
# Create a mock message with a SYSTEM role
system_msg = MagicMock()
system_msg.role = PromptMessageRole.SYSTEM
system_msg.content = "system instruction"
user_msg = UserPromptMessage(content="hi")
with patch.object(mem, "get_history_prompt_messages", return_value=[system_msg, user_msg]):
result = mem.get_history_prompt_text()
assert "system instruction" not in result
assert "Human: hi" in result
def test_passes_max_token_limit_and_message_limit(self):
"""Parameters are forwarded to get_history_prompt_messages."""
mem = self._make_memory()
with patch.object(mem, "get_history_prompt_messages", return_value=[]) as mock_get:
mem.get_history_prompt_text(max_token_limit=500, message_limit=10)
mock_get.assert_called_once_with(max_token_limit=500, message_limit=10)
def test_multiple_messages_joined_by_newline(self):
mem = self._make_memory()
messages = [
UserPromptMessage(content="Q1"),
AssistantPromptMessage(content="A1"),
UserPromptMessage(content="Q2"),
AssistantPromptMessage(content="A2"),
]
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
result = mem.get_history_prompt_text()
lines = result.split("\n")
assert len(lines) == 4
assert lines[0] == "Human: Q1"
assert lines[1] == "Assistant: A1"
assert lines[2] == "Human: Q2"
assert lines[3] == "Assistant: A2"
def test_assistant_list_content_formatted(self):
"""AssistantPromptMessage with list content is also handled."""
mem = self._make_memory()
messages = [
AssistantPromptMessage(
content=[
TextPromptMessageContent(data="response text"),
ImagePromptMessageContent(url="http://img2", format="png", mime_type="image/png"),
]
),
]
with patch.object(mem, "get_history_prompt_messages", return_value=messages):
result = mem.get_history_prompt_text()
assert "response text" in result
assert "[image]" in result

View File

@ -0,0 +1,964 @@
"""Comprehensive unit tests for core/model_runtime/callbacks/base_callback.py"""
from unittest.mock import MagicMock, patch
import pytest
from dify_graph.model_runtime.callbacks.base_callback import (
_TEXT_COLOR_MAPPING,
Callback,
)
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
# ---------------------------------------------------------------------------
# Concrete implementation of the abstract Callback for testing
# ---------------------------------------------------------------------------
class ConcreteCallback(Callback):
"""A minimal concrete subclass that satisfies all abstract methods."""
def __init__(self, raise_error: bool = False):
self.raise_error = raise_error
# Track invocations
self.before_invoke_calls: list[dict] = []
self.new_chunk_calls: list[dict] = []
self.after_invoke_calls: list[dict] = []
self.invoke_error_calls: list[dict] = []
def on_before_invoke(
self,
llm_instance,
model,
credentials,
prompt_messages,
model_parameters,
tools=None,
stop=None,
stream=True,
user=None,
):
self.before_invoke_calls.append(
{
"llm_instance": llm_instance,
"model": model,
"credentials": credentials,
"prompt_messages": prompt_messages,
"model_parameters": model_parameters,
"tools": tools,
"stop": stop,
"stream": stream,
"user": user,
}
)
# To cover the 'raise NotImplementedError()' in the base class
try:
super().on_before_invoke(
llm_instance, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
)
except NotImplementedError:
pass
def on_new_chunk(
self,
llm_instance,
chunk,
model,
credentials,
prompt_messages,
model_parameters,
tools=None,
stop=None,
stream=True,
user=None,
):
self.new_chunk_calls.append(
{
"llm_instance": llm_instance,
"chunk": chunk,
"model": model,
"credentials": credentials,
"prompt_messages": prompt_messages,
"model_parameters": model_parameters,
"tools": tools,
"stop": stop,
"stream": stream,
"user": user,
}
)
try:
super().on_new_chunk(
llm_instance, chunk, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
)
except NotImplementedError:
pass
def on_after_invoke(
self,
llm_instance,
result,
model,
credentials,
prompt_messages,
model_parameters,
tools=None,
stop=None,
stream=True,
user=None,
):
self.after_invoke_calls.append(
{
"llm_instance": llm_instance,
"result": result,
"model": model,
"credentials": credentials,
"prompt_messages": prompt_messages,
"model_parameters": model_parameters,
"tools": tools,
"stop": stop,
"stream": stream,
"user": user,
}
)
try:
super().on_after_invoke(
llm_instance, result, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
)
except NotImplementedError:
pass
def on_invoke_error(
self,
llm_instance,
ex,
model,
credentials,
prompt_messages,
model_parameters,
tools=None,
stop=None,
stream=True,
user=None,
):
self.invoke_error_calls.append(
{
"llm_instance": llm_instance,
"ex": ex,
"model": model,
"credentials": credentials,
"prompt_messages": prompt_messages,
"model_parameters": model_parameters,
"tools": tools,
"stop": stop,
"stream": stream,
"user": user,
}
)
try:
super().on_invoke_error(
llm_instance, ex, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
)
except NotImplementedError:
pass
# ---------------------------------------------------------------------------
# A subclass that deliberately leaves abstract methods un-implemented,
# used to verify that instantiation raises TypeError.
# ---------------------------------------------------------------------------
# ===========================================================================
# Tests for _TEXT_COLOR_MAPPING module-level constant
# ===========================================================================
class TestTextColorMapping:
"""Tests for the module-level _TEXT_COLOR_MAPPING dictionary."""
def test_contains_all_expected_colors(self):
expected_keys = {"blue", "yellow", "pink", "green", "red"}
assert set(_TEXT_COLOR_MAPPING.keys()) == expected_keys
def test_blue_escape_code(self):
assert _TEXT_COLOR_MAPPING["blue"] == "36;1"
def test_yellow_escape_code(self):
assert _TEXT_COLOR_MAPPING["yellow"] == "33;1"
def test_pink_escape_code(self):
assert _TEXT_COLOR_MAPPING["pink"] == "38;5;200"
def test_green_escape_code(self):
assert _TEXT_COLOR_MAPPING["green"] == "32;1"
def test_red_escape_code(self):
assert _TEXT_COLOR_MAPPING["red"] == "31;1"
def test_mapping_is_dict(self):
assert isinstance(_TEXT_COLOR_MAPPING, dict)
def test_all_values_are_strings(self):
for key, value in _TEXT_COLOR_MAPPING.items():
assert isinstance(value, str), f"Value for {key!r} should be str"
# ===========================================================================
# Tests for the Callback ABC itself
# ===========================================================================
class TestCallbackAbstract:
"""Tests verifying Callback is a proper ABC."""
def test_cannot_instantiate_abstract_class_directly(self):
"""Callback cannot be instantiated since it has abstract methods."""
with pytest.raises(TypeError):
Callback() # type: ignore[abstract]
def test_concrete_subclass_can_be_instantiated(self):
cb = ConcreteCallback()
assert isinstance(cb, Callback)
def test_default_raise_error_is_false(self):
cb = ConcreteCallback()
assert cb.raise_error is False
def test_raise_error_can_be_set_to_true(self):
cb = ConcreteCallback(raise_error=True)
assert cb.raise_error is True
def test_subclass_missing_on_before_invoke_raises_type_error(self):
"""A subclass missing any single abstract method cannot be instantiated."""
class IncompleteCallback(Callback):
def on_new_chunk(self, *a, **kw): ...
def on_after_invoke(self, *a, **kw): ...
def on_invoke_error(self, *a, **kw): ...
with pytest.raises(TypeError):
IncompleteCallback() # type: ignore[abstract]
def test_subclass_missing_on_new_chunk_raises_type_error(self):
class IncompleteCallback(Callback):
def on_before_invoke(self, *a, **kw): ...
def on_after_invoke(self, *a, **kw): ...
def on_invoke_error(self, *a, **kw): ...
with pytest.raises(TypeError):
IncompleteCallback() # type: ignore[abstract]
def test_subclass_missing_on_after_invoke_raises_type_error(self):
class IncompleteCallback(Callback):
def on_before_invoke(self, *a, **kw): ...
def on_new_chunk(self, *a, **kw): ...
def on_invoke_error(self, *a, **kw): ...
with pytest.raises(TypeError):
IncompleteCallback() # type: ignore[abstract]
def test_subclass_missing_on_invoke_error_raises_type_error(self):
class IncompleteCallback(Callback):
def on_before_invoke(self, *a, **kw): ...
def on_new_chunk(self, *a, **kw): ...
def on_after_invoke(self, *a, **kw): ...
with pytest.raises(TypeError):
IncompleteCallback() # type: ignore[abstract]
# ===========================================================================
# Tests for on_before_invoke
# ===========================================================================
class TestOnBeforeInvoke:
"""Tests for the on_before_invoke callback method."""
def setup_method(self):
self.cb = ConcreteCallback()
self.llm_instance = MagicMock()
self.model = "gpt-4"
self.credentials = {"api_key": "sk-test"}
self.prompt_messages = [MagicMock(spec=PromptMessage)]
self.model_parameters = {"temperature": 0.7}
def test_on_before_invoke_called_with_required_args(self):
self.cb.on_before_invoke(
llm_instance=self.llm_instance,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert len(self.cb.before_invoke_calls) == 1
call = self.cb.before_invoke_calls[0]
assert call["llm_instance"] is self.llm_instance
assert call["model"] == self.model
assert call["credentials"] == self.credentials
assert call["prompt_messages"] is self.prompt_messages
assert call["model_parameters"] is self.model_parameters
def test_on_before_invoke_defaults_tools_none(self):
self.cb.on_before_invoke(
llm_instance=self.llm_instance,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.before_invoke_calls[0]["tools"] is None
def test_on_before_invoke_defaults_stop_none(self):
self.cb.on_before_invoke(
llm_instance=self.llm_instance,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.before_invoke_calls[0]["stop"] is None
def test_on_before_invoke_defaults_stream_true(self):
self.cb.on_before_invoke(
llm_instance=self.llm_instance,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.before_invoke_calls[0]["stream"] is True
def test_on_before_invoke_defaults_user_none(self):
self.cb.on_before_invoke(
llm_instance=self.llm_instance,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.before_invoke_calls[0]["user"] is None
def test_on_before_invoke_with_all_optional_args(self):
tools = [MagicMock(spec=PromptMessageTool)]
stop = ["stop1", "stop2"]
self.cb.on_before_invoke(
llm_instance=self.llm_instance,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
tools=tools,
stop=stop,
stream=False,
user="user-123",
)
call = self.cb.before_invoke_calls[0]
assert call["tools"] is tools
assert call["stop"] == stop
assert call["stream"] is False
assert call["user"] == "user-123"
def test_on_before_invoke_called_multiple_times(self):
for i in range(3):
self.cb.on_before_invoke(
llm_instance=self.llm_instance,
model=f"model-{i}",
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert len(self.cb.before_invoke_calls) == 3
assert self.cb.before_invoke_calls[2]["model"] == "model-2"
# ===========================================================================
# Tests for on_new_chunk
# ===========================================================================
class TestOnNewChunk:
"""Tests for the on_new_chunk callback method."""
def setup_method(self):
self.cb = ConcreteCallback()
self.llm_instance = MagicMock()
self.chunk = MagicMock(spec=LLMResultChunk)
self.model = "gpt-3.5-turbo"
self.credentials = {"api_key": "sk-test"}
self.prompt_messages = [MagicMock(spec=PromptMessage)]
self.model_parameters = {"max_tokens": 256}
def test_on_new_chunk_called_with_required_args(self):
self.cb.on_new_chunk(
llm_instance=self.llm_instance,
chunk=self.chunk,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert len(self.cb.new_chunk_calls) == 1
call = self.cb.new_chunk_calls[0]
assert call["llm_instance"] is self.llm_instance
assert call["chunk"] is self.chunk
assert call["model"] == self.model
assert call["credentials"] == self.credentials
def test_on_new_chunk_defaults_tools_none(self):
self.cb.on_new_chunk(
llm_instance=self.llm_instance,
chunk=self.chunk,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.new_chunk_calls[0]["tools"] is None
def test_on_new_chunk_defaults_stop_none(self):
self.cb.on_new_chunk(
llm_instance=self.llm_instance,
chunk=self.chunk,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.new_chunk_calls[0]["stop"] is None
def test_on_new_chunk_defaults_stream_true(self):
self.cb.on_new_chunk(
llm_instance=self.llm_instance,
chunk=self.chunk,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.new_chunk_calls[0]["stream"] is True
def test_on_new_chunk_defaults_user_none(self):
self.cb.on_new_chunk(
llm_instance=self.llm_instance,
chunk=self.chunk,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.new_chunk_calls[0]["user"] is None
def test_on_new_chunk_with_all_optional_args(self):
tools = [MagicMock(spec=PromptMessageTool)]
stop = ["END"]
self.cb.on_new_chunk(
llm_instance=self.llm_instance,
chunk=self.chunk,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
tools=tools,
stop=stop,
stream=False,
user="chunk-user",
)
call = self.cb.new_chunk_calls[0]
assert call["tools"] is tools
assert call["stop"] == stop
assert call["stream"] is False
assert call["user"] == "chunk-user"
def test_on_new_chunk_called_multiple_times(self):
for i in range(5):
self.cb.on_new_chunk(
llm_instance=self.llm_instance,
chunk=self.chunk,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert len(self.cb.new_chunk_calls) == 5
# ===========================================================================
# Tests for on_after_invoke
# ===========================================================================
class TestOnAfterInvoke:
"""Tests for the on_after_invoke callback method."""
def setup_method(self):
self.cb = ConcreteCallback()
self.llm_instance = MagicMock()
self.result = MagicMock(spec=LLMResult)
self.model = "claude-3"
self.credentials = {"api_key": "anthropic-key"}
self.prompt_messages = [MagicMock(spec=PromptMessage)]
self.model_parameters = {"temperature": 1.0}
def test_on_after_invoke_called_with_required_args(self):
self.cb.on_after_invoke(
llm_instance=self.llm_instance,
result=self.result,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert len(self.cb.after_invoke_calls) == 1
call = self.cb.after_invoke_calls[0]
assert call["llm_instance"] is self.llm_instance
assert call["result"] is self.result
assert call["model"] == self.model
assert call["credentials"] is self.credentials
def test_on_after_invoke_defaults_tools_none(self):
self.cb.on_after_invoke(
llm_instance=self.llm_instance,
result=self.result,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.after_invoke_calls[0]["tools"] is None
def test_on_after_invoke_defaults_stop_none(self):
self.cb.on_after_invoke(
llm_instance=self.llm_instance,
result=self.result,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.after_invoke_calls[0]["stop"] is None
def test_on_after_invoke_defaults_stream_true(self):
self.cb.on_after_invoke(
llm_instance=self.llm_instance,
result=self.result,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.after_invoke_calls[0]["stream"] is True
def test_on_after_invoke_defaults_user_none(self):
self.cb.on_after_invoke(
llm_instance=self.llm_instance,
result=self.result,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.after_invoke_calls[0]["user"] is None
def test_on_after_invoke_with_all_optional_args(self):
tools = [MagicMock(spec=PromptMessageTool)]
stop = ["STOP"]
self.cb.on_after_invoke(
llm_instance=self.llm_instance,
result=self.result,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
tools=tools,
stop=stop,
stream=False,
user="after-user",
)
call = self.cb.after_invoke_calls[0]
assert call["tools"] is tools
assert call["stop"] == stop
assert call["stream"] is False
assert call["user"] == "after-user"
# ===========================================================================
# Tests for on_invoke_error
# ===========================================================================
class TestOnInvokeError:
"""Tests for the on_invoke_error callback method."""
def setup_method(self):
self.cb = ConcreteCallback()
self.llm_instance = MagicMock()
self.ex = ValueError("something went wrong")
self.model = "gemini-pro"
self.credentials = {"api_key": "google-key"}
self.prompt_messages = [MagicMock(spec=PromptMessage)]
self.model_parameters = {"top_p": 0.9}
def test_on_invoke_error_called_with_required_args(self):
self.cb.on_invoke_error(
llm_instance=self.llm_instance,
ex=self.ex,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert len(self.cb.invoke_error_calls) == 1
call = self.cb.invoke_error_calls[0]
assert call["llm_instance"] is self.llm_instance
assert call["ex"] is self.ex
assert call["model"] == self.model
assert call["credentials"] is self.credentials
def test_on_invoke_error_defaults_tools_none(self):
self.cb.on_invoke_error(
llm_instance=self.llm_instance,
ex=self.ex,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.invoke_error_calls[0]["tools"] is None
def test_on_invoke_error_defaults_stop_none(self):
self.cb.on_invoke_error(
llm_instance=self.llm_instance,
ex=self.ex,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.invoke_error_calls[0]["stop"] is None
def test_on_invoke_error_defaults_stream_true(self):
self.cb.on_invoke_error(
llm_instance=self.llm_instance,
ex=self.ex,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.invoke_error_calls[0]["stream"] is True
def test_on_invoke_error_defaults_user_none(self):
self.cb.on_invoke_error(
llm_instance=self.llm_instance,
ex=self.ex,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert self.cb.invoke_error_calls[0]["user"] is None
def test_on_invoke_error_with_all_optional_args(self):
tools = [MagicMock(spec=PromptMessageTool)]
stop = ["HALT"]
self.cb.on_invoke_error(
llm_instance=self.llm_instance,
ex=self.ex,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
tools=tools,
stop=stop,
stream=False,
user="error-user",
)
call = self.cb.invoke_error_calls[0]
assert call["tools"] is tools
assert call["stop"] == stop
assert call["stream"] is False
assert call["user"] == "error-user"
def test_on_invoke_error_accepts_various_exception_types(self):
for exc in [RuntimeError("r"), KeyError("k"), Exception("e")]:
self.cb.on_invoke_error(
llm_instance=self.llm_instance,
ex=exc,
model=self.model,
credentials=self.credentials,
prompt_messages=self.prompt_messages,
model_parameters=self.model_parameters,
)
assert len(self.cb.invoke_error_calls) == 3
# ===========================================================================
# Tests for print_text (concrete method on Callback)
# ===========================================================================
class TestPrintText:
"""Tests for the concrete print_text method."""
def setup_method(self):
self.cb = ConcreteCallback()
def test_print_text_without_color_prints_plain_text(self, capsys):
self.cb.print_text("hello world")
captured = capsys.readouterr()
assert captured.out == "hello world"
def test_print_text_with_color_prints_colored_text(self, capsys):
self.cb.print_text("colored text", color="blue")
captured = capsys.readouterr()
# Should contain ANSI escape sequences
assert "colored text" in captured.out
assert "\001b[" in captured.out or "\033[" in captured.out or "\x1b[" in captured.out
def test_print_text_without_color_no_ansi(self, capsys):
self.cb.print_text("plain text", color=None)
captured = capsys.readouterr()
assert captured.out == "plain text"
# No ANSI escape sequences
assert "\x1b" not in captured.out
def test_print_text_default_end_is_empty_string(self, capsys):
self.cb.print_text("no newline")
captured = capsys.readouterr()
assert not captured.out.endswith("\n")
def test_print_text_with_custom_end(self, capsys):
self.cb.print_text("with newline", end="\n")
captured = capsys.readouterr()
assert captured.out.endswith("\n")
def test_print_text_with_empty_string(self, capsys):
self.cb.print_text("", color=None)
captured = capsys.readouterr()
assert captured.out == ""
@pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"])
def test_print_text_all_colors_work(self, color, capsys):
"""Verify no KeyError is thrown for any valid color."""
self.cb.print_text("test", color=color)
captured = capsys.readouterr()
assert "test" in captured.out
def test_print_text_calls_get_colored_text_when_color_given(self):
with patch.object(self.cb, "_get_colored_text", return_value="[COLORED]") as mock_gct:
with patch("builtins.print") as mock_print:
self.cb.print_text("hello", color="green")
mock_gct.assert_called_once_with("hello", "green")
mock_print.assert_called_once_with("[COLORED]", end="")
def test_print_text_does_not_call_get_colored_text_when_no_color(self):
with patch.object(self.cb, "_get_colored_text") as mock_gct:
with patch("builtins.print"):
self.cb.print_text("hello", color=None)
mock_gct.assert_not_called()
def test_print_text_passes_end_to_print(self):
with patch("builtins.print") as mock_print:
self.cb.print_text("text", end="---")
mock_print.assert_called_once_with("text", end="---")
# ===========================================================================
# Tests for _get_colored_text (private helper method)
# ===========================================================================
class TestGetColoredText:
"""Tests for the _get_colored_text private method."""
def setup_method(self):
self.cb = ConcreteCallback()
@pytest.mark.parametrize(("color", "expected_code"), list(_TEXT_COLOR_MAPPING.items()))
def test_get_colored_text_uses_correct_escape_code(self, color, expected_code):
result = self.cb._get_colored_text("text", color)
assert expected_code in result
@pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"])
def test_get_colored_text_contains_input_text(self, color):
result = self.cb._get_colored_text("hello", color)
assert "hello" in result
@pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"])
def test_get_colored_text_starts_with_escape(self, color):
result = self.cb._get_colored_text("text", color)
# Should start with an ANSI escape (\x1b or \u001b)
assert result.startswith("\x1b[") or result.startswith("\u001b[")
@pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"])
def test_get_colored_text_ends_with_reset(self, color):
result = self.cb._get_colored_text("text", color)
# Should end with the ANSI reset code
assert result.endswith("\x1b[0m") or result.endswith("\u001b[0m")
def test_get_colored_text_returns_string(self):
result = self.cb._get_colored_text("text", "blue")
assert isinstance(result, str)
def test_get_colored_text_blue_exact_format(self):
result = self.cb._get_colored_text("hello", "blue")
expected = f"\u001b[{_TEXT_COLOR_MAPPING['blue']}m\033[1;3mhello\u001b[0m"
assert result == expected
def test_get_colored_text_red_exact_format(self):
result = self.cb._get_colored_text("error", "red")
expected = f"\u001b[{_TEXT_COLOR_MAPPING['red']}m\033[1;3merror\u001b[0m"
assert result == expected
def test_get_colored_text_green_exact_format(self):
result = self.cb._get_colored_text("ok", "green")
expected = f"\u001b[{_TEXT_COLOR_MAPPING['green']}m\033[1;3mok\u001b[0m"
assert result == expected
def test_get_colored_text_yellow_exact_format(self):
result = self.cb._get_colored_text("warn", "yellow")
expected = f"\u001b[{_TEXT_COLOR_MAPPING['yellow']}m\033[1;3mwarn\u001b[0m"
assert result == expected
def test_get_colored_text_pink_exact_format(self):
result = self.cb._get_colored_text("info", "pink")
expected = f"\u001b[{_TEXT_COLOR_MAPPING['pink']}m\033[1;3minfo\u001b[0m"
assert result == expected
def test_get_colored_text_empty_string(self):
result = self.cb._get_colored_text("", "blue")
assert isinstance(result, str)
# Empty text should still have escape codes
assert _TEXT_COLOR_MAPPING["blue"] in result
def test_get_colored_text_invalid_color_raises_key_error(self):
with pytest.raises(KeyError):
self.cb._get_colored_text("text", "purple")
def test_get_colored_text_with_special_characters(self):
special = "hello\nworld\ttab"
result = self.cb._get_colored_text(special, "blue")
assert special in result
def test_get_colored_text_with_long_text(self):
long_text = "a" * 10000
result = self.cb._get_colored_text(long_text, "green")
assert long_text in result
# ===========================================================================
# Integration-style tests: full workflow through a ConcreteCallback
# ===========================================================================
class TestConcreteCallbackIntegration:
"""End-to-end workflow tests using ConcreteCallback."""
def test_full_invocation_lifecycle(self):
"""Simulate a complete LLM invocation lifecycle through all callbacks."""
cb = ConcreteCallback()
llm_instance = MagicMock()
model = "gpt-4o"
credentials = {"api_key": "sk-xyz"}
prompt_messages = [MagicMock(spec=PromptMessage)]
model_parameters = {"temperature": 0.5}
tools = [MagicMock(spec=PromptMessageTool)]
stop = ["<END>"]
user = "user-abc"
# 1. Before invoke
cb.on_before_invoke(
llm_instance=llm_instance,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=True,
user=user,
)
# 2. Multiple chunks during streaming
for i in range(3):
chunk = MagicMock(spec=LLMResultChunk)
cb.on_new_chunk(
llm_instance=llm_instance,
chunk=chunk,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=True,
user=user,
)
# 3. After invoke
result = MagicMock(spec=LLMResult)
cb.on_after_invoke(
llm_instance=llm_instance,
result=result,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=True,
user=user,
)
assert len(cb.before_invoke_calls) == 1
assert len(cb.new_chunk_calls) == 3
assert len(cb.after_invoke_calls) == 1
assert len(cb.invoke_error_calls) == 0
def test_error_lifecycle(self):
"""Simulate an invoke that results in an error."""
cb = ConcreteCallback()
llm_instance = MagicMock()
model = "gpt-4"
credentials = {}
prompt_messages = []
model_parameters = {}
cb.on_before_invoke(
llm_instance=llm_instance,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
)
ex = RuntimeError("API timeout")
cb.on_invoke_error(
llm_instance=llm_instance,
ex=ex,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
)
assert len(cb.before_invoke_calls) == 1
assert len(cb.invoke_error_calls) == 1
assert cb.invoke_error_calls[0]["ex"] is ex
assert len(cb.after_invoke_calls) == 0
def test_print_text_with_color_in_integration(self, capsys):
"""verify print_text works correctly in a concrete instance."""
cb = ConcreteCallback()
cb.print_text("SUCCESS", color="green", end="\n")
captured = capsys.readouterr()
assert "SUCCESS" in captured.out
assert "\n" in captured.out
def test_print_text_no_color_in_integration(self, capsys):
cb = ConcreteCallback()
cb.print_text("plain output")
captured = capsys.readouterr()
assert captured.out == "plain output"

View File

@ -0,0 +1,700 @@
"""
Comprehensive unit tests for core/model_runtime/callbacks/logging_callback.py
Coverage targets:
- LoggingCallback.on_before_invoke (all branches: stop, tools, user, stream,
prompt_message.name, model_parameters)
- LoggingCallback.on_new_chunk (writes to stdout)
- LoggingCallback.on_after_invoke (all branches: tool_calls present / absent)
- LoggingCallback.on_invoke_error (logs exception via logger.exception)
"""
from __future__ import annotations
import json
from collections.abc import Sequence
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback
from dify_graph.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def _make_usage() -> LLMUsage:
"""Return a minimal LLMUsage instance."""
return LLMUsage(
prompt_tokens=10,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.01"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.002"),
completion_price=Decimal("0.04"),
total_tokens=30,
total_price=Decimal("0.05"),
currency="USD",
latency=0.5,
)
def _make_llm_result(
content: str = "hello world",
tool_calls: list | None = None,
model: str = "gpt-4",
system_fingerprint: str | None = "fp-abc",
) -> LLMResult:
"""Return an LLMResult with an AssistantPromptMessage."""
assistant_msg = AssistantPromptMessage(
content=content,
tool_calls=tool_calls or [],
)
return LLMResult(
model=model,
message=assistant_msg,
usage=_make_usage(),
system_fingerprint=system_fingerprint,
)
def _make_chunk(content: str = "chunk-text") -> LLMResultChunk:
"""Return a minimal LLMResultChunk."""
return LLMResultChunk(
model="gpt-4",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=content),
),
)
def _make_user_prompt(content: str = "Hello!", name: str | None = None) -> UserPromptMessage:
return UserPromptMessage(content=content, name=name)
def _make_system_prompt(content: str = "You are helpful.") -> SystemPromptMessage:
return SystemPromptMessage(content=content)
def _make_tool(name: str = "my_tool") -> PromptMessageTool:
return PromptMessageTool(name=name, description="A tool", parameters={})
def _make_tool_call(
call_id: str = "call-1",
func_name: str = "some_func",
arguments: str = '{"key": "value"}',
) -> AssistantPromptMessage.ToolCall:
return AssistantPromptMessage.ToolCall(
id=call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=func_name, arguments=arguments),
)
# ---------------------------------------------------------------------------
# Fixture: shared LoggingCallback instance (no heavy state)
# ---------------------------------------------------------------------------
@pytest.fixture
def cb() -> LoggingCallback:
return LoggingCallback()
@pytest.fixture
def llm_instance() -> MagicMock:
return MagicMock()
# ===========================================================================
# Tests for on_before_invoke
# ===========================================================================
class TestOnBeforeInvoke:
"""Tests for LoggingCallback.on_before_invoke."""
def _invoke(
self,
cb: LoggingCallback,
llm_instance: MagicMock,
*,
model: str = "gpt-4",
credentials: dict | None = None,
prompt_messages: list | None = None,
model_parameters: dict | None = None,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
):
cb.on_before_invoke(
llm_instance=llm_instance,
model=model,
credentials=credentials or {},
prompt_messages=prompt_messages or [],
model_parameters=model_parameters or {},
tools=tools,
stop=stop,
stream=stream,
user=user,
)
def test_minimal_call_does_not_raise(self, cb: LoggingCallback, llm_instance: MagicMock):
"""Calling with bare-minimum args should not raise."""
self._invoke(cb, llm_instance)
def test_model_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
"""The model name must appear in print_text calls."""
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, model="claude-3")
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "claude-3" in calls_text
def test_model_parameters_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
"""Each key-value pair of model_parameters must be printed."""
params = {"temperature": 0.7, "max_tokens": 512}
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, model_parameters=params)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "temperature" in calls_text
assert "0.7" in calls_text
assert "max_tokens" in calls_text
assert "512" in calls_text
def test_empty_model_parameters(self, cb: LoggingCallback, llm_instance: MagicMock):
"""Empty model_parameters dict should not raise."""
self._invoke(cb, llm_instance, model_parameters={})
# ------------------------------------------------------------------
# stop branch
# ------------------------------------------------------------------
def test_stop_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock):
"""stop words must appear in output when provided."""
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, stop=["STOP", "END"])
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "stop" in calls_text
def test_stop_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When stop=None the stop line must NOT appear."""
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, stop=None)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "\tstop:" not in calls_text
def test_stop_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When stop=[] (falsy) the stop line must NOT appear."""
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, stop=[])
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "\tstop:" not in calls_text
# ------------------------------------------------------------------
# tools branch
# ------------------------------------------------------------------
def test_tools_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock):
"""Tool names must appear in output when tools are provided."""
tools = [_make_tool("search"), _make_tool("calculate")]
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, tools=tools)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "search" in calls_text
assert "calculate" in calls_text
def test_tools_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When tools=None the Tools section must NOT appear."""
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, tools=None)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "Tools:" not in calls_text
def test_tools_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When tools=[] (falsy) the Tools section must NOT appear."""
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, tools=[])
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "Tools:" not in calls_text
# ------------------------------------------------------------------
# user branch
# ------------------------------------------------------------------
def test_user_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock):
"""User string must appear in output when provided."""
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, user="alice")
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "alice" in calls_text
def test_user_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When user=None the User line must NOT appear."""
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, user=None)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "User:" not in calls_text
# ------------------------------------------------------------------
# stream branch
# ------------------------------------------------------------------
def test_stream_true_prints_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When stream=True the [on_llm_new_chunk] marker must be printed."""
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, stream=True)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "[on_llm_new_chunk]" in calls_text
def test_stream_false_no_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When stream=False the [on_llm_new_chunk] marker must NOT appear."""
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, stream=False)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "[on_llm_new_chunk]" not in calls_text
# ------------------------------------------------------------------
# prompt_messages branch
# ------------------------------------------------------------------
def test_prompt_message_with_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When a PromptMessage has a name it must be printed."""
msg = _make_user_prompt("hi", name="bob")
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, prompt_messages=[msg])
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "bob" in calls_text
def test_prompt_message_without_name_skips_name_line(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When a PromptMessage has no name the name line must NOT appear."""
msg = _make_user_prompt("hi", name=None)
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, prompt_messages=[msg])
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "\tname:" not in calls_text
def test_prompt_message_role_and_content_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
"""Role and content of each PromptMessage must appear in output."""
msg = _make_system_prompt("Be concise.")
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, prompt_messages=[msg])
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "system" in calls_text
assert "Be concise." in calls_text
def test_multiple_prompt_messages_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
"""All entries in prompt_messages are iterated and printed."""
msgs = [
_make_system_prompt("sys"),
_make_user_prompt("user msg"),
]
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, prompt_messages=msgs)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "sys" in calls_text
assert "user msg" in calls_text
# ------------------------------------------------------------------
# Combination: everything provided
# ------------------------------------------------------------------
def test_all_optional_fields_combined(self, cb: LoggingCallback, llm_instance: MagicMock):
"""Supply stop, tools, user, multiple params, named message no exception."""
msgs = [_make_user_prompt("question", name="alice")]
tools = [_make_tool("tool_a")]
with patch.object(cb, "print_text"):
self._invoke(
cb,
llm_instance,
model="gpt-3.5",
model_parameters={"temperature": 1.0, "top_p": 0.9},
tools=tools,
stop=["DONE"],
stream=True,
user="alice",
prompt_messages=msgs,
)
# ===========================================================================
# Tests for on_new_chunk
# ===========================================================================
class TestOnNewChunk:
"""Tests for LoggingCallback.on_new_chunk."""
def test_chunk_content_written_to_stdout(self, cb: LoggingCallback, llm_instance: MagicMock):
"""on_new_chunk must write the chunk's text content to sys.stdout."""
chunk = _make_chunk("hello from LLM")
written = []
with patch("sys.stdout") as mock_stdout:
mock_stdout.write.side_effect = written.append
cb.on_new_chunk(
llm_instance=llm_instance,
chunk=chunk,
model="gpt-4",
credentials={},
prompt_messages=[],
model_parameters={},
)
mock_stdout.write.assert_called_once_with("hello from LLM")
mock_stdout.flush.assert_called_once()
def test_chunk_content_empty_string(self, cb: LoggingCallback, llm_instance: MagicMock):
"""Works correctly even when the chunk content is an empty string."""
chunk = _make_chunk("")
with patch("sys.stdout") as mock_stdout:
cb.on_new_chunk(
llm_instance=llm_instance,
chunk=chunk,
model="gpt-4",
credentials={},
prompt_messages=[],
model_parameters={},
)
mock_stdout.write.assert_called_once_with("")
mock_stdout.flush.assert_called_once()
def test_chunk_passes_all_optional_params(self, cb: LoggingCallback, llm_instance: MagicMock):
"""All optional parameters are accepted without errors."""
chunk = _make_chunk("data")
with patch("sys.stdout"):
cb.on_new_chunk(
llm_instance=llm_instance,
chunk=chunk,
model="gpt-4",
credentials={"key": "secret"},
prompt_messages=[_make_user_prompt("q")],
model_parameters={"temperature": 0.5},
tools=[_make_tool("t1")],
stop=["EOS"],
stream=True,
user="bob",
)
# ===========================================================================
# Tests for on_after_invoke
# ===========================================================================
class TestOnAfterInvoke:
"""Tests for LoggingCallback.on_after_invoke."""
def _invoke(
self,
cb: LoggingCallback,
llm_instance: MagicMock,
result: LLMResult,
**kwargs,
):
cb.on_after_invoke(
llm_instance=llm_instance,
result=result,
model=result.model,
credentials={},
prompt_messages=[],
model_parameters={},
**kwargs,
)
def test_basic_result_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
"""After-invoke header, content, model, usage, fingerprint must be printed."""
result = _make_llm_result()
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, result)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "[on_llm_after_invoke]" in calls_text
assert "hello world" in calls_text
assert "gpt-4" in calls_text
assert "fp-abc" in calls_text
def test_no_tool_calls_skips_tool_call_block(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When there are no tool_calls the 'Tool calls:' block must NOT appear."""
result = _make_llm_result(tool_calls=[])
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, result)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "Tool calls:" not in calls_text
def test_with_tool_calls_prints_all_fields(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When tool_calls exist their id, name, and JSON arguments must be printed."""
tc = _make_tool_call(
call_id="call-xyz",
func_name="fetch_data",
arguments='{"url": "https://example.com"}',
)
result = _make_llm_result(tool_calls=[tc])
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, result)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "Tool calls:" in calls_text
assert "call-xyz" in calls_text
assert "fetch_data" in calls_text
# arguments should be JSON-dumped
assert "https://example.com" in calls_text
def test_multiple_tool_calls_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
"""All tool calls in the list must be iterated."""
tcs = [
_make_tool_call("id-1", "func_a", '{"a": 1}'),
_make_tool_call("id-2", "func_b", '{"b": 2}'),
]
result = _make_llm_result(tool_calls=tcs)
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, result)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "id-1" in calls_text
assert "func_a" in calls_text
assert "id-2" in calls_text
assert "func_b" in calls_text
def test_system_fingerprint_none_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
"""When system_fingerprint is None it should still be printed (as None)."""
result = _make_llm_result(system_fingerprint=None)
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, result)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "System Fingerprint: None" in calls_text
def test_usage_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
"""The usage object must appear in the printed output."""
result = _make_llm_result()
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, result)
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "Usage:" in calls_text
def test_tool_call_arguments_are_json_dumped(self, cb: LoggingCallback, llm_instance: MagicMock):
"""Verify json.dumps is applied to the arguments field (a string)."""
raw_args = '{"x": 42}'
tc = _make_tool_call(arguments=raw_args)
result = _make_llm_result(tool_calls=[tc])
with patch.object(cb, "print_text") as mock_print:
self._invoke(cb, llm_instance, result)
# Check if any call to print_text included the expected (json-encoded) arguments
# json.dumps(raw_args) produces a string starting and ending with quotes
expected_substring = json.dumps(raw_args)
found = any(expected_substring in str(call.args[0]) for call in mock_print.call_args_list)
assert found, f"Expected {expected_substring} to be printed in one of the calls"
def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock):
"""All optional parameters should be accepted without error."""
result = _make_llm_result()
cb.on_after_invoke(
llm_instance=llm_instance,
result=result,
model=result.model,
credentials={"key": "secret"},
prompt_messages=[_make_user_prompt("q")],
model_parameters={"temperature": 0.9},
tools=[_make_tool("t")],
stop=["<EOS>"],
stream=False,
user="carol",
)
# ===========================================================================
# Tests for on_invoke_error
# ===========================================================================
class TestOnInvokeError:
"""Tests for LoggingCallback.on_invoke_error."""
def _invoke_error(
self,
cb: LoggingCallback,
llm_instance: MagicMock,
ex: Exception,
**kwargs,
):
cb.on_invoke_error(
llm_instance=llm_instance,
ex=ex,
model="gpt-4",
credentials={},
prompt_messages=[],
model_parameters={},
**kwargs,
)
def test_prints_error_header(self, cb: LoggingCallback, llm_instance: MagicMock):
"""The [on_llm_invoke_error] banner must be printed."""
with patch.object(cb, "print_text") as mock_print:
with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger:
self._invoke_error(cb, llm_instance, RuntimeError("boom"))
calls_text = " ".join(str(c) for c in mock_print.call_args_list)
assert "[on_llm_invoke_error]" in calls_text
def test_exception_logged_via_logger_exception(self, cb: LoggingCallback, llm_instance: MagicMock):
"""logger.exception must be called with the exception."""
ex = ValueError("something went wrong")
with patch.object(cb, "print_text"):
with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger:
self._invoke_error(cb, llm_instance, ex)
mock_logger.exception.assert_called_once_with(ex)
def test_exception_type_variety(self, cb: LoggingCallback, llm_instance: MagicMock):
"""Works with any exception type (TypeError, IOError, etc.)."""
for exc_cls in (TypeError, IOError, KeyError, Exception):
ex = exc_cls("error")
with patch.object(cb, "print_text"):
with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger:
self._invoke_error(cb, llm_instance, ex)
mock_logger.exception.assert_called_once_with(ex)
def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock):
"""All optional parameters should be accepted without error."""
ex = RuntimeError("fail")
with patch.object(cb, "print_text"):
with patch("dify_graph.model_runtime.callbacks.logging_callback.logger"):
cb.on_invoke_error(
llm_instance=llm_instance,
ex=ex,
model="gpt-4",
credentials={"key": "secret"},
prompt_messages=[_make_user_prompt("q")],
model_parameters={"temperature": 0.7},
tools=[_make_tool("t")],
stop=["STOP"],
stream=True,
user="dave",
)
# ===========================================================================
# Tests for print_text (inherited from Callback, exercised through LoggingCallback)
# ===========================================================================
class TestPrintText:
"""Verify that print_text from the Callback base class works correctly."""
def test_print_text_with_color(self, cb: LoggingCallback, capsys):
"""print_text with a known colour should emit an ANSI escape sequence."""
cb.print_text("hello", color="blue")
captured = capsys.readouterr()
assert "hello" in captured.out
# ANSI escape codes should be present
assert "\x1b[" in captured.out
def test_print_text_without_color(self, cb: LoggingCallback, capsys):
"""print_text without colour should print plain text."""
cb.print_text("plain text")
captured = capsys.readouterr()
assert "plain text" in captured.out
def test_print_text_all_colours(self, cb: LoggingCallback, capsys):
"""Verify all supported colour keys don't raise."""
for colour in ("blue", "yellow", "pink", "green", "red"):
cb.print_text("x", color=colour)
captured = capsys.readouterr()
# All outputs should contain 'x' (5 calls)
assert captured.out.count("x") >= 5
# ===========================================================================
# Integration-style test: real print_text called (no mocking)
# ===========================================================================
class TestLoggingCallbackIntegration:
"""Light integration tests real print_text calls, just checking no exceptions."""
def test_on_before_invoke_full_run(self, capsys):
"""Full on_before_invoke run with all optional fields verifies real output."""
cb = LoggingCallback()
llm = MagicMock()
msgs = [_make_user_prompt("Who are you?", name="tester")]
tools = [_make_tool("calculator")]
cb.on_before_invoke(
llm_instance=llm,
model="gpt-4-turbo",
credentials={"api_key": "sk-xxx"},
prompt_messages=msgs,
model_parameters={"temperature": 0.8},
tools=tools,
stop=["STOP"],
stream=True,
user="test_user",
)
captured = capsys.readouterr()
assert "gpt-4-turbo" in captured.out
assert "calculator" in captured.out
assert "test_user" in captured.out
assert "STOP" in captured.out
assert "tester" in captured.out
def test_on_new_chunk_full_run(self, capsys):
"""Full on_new_chunk run verifies real stdout write."""
cb = LoggingCallback()
chunk = _make_chunk("streaming token")
cb.on_new_chunk(
llm_instance=MagicMock(),
chunk=chunk,
model="gpt-4",
credentials={},
prompt_messages=[],
model_parameters={},
)
captured = capsys.readouterr()
assert "streaming token" in captured.out
def test_on_after_invoke_full_run_with_tool_calls(self, capsys):
"""Full on_after_invoke run with tool calls verifies real output."""
cb = LoggingCallback()
tc = _make_tool_call("call-99", "do_thing", '{"n": 5}')
result = _make_llm_result(content="result content", tool_calls=[tc], system_fingerprint="fp-xyz")
cb.on_after_invoke(
llm_instance=MagicMock(),
result=result,
model=result.model,
credentials={},
prompt_messages=[],
model_parameters={},
)
captured = capsys.readouterr()
assert "result content" in captured.out
assert "call-99" in captured.out
assert "do_thing" in captured.out
assert "fp-xyz" in captured.out
def test_on_invoke_error_full_run(self, capsys):
"""Full on_invoke_error run just verifies no exception is raised."""
cb = LoggingCallback()
ex = RuntimeError("something bad happened")
# logger.exception writes to stderr; we just confirm it doesn't crash
cb.on_invoke_error(
llm_instance=MagicMock(),
ex=ex,
model="gpt-4",
credentials={},
prompt_messages=[],
model_parameters={},
)
captured = capsys.readouterr()
assert "[on_llm_invoke_error]" in captured.out

View File

@ -0,0 +1,35 @@
from dify_graph.model_runtime.entities.common_entities import I18nObject
class TestI18nObject:
def test_i18n_object_with_both_languages(self):
"""
Test I18nObject when both zh_Hans and en_US are provided.
"""
i18n = I18nObject(zh_Hans="你好", en_US="Hello")
assert i18n.zh_Hans == "你好"
assert i18n.en_US == "Hello"
def test_i18n_object_fallback_to_en_us(self):
"""
Test I18nObject when zh_Hans is missing, it should fallback to en_US.
"""
i18n = I18nObject(en_US="Hello")
assert i18n.zh_Hans == "Hello"
assert i18n.en_US == "Hello"
def test_i18n_object_with_none_zh_hans(self):
"""
Test I18nObject when zh_Hans is None, it should fallback to en_US.
"""
i18n = I18nObject(zh_Hans=None, en_US="Hello")
assert i18n.zh_Hans == "Hello"
assert i18n.en_US == "Hello"
def test_i18n_object_with_empty_zh_hans(self):
"""
Test I18nObject when zh_Hans is an empty string, it should fallback to en_US.
"""
i18n = I18nObject(zh_Hans="", en_US="Hello")
assert i18n.zh_Hans == "Hello"
assert i18n.en_US == "Hello"

View File

@ -0,0 +1,210 @@
import pytest
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessageContent,
PromptMessageContentType,
PromptMessageFunction,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
class TestPromptMessageRole:
def test_value_of(self):
assert PromptMessageRole.value_of("system") == PromptMessageRole.SYSTEM
assert PromptMessageRole.value_of("user") == PromptMessageRole.USER
assert PromptMessageRole.value_of("assistant") == PromptMessageRole.ASSISTANT
assert PromptMessageRole.value_of("tool") == PromptMessageRole.TOOL
with pytest.raises(ValueError, match="invalid prompt message type value invalid"):
PromptMessageRole.value_of("invalid")
class TestPromptMessageEntities:
def test_prompt_message_tool(self):
tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"})
assert tool.name == "test_tool"
assert tool.description == "test desc"
assert tool.parameters == {"foo": "bar"}
def test_prompt_message_function(self):
tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"})
func = PromptMessageFunction(function=tool)
assert func.type == "function"
assert func.function == tool
class TestPromptMessageContent:
def test_text_content(self):
content = TextPromptMessageContent(data="hello")
assert content.type == PromptMessageContentType.TEXT
assert content.data == "hello"
def test_image_content(self):
content = ImagePromptMessageContent(
format="jpg", base64_data="abc", mime_type="image/jpeg", detail=ImagePromptMessageContent.DETAIL.HIGH
)
assert content.type == PromptMessageContentType.IMAGE
assert content.detail == ImagePromptMessageContent.DETAIL.HIGH
assert content.data == "data:image/jpeg;base64,abc"
def test_image_content_url(self):
content = ImagePromptMessageContent(format="jpg", url="https://example.com/image.jpg", mime_type="image/jpeg")
assert content.data == "https://example.com/image.jpg"
def test_audio_content(self):
content = AudioPromptMessageContent(format="mp3", base64_data="abc", mime_type="audio/mpeg")
assert content.type == PromptMessageContentType.AUDIO
assert content.data == "data:audio/mpeg;base64,abc"
def test_video_content(self):
content = VideoPromptMessageContent(format="mp4", base64_data="abc", mime_type="video/mp4")
assert content.type == PromptMessageContentType.VIDEO
assert content.data == "data:video/mp4;base64,abc"
def test_document_content(self):
content = DocumentPromptMessageContent(format="pdf", base64_data="abc", mime_type="application/pdf")
assert content.type == PromptMessageContentType.DOCUMENT
assert content.data == "data:application/pdf;base64,abc"
class TestPromptMessages:
def test_user_prompt_message(self):
msg = UserPromptMessage(content="hello")
assert msg.role == PromptMessageRole.USER
assert msg.content == "hello"
assert msg.is_empty() is False
assert msg.get_text_content() == "hello"
def test_user_prompt_message_complex_content(self):
content = [TextPromptMessageContent(data="hello "), TextPromptMessageContent(data="world")]
msg = UserPromptMessage(content=content)
assert msg.get_text_content() == "hello world"
# Test validation from dict
msg2 = UserPromptMessage(content=[{"type": "text", "data": "hi"}])
assert isinstance(msg2.content[0], TextPromptMessageContent)
assert msg2.content[0].data == "hi"
def test_prompt_message_empty(self):
msg = UserPromptMessage(content=None)
assert msg.is_empty() is True
assert msg.get_text_content() == ""
def test_assistant_prompt_message(self):
msg = AssistantPromptMessage(content="thinking...")
assert msg.role == PromptMessageRole.ASSISTANT
assert msg.is_empty() is False
tool_call = AssistantPromptMessage.ToolCall(
id="call_1",
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"),
)
msg_with_tools = AssistantPromptMessage(content=None, tool_calls=[tool_call])
assert msg_with_tools.is_empty() is False
assert msg_with_tools.role == PromptMessageRole.ASSISTANT
def test_assistant_tool_call_id_transform(self):
tool_call = AssistantPromptMessage.ToolCall(
id=123,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"),
)
assert tool_call.id == "123"
def test_system_prompt_message(self):
msg = SystemPromptMessage(content="you are a bot")
assert msg.role == PromptMessageRole.SYSTEM
assert msg.content == "you are a bot"
def test_tool_prompt_message(self):
# Case 1: Both content and tool_call_id are present
msg = ToolPromptMessage(content="result", tool_call_id="call_1")
assert msg.role == PromptMessageRole.TOOL
assert msg.tool_call_id == "call_1"
assert msg.is_empty() is False
# Case 2: Content is present, but tool_call_id is empty
msg_content_only = ToolPromptMessage(content="result", tool_call_id="")
assert msg_content_only.is_empty() is False
# Case 3: Content is None, but tool_call_id is present
msg_id_only = ToolPromptMessage(content=None, tool_call_id="call_1")
assert msg_id_only.is_empty() is False
# Case 4: Both content and tool_call_id are empty
msg_empty = ToolPromptMessage(content=None, tool_call_id="")
assert msg_empty.is_empty() is True
def test_prompt_message_validation_errors(self):
with pytest.raises(KeyError):
# Invalid content type in list
UserPromptMessage(content=[{"type": "invalid", "data": "foo"}])
with pytest.raises(ValueError, match="invalid prompt message"):
# Not a dict or PromptMessageContent
UserPromptMessage(content=[123])
def test_prompt_message_serialization(self):
# Case: content is None
assert UserPromptMessage(content=None).serialize_content(None) is None
# Case: content is str
assert UserPromptMessage(content="hello").serialize_content("hello") == "hello"
# Case: content is list of dict
content_list = [{"type": "text", "data": "hi"}]
msg = UserPromptMessage(content=content_list)
assert msg.serialize_content(msg.content) == [{"type": PromptMessageContentType.TEXT, "data": "hi"}]
# Case: content is Sequence but not list (e.g. tuple)
# To hit line 204, we can call serialize_content manually or
# try to pass a type that pydantic doesn't convert to list in its internal state.
# Actually, let's just call it manually on the instance.
msg = UserPromptMessage(content="test")
content_tuple = (TextPromptMessageContent(data="hi"),)
assert msg.serialize_content(content_tuple) == content_tuple
def test_prompt_message_mixed_content_validation(self):
# Test branch: isinstance(prompt, PromptMessageContent)
# but not (TextPromptMessageContent | MultiModalPromptMessageContent)
# Line 187: prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
# We need a PromptMessageContent that is NOT Text or MultiModal.
# But PromptMessageContentUnionTypes discriminator handles this usually.
# We can bypass high-level validation by passing the object directly in a list.
class MockContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.TEXT
data: str
mock_item = MockContent(data="test")
msg = UserPromptMessage(content=[mock_item])
# It should hit line 187 and convert to TextPromptMessageContent
assert isinstance(msg.content[0], TextPromptMessageContent)
assert msg.content[0].data == "test"
def test_prompt_message_get_text_content_branches(self):
# content is None
msg_none = UserPromptMessage(content=None)
assert msg_none.get_text_content() == ""
# content is list but no text content
image = ImagePromptMessageContent(format="jpg", base64_data="abc", mime_type="image/jpeg")
msg_image = UserPromptMessage(content=[image])
assert msg_image.get_text_content() == ""
# content is list with mixed
text = TextPromptMessageContent(data="hello")
msg_mixed = UserPromptMessage(content=[text, image])
assert msg_mixed.get_text_content() == "hello"

View File

@ -0,0 +1,220 @@
from decimal import Decimal
import pytest
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ModelUsage,
ParameterRule,
ParameterType,
PriceConfig,
PriceInfo,
PriceType,
ProviderModel,
)
class TestModelType:
def test_value_of(self):
assert ModelType.value_of("text-generation") == ModelType.LLM
assert ModelType.value_of(ModelType.LLM) == ModelType.LLM
assert ModelType.value_of("embeddings") == ModelType.TEXT_EMBEDDING
assert ModelType.value_of(ModelType.TEXT_EMBEDDING) == ModelType.TEXT_EMBEDDING
assert ModelType.value_of("reranking") == ModelType.RERANK
assert ModelType.value_of(ModelType.RERANK) == ModelType.RERANK
assert ModelType.value_of("speech2text") == ModelType.SPEECH2TEXT
assert ModelType.value_of(ModelType.SPEECH2TEXT) == ModelType.SPEECH2TEXT
assert ModelType.value_of("tts") == ModelType.TTS
assert ModelType.value_of(ModelType.TTS) == ModelType.TTS
assert ModelType.value_of(ModelType.MODERATION) == ModelType.MODERATION
with pytest.raises(ValueError, match="invalid origin model type invalid"):
ModelType.value_of("invalid")
def test_to_origin_model_type(self):
assert ModelType.LLM.to_origin_model_type() == "text-generation"
assert ModelType.TEXT_EMBEDDING.to_origin_model_type() == "embeddings"
assert ModelType.RERANK.to_origin_model_type() == "reranking"
assert ModelType.SPEECH2TEXT.to_origin_model_type() == "speech2text"
assert ModelType.TTS.to_origin_model_type() == "tts"
assert ModelType.MODERATION.to_origin_model_type() == "moderation"
# Testing the else branch in to_origin_model_type
# Since it's a StrEnum, it's hard to get an invalid value here unless we mock or Force it.
# But if we look at the implementation:
# if self == self.LLM: ... elif ... else: raise ValueError
# We can try to create a "dummy" member if possible, or just skip it if we have 100% coverage otherwise.
# Actually, adding a new member to an enum at runtime is possible but messy.
# Let's see if we can trigger it.
class TestFetchFrom:
def test_values(self):
assert FetchFrom.PREDEFINED_MODEL == "predefined-model"
assert FetchFrom.CUSTOMIZABLE_MODEL == "customizable-model"
class TestModelFeature:
def test_values(self):
assert ModelFeature.TOOL_CALL == "tool-call"
assert ModelFeature.MULTI_TOOL_CALL == "multi-tool-call"
assert ModelFeature.AGENT_THOUGHT == "agent-thought"
assert ModelFeature.VISION == "vision"
assert ModelFeature.STREAM_TOOL_CALL == "stream-tool-call"
assert ModelFeature.DOCUMENT == "document"
assert ModelFeature.VIDEO == "video"
assert ModelFeature.AUDIO == "audio"
assert ModelFeature.STRUCTURED_OUTPUT == "structured-output"
class TestDefaultParameterName:
def test_value_of(self):
assert DefaultParameterName.value_of("temperature") == DefaultParameterName.TEMPERATURE
assert DefaultParameterName.value_of("top_p") == DefaultParameterName.TOP_P
with pytest.raises(ValueError, match="invalid parameter name invalid"):
DefaultParameterName.value_of("invalid")
class TestParameterType:
def test_values(self):
assert ParameterType.FLOAT == "float"
assert ParameterType.INT == "int"
assert ParameterType.STRING == "string"
assert ParameterType.BOOLEAN == "boolean"
assert ParameterType.TEXT == "text"
class TestModelPropertyKey:
def test_values(self):
assert ModelPropertyKey.MODE == "mode"
assert ModelPropertyKey.CONTEXT_SIZE == "context_size"
class TestProviderModel:
def test_provider_model(self):
model = ProviderModel(
model="gpt-4",
label=I18nObject(en_US="GPT-4"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
)
assert model.model == "gpt-4"
assert model.support_structure_output is False
model_with_features = ProviderModel(
model="gpt-4",
label=I18nObject(en_US="GPT-4"),
model_type=ModelType.LLM,
features=[ModelFeature.STRUCTURED_OUTPUT],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
)
assert model_with_features.support_structure_output is True
class TestParameterRule:
def test_parameter_rule(self):
rule = ParameterRule(
name="temperature",
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
default=0.7,
min=0.0,
max=1.0,
precision=2,
)
assert rule.name == "temperature"
assert rule.default == 0.7
class TestPriceConfig:
def test_price_config(self):
config = PriceConfig(input=Decimal("0.01"), output=Decimal("0.02"), unit=Decimal("0.001"), currency="USD")
assert config.input == Decimal("0.01")
assert config.output == Decimal("0.02")
class TestAIModelEntity:
def test_ai_model_entity_no_json_schema(self):
entity = AIModelEntity(
model="gpt-4",
label=I18nObject(en_US="GPT-4"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
parameter_rules=[
ParameterRule(name="temperature", label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT)
],
)
assert ModelFeature.STRUCTURED_OUTPUT not in (entity.features or [])
def test_ai_model_entity_with_json_schema(self):
# Case: json_schema in parameter rules, features is None
entity = AIModelEntity(
model="gpt-4",
label=I18nObject(en_US="GPT-4"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
parameter_rules=[
ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING)
],
)
assert ModelFeature.STRUCTURED_OUTPUT in entity.features
def test_ai_model_entity_with_json_schema_and_features_empty(self):
# Case: json_schema in parameter rules, features is empty list
entity = AIModelEntity(
model="gpt-4",
label=I18nObject(en_US="GPT-4"),
model_type=ModelType.LLM,
features=[],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
parameter_rules=[
ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING)
],
)
assert ModelFeature.STRUCTURED_OUTPUT in entity.features
def test_ai_model_entity_with_json_schema_and_other_features(self):
# Case: json_schema in parameter rules, features has other things
entity = AIModelEntity(
model="gpt-4",
label=I18nObject(en_US="GPT-4"),
model_type=ModelType.LLM,
features=[ModelFeature.VISION],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
parameter_rules=[
ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING)
],
)
assert ModelFeature.STRUCTURED_OUTPUT in entity.features
assert ModelFeature.VISION in entity.features
class TestModelUsage:
def test_model_usage(self):
usage = ModelUsage()
assert isinstance(usage, ModelUsage)
class TestPriceType:
def test_values(self):
assert PriceType.INPUT == "input"
assert PriceType.OUTPUT == "output"
class TestPriceInfo:
def test_price_info(self):
info = PriceInfo(unit_price=Decimal("0.01"), unit=Decimal(1000), total_amount=Decimal("0.05"), currency="USD")
assert info.total_amount == Decimal("0.05")

View File

@ -0,0 +1,63 @@
from dify_graph.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
class TestInvokeErrors:
def test_invoke_error_with_description(self):
error = InvokeError("Custom description")
assert error.description == "Custom description"
assert str(error) == "Custom description"
assert isinstance(error, ValueError)
def test_invoke_error_without_description(self):
error = InvokeError()
assert error.description is None
assert str(error) == "InvokeError"
def test_invoke_connection_error(self):
# Now preserves class-level description
error = InvokeConnectionError()
assert error.description == "Connection Error"
assert str(error) == "Connection Error"
assert isinstance(error, InvokeError)
# Test with explicit description
error_with_desc = InvokeConnectionError("Connection Error")
assert error_with_desc.description == "Connection Error"
assert str(error_with_desc) == "Connection Error"
def test_invoke_server_unavailable_error(self):
error = InvokeServerUnavailableError()
assert error.description == "Server Unavailable Error"
assert str(error) == "Server Unavailable Error"
assert isinstance(error, InvokeError)
def test_invoke_rate_limit_error(self):
error = InvokeRateLimitError()
assert error.description == "Rate Limit Error"
assert str(error) == "Rate Limit Error"
assert isinstance(error, InvokeError)
def test_invoke_authorization_error(self):
error = InvokeAuthorizationError()
assert error.description == "Incorrect model credentials provided, please check and try again. "
assert str(error) == "Incorrect model credentials provided, please check and try again. "
assert isinstance(error, InvokeError)
def test_invoke_bad_request_error(self):
error = InvokeBadRequestError()
assert error.description == "Bad Request Error"
assert str(error) == "Bad Request Error"
assert isinstance(error, InvokeError)
def test_invoke_error_inheritance(self):
# Test that we can override the default description in subclasses
error = InvokeBadRequestError("Overridden Error")
assert error.description == "Overridden Error"
assert str(error) == "Overridden Error"

View File

@ -0,0 +1,336 @@
import decimal
from unittest.mock import MagicMock, patch
import pytest
from redis import RedisError
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
PriceConfig,
PriceType,
)
from dify_graph.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
class TestAIModel:
@pytest.fixture
def mock_plugin_model_provider(self):
return MagicMock(spec=PluginModelProviderEntity)
@pytest.fixture
def ai_model(self, mock_plugin_model_provider):
return AIModel(
tenant_id="tenant_123",
model_type=ModelType.LLM,
plugin_id="plugin_123",
provider_name="test_provider",
plugin_model_provider=mock_plugin_model_provider,
)
def test_invoke_error_mapping(self, ai_model):
mapping = ai_model._invoke_error_mapping
assert InvokeConnectionError in mapping
assert InvokeServerUnavailableError in mapping
assert InvokeRateLimitError in mapping
assert InvokeAuthorizationError in mapping
assert InvokeBadRequestError in mapping
assert PluginDaemonInnerError in mapping
assert ValueError in mapping
def test_transform_invoke_error(self, ai_model):
# Case: mapped error (InvokeAuthorizationError)
err = Exception("Original error")
with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}):
transformed = ai_model._transform_invoke_error(err)
assert isinstance(transformed, InvokeAuthorizationError)
assert "Incorrect model credentials provided" in str(transformed.description)
# Case: mapped error (InvokeError subclass)
with patch.object(AIModel, "_invoke_error_mapping", {InvokeRateLimitError("Rate limit"): [Exception]}):
transformed = ai_model._transform_invoke_error(err)
assert isinstance(transformed, InvokeError)
assert "[test_provider]" in transformed.description
# Case: mapped error (not InvokeError)
class CustomNonInvokeError(Exception):
pass
with patch.object(AIModel, "_invoke_error_mapping", {CustomNonInvokeError: [Exception]}):
transformed = ai_model._transform_invoke_error(err)
assert transformed == err
# Case: unmapped error
unmapped_err = Exception("Unmapped")
transformed = ai_model._transform_invoke_error(unmapped_err)
assert isinstance(transformed, InvokeError)
assert "Error: Unmapped" in transformed.description
def test_get_price(self, ai_model):
model_name = "test_model"
credentials = {"key": "value"}
# Mock get_model_schema
mock_schema = MagicMock(spec=AIModelEntity)
mock_schema.pricing = PriceConfig(
input=decimal.Decimal("0.002"),
output=decimal.Decimal("0.004"),
unit=decimal.Decimal(1000), # 1000 tokens per unit
currency="USD",
)
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
# Test INPUT
price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000)
assert price_info.unit_price == decimal.Decimal("0.002")
# Test OUTPUT
price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000)
assert price_info.unit_price == decimal.Decimal("0.004")
# Case: unit_price is None (returns zeroed PriceInfo)
mock_schema.pricing = None
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000)
assert price_info.total_amount == decimal.Decimal("0.0")
def test_get_price_no_price_config_error(self, ai_model):
model_name = "test_model"
# We need it to be truthy at line 107 and 112 but falsy at line 127.
class ChangingPriceConfig:
def __init__(self):
self.input = decimal.Decimal("0.01")
self.unit = decimal.Decimal(1)
self.currency = "USD"
self.called = 0
def __bool__(self):
self.called += 1
return self.called <= 2
mock_schema = MagicMock()
mock_schema.pricing = ChangingPriceConfig()
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
with pytest.raises(ValueError) as excinfo:
ai_model.get_price(model_name, {}, PriceType.INPUT, 1000)
assert "Price config not found" in str(excinfo.value)
def test_get_model_schema_cache_hit(self, ai_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
mock_schema = AIModelEntity(
model="test_model",
label=I18nObject(en_US="Test Model"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[],
)
with patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis:
mock_redis.get.return_value = mock_schema.model_dump_json().encode()
schema = ai_model.get_model_schema(model_name, credentials)
assert schema.model == "test_model"
mock_redis.get.assert_called_once()
def test_get_model_schema_cache_miss(self, ai_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
mock_schema = AIModelEntity(
model="test_model",
label=I18nObject(en_US="Test Model"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[],
)
with (
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
):
mock_redis.get.return_value = None
mock_manager = mock_client.return_value
mock_manager.get_model_schema.return_value = mock_schema
schema = ai_model.get_model_schema(model_name, credentials)
assert schema == mock_schema
mock_manager.get_model_schema.assert_called_once()
mock_redis.setex.assert_called_once()
def test_get_model_schema_redis_error(self, ai_model):
model_name = "test_model"
with (
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
):
mock_redis.get.side_effect = RedisError("Connection refused")
mock_manager = mock_client.return_value
mock_manager.get_model_schema.return_value = None
schema = ai_model.get_model_schema(model_name, {})
assert schema is None
mock_manager.get_model_schema.assert_called_once()
def test_get_model_schema_validation_error(self, ai_model):
model_name = "test_model"
with (
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
):
mock_redis.get.return_value = b"invalid json"
mock_manager = mock_client.return_value
mock_manager.get_model_schema.return_value = None
# This should trigger ValidationError at line 166 and go to delete()
schema = ai_model.get_model_schema(model_name, {})
assert schema is None
mock_redis.delete.assert_called()
def test_get_model_schema_redis_delete_error(self, ai_model):
model_name = "test_model"
with (
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
):
mock_redis.get.return_value = b'{"invalid": "schema"}'
mock_redis.delete.side_effect = RedisError("Delete failed")
mock_manager = mock_client.return_value
mock_manager.get_model_schema.return_value = None
schema = ai_model.get_model_schema(model_name, {})
assert schema is None
mock_redis.delete.assert_called()
def test_get_model_schema_redis_setex_error(self, ai_model):
model_name = "test_model"
mock_schema = AIModelEntity(
model="test_model",
label=I18nObject(en_US="Test Model"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[],
)
with (
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
):
mock_redis.get.return_value = None
mock_redis.setex.side_effect = RuntimeError("Setex failed")
mock_manager = mock_client.return_value
mock_manager.get_model_schema.return_value = mock_schema
schema = ai_model.get_model_schema(model_name, {})
assert schema == mock_schema
mock_redis.setex.assert_called()
def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(self, ai_model):
model_name = "test_model"
mock_schema = AIModelEntity(
model="test_model",
label=I18nObject(en_US="Test Model"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[
ParameterRule(
name="invalid",
use_template="invalid_template_name",
label=I18nObject(en_US="Invalid"),
type=ParameterType.FLOAT,
)
],
)
with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {})
assert schema.parameter_rules[0].use_template == "invalid_template_name"
def test_get_customizable_model_schema_from_credentials(self, ai_model):
model_name = "test_model"
mock_schema = AIModelEntity(
model="test_model",
label=I18nObject(en_US="Test Model"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[
ParameterRule(
name="temp", use_template="temperature", label=I18nObject(en_US="Temp"), type=ParameterType.FLOAT
),
ParameterRule(
name="top_p",
use_template="top_p",
label=I18nObject(en_US="Top P"),
type=ParameterType.FLOAT,
help=I18nObject(en_US=""),
),
ParameterRule(
name="max_tokens",
use_template="max_tokens",
label=I18nObject(en_US="Max Tokens"),
type=ParameterType.INT,
help=I18nObject(en_US="", zh_Hans=""),
),
ParameterRule(name="custom", label=I18nObject(en_US="Custom"), type=ParameterType.STRING),
],
)
with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {})
assert schema.parameter_rules[0].max == 1.0
assert schema.parameter_rules[1].help.en_US != ""
assert schema.parameter_rules[2].help.zh_Hans != ""
assert schema.parameter_rules[3].use_template is None
def test_get_customizable_model_schema_from_credentials_none(self, ai_model):
with patch.object(AIModel, "get_customizable_model_schema", return_value=None):
schema = ai_model.get_customizable_model_schema_from_credentials("model", {})
assert schema is None
def test_get_customizable_model_schema_default(self, ai_model):
assert ai_model.get_customizable_model_schema("model", {}) is None
def test_get_default_parameter_rule_variable_map(self, ai_model):
# Valid
res = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE)
assert res["default"] == 0.0
# Invalid
with pytest.raises(Exception) as excinfo:
ai_model._get_default_parameter_rule_variable_map("invalid_name")
assert "Invalid model parameter rule name" in str(excinfo.value)

View File

@ -0,0 +1,476 @@
import logging
from collections.abc import Generator, Iterator, Sequence
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
import dify_graph.model_runtime.model_providers.__base.large_language_model as llm_module
# Access large_language_model members via llm_module to avoid partial import issues in CI
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelType, PriceInfo
from dify_graph.model_runtime.model_providers.__base.large_language_model import _build_llm_result_from_chunks
def _usage(prompt_tokens: int = 1, completion_tokens: int = 2) -> LLMUsage:
return LLMUsage(
prompt_tokens=prompt_tokens,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1),
prompt_price=Decimal(prompt_tokens) * Decimal("0.001"),
completion_tokens=completion_tokens,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1),
completion_price=Decimal(completion_tokens) * Decimal("0.002"),
total_tokens=prompt_tokens + completion_tokens,
total_price=Decimal(prompt_tokens) * Decimal("0.001") + Decimal(completion_tokens) * Decimal("0.002"),
currency="USD",
latency=0.0,
)
def _tool_call_delta(
*,
tool_call_id: str,
tool_type: str = "function",
function_name: str = "",
function_arguments: str = "",
) -> AssistantPromptMessage.ToolCall:
return AssistantPromptMessage.ToolCall(
id=tool_call_id,
type=tool_type,
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=function_name, arguments=function_arguments),
)
def _chunk(
*,
model: str = "test-model",
content: str | list[Any] | None = None,
tool_calls: list[AssistantPromptMessage.ToolCall] | None = None,
usage: LLMUsage | None = None,
system_fingerprint: str | None = None,
) -> LLMResultChunk:
return LLMResultChunk(
model=model,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=content, tool_calls=tool_calls or []),
usage=usage,
),
)
@dataclass
class SpyCallback(Callback):
raise_error: bool = False
before: list[dict[str, Any]] = field(default_factory=list)
new_chunk: list[dict[str, Any]] = field(default_factory=list)
after: list[dict[str, Any]] = field(default_factory=list)
error: list[dict[str, Any]] = field(default_factory=list)
def on_before_invoke(self, **kwargs: Any) -> None: # type: ignore[override]
self.before.append(kwargs)
def on_new_chunk(self, **kwargs: Any) -> None: # type: ignore[override]
self.new_chunk.append(kwargs)
def on_after_invoke(self, **kwargs: Any) -> None: # type: ignore[override]
self.after.append(kwargs)
def on_invoke_error(self, **kwargs: Any) -> None: # type: ignore[override]
self.error.append(kwargs)
class _TestLLM(llm_module.LargeLanguageModel):
def get_price(self, model: str, credentials: dict, price_type: Any, tokens: int) -> PriceInfo: # type: ignore[override]
return PriceInfo(
unit_price=Decimal("0.01"),
unit=Decimal(1),
total_amount=Decimal(tokens) * Decimal("0.01"),
currency="USD",
)
def _transform_invoke_error(self, error: Exception) -> Exception: # type: ignore[override]
return RuntimeError(f"transformed: {error}")
@pytest.fixture
def llm() -> _TestLLM:
plugin_provider = PluginModelProviderEntity.model_construct(
id="provider-id",
created_at=datetime.now(),
updated_at=datetime.now(),
provider="provider",
tenant_id="tenant",
plugin_unique_identifier="plugin-uid",
plugin_id="plugin-id",
declaration=MagicMock(),
)
return _TestLLM.model_construct(
tenant_id="tenant",
model_type=ModelType.LLM,
plugin_id="plugin-id",
provider_name="provider",
plugin_model_provider=plugin_provider,
started_at=1.0,
)
def test_gen_tool_call_id_is_uuid_based(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="abc123"))
assert llm_module._gen_tool_call_id() == "chatcmpl-tool-abc123"
def test_run_callbacks_no_callbacks_noop() -> None:
invoked: list[int] = []
llm_module._run_callbacks(None, event="x", invoke=lambda _: invoked.append(1))
llm_module._run_callbacks([], event="x", invoke=lambda _: invoked.append(1))
assert invoked == []
def test_run_callbacks_swallows_error_when_raise_error_false(caplog: pytest.LogCaptureFixture) -> None:
class Boom:
raise_error = False
caplog.set_level(logging.WARNING)
llm_module._run_callbacks(
[Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom"))
)
assert any("Callback" in record.message and "failed with error" in record.message for record in caplog.records)
def test_run_callbacks_reraises_when_raise_error_true() -> None:
class Boom:
raise_error = True
with pytest.raises(ValueError, match="boom"):
llm_module._run_callbacks(
[Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom"))
)
def test_get_or_create_tool_call_empty_id_returns_last() -> None:
calls = [
_tool_call_delta(tool_call_id="id1", function_name="a"),
_tool_call_delta(tool_call_id="id2", function_name="b"),
]
assert llm_module._get_or_create_tool_call(calls, "") is calls[-1]
def test_get_or_create_tool_call_empty_id_without_existing_raises() -> None:
with pytest.raises(ValueError, match="tool_call_id is empty"):
llm_module._get_or_create_tool_call([], "")
def test_get_or_create_tool_call_creates_if_missing() -> None:
calls: list[AssistantPromptMessage.ToolCall] = []
tool_call = llm_module._get_or_create_tool_call(calls, "new-id")
assert tool_call.id == "new-id"
assert tool_call.function.name == ""
assert tool_call.function.arguments == ""
assert calls == [tool_call]
def test_get_or_create_tool_call_returns_existing_when_found() -> None:
existing = _tool_call_delta(tool_call_id="same-id", function_name="fn", function_arguments="{}")
calls = [existing]
assert llm_module._get_or_create_tool_call(calls, "same-id") is existing
def test_merge_tool_call_delta_updates_fields_and_appends_arguments() -> None:
tool_call = _tool_call_delta(tool_call_id="id", tool_type="function", function_name="x", function_arguments="{")
delta = _tool_call_delta(tool_call_id="id2", tool_type="function", function_name="y", function_arguments="}")
llm_module._merge_tool_call_delta(tool_call, delta)
assert tool_call.id == "id2"
assert tool_call.type == "function"
assert tool_call.function.name == "y"
assert tool_call.function.arguments == "{}"
def test_increase_tool_call_generates_id_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="fixed"))
delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{")
existing: list[AssistantPromptMessage.ToolCall] = []
llm_module._increase_tool_call([delta], existing)
assert len(existing) == 1
assert existing[0].id == "chatcmpl-tool-fixed"
assert existing[0].function.name == "fn"
assert existing[0].function.arguments == "{"
def test_increase_tool_call_merges_incremental_arguments() -> None:
existing: list[AssistantPromptMessage.ToolCall] = []
llm_module._increase_tool_call(
[_tool_call_delta(tool_call_id="id", function_name="fn", function_arguments="{")], existing
)
llm_module._increase_tool_call(
[_tool_call_delta(tool_call_id="id", function_name="", function_arguments="}")], existing
)
assert len(existing) == 1
assert existing[0].function.name == "fn"
assert existing[0].function.arguments == "{}"
@pytest.mark.parametrize(
("content", "expected_type"),
[
("hello", str),
([TextPromptMessageContent(data="hello")], list),
],
)
def test_build_llm_result_from_chunks_accumulates_and_raises_error(
content: str | list[TextPromptMessageContent],
expected_type: type,
monkeypatch: pytest.MonkeyPatch,
caplog: pytest.LogCaptureFixture,
) -> None:
monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="drain"))
caplog.set_level(logging.DEBUG)
tool_delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{}")
first = _chunk(content=content, tool_calls=[tool_delta], usage=_usage(3, 4), system_fingerprint="fp1")
def iter_with_error() -> Iterator[LLMResultChunk]:
yield first
raise RuntimeError("drain boom")
with pytest.raises(RuntimeError, match="drain boom"):
_build_llm_result_from_chunks(
model="m", prompt_messages=[UserPromptMessage(content="u")], chunks=iter_with_error()
)
assert any("Error while consuming non-stream plugin chunk iterator" in record.message for record in caplog.records)
def test_build_llm_result_from_chunks_empty_iterator() -> None:
def empty() -> Iterator[LLMResultChunk]:
if False: # pragma: no cover
yield _chunk()
return
result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=empty())
assert result.message.content == []
assert result.usage.total_tokens == 0
assert result.system_fingerprint is None
def test_build_llm_result_from_chunks_accumulates_all_chunks() -> None:
chunks = iter([_chunk(content="first"), _chunk(content="second")])
result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=chunks)
assert result.message.content == "firstsecond"
def test_invoke_llm_via_plugin_passes_list_converted_stop(monkeypatch: pytest.MonkeyPatch) -> None:
invoked: dict[str, Any] = {}
class FakePluginModelClient:
def invoke_llm(self, **kwargs: Any) -> str:
invoked.update(kwargs)
return "ok"
import core.plugin.impl.model as plugin_model_module
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
prompt_messages: Sequence[PromptMessage] = (UserPromptMessage(content="hi"),)
result = llm_module._invoke_llm_via_plugin(
tenant_id="t",
user_id="u",
plugin_id="p",
provider="prov",
model="m",
credentials={"k": "v"},
model_parameters={"temp": 1},
prompt_messages=prompt_messages,
tools=None,
stop=("a", "b"),
stream=True,
)
assert result == "ok"
assert invoked["prompt_messages"] == list(prompt_messages)
assert invoked["stop"] == ["a", "b"]
def test_normalize_non_stream_plugin_result_passthrough_llmresult() -> None:
llm_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
assert (
llm_module._normalize_non_stream_plugin_result(model="m", prompt_messages=[], result=llm_result) is llm_result
)
def test_normalize_non_stream_plugin_result_builds_from_chunks() -> None:
chunks = iter([_chunk(content="hello", usage=_usage(1, 1))])
result = llm_module._normalize_non_stream_plugin_result(
model="m", prompt_messages=[UserPromptMessage(content="u")], result=chunks
)
assert isinstance(result, LLMResult)
assert result.message.content == "hello"
def test_invoke_non_stream_normalizes_and_sets_prompt_messages(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
plugin_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
monkeypatch.setattr(
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
lambda **_: plugin_result,
)
cb = SpyCallback()
prompt_messages = [UserPromptMessage(content="hi")]
result = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=False, callbacks=[cb])
assert isinstance(result, LLMResult)
assert result.prompt_messages == prompt_messages
assert len(cb.before) == 1
assert len(cb.after) == 1
assert cb.after[0]["result"].prompt_messages == prompt_messages
def test_invoke_stream_wraps_generator_and_triggers_callbacks(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
plugin_chunks = iter(
[
_chunk(model="m1", content="a"),
_chunk(
model="m2", content=[TextPromptMessageContent(data="b")], usage=_usage(2, 3), system_fingerprint="fp"
),
_chunk(model="m3", content=None),
]
)
monkeypatch.setattr(
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
lambda **_: plugin_chunks,
)
cb = SpyCallback()
prompt_messages = [UserPromptMessage(content="hi")]
gen = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=True, callbacks=[cb])
assert isinstance(gen, Generator)
chunks = list(gen)
assert len(chunks) == 3
assert all(chunk.prompt_messages == prompt_messages for chunk in chunks)
assert len(cb.before) == 1
assert len(cb.new_chunk) == 3
assert len(cb.after) == 1
final_result: LLMResult = cb.after[0]["result"]
assert final_result.model == "m3"
assert final_result.system_fingerprint == "fp"
assert isinstance(final_result.message.content, list)
assert [c.data for c in final_result.message.content] == ["a", "b"]
assert final_result.usage.total_tokens == 5
def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
def boom(**_: Any) -> Any:
raise ValueError("plugin down")
monkeypatch.setattr(
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", boom
)
cb = SpyCallback()
with pytest.raises(RuntimeError, match="transformed: plugin down"):
llm.invoke(
model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False, callbacks=[cb]
)
assert len(cb.error) == 1
assert isinstance(cb.error[0]["ex"], ValueError)
def test_invoke_raises_not_implemented_for_unsupported_result_type(
llm: _TestLLM, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr(llm_module, "_invoke_llm_via_plugin", lambda **_: "not-a-result")
monkeypatch.setattr(llm_module, "_normalize_non_stream_plugin_result", lambda **_: "not-a-result")
with pytest.raises(NotImplementedError, match="unsupported invoke result type"):
llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False)
def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
captured_callbacks: list[list[Callback]] = []
class FakeLoggingCallback(SpyCallback):
pass
monkeypatch.setattr(llm_module, "LoggingCallback", FakeLoggingCallback)
monkeypatch.setattr(llm_module.dify_config, "DEBUG", True)
monkeypatch.setattr(
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
lambda **_: LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()),
)
original_trigger = llm._trigger_before_invoke_callbacks
def spy_trigger(*args: Any, **kwargs: Any) -> None:
captured_callbacks.append(list(kwargs["callbacks"]))
original_trigger(*args, **kwargs)
monkeypatch.setattr(llm, "_trigger_before_invoke_callbacks", spy_trigger)
llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False)
assert any(isinstance(cb, FakeLoggingCallback) for cb in captured_callbacks[0])
def test_get_num_tokens_returns_0_when_plugin_disabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False)
assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 0
def test_get_num_tokens_uses_plugin_when_enabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", True)
class FakePluginModelClient:
def get_llm_num_tokens(self, **kwargs: Any) -> int:
assert kwargs["tenant_id"] == "tenant"
assert kwargs["plugin_id"] == "plugin-id"
assert kwargs["provider"] == "provider"
assert kwargs["model_type"] == "llm"
return 42
import core.plugin.impl.model as plugin_model_module
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 42
def test_calc_response_usage_uses_prices_and_latency(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(llm_module.time, "perf_counter", lambda: 4.5)
llm.started_at = 1.0
usage = llm.calc_response_usage(model="m", credentials={}, prompt_tokens=10, completion_tokens=5)
assert usage.total_tokens == 15
assert usage.total_price == Decimal("0.15")
assert usage.latency == 3.5
def test_invoke_result_generator_raises_transformed_on_iteration_error(llm: _TestLLM) -> None:
def broken() -> Iterator[LLMResultChunk]:
yield _chunk(content="ok")
raise ValueError("chunk stream broken")
gen = llm._invoke_result_generator(
model="m",
result=broken(),
credentials={},
prompt_messages=[UserPromptMessage(content="u")],
model_parameters={},
callbacks=[SpyCallback()],
)
with pytest.raises(RuntimeError, match="transformed: chunk stream broken"):
list(gen)

View File

@ -0,0 +1,90 @@
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
class TestModerationModel:
@pytest.fixture
def mock_plugin_model_provider(self):
return MagicMock(spec=PluginModelProviderEntity)
@pytest.fixture
def moderation_model(self, mock_plugin_model_provider):
return ModerationModel(
tenant_id="tenant_123",
model_type=ModelType.MODERATION,
plugin_id="plugin_123",
provider_name="test_provider",
plugin_model_provider=mock_plugin_model_provider,
)
def test_model_type(self, moderation_model):
assert moderation_model.model_type == ModelType.MODERATION
def test_invoke_success(self, moderation_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
text = "test text"
user = "user_123"
with (
patch("core.plugin.impl.model.PluginModelClient") as mock_client_class,
patch("time.perf_counter", return_value=1.0),
):
mock_client = mock_client_class.return_value
mock_client.invoke_moderation.return_value = True
result = moderation_model.invoke(model=model_name, credentials=credentials, text=text, user=user)
assert result is True
assert moderation_model.started_at == 1.0
mock_client.invoke_moderation.assert_called_once_with(
tenant_id="tenant_123",
user_id="user_123",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
text=text,
)
def test_invoke_success_no_user(self, moderation_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
text = "test text"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_moderation.return_value = False
result = moderation_model.invoke(model=model_name, credentials=credentials, text=text)
assert result is False
mock_client.invoke_moderation.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
text=text,
)
def test_invoke_exception(self, moderation_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
text = "test text"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_moderation.side_effect = Exception("Test error")
with pytest.raises(InvokeError) as excinfo:
moderation_model.invoke(model=model_name, credentials=credentials, text=text)
assert "[test_provider] Error: Test error" in str(excinfo.value.description)

View File

@ -0,0 +1,181 @@
from datetime import datetime
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
@pytest.fixture
def rerank_model() -> RerankModel:
plugin_provider = PluginModelProviderEntity.model_construct(
id="provider-id",
created_at=datetime.now(),
updated_at=datetime.now(),
provider="provider",
tenant_id="tenant",
plugin_unique_identifier="plugin-uid",
plugin_id="plugin-id",
declaration=MagicMock(),
)
return RerankModel.model_construct(
tenant_id="tenant",
model_type=ModelType.RERANK,
plugin_id="plugin-id",
provider_name="provider",
plugin_model_provider=plugin_provider,
)
def test_model_type_is_rerank_by_default() -> None:
plugin_provider = PluginModelProviderEntity.model_construct(
id="provider-id",
created_at=datetime.now(),
updated_at=datetime.now(),
provider="provider",
tenant_id="tenant",
plugin_unique_identifier="plugin-uid",
plugin_id="plugin-id",
declaration=MagicMock(),
)
model = RerankModel(
tenant_id="tenant",
plugin_id="plugin-id",
provider_name="provider",
plugin_model_provider=plugin_provider,
)
assert model.model_type == ModelType.RERANK
def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None:
expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)])
class FakePluginModelClient:
def __init__(self) -> None:
self.invoke_rerank_called_with: dict[str, Any] | None = None
def invoke_rerank(self, **kwargs: Any) -> RerankResult:
self.invoke_rerank_called_with = kwargs
return expected
import core.plugin.impl.model as plugin_model_module
fake_client = FakePluginModelClient()
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
result = rerank_model.invoke(
model="rerank",
credentials={"k": "v"},
query="q",
docs=["d1", "d2"],
score_threshold=0.2,
top_n=10,
user="user-1",
)
assert result == expected
assert fake_client.invoke_rerank_called_with == {
"tenant_id": "tenant",
"user_id": "user-1",
"plugin_id": "plugin-id",
"provider": "provider",
"model": "rerank",
"credentials": {"k": "v"},
"query": "q",
"docs": ["d1", "d2"],
"score_threshold": 0.2,
"top_n": 10,
}
def test_invoke_uses_unknown_user_when_not_provided(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None:
class FakePluginModelClient:
def __init__(self) -> None:
self.kwargs: dict[str, Any] | None = None
def invoke_rerank(self, **kwargs: Any) -> RerankResult:
self.kwargs = kwargs
return RerankResult(model="m", docs=[])
import core.plugin.impl.model as plugin_model_module
fake_client = FakePluginModelClient()
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
assert fake_client.kwargs is not None
assert fake_client.kwargs["user_id"] == "unknown"
def test_invoke_transforms_and_raises_on_plugin_error(
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
) -> None:
class FakePluginModelClient:
def invoke_rerank(self, **_: Any) -> RerankResult:
raise ValueError("plugin down")
import core.plugin.impl.model as plugin_model_module
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}"))
with pytest.raises(RuntimeError, match="transformed: plugin down"):
rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
def test_invoke_multimodal_calls_plugin_and_passes_args(
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
) -> None:
expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)])
class FakePluginModelClient:
def __init__(self) -> None:
self.invoke_multimodal_rerank_called_with: dict[str, Any] | None = None
def invoke_multimodal_rerank(self, **kwargs: Any) -> RerankResult:
self.invoke_multimodal_rerank_called_with = kwargs
return expected
import core.plugin.impl.model as plugin_model_module
fake_client = FakePluginModelClient()
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
query = {"type": "text", "text": "q"}
docs = [{"type": "text", "text": "d1"}]
result = rerank_model.invoke_multimodal_rerank(
model="mm",
credentials={"k": "v"},
query=query,
docs=docs,
score_threshold=None,
top_n=None,
user=None,
)
assert result == expected
assert fake_client.invoke_multimodal_rerank_called_with is not None
assert fake_client.invoke_multimodal_rerank_called_with["tenant_id"] == "tenant"
assert fake_client.invoke_multimodal_rerank_called_with["user_id"] == "unknown"
assert fake_client.invoke_multimodal_rerank_called_with["query"] == query
assert fake_client.invoke_multimodal_rerank_called_with["docs"] == docs
def test_invoke_multimodal_transforms_and_raises_on_plugin_error(
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
) -> None:
class FakePluginModelClient:
def invoke_multimodal_rerank(self, **_: Any) -> RerankResult:
raise ValueError("plugin down")
import core.plugin.impl.model as plugin_model_module
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}"))
with pytest.raises(RuntimeError, match="transformed: plugin down"):
rerank_model.invoke_multimodal_rerank(model="m", credentials={}, query={"q": 1}, docs=[{"d": 1}])

View File

@ -0,0 +1,87 @@
from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
class TestSpeech2TextModel:
@pytest.fixture
def mock_plugin_model_provider(self):
return MagicMock(spec=PluginModelProviderEntity)
@pytest.fixture
def speech2text_model(self, mock_plugin_model_provider):
return Speech2TextModel(
tenant_id="tenant_123",
model_type=ModelType.SPEECH2TEXT,
plugin_id="plugin_123",
provider_name="test_provider",
plugin_model_provider=mock_plugin_model_provider,
)
def test_model_type(self, speech2text_model):
assert speech2text_model.model_type == ModelType.SPEECH2TEXT
def test_invoke_success(self, speech2text_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
file = BytesIO(b"audio data")
user = "user_123"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_speech_to_text.return_value = "transcribed text"
result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file, user=user)
assert result == "transcribed text"
mock_client.invoke_speech_to_text.assert_called_once_with(
tenant_id="tenant_123",
user_id="user_123",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
file=file,
)
def test_invoke_success_no_user(self, speech2text_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
file = BytesIO(b"audio data")
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_speech_to_text.return_value = "transcribed text"
result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file)
assert result == "transcribed text"
mock_client.invoke_speech_to_text.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
file=file,
)
def test_invoke_exception(self, speech2text_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
file = BytesIO(b"audio data")
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_speech_to_text.side_effect = Exception("Test error")
with pytest.raises(InvokeError) as excinfo:
speech2text_model.invoke(model=model_name, credentials=credentials, file=file)
assert "[test_provider] Error: Test error" in str(excinfo.value.description)

View File

@ -0,0 +1,185 @@
from unittest.mock import MagicMock, patch
import pytest
from core.entities.embedding_type import EmbeddingInputType
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
class TestTextEmbeddingModel:
@pytest.fixture
def mock_plugin_model_provider(self):
return MagicMock(spec=PluginModelProviderEntity)
@pytest.fixture
def text_embedding_model(self, mock_plugin_model_provider):
return TextEmbeddingModel(
tenant_id="tenant_123",
model_type=ModelType.TEXT_EMBEDDING,
plugin_id="plugin_123",
provider_name="test_provider",
plugin_model_provider=mock_plugin_model_provider,
)
def test_model_type(self, text_embedding_model):
assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING
def test_invoke_with_texts(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
texts = ["hello", "world"]
user = "user_123"
expected_result = MagicMock(spec=EmbeddingResult)
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_text_embedding.return_value = expected_result
result = text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts, user=user)
assert result == expected_result
mock_client.invoke_text_embedding.assert_called_once_with(
tenant_id="tenant_123",
user_id="user_123",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
texts=texts,
input_type=EmbeddingInputType.DOCUMENT,
)
def test_invoke_with_multimodel_documents(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
multimodel_documents = [{"type": "text", "text": "hello"}]
expected_result = MagicMock(spec=EmbeddingResult)
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_multimodal_embedding.return_value = expected_result
result = text_embedding_model.invoke(
model=model_name, credentials=credentials, multimodel_documents=multimodel_documents
)
assert result == expected_result
mock_client.invoke_multimodal_embedding.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
documents=multimodel_documents,
input_type=EmbeddingInputType.DOCUMENT,
)
def test_invoke_no_input(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
with pytest.raises(ValueError) as excinfo:
text_embedding_model.invoke(model=model_name, credentials=credentials)
assert "No texts or files provided" in str(excinfo.value)
def test_invoke_precedence(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
texts = ["hello"]
multimodel_documents = [{"type": "text", "text": "world"}]
expected_result = MagicMock(spec=EmbeddingResult)
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_text_embedding.return_value = expected_result
result = text_embedding_model.invoke(
model=model_name, credentials=credentials, texts=texts, multimodel_documents=multimodel_documents
)
assert result == expected_result
mock_client.invoke_text_embedding.assert_called_once()
mock_client.invoke_multimodal_embedding.assert_not_called()
def test_invoke_exception(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
texts = ["hello"]
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_text_embedding.side_effect = Exception("Test error")
with pytest.raises(InvokeError) as excinfo:
text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts)
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
def test_get_num_tokens(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
texts = ["hello", "world"]
expected_tokens = [1, 1]
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.get_text_embedding_num_tokens.return_value = expected_tokens
result = text_embedding_model.get_num_tokens(model=model_name, credentials=credentials, texts=texts)
assert result == expected_tokens
mock_client.get_text_embedding_num_tokens.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
texts=texts,
)
def test_get_context_size(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
# Test case 1: Context size in schema
mock_schema = MagicMock()
mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048}
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_context_size(model_name, credentials) == 2048
# Test case 2: No schema
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
assert text_embedding_model._get_context_size(model_name, credentials) == 1000
# Test case 3: Context size NOT in schema properties
mock_schema.model_properties = {}
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_context_size(model_name, credentials) == 1000
def test_get_max_chunks(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
# Test case 1: Max chunks in schema
mock_schema = MagicMock()
mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10}
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_max_chunks(model_name, credentials) == 10
# Test case 2: No schema
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
assert text_embedding_model._get_max_chunks(model_name, credentials) == 1
# Test case 3: Max chunks NOT in schema properties
mock_schema.model_properties = {}
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_max_chunks(model_name, credentials) == 1

View File

@ -0,0 +1,131 @@
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
class TestTTSModel:
@pytest.fixture
def mock_plugin_model_provider(self):
return MagicMock(spec=PluginModelProviderEntity)
@pytest.fixture
def tts_model(self, mock_plugin_model_provider):
return TTSModel(
tenant_id="tenant_123",
model_type=ModelType.TTS,
plugin_id="plugin_123",
provider_name="test_provider",
plugin_model_provider=mock_plugin_model_provider,
)
def test_model_type(self, tts_model):
assert tts_model.model_type == ModelType.TTS
def test_invoke_success(self, tts_model):
model_name = "test_model"
tenant_id = "ignored_tenant_id"
credentials = {"api_key": "abc"}
content_text = "Hello world"
voice = "alloy"
user = "user_123"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_tts.return_value = [b"audio_chunk"]
result = tts_model.invoke(
model=model_name,
tenant_id=tenant_id,
credentials=credentials,
content_text=content_text,
voice=voice,
user=user,
)
assert list(result) == [b"audio_chunk"]
mock_client.invoke_tts.assert_called_once_with(
tenant_id="tenant_123",
user_id="user_123",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
content_text=content_text,
voice=voice,
)
def test_invoke_success_no_user(self, tts_model):
model_name = "test_model"
tenant_id = "ignored_tenant_id"
credentials = {"api_key": "abc"}
content_text = "Hello world"
voice = "alloy"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_tts.return_value = [b"audio_chunk"]
result = tts_model.invoke(
model=model_name, tenant_id=tenant_id, credentials=credentials, content_text=content_text, voice=voice
)
assert list(result) == [b"audio_chunk"]
mock_client.invoke_tts.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
content_text=content_text,
voice=voice,
)
def test_invoke_exception(self, tts_model):
model_name = "test_model"
tenant_id = "ignored_tenant_id"
credentials = {"api_key": "abc"}
content_text = "Hello world"
voice = "alloy"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_tts.side_effect = Exception("Test error")
with pytest.raises(InvokeError) as excinfo:
tts_model.invoke(
model=model_name,
tenant_id=tenant_id,
credentials=credentials,
content_text=content_text,
voice=voice,
)
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
def test_get_tts_model_voices(self, tts_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
language = "en-US"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.get_tts_model_voices.return_value = [{"name": "Voice1"}]
result = tts_model.get_tts_model_voices(model=model_name, credentials=credentials, language=language)
assert result == [{"name": "Voice1"}]
mock_client.get_tts_model_voices.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
language=language,
)

View File

@ -0,0 +1,96 @@
from unittest.mock import MagicMock, patch
import dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer as gpt2_tokenizer_module
from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
class TestGPT2Tokenizer:
def setup_method(self):
# Reset the global tokenizer before each test to ensure we test initialization
gpt2_tokenizer_module._tokenizer = None
def test_get_encoder_tiktoken(self):
"""
Test that get_encoder successfully uses tiktoken when available.
"""
mock_encoding = MagicMock()
# Mock tiktoken to be sure it's used
with patch("tiktoken.get_encoding", return_value=mock_encoding) as mock_get_encoding:
encoder = GPT2Tokenizer.get_encoder()
assert encoder == mock_encoding
mock_get_encoding.assert_called_once_with("gpt2")
# Verify singleton behavior within the same test
encoder2 = GPT2Tokenizer.get_encoder()
assert encoder2 is encoder
assert mock_get_encoding.call_count == 1
def test_get_encoder_tiktoken_fallback(self):
"""
Test that get_encoder falls back to transformers when tiktoken fails.
"""
# patch tiktoken.get_encoding to raise an exception
with patch("tiktoken.get_encoding", side_effect=Exception("Tiktoken failure")):
# patch transformers.GPT2Tokenizer
with patch("transformers.GPT2Tokenizer.from_pretrained") as mock_from_pretrained:
mock_transformer_tokenizer = MagicMock()
mock_from_pretrained.return_value = mock_transformer_tokenizer
with patch(
"dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer.logger"
) as mock_logger:
encoder = GPT2Tokenizer.get_encoder()
assert encoder == mock_transformer_tokenizer
mock_from_pretrained.assert_called_once()
mock_logger.info.assert_called_once_with("Fallback to Transformers' GPT-2 tokenizer from tiktoken")
def test_get_num_tokens(self):
"""
Test get_num_tokens returns the correct count.
"""
mock_encoder = MagicMock()
mock_encoder.encode.return_value = [1, 2, 3, 4, 5]
with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder):
tokens_count = GPT2Tokenizer.get_num_tokens("test text")
assert tokens_count == 5
mock_encoder.encode.assert_called_once_with("test text")
def test_get_num_tokens_by_gpt2_direct(self):
"""
Test _get_num_tokens_by_gpt2 directly.
"""
mock_encoder = MagicMock()
mock_encoder.encode.return_value = [1, 2]
with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder):
tokens_count = GPT2Tokenizer._get_num_tokens_by_gpt2("hello")
assert tokens_count == 2
mock_encoder.encode.assert_called_once_with("hello")
def test_get_encoder_already_initialized(self):
"""
Test that if _tokenizer is already set, it returns it immediately.
"""
mock_existing_tokenizer = MagicMock()
gpt2_tokenizer_module._tokenizer = mock_existing_tokenizer
# Tiktoken should not be called if already initialized
with patch("tiktoken.get_encoding") as mock_get_encoding:
encoder = GPT2Tokenizer.get_encoder()
assert encoder == mock_existing_tokenizer
mock_get_encoding.assert_not_called()
def test_get_encoder_thread_safety(self):
"""
Simple test to ensure the lock is used.
"""
mock_encoding = MagicMock()
with patch("tiktoken.get_encoding", return_value=mock_encoding):
# We patch the lock in the module
with patch("dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer._lock") as mock_lock:
encoder = GPT2Tokenizer.get_encoder()
assert encoder == mock_encoding
mock_lock.__enter__.assert_called_once()
mock_lock.__exit__.assert_called_once()

View File

@ -0,0 +1,522 @@
import logging
from datetime import datetime
from threading import Lock
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from redis import RedisError
import contexts
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
)
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
def _provider_entity(
*,
provider: str,
supported_model_types: list[ModelType] | None = None,
models: list[AIModelEntity] | None = None,
icon_small: I18nObject | None = None,
icon_small_dark: I18nObject | None = None,
) -> ProviderEntity:
return ProviderEntity(
provider=provider,
label=I18nObject(en_US=provider),
supported_model_types=supported_model_types or [ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
models=models or [],
icon_small=icon_small,
icon_small_dark=icon_small_dark,
)
def _plugin_provider(
*, plugin_id: str, declaration: ProviderEntity, provider: str = "provider"
) -> PluginModelProviderEntity:
return PluginModelProviderEntity.model_construct(
id=f"{plugin_id}-id",
created_at=datetime.now(),
updated_at=datetime.now(),
provider=provider,
tenant_id="tenant",
plugin_unique_identifier=f"{plugin_id}-uid",
plugin_id=plugin_id,
declaration=declaration,
)
@pytest.fixture(autouse=True)
def _reset_plugin_model_provider_context() -> None:
contexts.plugin_model_providers_lock.set(Lock())
contexts.plugin_model_providers.set(None)
@pytest.fixture
def fake_plugin_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
manager = MagicMock()
import core.plugin.impl.model as plugin_model_module
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: manager)
return manager
@pytest.fixture
def factory(fake_plugin_manager: MagicMock) -> ModelProviderFactory:
return ModelProviderFactory(tenant_id="tenant")
def test_get_plugin_model_providers_initializes_context_on_lookup_error(
factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
) -> None:
declaration = _provider_entity(provider="openai")
fake_plugin_manager.fetch_model_providers.return_value = [
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
]
original_get = contexts.plugin_model_providers.get
calls = {"n": 0}
def flaky_get() -> Any:
calls["n"] += 1
if calls["n"] == 1:
raise LookupError
return original_get()
monkeypatch.setattr(contexts.plugin_model_providers, "get", flaky_get)
providers = factory.get_plugin_model_providers()
assert len(providers) == 1
assert providers[0].declaration.provider == "langgenius/openai/openai"
def test_get_plugin_model_providers_caches_and_does_not_refetch(
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
) -> None:
declaration = _provider_entity(provider="openai")
fake_plugin_manager.fetch_model_providers.return_value = [
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
]
first = factory.get_plugin_model_providers()
second = factory.get_plugin_model_providers()
assert first is second
fake_plugin_manager.fetch_model_providers.assert_called_once_with("tenant")
def test_get_providers_returns_declarations(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None:
d1 = _provider_entity(provider="openai")
d2 = _provider_entity(provider="anthropic")
fake_plugin_manager.fetch_model_providers.return_value = [
_plugin_provider(plugin_id="langgenius/openai", declaration=d1),
_plugin_provider(plugin_id="langgenius/anthropic", declaration=d2),
]
providers = factory.get_providers()
assert [p.provider for p in providers] == ["langgenius/openai/openai", "langgenius/anthropic/anthropic"]
def test_get_plugin_model_provider_converts_short_provider_id(
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
) -> None:
declaration = _provider_entity(provider="openai")
fake_plugin_manager.fetch_model_providers.return_value = [
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
]
provider = factory.get_plugin_model_provider("openai")
assert provider.declaration.provider == "langgenius/openai/openai"
def test_get_plugin_model_provider_raises_on_invalid_provider(
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
) -> None:
declaration = _provider_entity(provider="openai")
fake_plugin_manager.fetch_model_providers.return_value = [
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
]
with pytest.raises(ValueError, match="Invalid provider"):
factory.get_plugin_model_provider("langgenius/unknown/unknown")
def test_get_provider_schema_returns_declaration(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None:
declaration = _provider_entity(provider="openai")
fake_plugin_manager.fetch_model_providers.return_value = [
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
]
schema = factory.get_provider_schema("openai")
assert schema.provider == "langgenius/openai/openai"
def test_provider_credentials_validate_errors_when_schema_missing(
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
schema = _provider_entity(provider="openai")
schema.provider_credential_schema = None
monkeypatch.setattr(
factory,
"get_plugin_model_provider",
lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema),
)
with pytest.raises(ValueError, match="does not have provider_credential_schema"):
factory.provider_credentials_validate(provider="openai", credentials={"x": "y"})
def test_provider_credentials_validate_filters_and_calls_plugin_validation(
factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
) -> None:
schema = _provider_entity(provider="openai")
schema.provider_credential_schema = MagicMock()
plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema)
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider)
fake_validator = MagicMock()
fake_validator.validate_and_filter.return_value = {"filtered": True}
monkeypatch.setattr(
"dify_graph.model_runtime.model_providers.model_provider_factory.ProviderCredentialSchemaValidator",
lambda _: fake_validator,
)
filtered = factory.provider_credentials_validate(provider="openai", credentials={"raw": True})
assert filtered == {"filtered": True}
fake_plugin_manager.validate_provider_credentials.assert_called_once()
kwargs = fake_plugin_manager.validate_provider_credentials.call_args.kwargs
assert kwargs["plugin_id"] == "langgenius/openai"
assert kwargs["provider"] == "provider"
assert kwargs["credentials"] == {"filtered": True}
def test_model_credentials_validate_errors_when_schema_missing(
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
schema = _provider_entity(provider="openai")
schema.model_credential_schema = None
monkeypatch.setattr(
factory,
"get_plugin_model_provider",
lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema),
)
with pytest.raises(ValueError, match="does not have model_credential_schema"):
factory.model_credentials_validate(
provider="openai", model_type=ModelType.LLM, model="m", credentials={"x": "y"}
)
def test_model_credentials_validate_filters_and_calls_plugin_validation(
factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
) -> None:
schema = _provider_entity(provider="openai")
schema.model_credential_schema = MagicMock()
plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema)
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider)
fake_validator = MagicMock()
fake_validator.validate_and_filter.return_value = {"filtered": True}
monkeypatch.setattr(
"dify_graph.model_runtime.model_providers.model_provider_factory.ModelCredentialSchemaValidator",
lambda *_: fake_validator,
)
filtered = factory.model_credentials_validate(
provider="openai", model_type=ModelType.TEXT_EMBEDDING, model="m", credentials={"raw": True}
)
assert filtered == {"filtered": True}
kwargs = fake_plugin_manager.validate_model_credentials.call_args.kwargs
assert kwargs["plugin_id"] == "langgenius/openai"
assert kwargs["provider"] == "provider"
assert kwargs["model_type"] == "text-embedding"
assert kwargs["model"] == "m"
assert kwargs["credentials"] == {"filtered": True}
def test_get_model_schema_cache_hit(factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch) -> None:
model_schema = AIModelEntity(
model="m",
label=I18nObject(en_US="m"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[],
)
monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
mock_redis.get.return_value = model_schema.model_dump_json().encode()
assert (
factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials={"k": "v"})
== model_schema
)
def test_get_model_schema_cache_invalid_json_deletes_key(
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
) -> None:
caplog.set_level(logging.WARNING)
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
mock_redis.get.return_value = b'{"model":"m"}'
factory.plugin_model_manager.get_model_schema.return_value = None
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None
assert mock_redis.delete.called
assert any("Failed to validate cached plugin model schema" in r.message for r in caplog.records)
def test_get_model_schema_cache_delete_redis_error_is_logged(
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
) -> None:
caplog.set_level(logging.WARNING)
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
mock_redis.get.return_value = b'{"model":"m"}'
mock_redis.delete.side_effect = RedisError("nope")
factory.plugin_model_manager.get_model_schema.return_value = None
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None)
assert any("Failed to delete invalid plugin model schema cache" in r.message for r in caplog.records)
def test_get_model_schema_redis_get_error_falls_back_to_plugin(
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
) -> None:
caplog.set_level(logging.WARNING)
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
factory.plugin_model_manager.get_model_schema.return_value = None
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
mock_redis.get.side_effect = RedisError("down")
assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None
assert any("Failed to read plugin model schema cache" in r.message for r in caplog.records)
def test_get_model_schema_cache_miss_sets_cache_and_handles_setex_error(
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
) -> None:
caplog.set_level(logging.WARNING)
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
model_schema = AIModelEntity(
model="m",
label=I18nObject(en_US="m"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[],
)
factory.plugin_model_manager.get_model_schema.return_value = model_schema
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
mock_redis.get.return_value = None
mock_redis.setex.side_effect = RedisError("nope")
assert (
factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None)
== model_schema
)
assert any("Failed to write plugin model schema cache" in r.message for r in caplog.records)
@pytest.mark.parametrize(
("model_type", "expected_class"),
[
(ModelType.LLM, "LargeLanguageModel"),
(ModelType.TEXT_EMBEDDING, "TextEmbeddingModel"),
(ModelType.RERANK, "RerankModel"),
(ModelType.SPEECH2TEXT, "Speech2TextModel"),
(ModelType.MODERATION, "ModerationModel"),
(ModelType.TTS, "TTSModel"),
],
)
def test_get_model_type_instance_dispatches_by_type(
factory: ModelProviderFactory, model_type: ModelType, expected_class: str, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity))
sentinel = object()
monkeypatch.setattr(
f"dify_graph.model_runtime.model_providers.model_provider_factory.{expected_class}",
MagicMock(model_validate=lambda _: sentinel),
)
assert factory.get_model_type_instance("langgenius/openai/openai", model_type) is sentinel
def test_get_model_type_instance_raises_on_unsupported(
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity))
class UnknownModelType:
pass
with pytest.raises(ValueError, match="Unsupported model type"):
factory.get_model_type_instance("langgenius/openai/openai", UnknownModelType()) # type: ignore[arg-type]
def test_get_models_filters_by_provider_and_model_type(
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
) -> None:
llm = AIModelEntity(
model="m1",
label=I18nObject(en_US="m1"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[],
)
embed = AIModelEntity(
model="e1",
label=I18nObject(en_US="e1"),
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[],
)
openai = _provider_entity(
provider="openai", supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], models=[llm, embed]
)
anthropic = _provider_entity(provider="anthropic", supported_model_types=[ModelType.LLM], models=[llm])
fake_plugin_manager.fetch_model_providers.return_value = [
_plugin_provider(plugin_id="langgenius/openai", declaration=openai),
_plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic),
]
# ModelType filter picks only matching models
providers = factory.get_models(model_type=ModelType.TEXT_EMBEDDING)
assert len(providers) == 1
assert providers[0].provider == "langgenius/openai/openai"
assert [m.model for m in providers[0].models] == ["e1"]
# Provider filter excludes others
providers = factory.get_models(provider="langgenius/anthropic/anthropic", model_type=ModelType.LLM)
assert len(providers) == 1
assert providers[0].provider == "langgenius/anthropic/anthropic"
def test_get_models_provider_filter_skips_non_matching(
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
) -> None:
openai = _provider_entity(provider="openai")
anthropic = _provider_entity(provider="anthropic")
fake_plugin_manager.fetch_model_providers.return_value = [
_plugin_provider(plugin_id="langgenius/openai", declaration=openai),
_plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic),
]
providers = factory.get_models(provider="langgenius/not-exist/not-exist", model_type=ModelType.LLM)
assert providers == []
def test_get_provider_icon_fetches_asset_and_returns_mime_type(
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
provider_schema = _provider_entity(
provider="langgenius/openai/openai",
icon_small=I18nObject(en_US="icon.png", zh_Hans="icon-zh.png"),
icon_small_dark=I18nObject(en_US="dark.svg", zh_Hans="dark-zh.svg"),
)
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
class FakePluginAssetManager:
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
assert tenant_id == "tenant"
return f"bytes:{id}".encode()
import core.plugin.impl.asset as asset_module
monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager)
data, mime = factory.get_provider_icon("openai", "icon_small", "en_US")
assert data == b"bytes:icon.png"
assert mime == "image/png"
data, mime = factory.get_provider_icon("openai", "icon_small_dark", "zh_Hans")
assert data == b"bytes:dark-zh.svg"
assert mime == "image/svg+xml"
def test_get_provider_icon_uses_zh_hans_for_small_and_en_us_for_dark(
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
provider_schema = _provider_entity(
provider="langgenius/openai/openai",
icon_small=I18nObject(en_US="icon-en.png", zh_Hans="icon-zh.png"),
icon_small_dark=I18nObject(en_US="dark-en.svg", zh_Hans="dark-zh.svg"),
)
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
class FakePluginAssetManager:
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
return id.encode()
import core.plugin.impl.asset as asset_module
monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager)
data, _ = factory.get_provider_icon("openai", "icon_small", "zh_Hans")
assert data == b"icon-zh.png"
data, _ = factory.get_provider_icon("openai", "icon_small_dark", "en_US")
assert data == b"dark-en.svg"
def test_get_provider_icon_raises_for_missing_icons(
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
provider_schema = _provider_entity(provider="langgenius/openai/openai")
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
with pytest.raises(ValueError, match="does not have small icon"):
factory.get_provider_icon("openai", "icon_small", "en_US")
with pytest.raises(ValueError, match="does not have small dark icon"):
factory.get_provider_icon("openai", "icon_small_dark", "en_US")
def test_get_provider_icon_raises_for_unsupported_icon_type(
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
provider_schema = _provider_entity(
provider="langgenius/openai/openai",
icon_small=I18nObject(en_US="", zh_Hans=""),
)
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
with pytest.raises(ValueError, match="Unsupported icon type"):
factory.get_provider_icon("openai", "nope", "en_US")
def test_get_provider_icon_raises_when_file_name_missing(
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
provider_schema = _provider_entity(
provider="langgenius/openai/openai",
icon_small=I18nObject(en_US="", zh_Hans=""),
)
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
with pytest.raises(ValueError, match="does not have icon"):
factory.get_provider_icon("openai", "icon_small", "en_US")
def test_get_plugin_id_and_provider_name_from_provider_handles_google_special_case(
factory: ModelProviderFactory,
) -> None:
plugin_id, provider_name = factory.get_plugin_id_and_provider_name_from_provider("google")
assert plugin_id == "langgenius/gemini"
assert provider_name == "google"

View File

@ -0,0 +1,201 @@
import pytest
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FormOption,
FormShowOnObject,
FormType,
)
from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator
class TestCommonValidator:
def test_validate_credential_form_schema_required_missing(self):
validator = CommonValidator()
schema = CredentialFormSchema(
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True
)
with pytest.raises(ValueError, match="Variable api_key is required"):
validator._validate_credential_form_schema(schema, {})
def test_validate_credential_form_schema_not_required_missing_with_default(self):
validator = CommonValidator()
schema = CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key"),
type=FormType.TEXT_INPUT,
required=False,
default="default_value",
)
assert validator._validate_credential_form_schema(schema, {}) == "default_value"
def test_validate_credential_form_schema_not_required_missing_no_default(self):
validator = CommonValidator()
schema = CredentialFormSchema(
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=False
)
assert validator._validate_credential_form_schema(schema, {}) is None
def test_validate_credential_form_schema_max_length_exceeded(self):
validator = CommonValidator()
schema = CredentialFormSchema(
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, max_length=5
)
with pytest.raises(ValueError, match="Variable api_key length should not be greater than 5"):
validator._validate_credential_form_schema(schema, {"api_key": "123456"})
def test_validate_credential_form_schema_not_string(self):
validator = CommonValidator()
schema = CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT)
with pytest.raises(ValueError, match="Variable api_key should be string"):
validator._validate_credential_form_schema(schema, {"api_key": 123})
def test_validate_credential_form_schema_select_invalid_option(self):
validator = CommonValidator()
schema = CredentialFormSchema(
variable="mode",
label=I18nObject(en_US="Mode"),
type=FormType.SELECT,
options=[
FormOption(label=I18nObject(en_US="Fast"), value="fast"),
FormOption(label=I18nObject(en_US="Slow"), value="slow"),
],
)
with pytest.raises(ValueError, match="Variable mode is not in options"):
validator._validate_credential_form_schema(schema, {"mode": "medium"})
def test_validate_credential_form_schema_select_valid_option(self):
validator = CommonValidator()
schema = CredentialFormSchema(
variable="mode",
label=I18nObject(en_US="Mode"),
type=FormType.SELECT,
options=[
FormOption(label=I18nObject(en_US="Fast"), value="fast"),
FormOption(label=I18nObject(en_US="Slow"), value="slow"),
],
)
assert validator._validate_credential_form_schema(schema, {"mode": "fast"}) == "fast"
def test_validate_credential_form_schema_switch_invalid(self):
validator = CommonValidator()
schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH)
with pytest.raises(ValueError, match="Variable enabled should be true or false"):
validator._validate_credential_form_schema(schema, {"enabled": "maybe"})
def test_validate_credential_form_schema_switch_valid(self):
validator = CommonValidator()
schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH)
assert validator._validate_credential_form_schema(schema, {"enabled": "true"}) is True
assert validator._validate_credential_form_schema(schema, {"enabled": "FALSE"}) is False
def test_validate_and_filter_credential_form_schemas_with_show_on(self):
validator = CommonValidator()
schemas = [
CredentialFormSchema(
variable="auth_type",
label=I18nObject(en_US="Auth Type"),
type=FormType.SELECT,
options=[
FormOption(label=I18nObject(en_US="API Key"), value="api_key"),
FormOption(label=I18nObject(en_US="OAuth"), value="oauth"),
],
),
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key"),
type=FormType.TEXT_INPUT,
show_on=[FormShowOnObject(variable="auth_type", value="api_key")],
),
CredentialFormSchema(
variable="client_id",
label=I18nObject(en_US="Client ID"),
type=FormType.TEXT_INPUT,
show_on=[FormShowOnObject(variable="auth_type", value="oauth")],
),
]
# Case 1: auth_type = api_key
credentials = {"auth_type": "api_key", "api_key": "my_secret"}
result = validator._validate_and_filter_credential_form_schemas(schemas, credentials)
assert "auth_type" in result
assert "api_key" in result
assert "client_id" not in result
assert result["api_key"] == "my_secret"
# Case 2: auth_type = oauth
credentials = {"auth_type": "oauth", "client_id": "my_client"}
result = validator._validate_and_filter_credential_form_schemas(schemas, credentials)
# Note: 'auth_type' contains 'oauth'. 'result' contains keys that pass validation.
# Since 'oauth' is not an empty string, it is in result.
assert "auth_type" in result
assert "api_key" not in result
assert "client_id" in result
assert result["client_id"] == "my_client"
def test_validate_and_filter_show_on_missing_variable(self):
validator = CommonValidator()
schemas = [
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key"),
type=FormType.TEXT_INPUT,
show_on=[FormShowOnObject(variable="auth_type", value="api_key")],
)
]
# auth_type is missing in credentials, so api_key should be filtered out
result = validator._validate_and_filter_credential_form_schemas(schemas, {})
assert result == {}
def test_validate_and_filter_show_on_mismatch_value(self):
validator = CommonValidator()
schemas = [
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key"),
type=FormType.TEXT_INPUT,
show_on=[FormShowOnObject(variable="auth_type", value="api_key")],
)
]
# auth_type is oauth, which doesn't match show_on
result = validator._validate_and_filter_credential_form_schemas(schemas, {"auth_type": "oauth"})
assert result == {}
def test_validate_and_filter_multiple_show_on(self):
validator = CommonValidator()
schemas = [
CredentialFormSchema(
variable="target",
label=I18nObject(en_US="Target"),
type=FormType.TEXT_INPUT,
show_on=[FormShowOnObject(variable="v1", value="a"), FormShowOnObject(variable="v2", value="b")],
)
]
# Both match
assert "target" in validator._validate_and_filter_credential_form_schemas(
schemas, {"v1": "a", "v2": "b", "target": "val"}
)
# One mismatch
assert "target" not in validator._validate_and_filter_credential_form_schemas(
schemas, {"v1": "a", "v2": "c", "target": "val"}
)
# One missing
assert "target" not in validator._validate_and_filter_credential_form_schemas(
schemas, {"v1": "a", "target": "val"}
)
def test_validate_and_filter_skips_falsy_results(self):
validator = CommonValidator()
schemas = [
CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH),
CredentialFormSchema(
variable="empty_str", label=I18nObject(en_US="Empty"), type=FormType.TEXT_INPUT, required=False
),
]
# Result of false switch is False. if result: is false. Not added.
# Result of empty string is "", if result: is false. Not added.
credentials = {"enabled": "false", "empty_str": ""}
result = validator._validate_and_filter_credential_form_schemas(schemas, credentials)
assert "enabled" not in result
assert "empty_str" not in result

View File

@ -0,0 +1,233 @@
import pytest
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FieldModelSchema,
FormOption,
FormShowOnObject,
FormType,
ModelCredentialSchema,
)
from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
def test_validate_and_filter_with_none_schema():
validator = ModelCredentialSchemaValidator(ModelType.LLM, None)
with pytest.raises(ValueError, match="Model credential schema is None"):
validator.validate_and_filter({})
def test_validate_and_filter_success():
schema = ModelCredentialSchema(
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
credential_form_schemas=[
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key", zh_Hans="API Key"),
type=FormType.SECRET_INPUT,
required=True,
),
CredentialFormSchema(
variable="optional_field",
label=I18nObject(en_US="Optional", zh_Hans="可选"),
type=FormType.TEXT_INPUT,
required=False,
default="default_val",
),
],
)
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
credentials = {"api_key": "sk-123456"}
result = validator.validate_and_filter(credentials)
assert result["api_key"] == "sk-123456"
assert result["optional_field"] == "default_val"
assert credentials["__model_type"] == ModelType.LLM.value
def test_validate_and_filter_with_show_on():
schema = ModelCredentialSchema(
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
credential_form_schemas=[
CredentialFormSchema(
variable="mode", label=I18nObject(en_US="Mode", zh_Hans="模式"), type=FormType.TEXT_INPUT, required=True
),
CredentialFormSchema(
variable="conditional_field",
label=I18nObject(en_US="Conditional", zh_Hans="条件"),
type=FormType.TEXT_INPUT,
required=True,
show_on=[FormShowOnObject(variable="mode", value="advanced")],
),
],
)
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
# mode is 'simple', conditional_field should be filtered out
credentials = {"mode": "simple", "conditional_field": "secret"}
result = validator.validate_and_filter(credentials)
assert "conditional_field" not in result
assert result["mode"] == "simple"
# mode is 'advanced', conditional_field should be kept
credentials = {"mode": "advanced", "conditional_field": "secret"}
result = validator.validate_and_filter(credentials)
assert result["conditional_field"] == "secret"
assert result["mode"] == "advanced"
# show_on variable missing in credentials
credentials = {"conditional_field": "secret"} # mode missing
with pytest.raises(ValueError, match="Variable mode is required"): # because mode is required in schema
validator.validate_and_filter(credentials)
def test_validate_and_filter_show_on_missing_trigger_var():
# specifically test all_show_on_match = False when variable not in credentials
schema = ModelCredentialSchema(
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
credential_form_schemas=[
CredentialFormSchema(
variable="optional_trigger",
label=I18nObject(en_US="Optional Trigger", zh_Hans="可选触发"),
type=FormType.TEXT_INPUT,
required=False,
),
CredentialFormSchema(
variable="conditional_field",
label=I18nObject(en_US="Conditional", zh_Hans="条件"),
type=FormType.TEXT_INPUT,
required=False,
show_on=[FormShowOnObject(variable="optional_trigger", value="active")],
),
],
)
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
# optional_trigger missing, conditional_field should be skipped
result = validator.validate_and_filter({"conditional_field": "val"})
assert "conditional_field" not in result
def test_common_validator_logic_required():
schema = ModelCredentialSchema(
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
credential_form_schemas=[
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key", zh_Hans="API Key"),
type=FormType.SECRET_INPUT,
required=True,
)
],
)
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
with pytest.raises(ValueError, match="Variable api_key is required"):
validator.validate_and_filter({})
with pytest.raises(ValueError, match="Variable api_key is required"):
validator.validate_and_filter({"api_key": ""})
def test_common_validator_logic_max_length():
schema = ModelCredentialSchema(
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
credential_form_schemas=[
CredentialFormSchema(
variable="key",
label=I18nObject(en_US="Key", zh_Hans="Key"),
type=FormType.TEXT_INPUT,
required=True,
max_length=5,
)
],
)
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
with pytest.raises(ValueError, match="Variable key length should not be greater than 5"):
validator.validate_and_filter({"key": "123456"})
def test_common_validator_logic_invalid_type():
schema = ModelCredentialSchema(
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
credential_form_schemas=[
CredentialFormSchema(
variable="key", label=I18nObject(en_US="Key", zh_Hans="Key"), type=FormType.TEXT_INPUT, required=True
)
],
)
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
with pytest.raises(ValueError, match="Variable key should be string"):
validator.validate_and_filter({"key": 123})
def test_common_validator_logic_switch():
schema = ModelCredentialSchema(
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
credential_form_schemas=[
CredentialFormSchema(
variable="enabled",
label=I18nObject(en_US="Enabled", zh_Hans="启用"),
type=FormType.SWITCH,
required=True,
)
],
)
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
result = validator.validate_and_filter({"enabled": "true"})
assert result["enabled"] is True
result = validator.validate_and_filter({"enabled": "false"})
assert "enabled" not in result
with pytest.raises(ValueError, match="Variable enabled should be true or false"):
validator.validate_and_filter({"enabled": "not_a_bool"})
def test_common_validator_logic_options():
schema = ModelCredentialSchema(
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
credential_form_schemas=[
CredentialFormSchema(
variable="choice",
label=I18nObject(en_US="Choice", zh_Hans="选择"),
type=FormType.SELECT,
required=True,
options=[
FormOption(label=I18nObject(en_US="A", zh_Hans="A"), value="a"),
FormOption(label=I18nObject(en_US="B", zh_Hans="B"), value="b"),
],
)
],
)
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
result = validator.validate_and_filter({"choice": "a"})
assert result["choice"] == "a"
with pytest.raises(ValueError, match="Variable choice is not in options"):
validator.validate_and_filter({"choice": "c"})
def test_validate_and_filter_optional_no_default():
schema = ModelCredentialSchema(
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")),
credential_form_schemas=[
CredentialFormSchema(
variable="optional",
label=I18nObject(en_US="Optional", zh_Hans="可选"),
type=FormType.TEXT_INPUT,
required=False,
)
],
)
validator = ModelCredentialSchemaValidator(ModelType.LLM, schema)
result = validator.validate_and_filter({})
assert "optional" not in result

View File

@ -0,0 +1,72 @@
import pytest
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderCredentialSchema
from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import (
ProviderCredentialSchemaValidator,
)
class TestProviderCredentialSchemaValidator:
def test_validate_and_filter_success(self):
# Setup schema
schema = ProviderCredentialSchema(
credential_form_schemas=[
CredentialFormSchema(
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True
),
CredentialFormSchema(
variable="endpoint",
label=I18nObject(en_US="Endpoint"),
type=FormType.TEXT_INPUT,
required=False,
default="https://api.example.com",
),
]
)
validator = ProviderCredentialSchemaValidator(schema)
# Test valid credentials
credentials = {"api_key": "my-secret-key"}
result = validator.validate_and_filter(credentials)
assert result == {"api_key": "my-secret-key", "endpoint": "https://api.example.com"}
def test_validate_and_filter_missing_required(self):
# Setup schema
schema = ProviderCredentialSchema(
credential_form_schemas=[
CredentialFormSchema(
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True
)
]
)
validator = ProviderCredentialSchemaValidator(schema)
# Test missing required credentials
with pytest.raises(ValueError, match="Variable api_key is required"):
validator.validate_and_filter({})
def test_validate_and_filter_extra_fields_filtered(self):
# Setup schema
schema = ProviderCredentialSchema(
credential_form_schemas=[
CredentialFormSchema(
variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True
)
]
)
validator = ProviderCredentialSchemaValidator(schema)
# Test credentials with extra fields
credentials = {"api_key": "my-secret-key", "extra_field": "should-be-filtered"}
result = validator.validate_and_filter(credentials)
assert "api_key" in result
assert "extra_field" not in result
assert result == {"api_key": "my-secret-key"}
def test_init(self):
schema = ProviderCredentialSchema(credential_form_schemas=[])
validator = ProviderCredentialSchemaValidator(schema)
assert validator.provider_credential_schema == schema

View File

@ -0,0 +1,231 @@
import dataclasses
import datetime
from collections import deque
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path, PurePath
from re import compile
from typing import Any
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from pydantic import BaseModel, ConfigDict
from pydantic.networks import AnyUrl, NameEmail
from pydantic.types import SecretBytes, SecretStr
from pydantic_core import Url
from pydantic_extra_types.color import Color
from dify_graph.model_runtime.utils.encoders import (
_model_dump,
decimal_encoder,
generate_encoders_by_class_tuples,
isoformat,
jsonable_encoder,
)
class MockEnum(Enum):
A = "a"
B = "b"
class MockPydanticModel(BaseModel):
model_config = ConfigDict(populate_by_name=True)
name: str
age: int
@dataclasses.dataclass
class MockDataclass:
name: str
value: Any
class MockWithDict:
def __init__(self, data):
self.data = data
def __iter__(self):
return iter(self.data.items())
class MockWithVars:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class TestEncoders:
def test_model_dump(self):
model = MockPydanticModel(name="test", age=20)
result = _model_dump(model)
assert result == {"name": "test", "age": 20}
def test_isoformat(self):
d = datetime.date(2023, 1, 1)
assert isoformat(d) == "2023-01-01"
t = datetime.time(12, 0, 0)
assert isoformat(t) == "12:00:00"
def test_decimal_encoder(self):
assert decimal_encoder(Decimal("1.0")) == 1.0
assert decimal_encoder(Decimal(1)) == 1
assert decimal_encoder(Decimal("1.5")) == 1.5
assert decimal_encoder(Decimal(0)) == 0
assert decimal_encoder(Decimal(-1)) == -1
def test_generate_encoders_by_class_tuples(self):
type_map = {int: str, float: str, str: int}
result = generate_encoders_by_class_tuples(type_map)
assert result[str] == (int, float)
assert result[int] == (str,)
def test_jsonable_encoder_basic_types(self):
assert jsonable_encoder("string") == "string"
assert jsonable_encoder(123) == 123
assert jsonable_encoder(1.23) == 1.23
assert jsonable_encoder(None) is None
def test_jsonable_encoder_pydantic(self):
model = MockPydanticModel(name="test", age=20)
assert jsonable_encoder(model) == {"name": "test", "age": 20}
def test_jsonable_encoder_pydantic_root(self):
# Manually create a mock that behaves like a model with __root__
# because Pydantic v2 handles root differently, but the code checks for "__root__"
model = MagicMock(spec=BaseModel)
# _model_dump(obj, mode="json", ...) -> model.model_dump(mode="json", ...)
model.model_dump.return_value = {"__root__": [1, 2, 3]}
assert jsonable_encoder(model) == [1, 2, 3]
def test_jsonable_encoder_dataclass(self):
obj = MockDataclass(name="test", value=1)
assert jsonable_encoder(obj) == {"name": "test", "value": 1}
# Test dataclass type (should not be treated as instance)
# It should fall back to vars() or dict() or at least not crash
with pytest.raises(ValueError):
jsonable_encoder(MockDataclass)
def test_jsonable_encoder_enum(self):
assert jsonable_encoder(MockEnum.A) == "a"
def test_jsonable_encoder_path(self):
assert jsonable_encoder(Path("/tmp/test")) == "/tmp/test"
assert jsonable_encoder(PurePath("/tmp/test")) == "/tmp/test"
def test_jsonable_encoder_decimal(self):
# In jsonable_encoder, Decimal is formatted as string via format(obj, "f")
assert jsonable_encoder(Decimal("1.23")) == "1.23"
assert jsonable_encoder(Decimal("1.000")) == "1.000"
def test_jsonable_encoder_dict(self):
d = {"a": 1, "b": [2, 3], "_sa_instance": "hidden"}
assert jsonable_encoder(d) == {"a": 1, "b": [2, 3]}
assert jsonable_encoder(d, sqlalchemy_safe=False) == {"a": 1, "b": [2, 3], "_sa_instance": "hidden"}
d_with_none = {"a": 1, "b": None}
assert jsonable_encoder(d_with_none, exclude_none=True) == {"a": 1}
assert jsonable_encoder(d_with_none, exclude_none=False) == {"a": 1, "b": None}
def test_jsonable_encoder_collections(self):
assert jsonable_encoder([1, 2]) == [1, 2]
assert jsonable_encoder((1, 2)) == [1, 2]
assert jsonable_encoder({1, 2}) == [1, 2]
assert jsonable_encoder(frozenset([1, 2])) == [1, 2]
assert jsonable_encoder(deque([1, 2])) == [1, 2]
def gen():
yield 1
yield 2
assert jsonable_encoder(gen()) == [1, 2]
def test_jsonable_encoder_custom_encoder(self):
custom = {int: lambda x: str(x + 1)}
assert jsonable_encoder(1, custom_encoder=custom) == "2"
# Test subclass matching for custom encoder
class SubInt(int):
pass
assert jsonable_encoder(SubInt(1), custom_encoder=custom) == "2"
def test_jsonable_encoder_special_types(self):
# These hit ENCODERS_BY_TYPE or encoders_by_class_tuples
assert jsonable_encoder(b"bytes") == "bytes"
assert jsonable_encoder(Color("red")) == "red"
dt = datetime.datetime(2023, 1, 1, 12, 0, 0)
assert jsonable_encoder(dt) == dt.isoformat()
date = datetime.date(2023, 1, 1)
assert jsonable_encoder(date) == date.isoformat()
time = datetime.time(12, 0, 0)
assert jsonable_encoder(time) == time.isoformat()
td = datetime.timedelta(seconds=60)
assert jsonable_encoder(td) == 60.0
assert jsonable_encoder(IPv4Address("127.0.0.1")) == "127.0.0.1"
assert jsonable_encoder(IPv4Interface("127.0.0.1/24")) == "127.0.0.1/24"
assert jsonable_encoder(IPv4Network("127.0.0.0/24")) == "127.0.0.0/24"
assert jsonable_encoder(IPv6Address("::1")) == "::1"
assert jsonable_encoder(IPv6Interface("::1/128")) == "::1/128"
assert jsonable_encoder(IPv6Network("::/128")) == "::/128"
assert jsonable_encoder(NameEmail(name="test", email="test@example.com")) == "test <test@example.com>"
assert jsonable_encoder(compile("abc")) == "abc"
# Secret types
# Check what they actually return in this environment
res_bytes = jsonable_encoder(SecretBytes(b"secret"))
assert "**********" in res_bytes
res_str = jsonable_encoder(SecretStr("secret"))
assert res_str == "**********"
u = UUID("12345678-1234-5678-1234-567812345678")
assert jsonable_encoder(u) == str(u)
url = AnyUrl("https://example.com")
assert jsonable_encoder(url) == "https://example.com/"
purl = Url("https://example.com")
assert jsonable_encoder(purl) == "https://example.com/"
def test_jsonable_encoder_fallback(self):
# dict(obj) success
obj_dict = MockWithDict({"a": 1})
assert jsonable_encoder(obj_dict) == {"a": 1}
# vars(obj) success
obj_vars = MockWithVars(x=10, y=20)
assert jsonable_encoder(obj_vars) == {"x": 10, "y": 20}
# error fallback
class ReallyUnserializable:
__slots__ = ["__weakref__"] # No __dict__
def __iter__(self):
raise TypeError("not iterable")
with pytest.raises(ValueError) as exc:
jsonable_encoder(ReallyUnserializable())
assert "not iterable" in str(exc.value)
def test_jsonable_encoder_nested(self):
data = {
"model": MockPydanticModel(name="test", age=20),
"list": [Decimal("1.1"), {MockEnum.A: Path("/tmp")}],
"set": {1, 2},
}
expected = {
"model": {"name": "test", "age": 20},
"list": ["1.1", {"a": "/tmp"}],
"set": [1, 2],
}
assert jsonable_encoder(data) == expected