chore: add Type to test (#35942)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-05-09 12:16:22 +09:00 committed by GitHub
parent e03eb3a76c
commit 140ad6ba4e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
200 changed files with 1497 additions and 1264 deletions

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from dify_trace_aliyun.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
@ -31,7 +32,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from models import EndUser
def test_get_user_id_from_message_data_no_end_user(monkeypatch):
def test_get_user_id_from_message_data_no_end_user(monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.from_account_id = "account_id"
message_data.from_end_user_id = None
@ -39,7 +40,7 @@ def test_get_user_id_from_message_data_no_end_user(monkeypatch):
assert get_user_id_from_message_data(message_data) == "account_id"
def test_get_user_id_from_message_data_with_end_user(monkeypatch):
def test_get_user_id_from_message_data_with_end_user(monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.from_account_id = "account_id"
message_data.from_end_user_id = "end_user_id"
@ -57,7 +58,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
assert get_user_id_from_message_data(message_data) == "session_id"
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.from_account_id = "account_id"
message_data.from_end_user_id = "end_user_id"
@ -111,7 +112,7 @@ def test_get_workflow_node_status():
assert status.status_code == StatusCode.UNSET
def test_create_links_from_trace_id(monkeypatch):
def test_create_links_from_trace_id(monkeypatch: pytest.MonkeyPatch):
# Mock create_link
mock_link = MagicMock(spec=Link)
import dify_trace_aliyun.data_exporter.traceclient

View File

@ -40,7 +40,7 @@ def langfuse_config():
@pytest.fixture
def trace_instance(langfuse_config, monkeypatch):
def trace_instance(langfuse_config, monkeypatch: pytest.MonkeyPatch):
# Mock Langfuse client to avoid network calls
mock_client = MagicMock()
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
@ -49,7 +49,7 @@ def trace_instance(langfuse_config, monkeypatch):
return instance
def test_init(langfuse_config, monkeypatch):
def test_init(langfuse_config, monkeypatch: pytest.MonkeyPatch):
mock_langfuse = MagicMock()
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse)
monkeypatch.setenv("FILES_URL", "http://test.url")
@ -64,7 +64,7 @@ def test_init(langfuse_config, monkeypatch):
assert instance.file_base_url == "http://test.url"
def test_trace_dispatch(trace_instance, monkeypatch):
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
methods = [
"workflow_trace",
"message_trace",
@ -114,7 +114,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
mocks["generate_name_trace"].assert_called_once_with(info)
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
def test_workflow_trace_with_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
# Setup trace info
trace_info = WorkflowTraceInfo(
workflow_id="wf-1",
@ -218,7 +218,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
assert other_span.level == LevelEnum.ERROR
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
def test_workflow_trace_no_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
trace_info = WorkflowTraceInfo(
workflow_id="wf-1",
tenant_id="tenant-1",
@ -259,7 +259,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
assert trace_data.name == TraceTaskName.WORKFLOW_TRACE
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
trace_info = WorkflowTraceInfo(
workflow_id="wf-1",
tenant_id="tenant-1",
@ -287,7 +287,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
trace_instance.workflow_trace(trace_info)
def test_message_trace_basic(trace_instance, monkeypatch):
def test_message_trace_basic(trace_instance, monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.id = "msg-1"
message_data.from_account_id = "acc-1"
@ -331,7 +331,7 @@ def test_message_trace_basic(trace_instance, monkeypatch):
assert gen_data.usage.total == 30
def test_message_trace_with_end_user(trace_instance, monkeypatch):
def test_message_trace_with_end_user(trace_instance, monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.id = "msg-1"
message_data.from_account_id = "acc-1"
@ -636,7 +636,7 @@ def test_langfuse_trace_entity_with_list_dict_input():
assert data.input[0]["content"] == "hello"
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch, caplog):
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
# Setup trace info to trigger LLM node usage extraction
trace_info = WorkflowTraceInfo(
workflow_id="wf-1",

View File

@ -35,7 +35,7 @@ def langsmith_config():
@pytest.fixture
def trace_instance(langsmith_config, monkeypatch):
def trace_instance(langsmith_config, monkeypatch: pytest.MonkeyPatch):
# Mock LangSmith client
mock_client = MagicMock()
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client)
@ -44,7 +44,7 @@ def trace_instance(langsmith_config, monkeypatch):
return instance
def test_init(langsmith_config, monkeypatch):
def test_init(langsmith_config, monkeypatch: pytest.MonkeyPatch):
mock_client_class = MagicMock()
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class)
monkeypatch.setenv("FILES_URL", "http://test.url")
@ -57,7 +57,7 @@ def test_init(langsmith_config, monkeypatch):
assert instance.file_base_url == "http://test.url"
def test_trace_dispatch(trace_instance, monkeypatch):
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
methods = [
"workflow_trace",
"message_trace",
@ -107,7 +107,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
mocks["generate_name_trace"].assert_called_once_with(info)
def test_workflow_trace(trace_instance, monkeypatch):
def test_workflow_trace(trace_instance, monkeypatch: pytest.MonkeyPatch):
# Setup trace info
workflow_data = MagicMock()
workflow_data.created_at = _dt()
@ -223,7 +223,7 @@ def test_workflow_trace(trace_instance, monkeypatch):
assert call_args[4].run_type == LangSmithRunType.retriever
def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
def test_workflow_trace_no_start_time(trace_instance, monkeypatch: pytest.MonkeyPatch):
workflow_data = MagicMock()
workflow_data.created_at = _dt()
workflow_data.finished_at = _dt() + timedelta(seconds=1)
@ -266,7 +266,7 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
assert trace_instance.add_run.called
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.trace_id = "trace-1"
trace_info.message_id = None
@ -290,7 +290,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
trace_instance.workflow_trace(trace_info)
def test_message_trace(trace_instance, monkeypatch):
def test_message_trace(trace_instance, monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.id = "msg-1"
message_data.from_account_id = "acc-1"
@ -516,7 +516,7 @@ def test_update_run_error(trace_instance):
trace_instance.update_run(update_data)
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, caplog):
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
workflow_data = MagicMock()
workflow_data.created_at = _dt()
workflow_data.finished_at = _dt() + timedelta(seconds=1)

View File

@ -614,7 +614,7 @@ class TestMessageTrace:
span.set_status.assert_called_once()
span.add_event.assert_called_once()
def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch):
def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch: pytest.MonkeyPatch):
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"

View File

@ -35,7 +35,7 @@ def opik_config():
@pytest.fixture
def trace_instance(opik_config, monkeypatch):
def trace_instance(opik_config, monkeypatch: pytest.MonkeyPatch):
mock_client = MagicMock()
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client)
@ -65,7 +65,7 @@ def test_prepare_opik_uuid():
assert result is not None
def test_init(opik_config, monkeypatch):
def test_init(opik_config, monkeypatch: pytest.MonkeyPatch):
mock_opik = MagicMock()
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik)
monkeypatch.setenv("FILES_URL", "http://test.url")
@ -82,7 +82,7 @@ def test_init(opik_config, monkeypatch):
assert instance.project == opik_config.project
def test_trace_dispatch(trace_instance, monkeypatch):
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
methods = [
"workflow_trace",
"message_trace",
@ -132,7 +132,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
mocks["generate_name_trace"].assert_called_once_with(info)
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
def test_workflow_trace_with_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
# Define constants for better readability
WORKFLOW_ID = "fb05c7cd-6cec-4add-8a84-df03a408b4ce"
WORKFLOW_RUN_ID = "33c67568-7a8a-450e-8916-a5f135baeaef"
@ -221,7 +221,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
assert trace_instance.add_span.call_count >= 1
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
def test_workflow_trace_no_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
# Define constants for better readability
WORKFLOW_ID = "f0708b36-b1d7-42b3-a876-1d01b7d8f1a3"
WORKFLOW_RUN_ID = "d42ec285-c2fd-4248-8866-5c9386b101ac"
@ -265,7 +265,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
trace_instance.add_trace.assert_called_once()
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
trace_info = WorkflowTraceInfo(
workflow_id="5745f1b8-f8e6-4859-8110-996acb6c8d6a",
tenant_id="tenant-1",
@ -293,7 +293,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
trace_instance.workflow_trace(trace_info)
def test_message_trace_basic(trace_instance, monkeypatch):
def test_message_trace_basic(trace_instance, monkeypatch: pytest.MonkeyPatch):
# Define constants for better readability
MESSAGE_DATA_ID = "e3a26712-8cac-4a25-94a4-a3bff21ee3ab"
CONVERSATION_ID = "9d3f3751-7521-4c19-9307-20e3cf6789a3"
@ -340,7 +340,7 @@ def test_message_trace_basic(trace_instance, monkeypatch):
trace_instance.add_span.assert_called_once()
def test_message_trace_with_end_user(trace_instance, monkeypatch):
def test_message_trace_with_end_user(trace_instance, monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.id = "85411059-79fb-4deb-a76c-c2e215f1b97e"
message_data.from_account_id = "acc-1"
@ -614,7 +614,7 @@ def test_get_project_url_error(trace_instance):
trace_instance.get_project_url()
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch, caplog):
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
trace_info = WorkflowTraceInfo(
workflow_id="86a52565-4a6b-4a1b-9bfd-98e4595e70de",
tenant_id="66e8e918-472e-4b69-8051-12502c34fc07",

View File

@ -267,14 +267,14 @@ class TestInit:
with pytest.raises(ValueError, match="Weave login failed"):
WeaveDataTrace(config)
def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch):
def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch: pytest.MonkeyPatch):
"""Test FILES_URL is read from environment."""
monkeypatch.setenv("FILES_URL", "http://files.example.com")
config = _make_weave_config()
instance = WeaveDataTrace(config)
assert instance.file_base_url == "http://files.example.com"
def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch):
def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch: pytest.MonkeyPatch):
"""Test FILES_URL defaults to http://127.0.0.1:5001."""
monkeypatch.delenv("FILES_URL", raising=False)
config = _make_weave_config()
@ -302,7 +302,7 @@ class TestGetProjectUrl:
url = instance.get_project_url()
assert url == "https://wandb.ai/my-project"
def test_get_project_url_exception_raises(self, trace_instance, monkeypatch):
def test_get_project_url_exception_raises(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Raises ValueError when exception occurs in get_project_url."""
monkeypatch.setattr(trace_instance, "entity", None)
monkeypatch.setattr(trace_instance, "project_name", None)
@ -583,7 +583,7 @@ class TestFinishCall:
class TestWorkflowTrace:
def _setup_repo(self, monkeypatch, nodes=None):
def _setup_repo(self, monkeypatch: pytest.MonkeyPatch, nodes=None):
"""Helper to patch session/repo dependencies."""
if nodes is None:
nodes = []
@ -599,7 +599,7 @@ class TestWorkflowTrace:
monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
return repo
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch):
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Workflow trace with no nodes and no message_id."""
self._setup_repo(monkeypatch, nodes=[])
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@ -614,7 +614,7 @@ class TestWorkflowTrace:
assert trace_instance.start_call.call_count == 1
assert trace_instance.finish_call.call_count == 1
def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch):
def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Workflow trace with message_id creates both message and workflow runs."""
self._setup_repo(monkeypatch, nodes=[])
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@ -629,7 +629,7 @@ class TestWorkflowTrace:
assert trace_instance.start_call.call_count == 2
assert trace_instance.finish_call.call_count == 2
def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch):
def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Workflow trace iterates node executions and creates node runs."""
node = _make_node(
id="node-1",
@ -652,7 +652,7 @@ class TestWorkflowTrace:
# workflow run + node run = 2 calls
assert trace_instance.start_call.call_count == 2
def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch):
def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""LLM node uses process_data prompts as inputs."""
node = _make_node(
node_type=BuiltinNodeTypes.LLM,
@ -680,7 +680,7 @@ class TestWorkflowTrace:
# The key "messages" should be present (validator transforms the list)
assert "messages" in node_run.inputs
def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch):
def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Non-LLM node uses node_execution.inputs directly."""
node = _make_node(
node_type=BuiltinNodeTypes.TOOL,
@ -701,7 +701,7 @@ class TestWorkflowTrace:
node_run = node_call_args[0][0]
assert node_run.inputs.get("tool_input") == "val"
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch):
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Raises ValueError when app_id is missing from metadata."""
monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
@ -714,7 +714,7 @@ class TestWorkflowTrace:
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch):
def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""start_time defaults to datetime.now() when None."""
self._setup_repo(monkeypatch, nodes=[])
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@ -727,7 +727,7 @@ class TestWorkflowTrace:
assert trace_instance.start_call.call_count == 1
def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch):
def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Node with created_at=None uses datetime.now()."""
node = _make_node(created_at=None, elapsed_time=0.5)
self._setup_repo(monkeypatch, nodes=[node])
@ -740,7 +740,7 @@ class TestWorkflowTrace:
trace_instance.workflow_trace(trace_info)
assert trace_instance.start_call.call_count == 2
def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch):
def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Chat mode LLM node adds ls_provider and ls_model_name to attributes."""
node = _make_node(
node_type=BuiltinNodeTypes.LLM,
@ -765,7 +765,7 @@ class TestWorkflowTrace:
assert node_run.attributes.get("ls_provider") == "openai"
assert node_run.attributes.get("ls_model_name") == "gpt-4"
def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch):
def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Nodes are sorted by created_at before processing."""
node1 = _make_node(id="node-b", created_at=_dt() + timedelta(seconds=2))
node2 = _make_node(id="node-a", created_at=_dt())
@ -799,7 +799,7 @@ class TestMessageTrace:
trace_instance.message_trace(trace_info)
trace_instance.start_call.assert_not_called()
def test_basic_message_trace(self, trace_instance, monkeypatch):
def test_basic_message_trace(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""message_trace creates message run and llm child run."""
monkeypatch.setattr(
"dify_trace_weave.weave_trace.db.session.get",
@ -816,7 +816,7 @@ class TestMessageTrace:
assert trace_instance.start_call.call_count == 2
assert trace_instance.finish_call.call_count == 2
def test_message_trace_with_file_data(self, trace_instance, monkeypatch):
def test_message_trace_with_file_data(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""message_trace appends file URL to file_list."""
file_data = MagicMock()
file_data.url = "path/to/file.png"
@ -839,7 +839,7 @@ class TestMessageTrace:
message_run = trace_instance.start_call.call_args_list[0][0][0]
assert "http://files.test/path/to/file.png" in message_run.file_list
def test_message_trace_with_end_user(self, trace_instance, monkeypatch):
def test_message_trace_with_end_user(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""message_trace looks up end user and sets end_user_id attribute."""
end_user = MagicMock()
end_user.session_id = "session-xyz"
@ -862,7 +862,7 @@ class TestMessageTrace:
message_run = trace_instance.start_call.call_args_list[0][0][0]
assert message_run.attributes.get("end_user_id") == "session-xyz"
def test_message_trace_no_end_user(self, trace_instance, monkeypatch):
def test_message_trace_no_end_user(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""message_trace handles when from_end_user_id is None."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
@ -880,7 +880,7 @@ class TestMessageTrace:
trace_instance.message_trace(trace_info)
assert trace_instance.start_call.call_count == 2
def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch):
def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""trace_id falls back to message_id when trace_id is None."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
@ -895,7 +895,7 @@ class TestMessageTrace:
message_run = trace_instance.start_call.call_args_list[0][0][0]
assert message_run.id == "msg-1"
def test_message_trace_file_list_none(self, trace_instance, monkeypatch):
def test_message_trace_file_list_none(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""message_trace handles file_list=None gracefully."""
mock_db = MagicMock()
mock_db.session.get.return_value = None

View File

@ -20,7 +20,7 @@ def test_validate_distance_function_rejects_unsupported_values():
factory._validate_distance_function("dot_product")
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch):
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch: pytest.MonkeyPatch):
factory = AlibabaCloudMySQLVectorFactory()
dataset = SimpleNamespace(
id="dataset-1",
@ -45,7 +45,7 @@ def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch
assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection"
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch):
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch: pytest.MonkeyPatch):
factory = AlibabaCloudMySQLVectorFactory()
dataset = SimpleNamespace(
id="dataset-2",

View File

@ -83,7 +83,7 @@ def test_get_type_is_analyticdb():
assert vector.get_type() == "analyticdb"
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch: pytest.MonkeyPatch):
factory = AnalyticdbVectorFactory()
dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
@ -109,7 +109,7 @@ def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
assert dataset.index_struct is not None
def test_factory_builds_sql_config_when_host_is_present(monkeypatch):
def test_factory_builds_sql_config_when_host_is_present(monkeypatch: pytest.MonkeyPatch):
factory = AnalyticdbVectorFactory()
dataset = SimpleNamespace(
id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None

View File

@ -24,7 +24,7 @@ def _request_class(name: str):
return _Request
def _install_openapi_stubs(monkeypatch):
def _install_openapi_stubs(monkeypatch: pytest.MonkeyPatch):
gpdb_package = types.ModuleType("alibabacloud_gpdb20160503")
gpdb_package.__path__ = []
gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models")
@ -130,7 +130,7 @@ def test_openapi_config_to_client_params():
assert params["read_timeout"] == 60000
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch: pytest.MonkeyPatch):
stubs = _install_openapi_stubs(monkeypatch)
initialize_mock = MagicMock()
monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock)
@ -145,7 +145,7 @@ def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
initialize_mock.assert_called_once_with()
def test_initialize_skips_when_cached(monkeypatch):
def test_initialize_skips_when_cached(monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -164,7 +164,7 @@ def test_initialize_skips_when_cached(monkeypatch):
vector._create_namespace_if_not_exists.assert_not_called()
def test_initialize_runs_when_cache_is_missing(monkeypatch):
def test_initialize_runs_when_cache_is_missing(monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -184,7 +184,7 @@ def test_initialize_runs_when_cache_is_missing(monkeypatch):
openapi_module.redis_client.set.assert_called_once()
def test_initialize_vector_database_calls_openapi_client(monkeypatch):
def test_initialize_vector_database_calls_openapi_client(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector.config = _config()
@ -199,7 +199,7 @@ def test_initialize_vector_database_calls_openapi_client(monkeypatch):
assert request.manager_account_password == "password"
def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
def test_create_namespace_creates_when_namespace_not_found(monkeypatch: pytest.MonkeyPatch):
stubs = _install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector.config = _config()
@ -211,7 +211,7 @@ def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
vector._client.create_namespace.assert_called_once()
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch: pytest.MonkeyPatch):
stubs = _install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector.config = _config()
@ -222,7 +222,7 @@ def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
vector._create_namespace_if_not_exists()
def test_create_namespace_noop_when_namespace_exists(monkeypatch):
def test_create_namespace_noop_when_namespace_exists(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector.config = _config()
@ -234,7 +234,7 @@ def test_create_namespace_noop_when_namespace_exists(monkeypatch):
vector._client.create_namespace.assert_not_called()
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch: pytest.MonkeyPatch):
stubs = _install_openapi_stubs(monkeypatch)
lock = MagicMock()
lock.__enter__.return_value = None
@ -255,7 +255,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
openapi_module.redis_client.set.assert_called_once()
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -274,7 +274,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
vector._client.create_collection.assert_not_called()
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch: pytest.MonkeyPatch):
stubs = _install_openapi_stubs(monkeypatch)
lock = MagicMock()
lock.__enter__.return_value = None
@ -293,7 +293,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
vector.create_collection_if_not_exists(embedding_dimension=512)
def test_openapi_add_delete_and_search_methods(monkeypatch):
def test_openapi_add_delete_and_search_methods(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector._collection_name = "collection_1"
@ -348,7 +348,7 @@ def test_openapi_add_delete_and_search_methods(monkeypatch):
assert docs_by_text[0].page_content == "high"
def test_text_exists_returns_false_when_matches_empty(monkeypatch):
def test_text_exists_returns_false_when_matches_empty(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector._collection_name = "collection_1"
@ -361,7 +361,7 @@ def test_text_exists_returns_false_when_matches_empty(monkeypatch):
assert vector.text_exists("missing-id") is False
def test_openapi_delete_success(monkeypatch):
def test_openapi_delete_success(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector._collection_name = "collection_1"
@ -372,7 +372,7 @@ def test_openapi_delete_success(monkeypatch):
vector._client.delete_collection.assert_called_once()
def test_openapi_delete_propagates_errors(monkeypatch):
def test_openapi_delete_propagates_errors(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector._collection_name = "collection_1"

View File

@ -53,7 +53,7 @@ def test_sql_config_rejects_min_connection_greater_than_max_connection():
AnalyticdbVectorBySqlConfig.model_validate(values)
def test_initialize_skips_when_cache_exists(monkeypatch):
def test_initialize_skips_when_cache_exists(monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -70,7 +70,7 @@ def test_initialize_skips_when_cache_exists(monkeypatch):
vector._initialize_vector_database.assert_not_called()
def test_initialize_runs_when_cache_is_missing(monkeypatch):
def test_initialize_runs_when_cache_is_missing(monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -88,7 +88,7 @@ def test_initialize_runs_when_cache_is_missing(monkeypatch):
sql_module.redis_client.set.assert_called_once()
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch):
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
vector.databaseName = "knowledgebase"
@ -119,7 +119,7 @@ def test_get_cursor_context_manager_handles_connection_lifecycle():
pool.putconn.assert_called_once_with(connection)
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch):
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.table_name = "dify.collection"
@ -273,7 +273,7 @@ def test_delete_drops_table():
cursor.execute.assert_called_once()
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch):
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch: pytest.MonkeyPatch):
config = AnalyticdbVectorBySqlConfig(**_config_values())
created_pool = MagicMock()
@ -288,7 +288,7 @@ def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypat
assert vector.pool is created_pool
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch):
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
vector.databaseName = "knowledgebase"
@ -326,7 +326,7 @@ def test_initialize_vector_database_handles_existing_database_and_search_config(
assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list)
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch):
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
vector.databaseName = "knowledgebase"
@ -353,7 +353,7 @@ def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(mon
worker_connection.rollback.assert_called_once()
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch):
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
vector._collection_name = "collection"
@ -381,7 +381,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp
sql_module.redis_client.set.assert_called_once()
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch):
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
vector._collection_name = "collection"

View File

@ -121,7 +121,7 @@ def _build_fake_pymochow_modules():
@pytest.fixture
def baidu_module(monkeypatch):
def baidu_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_pymochow_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import dify_vdb_baidu.baidu_vector as module
@ -254,7 +254,7 @@ def test_search_methods_delegate_to_database_table(baidu_module):
assert vector._get_search_res.call_count == 2
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch):
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch: pytest.MonkeyPatch):
factory = baidu_module.BaiduVectorFactory()
dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
@ -279,7 +279,7 @@ def test_factory_initializes_collection_name_and_index_struct(baidu_module, monk
assert dataset.index_struct is not None
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch):
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch: pytest.MonkeyPatch):
init_client = MagicMock(return_value="client")
init_database = MagicMock(return_value="database")
monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client)
@ -372,7 +372,7 @@ def test_get_search_result_handles_invalid_metadata_json(baidu_module):
assert "document_id" not in docs[0].metadata
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch):
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch: pytest.MonkeyPatch):
credentials = MagicMock(return_value="credentials")
configuration = MagicMock(return_value="configuration")
client_cls = MagicMock(return_value="client")
@ -411,7 +411,7 @@ def test_init_database_raises_for_unknown_create_database_error(baidu_module):
vector._init_database()
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch):
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch: pytest.MonkeyPatch):
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
vector._collection_name = "collection_1"
vector._client_config = SimpleNamespace(
@ -460,7 +460,7 @@ def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypat
vector._wait_for_index_ready.assert_called_once_with(table, 3600)
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch):
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch: pytest.MonkeyPatch):
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
vector._collection_name = "collection_1"
vector._db = MagicMock()
@ -493,7 +493,7 @@ def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypat
vector._create_table(3)
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch):
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch: pytest.MonkeyPatch):
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
vector._collection_name = "collection_1"
vector._client_config = SimpleNamespace(
@ -524,7 +524,9 @@ def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module,
vector._create_table(3)
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch):
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(
baidu_module, monkeypatch: pytest.MonkeyPatch
):
factory = baidu_module.BaiduVectorFactory()
dataset = SimpleNamespace(
id="dataset-1",

View File

@ -44,7 +44,7 @@ def _build_fake_chroma_modules():
@pytest.fixture
def chroma_module(monkeypatch):
def chroma_module(monkeypatch: pytest.MonkeyPatch):
fake_chroma = _build_fake_chroma_modules()
monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
import dify_vdb_chroma.chroma_vector as module
@ -73,7 +73,7 @@ def test_chroma_config_to_params_builds_expected_payload(chroma_module):
assert params["settings"].chroma_client_auth_credentials == "credentials"
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch):
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -173,7 +173,7 @@ def test_search_by_full_text_returns_empty_list(chroma_module):
assert vector.search_by_full_text("query") == []
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch):
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch: pytest.MonkeyPatch):
factory = chroma_module.ChromaVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None

View File

@ -45,7 +45,7 @@ def _build_fake_clickzetta_module():
@pytest.fixture
def clickzetta_module(monkeypatch):
def clickzetta_module(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
import dify_vdb_clickzetta.clickzetta_vector as module
@ -218,7 +218,7 @@ def test_search_by_like_returns_documents_with_default_score(clickzetta_module):
assert docs[0].metadata["score"] == 0.5
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
factory = clickzetta_module.ClickzettaVectorFactory()
dataset = SimpleNamespace(id="dataset-1")
@ -243,7 +243,7 @@ def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
assert vector_cls.call_args.kwargs["collection_name"] == "collection"
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch):
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
clickzetta_module.ClickzettaConnectionPool._instance = None
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
@ -255,7 +255,7 @@ def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch
assert "username:instance:service:workspace:cluster:dify" in key
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch):
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
pool = clickzetta_module.ClickzettaConnectionPool()
config = _config(clickzetta_module)
@ -274,7 +274,7 @@ def test_connection_pool_create_connection_retries_and_configures(clickzetta_mod
pool._configure_connection.assert_called_once_with(connection)
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch):
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
pool = clickzetta_module.ClickzettaConnectionPool()
config = _config(clickzetta_module)
@ -318,7 +318,7 @@ def test_connection_pool_configure_connection_swallows_errors(clickzetta_module)
monkeypatch.undo()
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch):
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
pool = clickzetta_module.ClickzettaConnectionPool()
config = _config(clickzetta_module)
@ -360,7 +360,7 @@ def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monk
assert pool._shutdown is True
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch):
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
pool._shutdown = False
pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True))
@ -384,7 +384,7 @@ def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module
pool._cleanup_expired_connections.assert_called_once()
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch):
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
pool.get_connection.return_value = "conn"
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool))
@ -405,7 +405,7 @@ def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypat
assert vector._ensure_connection() == "conn"
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch):
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
class _Thread:
def __init__(self, target, daemon):
self.target = target
@ -579,7 +579,7 @@ def test_create_inverted_index_branches(clickzetta_module):
vector._create_inverted_index(cursor)
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch):
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
vector._config = _config(clickzetta_module)
vector._config.batch_size = 2
@ -811,7 +811,7 @@ def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module):
assert pool._shutdown is True
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch):
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
pool._shutdown = False

View File

@ -150,7 +150,7 @@ def _build_fake_couchbase_modules():
@pytest.fixture
def couchbase_module(monkeypatch):
def couchbase_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_couchbase_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -194,7 +194,7 @@ def test_init_sets_cluster_handles(couchbase_module):
vector._cluster.wait_until_ready.assert_called_once()
def test_create_and_create_collection_branches(couchbase_module, monkeypatch):
def test_create_and_create_collection_branches(couchbase_module, monkeypatch: pytest.MonkeyPatch):
vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector)
vector._collection_name = "collection_1"
vector._client_config = _config(couchbase_module)
@ -319,7 +319,7 @@ def test_search_methods_and_format_metadata(couchbase_module):
assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2}
def test_delete_collection_and_factory(couchbase_module, monkeypatch):
def test_delete_collection_and_factory(couchbase_module, monkeypatch: pytest.MonkeyPatch):
vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
scopes = [
SimpleNamespace(collections=[SimpleNamespace(name="other")]),

View File

@ -28,7 +28,7 @@ def _build_fake_elasticsearch_modules():
@pytest.fixture
def elasticsearch_ja_module(monkeypatch):
def elasticsearch_ja_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -39,7 +39,7 @@ def elasticsearch_ja_module(monkeypatch):
return importlib.reload(ja_module)
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -57,7 +57,7 @@ def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
elasticsearch_ja_module.redis_client.set.assert_not_called()
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch):
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -87,7 +87,7 @@ def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monk
elasticsearch_ja_module.redis_client.set.assert_called_once()
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch):
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -38,7 +38,7 @@ def _build_fake_elasticsearch_modules():
@pytest.fixture
def elasticsearch_module(monkeypatch):
def elasticsearch_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -287,7 +287,7 @@ def test_search_by_vector_and_full_text(elasticsearch_module):
assert "bool" in query
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -331,7 +331,7 @@ def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
elasticsearch_module.redis_client.set.assert_called_once()
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch):
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch: pytest.MonkeyPatch):
factory = elasticsearch_module.ElasticSearchVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -38,7 +38,7 @@ def _build_fake_hologres_modules():
@pytest.fixture
def hologres_module(monkeypatch):
def hologres_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_hologres_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -266,7 +266,7 @@ def test_delete_handles_existing_and_missing_tables(hologres_module):
vector._client.drop_table.assert_called_once_with(vector.table_name)
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch):
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = False
@ -281,7 +281,7 @@ def test_create_collection_returns_early_when_cache_hits(hologres_module, monkey
hologres_module.redis_client.set.assert_not_called()
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch):
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = False
@ -313,7 +313,7 @@ def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatc
hologres_module.redis_client.set.assert_called_once()
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch):
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = False
@ -331,7 +331,7 @@ def test_create_collection_raises_when_table_never_becomes_ready(hologres_module
hologres_module.redis_client.set.assert_not_called()
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch):
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch: pytest.MonkeyPatch):
factory = hologres_module.HologresVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -29,7 +29,7 @@ def _build_fake_elasticsearch_modules():
@pytest.fixture
def huawei_module(monkeypatch):
def huawei_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -155,7 +155,7 @@ def test_search_by_vector_and_full_text(huawei_module):
assert docs[0].page_content == "text-hit"
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch):
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch: pytest.MonkeyPatch):
class FakeDocument:
def __init__(self, page_content, vector, metadata):
self.page_content = page_content
@ -185,7 +185,7 @@ def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch
assert docs == []
def test_create_and_create_collection_paths(huawei_module, monkeypatch):
def test_create_and_create_collection_paths(huawei_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -218,7 +218,7 @@ def test_create_and_create_collection_paths(huawei_module, monkeypatch):
huawei_module.redis_client.set.assert_called_once()
def test_huawei_factory_branches(huawei_module, monkeypatch):
def test_huawei_factory_branches(huawei_module, monkeypatch: pytest.MonkeyPatch):
factory = huawei_module.HuaweiCloudVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -23,7 +23,7 @@ def _build_fake_iris_module():
@pytest.fixture
def iris_module(monkeypatch):
def iris_module(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module())
import dify_vdb_iris.iris_vector as module
@ -249,7 +249,7 @@ def test_iris_vector_init_get_cursor_and_create(iris_module):
vector._create_collection.assert_called_once_with(2)
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch: pytest.MonkeyPatch):
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
vector = iris_module.IrisVector("collection", _config(iris_module))
@ -297,7 +297,7 @@ def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
assert docs[0].metadata["score"] == pytest.approx(0.9)
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch: pytest.MonkeyPatch):
cfg = _config(iris_module, IRIS_TEXT_INDEX=True)
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
vector = iris_module.IrisVector("collection", cfg)
@ -344,7 +344,7 @@ def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
assert vector_like.search_by_full_text("100%", top_k=1) == []
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch):
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch: pytest.MonkeyPatch):
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True))

View File

@ -47,7 +47,7 @@ def _build_fake_opensearch_modules():
@pytest.fixture
def lindorm_module(monkeypatch):
def lindorm_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_opensearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -100,7 +100,7 @@ def test_to_opensearch_params_and_init(lindorm_module):
assert vector_ugc._routing == "route"
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch):
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch: pytest.MonkeyPatch):
vector = lindorm_module.LindormVectorStore(
"collection", _config(lindorm_module), using_ugc=True, routing_value="route"
)
@ -301,7 +301,7 @@ def test_search_by_full_text_success_and_error(lindorm_module):
vector.search_by_full_text("hello")
def test_create_collection_paths(lindorm_module, monkeypatch):
def test_create_collection_paths(lindorm_module, monkeypatch: pytest.MonkeyPatch):
vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
with pytest.raises(ValueError, match="cannot be empty"):
@ -331,7 +331,7 @@ def test_create_collection_paths(lindorm_module, monkeypatch):
vector._client.indices.create.assert_not_called()
def test_lindorm_factory_branches(lindorm_module, monkeypatch):
def test_lindorm_factory_branches(lindorm_module, monkeypatch: pytest.MonkeyPatch):
factory = lindorm_module.LindormVectorStoreFactory()
monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200")

View File

@ -32,7 +32,7 @@ def _build_fake_mo_vector_modules():
@pytest.fixture
def matrixone_module(monkeypatch):
def matrixone_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_mo_vector_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -70,7 +70,7 @@ def test_matrixone_config_validation(matrixone_module, field, value, message):
matrixone_module.MatrixoneConfig.model_validate(values)
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch):
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -86,7 +86,7 @@ def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module,
matrixone_module.redis_client.set.assert_called_once()
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch):
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -146,7 +146,7 @@ def test_get_type_and_create_delegate_to_add_texts(matrixone_module):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch):
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -165,7 +165,7 @@ def test_get_client_handles_full_text_index_creation_error(matrixone_module, mon
matrixone_module.redis_client.set.assert_not_called()
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch):
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch: pytest.MonkeyPatch):
vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
vector.client = MagicMock()
monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid")
@ -224,7 +224,7 @@ def test_search_by_vector_builds_documents(matrixone_module):
assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}}
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch):
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch: pytest.MonkeyPatch):
factory = matrixone_module.MatrixoneVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -99,7 +99,7 @@ def _build_fake_pymilvus_modules():
@pytest.fixture
def milvus_module(monkeypatch):
def milvus_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_pymilvus_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -327,7 +327,7 @@ def test_process_search_results_and_search_methods(milvus_module):
assert "document_id" in vector._client.search.call_args.kwargs["filter"]
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch):
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -351,7 +351,7 @@ def test_create_collection_cache_and_existing_collection(milvus_module, monkeypa
milvus_module.redis_client.set.assert_called()
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch):
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -385,7 +385,7 @@ def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch)
assert call_kwargs["consistency_level"] == "Session"
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch):
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch: pytest.MonkeyPatch):
factory = milvus_module.MilvusVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -38,7 +38,7 @@ def _build_fake_clickhouse_connect_module():
@pytest.fixture
def myscale_module(monkeypatch):
def myscale_module(monkeypatch: pytest.MonkeyPatch):
fake_module = _build_fake_clickhouse_connect_module()
monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module)
@ -90,7 +90,7 @@ def test_delete_by_ids_short_circuits_on_empty_list(myscale_module):
vector._client.command.assert_not_called()
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch):
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch: pytest.MonkeyPatch):
factory = myscale_module.MyScaleVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",
@ -160,7 +160,7 @@ def test_create_collection_builds_expected_sql(myscale_module):
assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch):
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch: pytest.MonkeyPatch):
vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid")
docs = [

View File

@ -53,7 +53,7 @@ def _build_fake_pyobvector_module():
@pytest.fixture
def oceanbase_module(monkeypatch):
def oceanbase_module(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module())
import dify_vdb_oceanbase.oceanbase_vector as module
@ -208,7 +208,7 @@ def test_create_delegates_to_collection_and_insert(oceanbase_module):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch):
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -234,7 +234,7 @@ def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_mod
vector.delete.assert_not_called()
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch):
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -271,7 +271,7 @@ def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, mo
oceanbase_module.redis_client.set.assert_called_once()
def test_create_collection_error_paths(oceanbase_module, monkeypatch):
def test_create_collection_error_paths(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -308,7 +308,7 @@ def test_create_collection_error_paths(oceanbase_module, monkeypatch):
vector._create_collection()
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch):
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -517,7 +517,7 @@ def test_delete_success_and_exception(oceanbase_module):
vector.delete()
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch):
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
factory = oceanbase_module.OceanBaseVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -37,7 +37,7 @@ def _build_fake_psycopg2_modules():
@pytest.fixture
def opengauss_module(monkeypatch):
def opengauss_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_psycopg2_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -88,7 +88,7 @@ def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_modu
opengauss_module.OpenGaussConfig.model_validate(values)
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -99,7 +99,7 @@ def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
assert vector.pool is pool
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -126,7 +126,7 @@ def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
opengauss_module.redis_client.set.assert_called_once()
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch):
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -158,7 +158,7 @@ def test_search_by_vector_validates_top_k(opengauss_module):
vector.search_by_vector([0.1, 0.2], top_k=0)
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch):
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
@ -200,7 +200,7 @@ def test_create_calls_collection_insert_and_index(opengauss_module):
vector._create_index.assert_called_once_with(2)
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -220,7 +220,7 @@ def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
opengauss_module.redis_client.set.assert_not_called()
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch):
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -245,7 +245,7 @@ def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, m
assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql)
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch):
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
@ -342,7 +342,7 @@ def test_search_by_full_text_validates_top_k(opengauss_module):
vector.search_by_full_text("query", top_k=0)
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
lock = MagicMock()
@ -370,7 +370,7 @@ def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
opengauss_module.redis_client.set.assert_called_once()
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch):
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch: pytest.MonkeyPatch):
factory = opengauss_module.OpenGaussFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -59,7 +59,7 @@ def _build_fake_opensearch_modules():
@pytest.fixture
def opensearch_module(monkeypatch):
def opensearch_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_opensearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -95,7 +95,7 @@ class TestOpenSearchConfig:
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
assert params["http_auth"] == ("admin", "password")
def test_to_opensearch_params_with_aws_managed_iam(self, opensearch_module, monkeypatch):
def test_to_opensearch_params_with_aws_managed_iam(self, opensearch_module, monkeypatch: pytest.MonkeyPatch):
class _Session:
def get_credentials(self):
return "creds"

View File

@ -58,7 +58,7 @@ def _build_fake_opensearch_modules():
@pytest.fixture
def opensearch_module(monkeypatch):
def opensearch_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_opensearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -116,7 +116,7 @@ def test_config_validation_for_aws_auth_and_https_fields(opensearch_module):
opensearch_module.OpenSearchConfig.model_validate(values)
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch):
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch: pytest.MonkeyPatch):
class _Session:
def get_credentials(self):
return "creds"
@ -167,7 +167,7 @@ def test_init_and_create_delegate_calls(opensearch_module):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch):
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch: pytest.MonkeyPatch):
vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es"))
docs = [
Document(page_content="a", metadata={"doc_id": "1"}),
@ -308,7 +308,7 @@ def test_search_by_full_text_and_filters(opensearch_module):
assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}]
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch):
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -331,7 +331,7 @@ def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch)
opensearch_module.redis_client.set.assert_called()
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch):
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch: pytest.MonkeyPatch):
factory = opensearch_module.OpenSearchVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -51,7 +51,7 @@ def _connection_with_cursor(cursor):
@pytest.fixture
def oracle_module(monkeypatch):
def oracle_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_oracle_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -94,7 +94,7 @@ def test_oracle_config_validation_autonomous_requirements(oracle_module):
)
def test_init_and_get_type(oracle_module, monkeypatch):
def test_init_and_get_type(oracle_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool))
vector = oracle_module.OracleVector("collection_1", _config(oracle_module))
@ -139,7 +139,7 @@ def test_numpy_converters_and_type_handlers(oracle_module):
assert out_float64.dtype == numpy.float64
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch):
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch: pytest.MonkeyPatch):
connect = MagicMock(return_value="connection")
monkeypatch.setattr(oracle_module.oracledb, "connect", connect)
@ -173,7 +173,7 @@ def test_create_delegates_collection_and_insert(oracle_module):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch):
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch: pytest.MonkeyPatch):
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
vector.table_name = "embedding_collection_1"
vector.input_type_handler = MagicMock()
@ -279,7 +279,7 @@ def _fake_nltk_module(*, missing_data=False):
return nltk, nltk_corpus
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch):
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch: pytest.MonkeyPatch):
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
vector.table_name = "embedding_collection_1"
@ -305,7 +305,7 @@ def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatc
assert "doc_id_0" in en_params
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch):
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch: pytest.MonkeyPatch):
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
vector.table_name = "embedding_collection_1"
vector._get_connection = MagicMock()
@ -320,7 +320,7 @@ def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeyp
vector.search_by_full_text("english query")
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -346,7 +346,9 @@ def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
oracle_module.redis_client.set.assert_called_once()
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch):
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(
oracle_module, monkeypatch: pytest.MonkeyPatch
):
factory = oracle_module.OracleVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -79,7 +79,7 @@ def _patch_both(monkeypatch, module, calls, execute_results=None):
@pytest.fixture
def pgvecto_module(monkeypatch):
def pgvecto_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_pgvecto_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -126,7 +126,7 @@ def test_collection_base_has_expected_annotations(pgvecto_module):
assert {"id", "text", "meta", "vector"} <= set(annotations)
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
module, _ = pgvecto_module
session_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
@ -145,7 +145,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
module, _ = pgvecto_module
session_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
@ -169,7 +169,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
module.redis_client.set.assert_called()
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
module, _ = pgvecto_module
init_calls = []
runtime_calls = []
@ -241,7 +241,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
module, _ = pgvecto_module
init_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
@ -313,7 +313,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
assert vector.search_by_full_text("hello") == []
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch):
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
module, _ = pgvecto_module
factory = module.PGVectoRSFactory()
dataset_with_index = SimpleNamespace(

View File

@ -336,7 +336,7 @@ def test_create_delegates_collection_creation_and_insert():
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch):
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch: pytest.MonkeyPatch):
vector = PGVector.__new__(PGVector)
vector.table_name = "embedding_collection_1"
@ -387,7 +387,7 @@ def test_text_get_and_delete_methods():
assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch):
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch: pytest.MonkeyPatch):
vector = PGVector.__new__(PGVector)
vector.table_name = "embedding_collection_1"
cursor = MagicMock()
@ -464,7 +464,7 @@ def test_search_by_full_text_branches_for_bigm_and_standard():
assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0]
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch):
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch: pytest.MonkeyPatch):
factory = pgvector_module.PGVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -121,7 +121,7 @@ def _build_fake_qdrant_modules():
@pytest.fixture
def qdrant_module(monkeypatch):
def qdrant_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_qdrant_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -170,7 +170,7 @@ def test_init_and_basic_behaviour(qdrant_module):
vector.add_texts.assert_called_once()
def test_create_collection_and_add_texts(qdrant_module, monkeypatch):
def test_create_collection_and_add_texts(qdrant_module, monkeypatch: pytest.MonkeyPatch):
vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
lock = MagicMock()
lock.__enter__.return_value = None
@ -288,7 +288,7 @@ def test_search_and_helper_methods(qdrant_module):
assert doc.page_content == "doc"
def test_qdrant_factory_paths(qdrant_module, monkeypatch):
def test_qdrant_factory_paths(qdrant_module, monkeypatch: pytest.MonkeyPatch):
factory = qdrant_module.QdrantVectorFactory()
dataset = SimpleNamespace(
id="dataset-1",

View File

@ -59,7 +59,7 @@ def _patch_both(monkeypatch, module, session):
@pytest.fixture
def relyt_module(monkeypatch):
def relyt_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_relyt_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -97,7 +97,7 @@ def test_relyt_config_validation(relyt_module, field, value, message):
relyt_module.RelytConfig.model_validate(values)
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch: pytest.MonkeyPatch):
engine = MagicMock()
monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine))
vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1")
@ -114,7 +114,7 @@ def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -142,7 +142,7 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
relyt_module.redis_client.set.assert_called_once()
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch):
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch: pytest.MonkeyPatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector._group_id = "group-1"
@ -212,7 +212,7 @@ def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module):
# 3. delete_by_ids translates to uuids
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch: pytest.MonkeyPatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
@ -225,7 +225,7 @@ def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
# 4. text_exists True
def test_text_exists_true(relyt_module, monkeypatch):
def test_text_exists_true(relyt_module, monkeypatch: pytest.MonkeyPatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
@ -236,7 +236,7 @@ def test_text_exists_true(relyt_module, monkeypatch):
# 5. text_exists False
def test_text_exists_false(relyt_module, monkeypatch):
def test_text_exists_false(relyt_module, monkeypatch: pytest.MonkeyPatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
@ -284,7 +284,7 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
# 8. delete commits session
def test_delete_drops_table(relyt_module, monkeypatch):
def test_delete_drops_table(relyt_module, monkeypatch: pytest.MonkeyPatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
@ -295,7 +295,7 @@ def test_delete_drops_table(relyt_module, monkeypatch):
session.execute.assert_called_once()
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch: pytest.MonkeyPatch):
factory = relyt_module.RelytVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -77,7 +77,7 @@ def _build_fake_tablestore_module():
@pytest.fixture
def tablestore_module(monkeypatch):
def tablestore_module(monkeypatch: pytest.MonkeyPatch):
fake_module = _build_fake_tablestore_module()
monkeypatch.setitem(sys.modules, "tablestore", fake_module)
@ -177,7 +177,7 @@ def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module):
vector._delete_table_if_exist.assert_called_once()
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch):
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch: pytest.MonkeyPatch):
vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
lock = MagicMock()
lock.__enter__.return_value = None
@ -289,7 +289,7 @@ def test_write_row_and_search_helpers(tablestore_module):
assert "score" not in docs[0].metadata
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch):
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch: pytest.MonkeyPatch):
factory = tablestore_module.TableStoreVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -136,7 +136,7 @@ def _build_fake_tencent_modules():
@pytest.fixture
def tencent_module(monkeypatch):
def tencent_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_tencent_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -187,7 +187,7 @@ def test_config_and_init_paths(tencent_module):
assert vector._enable_hybrid_search is False
def test_create_collection_branches(tencent_module, monkeypatch):
def test_create_collection_branches(tencent_module, monkeypatch: pytest.MonkeyPatch):
vector = tencent_module.TencentVector("collection_1", _config(tencent_module))
lock = MagicMock()
@ -279,7 +279,7 @@ def test_create_add_delete_and_search_behaviour(tencent_module):
vector._client.drop_collection.assert_called_once()
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch):
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch: pytest.MonkeyPatch):
factory = tencent_module.TencentVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -46,7 +46,7 @@ def test_tidb_config_validation(tidb_module, field, value, message):
tidb_module.TiDBVectorConfig.model_validate(values)
def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
def test_init_get_type_and_distance_func(tidb_module, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine"))
vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2")
@ -63,7 +63,7 @@ def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
assert vector._get_distance_func() == "VEC_COSINE_DISTANCE"
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch):
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch: pytest.MonkeyPatch):
fake_tidb_vector = types.ModuleType("tidb_vector")
fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy")
@ -107,7 +107,7 @@ def test_create_calls_collection_and_add_texts(tidb_module):
assert vector._dimension == 2
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -127,7 +127,7 @@ def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
tidb_module.redis_client.set.assert_not_called()
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch):
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -160,7 +160,7 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
tidb_module.redis_client.set.assert_called_once()
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch: pytest.MonkeyPatch):
class _InsertStmt:
def __init__(self, table):
self.table = table
@ -198,7 +198,7 @@ def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
@pytest.fixture
def tidb_vector_with_session(tidb_module, monkeypatch):
def tidb_vector_with_session(tidb_module, monkeypatch: pytest.MonkeyPatch):
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1"
vector._engine = MagicMock()
@ -354,7 +354,7 @@ def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module):
# Test search_by_vector filters and scores
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch: pytest.MonkeyPatch):
session = MagicMock()
session.execute.return_value = [
('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2),
@ -392,7 +392,7 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
# Test delete drops table
def test_delete_drops_table(tidb_module, monkeypatch):
def test_delete_drops_table(tidb_module, monkeypatch: pytest.MonkeyPatch):
session = MagicMock()
session.execute.return_value = None
@ -413,7 +413,7 @@ def test_delete_drops_table(tidb_module, monkeypatch):
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch: pytest.MonkeyPatch):
factory = tidb_module.TiDBVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -36,7 +36,7 @@ def _build_fake_upstash_module():
@pytest.fixture
def upstash_module(monkeypatch):
def upstash_module(monkeypatch: pytest.MonkeyPatch):
# Remove patched modules if present
for modname in ["upstash_vector", "dify_vdb_upstash.upstash_vector"]:
if modname in sys.modules:
@ -65,7 +65,7 @@ def test_upstash_config_validation(upstash_module, field, value, message):
upstash_module.UpstashVectorConfig.model_validate(values)
def test_init_get_type_and_dimension(upstash_module, monkeypatch):
def test_init_get_type_and_dimension(upstash_module, monkeypatch: pytest.MonkeyPatch):
vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
assert vector.get_type() == upstash_module.VectorType.UPSTASH
@ -162,7 +162,7 @@ def test_search_by_vector_filter_threshold_and_delete(upstash_module):
vector.index.reset.assert_called_once()
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch):
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch: pytest.MonkeyPatch):
factory = upstash_module.UpstashVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -37,7 +37,7 @@ def _build_fake_psycopg2_modules():
@pytest.fixture
def vastbase_module(monkeypatch):
def vastbase_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_psycopg2_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -93,7 +93,7 @@ def test_vastbase_config_rejects_invalid_connection_window(vastbase_module):
)
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -114,7 +114,7 @@ def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
pool.putconn.assert_called_once_with(conn)
def test_create_and_add_texts(vastbase_module, monkeypatch):
def test_create_and_add_texts(vastbase_module, monkeypatch: pytest.MonkeyPatch):
vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
vector.table_name = "embedding_collection_1"
vector._create_collection = MagicMock()
@ -205,7 +205,7 @@ def test_search_by_vector_and_full_text(vastbase_module):
assert full_docs[0].page_content == "full-text"
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch):
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -240,7 +240,7 @@ def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeyp
vastbase_module.redis_client.set.assert_called()
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch):
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch: pytest.MonkeyPatch):
factory = vastbase_module.VastbaseVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -79,7 +79,7 @@ def _build_fake_vikingdb_modules():
@pytest.fixture
def vikingdb_module(monkeypatch):
def vikingdb_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_vikingdb_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -117,7 +117,7 @@ def test_init_get_type_and_has_checks(vikingdb_module):
assert vector._has_index() is False
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch):
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -253,7 +253,7 @@ def test_delete_drops_index_and_collection_when_present(vikingdb_module):
vector._client.drop_collection.assert_not_called()
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch):
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch: pytest.MonkeyPatch):
factory = vikingdb_module.VikingDBVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",
@ -293,7 +293,9 @@ def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, mo
("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"),
],
)
def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message):
def test_vikingdb_factory_raises_when_required_config_missing(
vikingdb_module, monkeypatch: pytest.MonkeyPatch, field, message
):
factory = vikingdb_module.VikingDBVectorFactory()
dataset = SimpleNamespace(
id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None

View File

@ -13,7 +13,7 @@ from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.enums import ConversationFromSource
from models.enums import AppStatus, ConversationFromSource
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -28,7 +28,7 @@ class TestChatMessageApiPermissions:
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.status = AppStatus.NORMAL
return app
@pytest.fixture
@ -78,7 +78,7 @@ class TestChatMessageApiPermissions:
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
monkeypatch: pytest.MonkeyPatch,
mock_app_model,
mock_account,
role: TenantAccountRole,
@ -130,7 +130,7 @@ class TestChatMessageApiPermissions:
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
monkeypatch: pytest.MonkeyPatch,
mock_app_model,
mock_account,
role: TenantAccountRole,

View File

@ -14,7 +14,7 @@ from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.enums import FeedbackFromSource, FeedbackRating
from models.enums import AppStatus, FeedbackFromSource, FeedbackRating
from models.model import AppMode, MessageFeedback
from services.feedback_service import FeedbackService
@ -29,7 +29,7 @@ class TestFeedbackExportApi:
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.status = AppStatus.NORMAL
app.name = "Test App"
return app
@ -135,7 +135,7 @@ class TestFeedbackExportApi:
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
monkeypatch: pytest.MonkeyPatch,
mock_app_model,
mock_account,
role: TenantAccountRole,
@ -167,7 +167,13 @@ class TestFeedbackExportApi:
mock_export_feedbacks.assert_called_once()
def test_feedback_export_csv_format(
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
self,
test_client: FlaskClient,
auth_header,
monkeypatch: pytest.MonkeyPatch,
mock_app_model,
mock_account,
sample_feedback_data,
):
"""Test feedback export in CSV format."""
@ -202,7 +208,13 @@ class TestFeedbackExportApi:
assert "text/csv" in response.content_type
def test_feedback_export_json_format(
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
self,
test_client: FlaskClient,
auth_header,
monkeypatch: pytest.MonkeyPatch,
mock_app_model,
mock_account,
sample_feedback_data,
):
"""Test feedback export in JSON format."""
@ -246,7 +258,7 @@ class TestFeedbackExportApi:
assert "application/json" in response.content_type
def test_feedback_export_with_filters(
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
self, test_client: FlaskClient, auth_header, monkeypatch: pytest.MonkeyPatch, mock_app_model, mock_account
):
"""Test feedback export with various filters."""
@ -287,7 +299,7 @@ class TestFeedbackExportApi:
)
def test_feedback_export_invalid_date_format(
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
self, test_client: FlaskClient, auth_header, monkeypatch: pytest.MonkeyPatch, mock_app_model, mock_account
):
"""Test feedback export with invalid date format."""
@ -312,7 +324,7 @@ class TestFeedbackExportApi:
assert "Parameter validation error" in response_json["error"]
def test_feedback_export_server_error(
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
self, test_client: FlaskClient, auth_header, monkeypatch: pytest.MonkeyPatch, mock_app_model, mock_account
):
"""Test feedback export with server error."""

View File

@ -11,6 +11,7 @@ from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.enums import AppStatus
from models.model import AppMode
from services.app_model_config_service import AppModelConfigService
@ -25,7 +26,7 @@ class TestModelConfigResourcePermissions:
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.status = AppStatus.NORMAL
app.app_model_config_id = str(uuid.uuid4())
return app
@ -73,7 +74,7 @@ class TestModelConfigResourcePermissions:
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
monkeypatch: pytest.MonkeyPatch,
mock_app_model,
mock_account,
role: TenantAccountRole,

View File

@ -1,5 +1,7 @@
from collections.abc import Generator
from pytest_mock import MockerFixture
from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceMessage
from graphon.node_events import StreamCompletedEvent
@ -19,7 +21,7 @@ def _gen_var_stream() -> Generator[DatasourceMessage, None, None]:
)
def test_stream_node_events_accumulates_variables(mocker):
def test_stream_node_events_accumulates_variables(mocker: MockerFixture):
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_var_stream())
events = list(
DatasourceManager.stream_node_events(

View File

@ -1,3 +1,5 @@
from pytest_mock import MockerFixture
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.datasource.entities import DatasourceNodeData
@ -44,7 +46,7 @@ class _GP:
call_depth = 0
def test_node_integration_minimal_stream(mocker):
def test_node_integration_minimal_stream(mocker: MockerFixture):
sys_d = {
"sys": {
"datasource_type": "online_document",

View File

@ -2,6 +2,8 @@ import time
import uuid
from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.node_factory import DifyNodeFactory
@ -71,7 +73,7 @@ def init_tool_node(config: dict):
return node
def test_tool_variable_invoke(monkeypatch):
def test_tool_variable_invoke(monkeypatch: pytest.MonkeyPatch):
node = init_tool_node(
config={
"id": "1",
@ -106,7 +108,7 @@ def test_tool_variable_invoke(monkeypatch):
assert item.node_run_result.outputs.get("text") is not None
def test_tool_mixed_invoke(monkeypatch):
def test_tool_mixed_invoke(monkeypatch: pytest.MonkeyPatch):
node = init_tool_node(
config={
"id": "1",

View File

@ -11,7 +11,7 @@ from libs import helper as helper_module
@pytest.mark.usefixtures("flask_app_with_containers")
def test_rate_limiter_counts_multiple_attempts_in_same_second(monkeypatch):
def test_rate_limiter_counts_multiple_attempts_in_same_second(monkeypatch: pytest.MonkeyPatch):
prefix = f"test_rate_limit:{uuid.uuid4().hex}"
limiter = helper_module.RateLimiter(prefix=prefix, max_attempts=2, time_window=60)
key = limiter._get_key("203.0.113.10")

View File

@ -6,7 +6,7 @@ from faker import Faker
from sqlalchemy.orm import Session
from core.plugin.impl.exc import PluginDaemonClientSideError
from models import Account
from models import Account, CreatorUserRole
from models.enums import ConversationFromSource, MessageFileBelongsTo
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService
@ -246,7 +246,7 @@ class TestAgentService:
tool_input=json.dumps({"test_tool": {"input": "test_input"}}),
observation=json.dumps({"test_tool": {"output": "test_output"}}),
tokens=50,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
db_session_with_containers.add(thought1)
@ -294,7 +294,7 @@ class TestAgentService:
agent_thoughts = self._create_test_agent_thoughts(db_session_with_containers, message)
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
result = AgentService.get_agent_logs(app, conversation.id, message.id)
# Verify the result structure
assert result is not None
@ -370,7 +370,7 @@ class TestAgentService:
# Execute the method under test with non-existent message
with pytest.raises(ValueError, match="Message not found"):
AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4())
AgentService.get_agent_logs(app, conversation.id, fake.uuid4())
def test_get_agent_logs_with_end_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -451,7 +451,7 @@ class TestAgentService:
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
result = AgentService.get_agent_logs(app, conversation.id, message.id)
# Verify the result
assert result is not None
@ -523,7 +523,7 @@ class TestAgentService:
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
result = AgentService.get_agent_logs(app, conversation.id, message.id)
# Verify the result
assert result is not None
@ -561,14 +561,14 @@ class TestAgentService:
tool_input=json.dumps({"error_tool": {"input": "test_input"}}),
observation=json.dumps({"error_tool": {"output": "error_output"}}),
tokens=50,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
db_session_with_containers.add(thought_with_error)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
result = AgentService.get_agent_logs(app, conversation.id, message.id)
# Verify the result
assert result is not None
@ -592,7 +592,7 @@ class TestAgentService:
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
result = AgentService.get_agent_logs(app, conversation.id, message.id)
# Verify the result
assert result is not None
@ -654,7 +654,7 @@ class TestAgentService:
# Execute the method under test
with pytest.raises(ValueError, match="App model config not found"):
AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
AgentService.get_agent_logs(app, conversation.id, message.id)
def test_get_agent_logs_agent_config_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -673,7 +673,7 @@ class TestAgentService:
# Execute the method under test
with pytest.raises(ValueError, match="Agent config not found"):
AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
AgentService.get_agent_logs(app, conversation.id, message.id)
def test_list_agent_providers_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -687,7 +687,7 @@ class TestAgentService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Execute the method under test
result = AgentService.list_agent_providers(str(account.id), str(app.tenant_id))
result = AgentService.list_agent_providers(account.id, app.tenant_id)
# Verify the result
assert result is not None
@ -696,7 +696,7 @@ class TestAgentService:
# Verify the mock was called correctly
mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value
mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id))
mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(app.tenant_id)
def test_get_agent_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
@ -710,7 +710,7 @@ class TestAgentService:
provider_name = "test_provider"
# Execute the method under test
result = AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name)
result = AgentService.get_agent_provider(account.id, app.tenant_id, provider_name)
# Verify the result
assert result is not None
@ -718,7 +718,7 @@ class TestAgentService:
# Verify the mock was called correctly
mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value
mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name)
mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(app.tenant_id, provider_name)
def test_get_agent_provider_plugin_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -740,7 +740,7 @@ class TestAgentService:
# Execute the method under test
with pytest.raises(ValueError, match=error_message):
AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name)
AgentService.get_agent_provider(account.id, app.tenant_id, provider_name)
def test_get_agent_logs_with_complex_tool_data(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -796,14 +796,14 @@ class TestAgentService:
{"tool1": {"output1": "result1"}, "tool2": {"output2": "result2"}, "tool3": {"output3": "result3"}}
),
tokens=100,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
db_session_with_containers.add(complex_thought)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
result = AgentService.get_agent_logs(app, conversation.id, message.id)
# Verify the result
assert result is not None
@ -891,14 +891,14 @@ class TestAgentService:
observation=json.dumps({"file_tool": {"output": "test_output"}}),
message_files=json.dumps(["file1", "file2"]),
tokens=50,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
db_session_with_containers.add(thought_with_files)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
result = AgentService.get_agent_logs(app, conversation.id, message.id)
# Verify the result
assert result is not None
@ -926,7 +926,7 @@ class TestAgentService:
mock_external_service_dependencies["current_user"].timezone = "Asia/Shanghai"
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
result = AgentService.get_agent_logs(app, conversation.id, message.id)
# Verify the result
assert result is not None
@ -960,14 +960,14 @@ class TestAgentService:
tool_input="", # Empty input
observation="", # Empty observation
tokens=50,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
db_session_with_containers.add(empty_thought)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
result = AgentService.get_agent_logs(app, conversation.id, message.id)
# Verify the result
assert result is not None
@ -1001,14 +1001,14 @@ class TestAgentService:
tool_input="invalid json", # Malformed JSON
observation="invalid json", # Malformed JSON
tokens=50,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
db_session_with_containers.add(malformed_thought)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
result = AgentService.get_agent_logs(app, conversation.id, message.id)
# Verify the result - should handle malformed JSON gracefully
assert result is not None

View File

@ -198,7 +198,7 @@ class TestAppDslService:
def test_check_version_compatibility_newer_version_returns_pending(self):
assert _check_version_compatibility("99.0.0") == ImportStatus.PENDING
def test_check_version_compatibility_major_older_returns_pending(self, monkeypatch):
def test_check_version_compatibility_major_older_returns_pending(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(app_dsl_service, "CURRENT_DSL_VERSION", "1.0.0")
assert _check_version_compatibility("0.9.9") == ImportStatus.PENDING
@ -272,7 +272,9 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Missing app data" in result.error
def test_import_app_yaml_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
def test_import_app_yaml_error_returns_failed(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
def bad_safe_load(_content: str):
raise yaml.YAMLError("bad")
@ -287,7 +289,9 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert result.error.startswith("Invalid YAML format:")
def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
def test_import_app_unexpected_error_returns_failed(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
monkeypatch.setattr(
AppDslService,
"_create_or_update_app",
@ -305,7 +309,9 @@ class TestAppDslService:
# ── Import: YAML URL ──────────────────────────────────────────────
def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
def test_import_app_yaml_url_fetch_error_returns_failed(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
monkeypatch.setattr(
app_dsl_service.ssrf_proxy,
"get",
@ -321,7 +327,9 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Error fetching YAML from URL: boom" in result.error
def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers: Session, monkeypatch):
def test_import_app_yaml_url_empty_content_returns_failed(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
response = MagicMock()
response.content = b""
response.raise_for_status.return_value = None
@ -336,7 +344,9 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Empty content" in result.error
def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers: Session, monkeypatch):
def test_import_app_yaml_url_file_too_large_returns_failed(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
response = MagicMock()
response.content = b"x" * (DSL_MAX_SIZE + 1)
response.raise_for_status.return_value = None
@ -379,7 +389,9 @@ class TestAppDslService:
assert result.imported_dsl_version == "99.0.0"
assert requested_urls == [yaml_url]
def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers: Session, monkeypatch):
def test_import_app_yaml_url_github_blob_rewrites_to_raw(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
yaml_url = "https://github.com/acme/repo/blob/main/app.yml"
raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml"
yaml_bytes = _pending_yaml_content()
@ -491,7 +503,7 @@ class TestAppDslService:
@pytest.mark.parametrize("has_workflow", [True, False])
def test_import_app_legacy_versions_extract_dependencies(
self, db_session_with_containers: Session, monkeypatch, has_workflow: bool
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch, has_workflow: bool
):
monkeypatch.setattr(
AppDslService,
@ -554,7 +566,9 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "expired" in result.error
def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers: Session, monkeypatch):
def test_confirm_import_success_deletes_redis_key(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
import_id = str(uuid4())
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
@ -614,7 +628,9 @@ class TestAppDslService:
result = service.check_dependencies(app_model=app_model)
assert result.leaked_dependencies == []
def test_check_dependencies_calls_analysis_service(self, db_session_with_containers: Session, monkeypatch):
def test_check_dependencies_calls_analysis_service(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
app_id = str(uuid4())
pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id)
redis_client.setex(
@ -665,7 +681,9 @@ class TestAppDslService:
with pytest.raises(ValueError, match="loss app mode"):
service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock())
def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers: Session, monkeypatch):
def test_create_or_update_app_existing_app_updates_fields(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
fixed_now = object()
monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now)
@ -778,8 +796,8 @@ class TestAppDslService:
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Missing model_config"):
service._create_or_update_app(
app=_app_stub(mode=AppMode.CHAT.value),
data={"app": {"mode": AppMode.CHAT.value}},
app=_app_stub(mode=AppMode.CHAT),
data={"app": {"mode": AppMode.CHAT}},
account=_account_mock(),
)
@ -794,7 +812,7 @@ class TestAppDslService:
service._create_or_update_app(
app=app,
data={
"app": {"mode": AppMode.CHAT.value},
"app": {"mode": AppMode.CHAT},
"model_config": {"model": {"provider": "openai"}},
},
account=account,
@ -807,14 +825,14 @@ class TestAppDslService:
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Invalid app mode"):
service._create_or_update_app(
app=_app_stub(mode=AppMode.RAG_PIPELINE.value),
data={"app": {"mode": AppMode.RAG_PIPELINE.value}},
app=_app_stub(mode=AppMode.RAG_PIPELINE),
data={"app": {"mode": AppMode.RAG_PIPELINE}},
account=_account_mock(),
)
# ── Export ─────────────────────────────────────────────────────────
def test_export_dsl_delegates_by_mode(self, monkeypatch):
def test_export_dsl_delegates_by_mode(self, monkeypatch: pytest.MonkeyPatch):
workflow_calls: list[bool] = []
model_calls: list[bool] = []
monkeypatch.setattr(
@ -836,14 +854,14 @@ class TestAppDslService:
assert workflow_calls == [True]
chat_app = _app_stub(
mode=AppMode.CHAT.value,
mode=AppMode.CHAT,
icon_type="emoji",
app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}),
)
AppDslService.export_dsl(chat_app)
assert model_calls == [True]
def test_export_dsl_preserves_icon_and_icon_type(self, monkeypatch):
def test_export_dsl_preserves_icon_and_icon_type(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
AppDslService,
"_append_workflow_export_data",
@ -1011,7 +1029,7 @@ class TestAppDslService:
# ── Workflow Export Data ───────────────────────────────────────────
def test_append_workflow_export_data_filters_and_overrides(self, monkeypatch):
def test_append_workflow_export_data_filters_and_overrides(self, monkeypatch: pytest.MonkeyPatch):
workflow_dict = {
"graph": {
"nodes": [
@ -1111,7 +1129,7 @@ class TestAppDslService:
assert nodes[5]["data"]["subscription_id"] == ""
assert export_data["dependencies"] == [{"tenant": _DEFAULT_TENANT_ID, "dep": "dep-1"}]
def test_append_workflow_export_data_missing_workflow_raises(self, monkeypatch):
def test_append_workflow_export_data_missing_workflow_raises(self, monkeypatch: pytest.MonkeyPatch):
workflow_service = MagicMock()
workflow_service.get_draft_workflow.return_value = None
monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service)
@ -1126,7 +1144,7 @@ class TestAppDslService:
# ── Model Config Export Data ──────────────────────────────────────
def test_append_model_config_export_data_filters_credential_id(self, monkeypatch):
def test_append_model_config_export_data_filters_credential_id(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
AppDslService,
"_extract_dependencies_from_model_config",
@ -1160,7 +1178,7 @@ class TestAppDslService:
# ── Dependency Extraction ─────────────────────────────────────────
def test_extract_dependencies_from_workflow_graph_covers_all_node_types(self, monkeypatch):
def test_extract_dependencies_from_workflow_graph_covers_all_node_types(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
app_dsl_service.DependenciesAnalysisService,
"analyze_tool_dependency",
@ -1230,7 +1248,7 @@ class TestAppDslService:
"model:m4",
]
def test_extract_dependencies_from_workflow_graph_handles_exceptions(self, monkeypatch):
def test_extract_dependencies_from_workflow_graph_handles_exceptions(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
app_dsl_service.ToolNodeData,
"model_validate",
@ -1241,7 +1259,7 @@ class TestAppDslService:
)
assert deps == []
def test_extract_dependencies_from_model_config_parses_providers(self, monkeypatch):
def test_extract_dependencies_from_model_config_parses_providers(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
app_dsl_service.DependenciesAnalysisService,
"analyze_model_provider_dependency",
@ -1264,7 +1282,7 @@ class TestAppDslService:
)
assert deps == ["model:p1", "model:p2", "tool:t1"]
def test_extract_dependencies_from_model_config_handles_exceptions(self, monkeypatch):
def test_extract_dependencies_from_model_config_handles_exceptions(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
app_dsl_service.DependenciesAnalysisService,
"analyze_model_provider_dependency",
@ -1278,7 +1296,7 @@ class TestAppDslService:
def test_get_leaked_dependencies_empty_returns_empty(self):
assert AppDslService.get_leaked_dependencies(_DEFAULT_TENANT_ID, []) == []
def test_get_leaked_dependencies_delegates(self, monkeypatch):
def test_get_leaked_dependencies_delegates(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
app_dsl_service.DependenciesAnalysisService,
"get_leaked_dependencies",
@ -1289,7 +1307,7 @@ class TestAppDslService:
# ── Encryption/Decryption ─────────────────────────────────────────
def test_encrypt_decrypt_dataset_id_respects_config(self, monkeypatch):
def test_encrypt_decrypt_dataset_id_respects_config(self, monkeypatch: pytest.MonkeyPatch):
tenant_id = _DEFAULT_TENANT_ID
dataset_uuid = "00000000-0000-0000-0000-000000000000"
@ -1314,7 +1332,7 @@ class TestAppDslService:
value = "00000000-0000-0000-0000-000000000000"
assert AppDslService.decrypt_dataset_id(encrypted_data=value, tenant_id=_DEFAULT_TENANT_ID) == value
def test_decrypt_dataset_id_returns_none_on_invalid_data(self, monkeypatch):
def test_decrypt_dataset_id_returns_none_on_invalid_data(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
app_dsl_service.dify_config,
"DSL_EXPORT_ENCRYPT_DATASET_ID",
@ -1322,7 +1340,7 @@ class TestAppDslService:
)
assert AppDslService.decrypt_dataset_id(encrypted_data="not-base64", tenant_id=_DEFAULT_TENANT_ID) is None
def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(self, monkeypatch):
def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
app_dsl_service.dify_config,
"DSL_EXPORT_ENCRYPT_DATASET_ID",

View File

@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
from constants.model_template import default_app_templates
from models import Account
from models.enums import AppStatus, CustomizeTokenStrategy
from models.model import App, IconType, Site
from services.account_service import AccountService, TenantService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -1079,9 +1080,9 @@ class TestAppService:
site.app_id = app.id
site.code = fake.postalcode()
site.title = fake.company()
site.status = "normal"
site.status = AppStatus.NORMAL
site.default_language = "en-US"
site.customize_token_strategy = "uuid"
site.customize_token_strategy = CustomizeTokenStrategy.UUID
db_session_with_containers.add(site)
db_session_with_containers.commit()

View File

@ -10,6 +10,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import TenantAccountRole
from models.account import Account, Tenant, TenantAccountJoin
from models.enums import ConversationFromSource
from models.model import App, Conversation, EndUser, Message, MessageAnnotation
@ -22,7 +23,7 @@ from services.message_service import MessageService
class ConversationServiceIntegrationTestDataFactory:
@staticmethod
def create_app_and_account(db_session_with_containers):
def create_app_and_account(db_session_with_containers: Session):
tenant = Tenant(name=f"Tenant {uuid4()}")
db_session_with_containers.add(tenant)
db_session_with_containers.flush()
@ -41,7 +42,7 @@ class ConversationServiceIntegrationTestDataFactory:
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role="owner",
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(tenant_join)
@ -155,7 +156,7 @@ class ConversationServiceIntegrationTestDataFactory:
total_price=Decimal(0),
currency="USD",
status="normal",
invoke_from=InvokeFrom.WEB_APP.value,
invoke_from=InvokeFrom.WEB_APP,
from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
from_end_user_id=user.id if isinstance(user, EndUser) else None,
from_account_id=user.id if isinstance(user, Account) else None,

View File

@ -25,7 +25,7 @@ from services.errors.conversation import (
class ConversationServiceVariableIntegrationFactory:
@staticmethod
def create_app_and_account(db_session_with_containers):
def create_app_and_account(db_session_with_containers: Session):
tenant = Tenant(name=f"Tenant {uuid4()}")
db_session_with_containers.add(tenant)
db_session_with_containers.flush()

View File

@ -6,6 +6,7 @@ from unittest.mock import create_autospec, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -119,13 +120,13 @@ def current_user_mock():
yield current_user
def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers):
def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
assert DocumentService.get_document(dataset.id, None) is None
def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers):
def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset)
@ -135,7 +136,7 @@ def test_get_document_queries_by_dataset_and_document_id(db_session_with_contain
assert result.id == document.id
def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers):
def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
result = DocumentService.get_documents_by_ids(dataset.id, [])
@ -143,7 +144,7 @@ def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_cont
assert result == []
def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers):
def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt")
doc_b = DocumentServiceIntegrationFactory.create_document(
@ -158,13 +159,13 @@ def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers
assert {document.id for document in result} == {doc_a.id, doc_b.id}
def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers):
def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
assert DocumentService.update_documents_need_summary(dataset.id, []) == 0
def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers):
def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
paragraph_doc = DocumentServiceIntegrationFactory.create_document(
db_session_with_containers,
@ -195,7 +196,7 @@ def test_update_documents_need_summary_updates_matching_non_qa_documents(db_sess
assert refreshed_qa.need_summary is True
def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers):
def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
db_session_with_containers,
@ -215,7 +216,7 @@ def test_get_document_download_url_uses_signed_url_helper(db_session_with_contai
get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True)
def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers):
def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
document = DocumentServiceIntegrationFactory.create_document(
db_session_with_containers,
@ -232,7 +233,9 @@ def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type
)
def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers):
def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(
db_session_with_containers: Session,
):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
document = DocumentServiceIntegrationFactory.create_document(
db_session_with_containers,
@ -248,7 +251,7 @@ def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file
)
def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers):
def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
document = DocumentServiceIntegrationFactory.create_document(
db_session_with_containers,
@ -265,7 +268,9 @@ def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_sessio
assert result == "99"
def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers):
def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(
db_session_with_containers: Session,
):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
document = DocumentServiceIntegrationFactory.create_document(
db_session_with_containers,
@ -278,7 +283,7 @@ def test_get_upload_file_for_upload_file_document_raises_when_file_service_retur
DocumentService._get_upload_file_for_upload_file_document(document)
def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers):
def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
db_session_with_containers,
@ -296,7 +301,9 @@ def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session
assert result.id == upload_file.id
def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers):
def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(
db_session_with_containers: Session,
):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
with pytest.raises(NotFound, match="Document not found"):
@ -307,7 +314,9 @@ def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_doc
)
def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers):
def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(
db_session_with_containers: Session,
):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
db_session_with_containers,
@ -329,7 +338,9 @@ def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_a
)
def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers):
def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(
db_session_with_containers: Session,
):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
document = DocumentServiceIntegrationFactory.create_document(
db_session_with_containers,
@ -345,7 +356,9 @@ def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload
)
def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers):
def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(
db_session_with_containers: Session,
):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
db_session_with_containers,
@ -395,7 +408,7 @@ def test_prepare_document_batch_download_zip_raises_not_found_for_missing_datase
def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden(
db_session_with_containers,
db_session_with_containers: Session,
current_user_mock,
):
dataset = DocumentServiceIntegrationFactory.create_dataset(
@ -418,7 +431,7 @@ def test_prepare_document_batch_download_zip_translates_permission_error_to_forb
def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order(
db_session_with_containers,
db_session_with_containers: Session,
current_user_mock,
):
dataset = DocumentServiceIntegrationFactory.create_dataset(
@ -461,7 +474,7 @@ def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_o
assert download_name.endswith(".zip")
def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers):
def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
enabled_document = DocumentServiceIntegrationFactory.create_document(
db_session_with_containers,
@ -480,7 +493,9 @@ def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_co
assert [document.id for document in result] == [enabled_document.id]
def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers):
def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(
db_session_with_containers: Session,
):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
available_document = DocumentServiceIntegrationFactory.create_document(
db_session_with_containers,
@ -501,7 +516,7 @@ def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchive
assert [document.id for document in result] == [available_document.id]
def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers):
def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
error_document = DocumentServiceIntegrationFactory.create_document(
db_session_with_containers,
@ -526,7 +541,7 @@ def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db
assert {document.id for document in result} == {error_document.id, paused_document.id}
def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers):
def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
batch = f"batch-{uuid4()}"
matching_document = DocumentServiceIntegrationFactory.create_document(
@ -549,7 +564,7 @@ def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_cont
assert [document.id for document in result] == [matching_document.id]
def test_get_document_file_detail_returns_upload_file(db_session_with_containers):
def test_get_document_file_detail_returns_upload_file(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
db_session_with_containers,
@ -563,7 +578,7 @@ def test_get_document_file_detail_returns_upload_file(db_session_with_containers
assert result.id == upload_file.id
def test_delete_document_emits_signal_and_commits(db_session_with_containers):
def test_delete_document_emits_signal_and_commits(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
db_session_with_containers,
@ -588,7 +603,7 @@ def test_delete_document_emits_signal_and_commits(db_session_with_containers):
)
def test_delete_documents_ignores_empty_input(db_session_with_containers):
def test_delete_documents_ignores_empty_input(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
@ -597,7 +612,7 @@ def test_delete_documents_ignores_empty_input(db_session_with_containers):
delay.assert_not_called()
def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers):
def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX
db_session_with_containers.commit()
@ -637,14 +652,14 @@ def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_wi
assert set(args[3]) == {upload_file_a.id, upload_file_b.id}
def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers):
def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3)
assert DocumentService.get_documents_position(dataset.id) == 4
def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers):
def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers: Session):
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
assert DocumentService.get_documents_position(dataset.id) == 1

View File

@ -2,6 +2,7 @@ import datetime
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document
@ -58,7 +59,7 @@ def _create_document(
return document
def test_build_display_status_filters_available(db_session_with_containers):
def test_build_display_status_filters_available(db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers)
available_doc = _create_document(
db_session_with_containers,
@ -97,7 +98,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
assert [row.id for row in rows] == [available_doc.id]
def test_apply_display_status_filter_applies_when_status_present(db_session_with_containers):
def test_apply_display_status_filter_applies_when_status_present(db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers)
waiting_doc = _create_document(
db_session_with_containers,
@ -121,7 +122,7 @@ def test_apply_display_status_filter_applies_when_status_present(db_session_with
assert [row.id for row in rows] == [waiting_doc.id]
def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_containers):
def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers)
doc1 = _create_document(
db_session_with_containers,

View File

@ -7,6 +7,7 @@ import pytest
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import TenantAccountRole
from models.account import Account, Tenant, TenantAccountJoin
from models.model import App, DefaultEndUserSessionID, EndUser
from services.end_user_service import EndUserService
@ -16,7 +17,7 @@ class TestEndUserServiceFactory:
"""Factory class for creating test data and mock objects for end user service tests."""
@staticmethod
def create_app_and_account(db_session_with_containers):
def create_app_and_account(db_session_with_containers: Session):
tenant = Tenant(name=f"Tenant {uuid4()}")
db_session_with_containers.add(tenant)
db_session_with_containers.flush()
@ -35,7 +36,7 @@ class TestEndUserServiceFactory:
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role="owner",
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(tenant_join)

View File

@ -644,7 +644,7 @@ class TestFeatureService:
assert result.max_plugin_package_size == 15728640
# Verify default license status
assert result.license.status.value == "none"
assert result.license.status == "none"
assert result.license.expired_at == ""
assert result.license.workspaces.enabled is False

View File

@ -23,7 +23,7 @@ class TestFeedbackService:
"""Test FeedbackService methods."""
@pytest.fixture
def mock_db_session(self, monkeypatch):
def mock_db_session(self, monkeypatch: pytest.MonkeyPatch):
"""Mock database session."""
mock_session = mock.Mock()
monkeypatch.setattr(db, "session", mock_session)

View File

@ -122,7 +122,7 @@ class TestEmailDeliveryTestHandler:
with pytest.raises(DeliveryTestUnsupportedError):
handler.send_test(context=MagicMock(), method=MagicMock())
def test_send_test_feature_disabled(self, monkeypatch):
def test_send_test_feature_disabled(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
service_module.FeatureService,
"get_features",
@ -137,7 +137,7 @@ class TestEmailDeliveryTestHandler:
with pytest.raises(DeliveryTestError, match="Email delivery is not available"):
handler.send_test(context=context, method=method)
def test_send_test_mail_not_inited(self, monkeypatch):
def test_send_test_mail_not_inited(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
service_module.FeatureService,
"get_features",
@ -154,7 +154,7 @@ class TestEmailDeliveryTestHandler:
with pytest.raises(DeliveryTestError, match="Mail client is not initialized."):
handler.send_test(context=context, method=method)
def test_send_test_no_recipients(self, monkeypatch):
def test_send_test_no_recipients(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
service_module.FeatureService,
"get_features",
@ -173,7 +173,7 @@ class TestEmailDeliveryTestHandler:
with pytest.raises(DeliveryTestError, match="No recipients configured"):
handler.send_test(context=context, method=method)
def test_send_test_success(self, monkeypatch):
def test_send_test_success(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
service_module.FeatureService,
"get_features",
@ -209,7 +209,7 @@ class TestEmailDeliveryTestHandler:
assert kwargs["to"] == "test@example.com"
assert "RENDERED_Subj" in kwargs["subject"]
def test_send_test_sanitizes_subject(self, monkeypatch):
def test_send_test_sanitizes_subject(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
service_module.FeatureService,
"get_features",

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import pytest
from sqlalchemy.orm import Session
from services.message_service import MessageService
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
@ -9,7 +10,7 @@ from tests.test_containers_integration_tests.helpers.execution_extra_content imp
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
def test_pagination_returns_extra_contents(db_session_with_containers):
def test_pagination_returns_extra_contents(db_session_with_containers: Session):
fixture = create_human_input_message_fixture(db_session_with_containers)
pagination = MessageService.pagination_by_first_id(

View File

@ -16,7 +16,7 @@ from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.create_segment_to_index_task import create_segment_to_index_task
@ -73,7 +73,7 @@ class TestCreateSegmentToIndexTask:
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
@ -82,7 +82,7 @@ class TestCreateSegmentToIndexTask:
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
status=TenantStatus.NORMAL,
plan="basic",
)
db_session_with_containers.add(tenant)

View File

@ -12,7 +12,7 @@ from sqlalchemy.orm import Session
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from tasks.document_indexing_task import (
@ -54,7 +54,7 @@ class _TrackedSessionContext:
@pytest.fixture(autouse=True)
def _ensure_testcontainers_db(db_session_with_containers):
def _ensure_testcontainers_db(db_session_with_containers: Session):
"""Ensure this suite always runs on testcontainers infrastructure."""
return db_session_with_containers
@ -121,12 +121,12 @@ class TestDatasetIndexingTaskIntegration:
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
db_session_with_containers.flush()
tenant = Tenant(name=fake.company(), status="normal")
tenant = Tenant(name=fake.company(), status=TenantStatus.NORMAL)
db_session_with_containers.add(tenant)
db_session_with_containers.flush()

View File

@ -5,6 +5,7 @@ from faker import Faker
from sqlalchemy.orm import Session
from libs.email_i18n import EmailType
from models import TenantStatus
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task
@ -55,7 +56,7 @@ class TestMailAccountDeletionTask:
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
status=TenantStatus.NORMAL,
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()

View File

@ -18,6 +18,7 @@ from sqlalchemy import delete
from sqlalchemy.orm import Session
from libs.email_i18n import EmailType
from models import AccountStatus, TenantStatus
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from tasks.mail_email_code_login import send_email_code_login_mail_task
@ -91,7 +92,7 @@ class TestSendEmailCodeLoginMailTask:
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
@ -120,7 +121,7 @@ class TestSendEmailCodeLoginMailTask:
tenant = Tenant(
name=fake.company(),
plan="basic",
status="normal",
status=TenantStatus.NORMAL,
)
db_session_with_containers.add(tenant)

View File

@ -31,7 +31,7 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
@pytest.fixture(autouse=True)
def cleanup_database(db_session_with_containers):
def cleanup_database(db_session_with_containers: Session):
db_session_with_containers.execute(delete(HumanInputFormRecipient))
db_session_with_containers.execute(delete(HumanInputDelivery))
db_session_with_containers.execute(delete(HumanInputForm))
@ -43,7 +43,7 @@ def cleanup_database(db_session_with_containers):
db_session_with_containers.commit()
def _create_workspace_member(db_session_with_containers):
def _create_workspace_member(db_session_with_containers: Session):
account = Account(
email="owner@example.com",
name="Owner",

View File

@ -21,7 +21,7 @@ from tasks.remove_app_and_related_data_task import (
@pytest.fixture(autouse=True)
def cleanup_database(db_session_with_containers):
def cleanup_database(db_session_with_containers: Session):
db_session_with_containers.execute(delete(WorkflowDraftVariable))
db_session_with_containers.execute(delete(WorkflowDraftVariableFile))
db_session_with_containers.execute(delete(UploadFile))
@ -30,7 +30,7 @@ def cleanup_database(db_session_with_containers):
db_session_with_containers.commit()
def _create_tenant_and_app(db_session_with_containers):
def _create_tenant_and_app(db_session_with_containers: Session):
tenant = Tenant(name=f"test_tenant_{uuid.uuid4()}")
db_session_with_containers.add(tenant)
db_session_with_containers.flush()

View File

@ -57,7 +57,7 @@ class TestGuessFileInfoFromResponse:
(False, "bin"),
],
)
def test_generated_filename_when_missing(self, monkeypatch, magic_available, expected_ext):
def test_generated_filename_when_missing(self, monkeypatch: pytest.MonkeyPatch, magic_available, expected_ext):
if magic_available:
if helpers.magic is None:
pytest.skip("python-magic is not installed, cannot run 'magic_available=True' test variant")
@ -155,7 +155,7 @@ class TestMagicImportWarnings:
)
def test_magic_import_warning_per_platform(
self,
monkeypatch,
monkeypatch: pytest.MonkeyPatch,
platform_name,
expected_message,
):

View File

@ -101,7 +101,7 @@ def test_register_schema_models_registers_multiple_models():
assert called_names == ["UserModel", "ProductModel"]
def test_register_schema_models_calls_register_schema_model(monkeypatch):
def test_register_schema_models_calls_register_schema_model(monkeypatch: pytest.MonkeyPatch):
from controllers.common.schema import register_schema_models
namespace = MagicMock(spec=Namespace)

View File

@ -68,7 +68,7 @@ def _segment():
)
def test_get_segment_with_summary(monkeypatch):
def test_get_segment_with_summary(monkeypatch: pytest.MonkeyPatch):
segment = _segment()
summary = SimpleNamespace(summary_content="summary")

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from pytest_mock import MockerFixture
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -35,7 +36,7 @@ def dataset():
@pytest.fixture(autouse=True)
def bypass_decorators(mocker):
def bypass_decorators(mocker: MockerFixture):
"""Bypass all decorators on the API method."""
mocker.patch(
"controllers.console.datasets.hit_testing.setup_required",
@ -56,7 +57,7 @@ def bypass_decorators(mocker):
class TestHitTestingApi:
def test_hit_testing_success(self, app, dataset, dataset_id):
def test_hit_testing_success(self, app: Flask, dataset, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
@ -99,7 +100,7 @@ class TestHitTestingApi:
assert "records" in result
assert result["records"] == []
def test_hit_testing_success_with_optional_record_fields(self, app, dataset, dataset_id):
def test_hit_testing_success_with_optional_record_fields(self, app: Flask, dataset, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
@ -150,7 +151,7 @@ class TestHitTestingApi:
assert result["query"] == payload["query"]
assert result["records"] == records
def test_hit_testing_dataset_not_found(self, app, dataset_id):
def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
@ -175,7 +176,7 @@ class TestHitTestingApi:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_hit_testing_invalid_args(self, app, dataset, dataset_id):
def test_hit_testing_invalid_args(self, app: Flask, dataset, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from pytest_mock import MockerFixture
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -60,7 +61,7 @@ def metadata_id():
@pytest.fixture(autouse=True)
def bypass_decorators(mocker):
def bypass_decorators(mocker: MockerFixture):
"""Bypass setup/login/license decorators."""
mocker.patch(
"controllers.console.datasets.metadata.setup_required",

View File

@ -2,6 +2,7 @@ from unittest.mock import Mock, PropertyMock, patch
import pytest
from flask import Flask
from pytest_mock import MockerFixture
from controllers.console import console_ns
from controllers.console.datasets.error import WebsiteCrawlError
@ -31,7 +32,7 @@ def app():
@pytest.fixture(autouse=True)
def bypass_auth_and_setup(mocker):
def bypass_auth_and_setup(mocker: MockerFixture):
"""Bypass setup/login/account decorators."""
mocker.patch(
"controllers.console.datasets.website.login_required",
@ -48,7 +49,7 @@ def bypass_auth_and_setup(mocker):
class TestWebsiteCrawlApi:
def test_crawl_success(self, app, mocker):
def test_crawl_success(self, app, mocker: MockerFixture):
api = WebsiteCrawlApi()
method = unwrap(api.post)
@ -85,7 +86,7 @@ class TestWebsiteCrawlApi:
assert status == 200
assert result["job_id"] == "job-1"
def test_crawl_invalid_payload(self, app, mocker):
def test_crawl_invalid_payload(self, app, mocker: MockerFixture):
api = WebsiteCrawlApi()
method = unwrap(api.post)
@ -113,7 +114,7 @@ class TestWebsiteCrawlApi:
with pytest.raises(WebsiteCrawlError, match="invalid payload"):
method(api)
def test_crawl_service_error(self, app, mocker):
def test_crawl_service_error(self, app, mocker: MockerFixture):
api = WebsiteCrawlApi()
method = unwrap(api.post)
@ -150,7 +151,7 @@ class TestWebsiteCrawlApi:
class TestWebsiteCrawlStatusApi:
def test_get_status_success(self, app, mocker):
def test_get_status_success(self, app, mocker: MockerFixture):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
@ -181,7 +182,7 @@ class TestWebsiteCrawlStatusApi:
assert status == 200
assert result["status"] == "completed"
def test_get_status_invalid_provider(self, app, mocker):
def test_get_status_invalid_provider(self, app, mocker: MockerFixture):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
@ -203,7 +204,7 @@ class TestWebsiteCrawlStatusApi:
with pytest.raises(WebsiteCrawlError, match="invalid provider"):
method(api, job_id)
def test_get_status_service_error(self, app, mocker):
def test_get_status_service_error(self, app, mocker: MockerFixture):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)

View File

@ -1,6 +1,7 @@
from unittest.mock import Mock
import pytest
from pytest_mock import MockerFixture
from controllers.console.datasets.error import PipelineNotFoundError
from controllers.console.datasets.wraps import get_rag_pipeline
@ -16,7 +17,7 @@ class TestGetRagPipeline:
with pytest.raises(ValueError, match="missing pipeline_id"):
dummy_view()
def test_pipeline_not_found(self, mocker):
def test_pipeline_not_found(self, mocker: MockerFixture):
@get_rag_pipeline
def dummy_view(**kwargs):
return "ok"
@ -34,7 +35,7 @@ class TestGetRagPipeline:
with pytest.raises(PipelineNotFoundError):
dummy_view(pipeline_id="pipeline-1")
def test_pipeline_found_and_injected(self, mocker):
def test_pipeline_found_and_injected(self, mocker: MockerFixture):
pipeline = Mock(spec=Pipeline)
pipeline.id = "pipeline-1"
pipeline.tenant_id = "tenant-1"
@ -57,7 +58,7 @@ class TestGetRagPipeline:
assert result is pipeline
def test_pipeline_id_removed_from_kwargs(self, mocker):
def test_pipeline_id_removed_from_kwargs(self, mocker: MockerFixture):
pipeline = Mock(spec=Pipeline)
@get_rag_pipeline
@ -79,7 +80,7 @@ class TestGetRagPipeline:
assert result == "ok"
def test_pipeline_id_cast_to_string(self, mocker):
def test_pipeline_id_cast_to_string(self, mocker: MockerFixture):
pipeline = Mock(spec=Pipeline)
@get_rag_pipeline

View File

@ -4,6 +4,7 @@ import uuid
from unittest.mock import Mock, PropertyMock, patch
import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.console.admin import (
@ -18,7 +19,7 @@ from models.model import App, InstalledApp, RecommendedApp
@pytest.fixture(autouse=True)
def bypass_only_edition_cloud(mocker):
def bypass_only_edition_cloud(mocker: MockerFixture):
"""
Bypass only_edition_cloud decorator by setting EDITION to "CLOUD".
"""
@ -29,7 +30,7 @@ def bypass_only_edition_cloud(mocker):
@pytest.fixture
def mock_admin_auth(mocker):
def mock_admin_auth(mocker: MockerFixture):
"""
Provide valid admin authentication for controller tests.
"""
@ -44,7 +45,7 @@ def mock_admin_auth(mocker):
@pytest.fixture
def mock_console_payload(mocker):
def mock_console_payload(mocker: MockerFixture):
payload = {
"app_id": str(uuid.uuid4()),
"language": "en-US",
@ -62,7 +63,7 @@ def mock_console_payload(mocker):
@pytest.fixture
def mock_banner_payload(mocker):
def mock_banner_payload(mocker: MockerFixture):
mocker.patch(
"flask_restx.namespace.Namespace.payload",
new_callable=PropertyMock,
@ -78,7 +79,7 @@ def mock_banner_payload(mocker):
@pytest.fixture
def mock_session_factory(mocker):
def mock_session_factory(mocker: MockerFixture):
mock_session = Mock()
mock_session.execute = Mock()
mock_session.add = Mock()
@ -97,7 +98,7 @@ class TestDeleteExploreBannerApi:
def setup_method(self):
self.api = DeleteExploreBannerApi()
def test_delete_banner_not_found(self, mocker, mock_admin_auth):
def test_delete_banner_not_found(self, mocker: MockerFixture, mock_admin_auth):
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: None),
@ -106,7 +107,7 @@ class TestDeleteExploreBannerApi:
with pytest.raises(NotFound, match="is not found"):
self.api.delete(uuid.uuid4())
def test_delete_banner_success(self, mocker, mock_admin_auth):
def test_delete_banner_success(self, mocker: MockerFixture, mock_admin_auth):
mock_banner = Mock()
mocker.patch(
@ -126,7 +127,7 @@ class TestInsertExploreBannerApi:
def setup_method(self):
self.api = InsertExploreBannerApi()
def test_insert_banner_success(self, mocker, mock_admin_auth, mock_banner_payload):
def test_insert_banner_success(self, mocker: MockerFixture, mock_admin_auth, mock_banner_payload):
mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
@ -168,7 +169,7 @@ class TestInsertExploreAppApiDelete:
def setup_method(self):
self.api = InsertExploreAppApi()
def test_delete_when_not_in_explore(self, mocker, mock_admin_auth):
def test_delete_when_not_in_explore(self, mocker: MockerFixture, mock_admin_auth):
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
@ -183,7 +184,7 @@ class TestInsertExploreAppApiDelete:
assert status == 204
assert response["result"] == "success"
def test_delete_when_in_explore_with_trial_app(self, mocker, mock_admin_auth):
def test_delete_when_in_explore_with_trial_app(self, mocker: MockerFixture, mock_admin_auth):
"""Test deleting an app from explore that has a trial app."""
app_id = uuid.uuid4()
@ -225,7 +226,7 @@ class TestInsertExploreAppApiDelete:
assert response["result"] == "success"
assert mock_app.is_public is False
def test_delete_with_installed_apps(self, mocker, mock_admin_auth):
def test_delete_with_installed_apps(self, mocker: MockerFixture, mock_admin_auth):
"""Test deleting an app that has installed apps in other tenants."""
app_id = uuid.uuid4()
@ -270,7 +271,7 @@ class TestInsertExploreAppListApi:
def setup_method(self):
self.api = InsertExploreAppListApi()
def test_app_not_found(self, mocker, mock_admin_auth, mock_console_payload):
def test_app_not_found(self, mocker: MockerFixture, mock_admin_auth, mock_console_payload):
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: None),
@ -281,7 +282,7 @@ class TestInsertExploreAppListApi:
def test_create_recommended_app(
self,
mocker,
mocker: MockerFixture,
mock_admin_auth,
mock_console_payload,
):
@ -318,7 +319,9 @@ class TestInsertExploreAppListApi:
assert response["result"] == "success"
assert mock_app.is_public is True
def test_update_recommended_app(self, mocker, mock_admin_auth, mock_console_payload, mock_session_factory):
def test_update_recommended_app(
self, mocker: MockerFixture, mock_admin_auth, mock_console_payload, mock_session_factory
):
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
@ -344,7 +347,7 @@ class TestInsertExploreAppListApi:
def test_site_data_overrides_payload(
self,
mocker,
mocker: MockerFixture,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
@ -381,7 +384,7 @@ class TestInsertExploreAppListApi:
def test_create_trial_app_when_can_trial_enabled(
self,
mocker,
mocker: MockerFixture,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
@ -413,7 +416,7 @@ class TestInsertExploreAppListApi:
def test_update_recommended_app_with_trial(
self,
mocker,
mocker: MockerFixture,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
@ -450,7 +453,7 @@ class TestInsertExploreAppListApi:
def test_update_recommended_app_without_trial(
self,
mocker,
mocker: MockerFixture,
mock_admin_auth,
mock_console_payload,
mock_session_factory,

View File

@ -1,3 +1,4 @@
from pytest_mock import MockerFixture
from werkzeug.exceptions import Unauthorized
@ -11,7 +12,7 @@ def unwrap(func):
class TestFeatureApi:
def test_get_tenant_features_success(self, mocker):
def test_get_tenant_features_success(self, mocker: MockerFixture):
from controllers.console.feature import FeatureApi
mocker.patch(
@ -32,7 +33,7 @@ class TestFeatureApi:
class TestSystemFeatureApi:
def test_get_system_features_authenticated(self, mocker):
def test_get_system_features_authenticated(self, mocker: MockerFixture):
"""
current_user.is_authenticated == True
"""
@ -56,7 +57,7 @@ class TestSystemFeatureApi:
assert result == {"features": {"sys_feature": True}}
def test_get_system_features_unauthenticated(self, mocker):
def test_get_system_features_unauthenticated(self, mocker: MockerFixture):
"""
current_user.is_authenticated raises Unauthorized
"""

View File

@ -32,7 +32,7 @@ class TestDefaultModelApi:
with (
app.test_request_context(
"/",
query_string={"model_type": ModelType.LLM.value},
query_string={"model_type": ModelType.LLM},
),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
@ -53,7 +53,7 @@ class TestDefaultModelApi:
payload = {
"model_settings": [
{
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
"provider": "openai",
"model": "gpt-4",
}
@ -77,7 +77,7 @@ class TestDefaultModelApi:
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}),
app.test_request_context("/", query_string={"model_type": ModelType.LLM}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
@ -113,7 +113,7 @@ class TestModelProviderModelApi:
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
"load_balancing": {
"configs": [{"weight": 1}],
"enabled": True,
@ -139,7 +139,7 @@ class TestModelProviderModelApi:
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
}
with (
@ -180,7 +180,7 @@ class TestModelProviderModelCredentialApi:
"/",
query_string={
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
},
),
patch(
@ -208,7 +208,7 @@ class TestModelProviderModelCredentialApi:
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
"credentials": {"key": "val"},
}
@ -229,7 +229,7 @@ class TestModelProviderModelCredentialApi:
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}),
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb,
@ -248,7 +248,7 @@ class TestModelProviderModelCredentialApi:
payload = {
"model": "gpt",
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
"credential_id": "123e4567-e89b-12d3-a456-426614174000",
}
@ -269,7 +269,7 @@ class TestModelProviderModelCredentialSwitchApi:
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
"credential_id": "abc",
}
@ -293,7 +293,7 @@ class TestModelEnableDisableApis:
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
}
with (
@ -314,7 +314,7 @@ class TestModelEnableDisableApis:
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
}
with (
@ -337,7 +337,7 @@ class TestModelProviderModelValidateApi:
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
"credentials": {"key": "val"},
}
@ -360,7 +360,7 @@ class TestModelProviderModelValidateApi:
payload = {
"model": model_name,
"model_type": ModelType.LLM.value,
"model_type": ModelType.LLM,
"credentials": {},
}
@ -412,7 +412,7 @@ class TestParameterAndAvailableModels:
):
service_mock.return_value.get_models_by_model_type.return_value = []
result = method(api, ModelType.LLM.value)
result = method(api, ModelType.LLM)
assert "data" in result
@ -442,6 +442,6 @@ class TestParameterAndAvailableModels:
):
service.return_value.get_models_by_model_type.return_value = []
result = method(api, ModelType.LLM.value)
result = method(api, ModelType.LLM)
assert result["data"] == []

View File

@ -189,7 +189,7 @@ class TestGetUserTenant:
"""Test get_user_tenant decorator"""
@patch("controllers.inner_api.plugin.wraps.Tenant")
def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch):
def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch: pytest.MonkeyPatch):
"""Test that decorator injects tenant_model and user_model into kwargs"""
# Arrange
@ -244,7 +244,9 @@ class TestGetUserTenant:
protected_view()
@patch("controllers.inner_api.plugin.wraps.Tenant")
def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch):
def test_should_use_default_session_id_when_user_id_empty(
self, mock_tenant_class, app: Flask, monkeypatch: pytest.MonkeyPatch
):
"""Test that default session ID is used when user_id is empty string"""
# Arrange

View File

@ -340,7 +340,7 @@ class TestConversationAppModeValidation:
@pytest.mark.parametrize(
"mode",
[
AppMode.CHAT.value,
AppMode.CHAT,
AppMode.AGENT_CHAT.value,
AppMode.ADVANCED_CHAT.value,
],
@ -365,7 +365,7 @@ class TestConversationAppModeValidation:
app raises NotChatAppError.
"""
app = Mock(spec=App)
app.mode = AppMode.COMPLETION.value
app.mode = AppMode.COMPLETION
app_mode = AppMode.value_of(app.mode)
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
@ -498,7 +498,7 @@ class TestConversationApiController:
def test_list_not_chat(self, app) -> None:
api = ConversationApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace()
with app.test_request_context("/conversations", method="GET"):
@ -531,7 +531,7 @@ class TestConversationApiController:
api = ConversationApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
with app.test_request_context(
@ -546,7 +546,7 @@ class TestConversationDetailApiController:
def test_delete_not_chat(self, app) -> None:
api = ConversationDetailApi()
handler = _unwrap(api.delete)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace()
with app.test_request_context("/conversations/1", method="DELETE"):
@ -562,7 +562,7 @@ class TestConversationDetailApiController:
api = ConversationDetailApi()
handler = _unwrap(api.delete)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
with app.test_request_context("/conversations/1", method="DELETE"):
@ -580,7 +580,7 @@ class TestConversationRenameApiController:
api = ConversationRenameApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
with app.test_request_context(
@ -596,7 +596,7 @@ class TestConversationVariablesApiController:
def test_not_chat(self, app) -> None:
api = ConversationVariablesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace()
with app.test_request_context("/conversations/1/variables", method="GET"):
@ -612,7 +612,7 @@ class TestConversationVariablesApiController:
api = ConversationVariablesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
with app.test_request_context(
@ -645,7 +645,7 @@ class TestConversationVariablesApiController:
api = ConversationVariablesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
with app.test_request_context(
@ -671,7 +671,7 @@ class TestConversationVariableDetailApiController:
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
with app.test_request_context(
@ -697,7 +697,7 @@ class TestConversationVariableDetailApiController:
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
with app.test_request_context(
@ -731,7 +731,7 @@ class TestConversationVariableDetailApiController:
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
with app.test_request_context(

View File

@ -3,6 +3,7 @@ from unittest.mock import Mock
from uuid import UUID, uuid4
import pytest
from pytest_mock import MockerFixture
from controllers.service_api.end_user.end_user import EndUserApi
from controllers.service_api.end_user.error import EndUserNotFoundError
@ -21,7 +22,9 @@ class TestEndUserApi:
app.tenant_id = str(uuid4())
return app
def test_get_end_user_returns_all_attributes(self, mocker, resource: EndUserApi, app_model: App) -> None:
def test_get_end_user_returns_all_attributes(
self, mocker: MockerFixture, resource: EndUserApi, app_model: App
) -> None:
end_user = Mock(spec=EndUser)
end_user.id = str(uuid4())
end_user.tenant_id = app_model.tenant_id
@ -54,7 +57,7 @@ class TestEndUserApi:
assert result["created_at"].startswith("2024-01-01T00:00:00")
assert result["updated_at"].startswith("2024-01-02T00:00:00")
def test_get_end_user_not_found(self, mocker, resource: EndUserApi, app_model: App) -> None:
def test_get_end_user_not_found(self, mocker: MockerFixture, resource: EndUserApi, app_model: App) -> None:
mocker.patch("controllers.service_api.end_user.end_user.EndUserService.get_end_user_by_id", return_value=None)
with pytest.raises(EndUserNotFoundError):

View File

@ -12,12 +12,13 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
@pytest.fixture
def mock_action_class(mocker):
def mock_action_class(mocker: MockerFixture):
mock_action = MagicMock()
mocker.patch(
"core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action",

View File

@ -3,6 +3,7 @@
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.agent.strategy.plugin import PluginAgentStrategy
@ -213,7 +214,9 @@ class TestInvoke:
(None, None, "msg"),
],
)
def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None:
def test_invoke_optional_arguments(
self, strategy, mocker: MockerFixture, conversation_id, app_id, message_id
) -> None:
mock_manager = MagicMock()
mock_manager.invoke = MagicMock(return_value=iter([]))

View File

@ -3,6 +3,7 @@ from decimal import Decimal
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
import core.agent.base_agent_runner as module
from core.agent.base_agent_runner import BaseAgentRunner
@ -13,7 +14,7 @@ from core.agent.base_agent_runner import BaseAgentRunner
@pytest.fixture
def mock_db_session(mocker):
def mock_db_session(mocker: MockerFixture):
session = mocker.MagicMock()
mocker.patch.object(module.db, "session", session)
return session
@ -41,13 +42,13 @@ def runner(mocker, mock_db_session):
class TestRepack:
def test_sets_empty_if_none(self, runner, mocker):
def test_sets_empty_if_none(self, runner, mocker: MockerFixture):
entity = mocker.MagicMock()
entity.app_config.prompt_template.simple_prompt_template = None
result = runner._repack_app_generate_entity(entity)
assert result.app_config.prompt_template.simple_prompt_template == ""
def test_keeps_existing(self, runner, mocker):
def test_keeps_existing(self, runner, mocker: MockerFixture):
entity = mocker.MagicMock()
entity.app_config.prompt_template.simple_prompt_template = "abc"
result = runner._repack_app_generate_entity(entity)
@ -60,7 +61,7 @@ class TestRepack:
class TestUpdatePromptTool:
def build_param(self, mocker, **kwargs):
def build_param(self, mocker: MockerFixture, **kwargs):
p = mocker.MagicMock()
p.form = kwargs.get("form")
@ -75,7 +76,7 @@ class TestUpdatePromptTool:
p.required = kwargs.get("required", False)
return p
def test_skip_non_llm(self, runner, mocker):
def test_skip_non_llm(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock()
param = self.build_param(mocker, form="NOT_LLM")
tool.get_runtime_parameters.return_value = [param]
@ -86,7 +87,7 @@ class TestUpdatePromptTool:
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"] == {}
def test_enum_and_required(self, runner, mocker):
def test_enum_and_required(self, runner, mocker: MockerFixture):
option = mocker.MagicMock(value="opt1")
param = self.build_param(
mocker,
@ -104,7 +105,7 @@ class TestUpdatePromptTool:
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert "p1" in result.parameters["required"]
def test_skip_file_type_param(self, runner, mocker):
def test_skip_file_type_param(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock()
param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM)
param.type = module.ToolParameter.ToolParameterType.FILE
@ -116,7 +117,7 @@ class TestUpdatePromptTool:
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"] == {}
def test_duplicate_required_not_duplicated(self, runner, mocker):
def test_duplicate_required_not_duplicated(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock()
param = self.build_param(
@ -141,7 +142,7 @@ class TestUpdatePromptTool:
class TestCreateAgentThought:
def test_with_files(self, runner, mock_db_session, mocker):
def test_with_files(self, runner, mock_db_session, mocker: MockerFixture):
mock_thought = mocker.MagicMock(id=10)
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
@ -149,7 +150,7 @@ class TestCreateAgentThought:
assert result == "10"
assert runner.agent_thought_count == 1
def test_without_files(self, runner, mock_db_session, mocker):
def test_without_files(self, runner, mock_db_session, mocker: MockerFixture):
mock_thought = mocker.MagicMock(id=11)
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
@ -163,7 +164,7 @@ class TestCreateAgentThought:
class TestSaveAgentThought:
def setup_agent(self, mocker):
def setup_agent(self, mocker: MockerFixture):
agent = mocker.MagicMock()
agent.tool = "tool1;tool2"
agent.tool_labels = {}
@ -175,7 +176,7 @@ class TestSaveAgentThought:
with pytest.raises(ValueError):
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
def test_full_update(self, runner, mock_db_session, mocker):
def test_full_update(self, runner, mock_db_session, mocker: MockerFixture):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
@ -210,7 +211,7 @@ class TestSaveAgentThought:
assert agent.tokens == 3
assert "tool1" in json.loads(agent.tool_labels_str)
def test_label_fallback_when_none(self, runner, mock_db_session, mocker):
def test_label_fallback_when_none(self, runner, mock_db_session, mocker: MockerFixture):
agent = self.setup_agent(mocker)
agent.tool = "unknown_tool"
mock_db_session.scalar.return_value = agent
@ -220,7 +221,7 @@ class TestSaveAgentThought:
labels = json.loads(agent.tool_labels_str)
assert "unknown_tool" in labels
def test_json_failure_paths(self, runner, mock_db_session, mocker):
def test_json_failure_paths(self, runner, mock_db_session, mocker: MockerFixture):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
@ -241,13 +242,13 @@ class TestSaveAgentThought:
assert mock_db_session.commit.called
def test_messages_ids_none(self, runner, mock_db_session, mocker):
def test_messages_ids_none(self, runner, mock_db_session, mocker: MockerFixture):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
runner.save_agent_thought("id", None, None, None, None, None, None, None, None)
assert mock_db_session.commit.called
def test_success_dict_serialization(self, runner, mock_db_session, mocker):
def test_success_dict_serialization(self, runner, mock_db_session, mocker: MockerFixture):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
@ -273,19 +274,19 @@ class TestSaveAgentThought:
class TestOrganizeUserPrompt:
def test_no_files(self, runner, mock_db_session, mocker):
def test_no_files(self, runner, mock_db_session, mocker: MockerFixture):
mock_db_session.scalars.return_value.all.return_value = []
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
result = runner.organize_agent_user_prompt(msg)
assert result.content == "hello"
def test_with_files_no_config(self, runner, mock_db_session, mocker):
def test_with_files_no_config(self, runner, mock_db_session, mocker: MockerFixture):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
result = runner.organize_agent_user_prompt(msg)
assert result.content == "hello"
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker):
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker: MockerFixture):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
file_config = mocker.MagicMock()
file_config.image_config = mocker.MagicMock(detail=None)
@ -305,27 +306,27 @@ class TestOrganizeUserPrompt:
class TestOrganizeHistory:
def test_empty(self, runner, mock_db_session, mocker):
def test_empty(self, runner, mock_db_session, mocker: MockerFixture):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
mocker.patch.object(module, "extract_thread_messages", return_value=[])
result = runner.organize_agent_history([])
assert result == []
def test_with_answer_only(self, runner, mock_db_session, mocker):
def test_with_answer_only(self, runner, mock_db_session, mocker: MockerFixture):
msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert any(isinstance(x, module.AssistantPromptMessage) for x in result)
def test_skip_current_message(self, runner, mock_db_session, mocker):
def test_skip_current_message(self, runner, mock_db_session, mocker: MockerFixture):
msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert result == []
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker):
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker: MockerFixture):
thought = mocker.MagicMock(
tool="tool1",
tool_input="invalid",
@ -341,7 +342,7 @@ class TestOrganizeHistory:
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_empty_tool_name_split(self, runner, mock_db_session, mocker):
def test_empty_tool_name_split(self, runner, mock_db_session, mocker: MockerFixture):
thought = mocker.MagicMock(tool=";", thought="thinking")
msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None)
@ -350,7 +351,7 @@ class TestOrganizeHistory:
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker):
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker: MockerFixture):
thought = mocker.MagicMock(
tool="tool1",
tool_input=json.dumps({"tool1": {"x": 1}}),
@ -379,7 +380,7 @@ class TestOrganizeHistory:
class TestConvertToolToPromptMessageTool:
def test_basic_conversion(self, runner, mocker):
def test_basic_conversion(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock(tool_name="tool1")
runtime_param = mocker.MagicMock()
@ -404,7 +405,7 @@ class TestConvertToolToPromptMessageTool:
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
assert entity == tool_entity
def test_full_conversion_multiple_params(self, runner, mocker):
def test_full_conversion_multiple_params(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock(tool_name="tool1")
# LLM param with input_schema override
@ -441,7 +442,7 @@ class TestConvertToolToPromptMessageTool:
class TestInitPromptToolsExtended:
def test_agent_tool_branch(self, runner, mocker):
def test_agent_tool_branch(self, runner, mocker: MockerFixture):
agent_tool = mocker.MagicMock(tool_name="agent_tool")
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity"))
@ -449,7 +450,7 @@ class TestInitPromptToolsExtended:
tools, prompts = runner._init_prompt_tools()
assert "agent_tool" in tools
def test_exception_in_conversion(self, runner, mocker):
def test_exception_in_conversion(self, runner, mocker: MockerFixture):
agent_tool = mocker.MagicMock(tool_name="bad_tool")
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception)
@ -464,7 +465,7 @@ class TestInitPromptToolsExtended:
class TestAdditionalCoverage:
def test_update_prompt_with_input_schema(self, runner, mocker):
def test_update_prompt_with_input_schema(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock()
param = mocker.MagicMock()
@ -487,7 +488,7 @@ class TestAdditionalCoverage:
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"]["p1"]["type"] == "number"
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker):
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker: MockerFixture):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {"tool1": {"en_US": "existing"}}
@ -498,7 +499,7 @@ class TestAdditionalCoverage:
labels = json.loads(agent.tool_labels_str)
assert labels["tool1"]["en_US"] == "existing"
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker):
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker: MockerFixture):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {}
@ -508,7 +509,7 @@ class TestAdditionalCoverage:
runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None)
assert agent.tool_meta_str == "meta_string"
def test_convert_dataset_retriever_tool(self, runner, mocker):
def test_convert_dataset_retriever_tool(self, runner, mocker: MockerFixture):
ds_tool = mocker.MagicMock()
ds_tool.entity.identity.name = "ds"
ds_tool.entity.description.llm = "desc"
@ -525,7 +526,7 @@ class TestAdditionalCoverage:
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
assert prompt is not None
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker):
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker: MockerFixture):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
file_config = mocker.MagicMock()
@ -544,7 +545,7 @@ class TestAdditionalCoverage:
result = runner.organize_agent_user_prompt(msg)
assert result is not None
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker):
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker: MockerFixture):
thought = mocker.MagicMock(tool=None, thought="thinking")
msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None)
@ -554,7 +555,7 @@ class TestAdditionalCoverage:
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker):
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker: MockerFixture):
thought = mocker.MagicMock(
tool="tool1;tool2",
tool_input=json.dumps({"tool1": {}, "tool2": {}}),
@ -572,7 +573,7 @@ class TestAdditionalCoverage:
# ================= Additional Surgical Coverage =================
def test_convert_tool_select_enum_branch(self, runner, mocker):
def test_convert_tool_select_enum_branch(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock(tool_name="tool1")
param = mocker.MagicMock()
@ -599,7 +600,7 @@ class TestAdditionalCoverage:
class TestConvertDatasetRetrieverTool:
def test_required_param_added(self, runner, mocker):
def test_required_param_added(self, runner, mocker: MockerFixture):
ds_tool = mocker.MagicMock()
ds_tool.entity.identity.name = "ds"
ds_tool.entity.description.llm = "desc"
@ -619,7 +620,7 @@ class TestConvertDatasetRetrieverTool:
class TestBaseAgentRunnerInit:
def test_init_sets_stream_tool_call_and_files(self, mocker):
def test_init_sets_stream_tool_call_and_files(self, mocker: MockerFixture):
session = mocker.MagicMock()
session.scalar.return_value = 2
mocker.patch.object(module.db, "session", session)
@ -662,7 +663,7 @@ class TestBaseAgentRunnerInit:
class TestBaseAgentRunnerCoverage:
def test_convert_tool_skips_non_llm_param(self, runner, mocker):
def test_convert_tool_skips_non_llm_param(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock(tool_name="tool1")
param = mocker.MagicMock()
@ -680,7 +681,7 @@ class TestBaseAgentRunnerCoverage:
assert prompt_tool.parameters["properties"] == {}
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker):
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker: MockerFixture):
dataset_tool = mocker.MagicMock()
dataset_tool.entity.identity.name = "ds"
runner.dataset_tools = [dataset_tool]
@ -692,7 +693,7 @@ class TestBaseAgentRunnerCoverage:
assert tools["ds"] == dataset_tool
assert len(prompt_tools) == 1
def test_update_prompt_message_tool_select_enum(self, runner, mocker):
def test_update_prompt_message_tool_select_enum(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock()
option1 = mocker.MagicMock(value="A")
@ -716,7 +717,7 @@ class TestBaseAgentRunnerCoverage:
assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"]
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker):
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker: MockerFixture):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {}
@ -754,7 +755,7 @@ class TestBaseAgentRunnerCoverage:
assert isinstance(agent.observation, str)
assert isinstance(agent.tool_meta_str, str)
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker):
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker: MockerFixture):
agent = mocker.MagicMock()
agent.tool = "tool1;;"
agent.tool_labels = {}
@ -768,7 +769,7 @@ class TestBaseAgentRunnerCoverage:
labels = json.loads(agent.tool_labels_str)
assert "" not in labels
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker):
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker: MockerFixture):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
mocker.patch.object(module, "extract_thread_messages", return_value=[])
@ -778,7 +779,7 @@ class TestBaseAgentRunnerCoverage:
assert system_message in result
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker):
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker: MockerFixture):
thought = mocker.MagicMock(
tool="tool1",
tool_input=None,

View File

@ -2,6 +2,7 @@ import json
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.agent.cot_agent_runner import CotAgentRunner
from core.agent.entities import AgentScratchpadUnit
@ -25,7 +26,7 @@ class DummyRunner(CotAgentRunner):
@pytest.fixture
def runner(mocker):
def runner(mocker: MockerFixture):
# Prevent BaseAgentRunner __init__ from hitting database
mocker.patch(
"core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history",
@ -165,7 +166,7 @@ class TestHandleInvokeAction:
response, meta = runner._handle_invoke_action(action, {}, [])
assert "there is not a tool named" in response
def test_tool_with_json_string_args(self, runner, mocker):
def test_tool_with_json_string_args(self, runner, mocker: MockerFixture):
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
tool_instance = MagicMock()
tool_instances = {"tool": tool_instance}
@ -180,7 +181,7 @@ class TestHandleInvokeAction:
class TestOrganizeHistoricPromptMessages:
def test_empty_history(self, runner, mocker):
def test_empty_history(self, runner, mocker: MockerFixture):
mocker.patch(
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
return_value=[],
@ -190,7 +191,7 @@ class TestOrganizeHistoricPromptMessages:
class TestRun:
def test_run_handles_empty_parser_output(self, runner, mocker):
def test_run_handles_empty_parser_output(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -202,7 +203,7 @@ class TestRun:
results = list(runner.run(message, "query", {}))
assert isinstance(results, list)
def test_run_with_action_and_tool_invocation(self, runner, mocker):
def test_run_with_action_and_tool_invocation(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -223,7 +224,7 @@ class TestRun:
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {"tool": MagicMock()}))
def test_run_respects_max_iteration_boundary(self, runner, mocker):
def test_run_respects_max_iteration_boundary(self, runner, mocker: MockerFixture):
runner.app_config.agent.max_iteration = 1
message = MagicMock()
message.id = "msg-id"
@ -245,7 +246,7 @@ class TestRun:
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {"tool": MagicMock()}))
def test_run_basic_flow(self, runner, mocker):
def test_run_basic_flow(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -257,7 +258,7 @@ class TestRun:
results = list(runner.run(message, "query", {"name": "John"}))
assert results
def test_run_max_iteration_error(self, runner, mocker):
def test_run_max_iteration_error(self, runner, mocker: MockerFixture):
runner.app_config.agent.max_iteration = 0
message = MagicMock()
message.id = "msg-id"
@ -272,7 +273,7 @@ class TestRun:
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {}))
def test_run_increase_usage_aggregation(self, runner, mocker):
def test_run_increase_usage_aggregation(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
runner.app_config.agent.max_iteration = 2
@ -329,7 +330,7 @@ class TestRun:
assert final_usage.completion_price == 2
assert final_usage.total_price == 4
def test_run_when_no_action_branch(self, runner, mocker):
def test_run_when_no_action_branch(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -341,7 +342,7 @@ class TestRun:
results = list(runner.run(message, "query", {}))
assert results[-1].delta.message.content == ""
def test_run_usage_missing_key_branch(self, runner, mocker):
def test_run_usage_missing_key_branch(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -354,7 +355,7 @@ class TestRun:
list(runner.run(message, "query", {}))
def test_run_prompt_tool_update_branch(self, runner, mocker):
def test_run_prompt_tool_update_branch(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -410,7 +411,7 @@ class TestRun:
class TestInitReactState:
def test_init_react_state_resets_state(self, runner, mocker):
def test_init_react_state_resets_state(self, runner, mocker: MockerFixture):
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
runner._agent_scratchpad = ["old"]
runner._query = "old"
@ -423,7 +424,7 @@ class TestInitReactState:
class TestHandleInvokeActionExtended:
def test_tool_with_invalid_json_string_args(self, runner, mocker):
def test_tool_with_invalid_json_string_args(self, runner, mocker: MockerFixture):
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
tool_instance = MagicMock()
tool_instances = {"tool": tool_instance}
@ -457,7 +458,7 @@ class TestFillInputsEdgeCases:
class TestOrganizeHistoricPromptMessagesExtended:
def test_user_message_flushes_scratchpad(self, runner, mocker):
def test_user_message_flushes_scratchpad(self, runner, mocker: MockerFixture):
from graphon.model_runtime.entities.message_entities import UserPromptMessage
user_message = UserPromptMessage(content="Hi")
@ -480,7 +481,7 @@ class TestOrganizeHistoricPromptMessagesExtended:
with pytest.raises(NotImplementedError):
runner._organize_historic_prompt_messages([])
def test_agent_history_transform_invocation(self, runner, mocker):
def test_agent_history_transform_invocation(self, runner, mocker: MockerFixture):
mock_transform = MagicMock()
mock_transform.get_prompt.return_value = []
@ -495,7 +496,7 @@ class TestOrganizeHistoricPromptMessagesExtended:
class TestRunAdditionalBranches:
def test_run_with_no_action_final_answer_empty(self, runner, mocker):
def test_run_with_no_action_final_answer_empty(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -507,7 +508,7 @@ class TestRunAdditionalBranches:
results = list(runner.run(message, "query", {}))
assert any(hasattr(r, "delta") for r in results)
def test_run_with_final_answer_action_string(self, runner, mocker):
def test_run_with_final_answer_action_string(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -521,7 +522,7 @@ class TestRunAdditionalBranches:
results = list(runner.run(message, "query", {}))
assert results[-1].delta.message.content == "done"
def test_run_with_final_answer_action_dict(self, runner, mocker):
def test_run_with_final_answer_action_dict(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -535,7 +536,7 @@ class TestRunAdditionalBranches:
results = list(runner.run(message, "query", {}))
assert json.loads(results[-1].delta.message.content) == {"a": 1}
def test_run_with_string_final_answer(self, runner, mocker):
def test_run_with_string_final_answer(self, runner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from pytest_mock import MockerFixture
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
@ -55,7 +56,7 @@ def runner():
class TestOrganizeSystemPrompt:
def test_organize_system_prompt_success(self, runner, mocker):
def test_organize_system_prompt_success(self, runner, mocker: MockerFixture):
first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}"
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt)))
@ -154,7 +155,7 @@ class TestOrganizeUserQuery:
class TestOrganizePromptMessages:
def test_no_scratchpad(self, runner, mocker):
def test_no_scratchpad(self, runner, mocker: MockerFixture):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
@ -164,7 +165,7 @@ class TestOrganizePromptMessages:
assert "query" in result
runner._organize_historic_prompt_messages.assert_called_once()
def test_with_final_scratchpad(self, runner, mocker):
def test_with_final_scratchpad(self, runner, mocker: MockerFixture):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
@ -177,7 +178,7 @@ class TestOrganizePromptMessages:
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
assert "Final Answer: done" in combined
def test_with_thought_action_observation(self, runner, mocker):
def test_with_thought_action_observation(self, runner, mocker: MockerFixture):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
@ -197,7 +198,7 @@ class TestOrganizePromptMessages:
assert "Action: action" in combined
assert "Observation: observe" in combined
def test_multiple_units_mixed(self, runner, mocker):
def test_multiple_units_mixed(self, runner, mocker: MockerFixture):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])

View File

@ -1,6 +1,7 @@
import json
import pytest
from pytest_mock import MockerFixture
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from graphon.model_runtime.entities.message_entities import (
@ -74,7 +75,7 @@ class TestOrganizeInstructionPrompt:
class TestOrganizeHistoricPrompt:
def test_with_user_and_assistant_string(self, runner, mocker):
def test_with_user_and_assistant_string(self, runner, mocker: MockerFixture):
user_msg = UserPromptMessage(content="Hello")
assistant_msg = AssistantPromptMessage(content="Hi there")
@ -89,7 +90,7 @@ class TestOrganizeHistoricPrompt:
assert "Question: Hello" in result
assert "Hi there" in result
def test_assistant_list_with_text_content(self, runner, mocker):
def test_assistant_list_with_text_content(self, runner, mocker: MockerFixture):
text_content = TextPromptMessageContent(data="Partial answer")
assistant_msg = AssistantPromptMessage(content=[text_content])
@ -103,7 +104,7 @@ class TestOrganizeHistoricPrompt:
assert "Partial answer" in result
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker):
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker: MockerFixture):
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
assistant_msg = AssistantPromptMessage(content=[non_text_content])
@ -116,7 +117,7 @@ class TestOrganizeHistoricPrompt:
result = runner._organize_historic_prompt()
assert result == ""
def test_empty_history(self, runner, mocker):
def test_empty_history(self, runner, mocker: MockerFixture):
mocker.patch.object(
runner,
"_organize_historic_prompt_messages",
@ -136,7 +137,7 @@ class TestOrganizePromptMessages:
def test_full_flow_with_scratchpad(
self,
runner,
mocker,
mocker: MockerFixture,
dummy_app_config_factory,
dummy_agent_config_factory,
dummy_prompt_entity_factory,
@ -171,7 +172,12 @@ class TestOrganizePromptMessages:
assert "Question: What is Python?" in content
def test_no_scratchpad(
self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
self,
runner,
mocker: MockerFixture,
dummy_app_config_factory,
dummy_agent_config_factory,
dummy_prompt_entity_factory,
):
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
@ -198,7 +204,7 @@ class TestOrganizePromptMessages:
def test_partial_scratchpad_units(
self,
runner,
mocker,
mocker: MockerFixture,
thought,
action,
observation,

View File

@ -3,6 +3,7 @@ from typing import Any
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.agent.errors import AgentMaxIterationError
from core.agent.fc_agent_runner import FunctionCallAgentRunner
@ -68,7 +69,7 @@ class DummyResult:
@pytest.fixture
def runner(mocker):
def runner(mocker: MockerFixture):
# Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context
mocker.patch(
"core.agent.base_agent_runner.BaseAgentRunner.__init__",
@ -230,7 +231,7 @@ class TestOrganizeUserQuery:
result = runner._organize_user_query(None, [])
assert len(result) == 1
def test_with_files_uses_image_detail_config(self, runner, mocker):
def test_with_files_uses_image_detail_config(self, runner, mocker: MockerFixture):
file_content = TextPromptMessageContent(data="file-content")
mock_to_prompt = mocker.patch(
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
@ -352,7 +353,7 @@ class TestRunMethod:
assert len(outputs) == 1
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker):
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker: MockerFixture):
message = MagicMock(id="m1")
runner.stream_tool_call = True
@ -398,7 +399,7 @@ class TestRunMethod:
outputs = list(runner.run(message, "query"))
assert len(outputs) >= 1
def test_run_with_tool_instance_and_files(self, runner, mocker):
def test_run_with_tool_instance_and_files(self, runner, mocker: MockerFixture):
message = MagicMock(id="m1")
tool_call = MagicMock()

View File

@ -9,6 +9,7 @@ mocking; ensure entity invariants and validation rules remain stable.
import pytest
from pydantic import ValidationError
from pytest_mock import MockerFixture
from core.agent.plugin_entities import (
AgentFeature,
@ -28,12 +29,12 @@ from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity
@pytest.fixture
def mock_identity(mocker):
def mock_identity(mocker: MockerFixture):
return mocker.MagicMock(spec=AgentStrategyIdentity)
@pytest.fixture
def mock_provider_identity(mocker):
def mock_provider_identity(mocker: MockerFixture):
return mocker.MagicMock(spec=AgentStrategyProviderIdentity)
@ -47,7 +48,7 @@ class TestAgentStrategyParameterType:
"enum_member",
list(AgentStrategyParameter.AgentStrategyParameterType),
)
def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None:
def test_as_normal_type_calls_external_function(self, mocker: MockerFixture, enum_member) -> None:
mock_func = mocker.patch(
"core.agent.plugin_entities.as_normal_type",
return_value="normalized",
@ -58,7 +59,7 @@ class TestAgentStrategyParameterType:
mock_func.assert_called_once_with(enum_member)
assert result == "normalized"
def test_as_normal_type_propagates_exception(self, mocker) -> None:
def test_as_normal_type_propagates_exception(self, mocker: MockerFixture) -> None:
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
mocker.patch(
"core.agent.plugin_entities.as_normal_type",
@ -79,7 +80,7 @@ class TestAgentStrategyParameterType:
(AgentStrategyParameter.AgentStrategyParameterType.FILES, []),
],
)
def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None:
def test_cast_value_calls_external_function(self, mocker: MockerFixture, enum_member, value) -> None:
mock_func = mocker.patch(
"core.agent.plugin_entities.cast_parameter_value",
return_value="casted",
@ -90,7 +91,7 @@ class TestAgentStrategyParameterType:
mock_func.assert_called_once_with(enum_member, value)
assert result == "casted"
def test_cast_value_propagates_exception(self, mocker) -> None:
def test_cast_value_propagates_exception(self, mocker: MockerFixture) -> None:
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
mocker.patch(
"core.agent.plugin_entities.cast_parameter_value",
@ -136,7 +137,7 @@ class TestAgentStrategyParameter:
assert any(error["loc"] == ("type",) for error in exc_info.value.errors())
def test_init_frontend_parameter_calls_external(self, mocker) -> None:
def test_init_frontend_parameter_calls_external(self, mocker: MockerFixture) -> None:
mock_func = mocker.patch(
"core.agent.plugin_entities.init_frontend_parameter",
return_value="frontend",
@ -153,7 +154,7 @@ class TestAgentStrategyParameter:
mock_func.assert_called_once_with(param, param.type, "value")
assert result == "frontend"
def test_init_frontend_parameter_propagates_exception(self, mocker) -> None:
def test_init_frontend_parameter_propagates_exception(self, mocker: MockerFixture) -> None:
mocker.patch(
"core.agent.plugin_entities.init_frontend_parameter",
side_effect=RuntimeError("error"),

View File

@ -10,7 +10,7 @@ class TestGetParametersFromFeatureDict:
"""Test suite for get_parameters_from_feature_dict"""
@pytest.fixture
def mock_config(self, monkeypatch):
def mock_config(self, monkeypatch: pytest.MonkeyPatch):
"""Mock dify_config values"""
mock = MagicMock()
mock.UPLOAD_IMAGE_FILE_SIZE_LIMIT = 1
@ -23,7 +23,7 @@ class TestGetParametersFromFeatureDict:
return mock
@pytest.fixture
def mock_default_file_limits(self, monkeypatch):
def mock_default_file_limits(self, monkeypatch: pytest.MonkeyPatch):
"""Mock DEFAULT_FILE_NUMBER_LIMITS constant"""
monkeypatch.setattr(parameters_mapping, "DEFAULT_FILE_NUMBER_LIMITS", 99)
return 99

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.app.app_config.common.sensitive_word_avoidance.manager import (
SensitiveWordAvoidanceConfigManager,
@ -26,7 +27,7 @@ class TestSensitiveWordAvoidanceConfigManagerConvert:
# Assert
assert result is None
def test_convert_returns_entity_when_enabled(self, mocker):
def test_convert_returns_entity_when_enabled(self, mocker: MockerFixture):
# Arrange
mock_entity = MagicMock()
mocker.patch(
@ -48,7 +49,7 @@ class TestSensitiveWordAvoidanceConfigManagerConvert:
# Assert
assert result == mock_entity
def test_convert_enabled_without_type_or_config(self, mocker):
def test_convert_enabled_without_type_or_config(self, mocker: MockerFixture):
# Arrange
mock_entity = MagicMock()
patched = mocker.patch(
@ -135,7 +136,7 @@ class TestSensitiveWordAvoidanceConfigManagerValidateAndSetDefaults:
with pytest.raises(ValueError, match="must be a dict"):
SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config)
def test_validate_calls_moderation_factory(self, mocker):
def test_validate_calls_moderation_factory(self, mocker: MockerFixture):
# Arrange
mock_validate = mocker.patch(
"core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config"
@ -159,7 +160,7 @@ class TestSensitiveWordAvoidanceConfigManagerValidateAndSetDefaults:
assert result_config["sensitive_word_avoidance"]["enabled"] is True
assert fields == ["sensitive_word_avoidance"]
def test_validate_sets_empty_dict_when_config_none(self, mocker):
def test_validate_sets_empty_dict_when_config_none(self, mocker: MockerFixture):
# Arrange
mock_validate = mocker.patch(
"core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config"
@ -179,7 +180,7 @@ class TestSensitiveWordAvoidanceConfigManagerValidateAndSetDefaults:
# Assert
mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={})
def test_validate_only_structure_validate_skips_factory(self, mocker):
def test_validate_only_structure_validate_skips_factory(self, mocker: MockerFixture):
# Arrange
mock_validate = mocker.patch(
"core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config"

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
@ -84,7 +85,7 @@ class TestAgentConfigManagerConvert:
assert result.strategy.name == "CHAIN_OF_THOUGHT"
def test_convert_skips_disabled_tools(self, mocker, base_config):
def test_convert_skips_disabled_tools(self, mocker: MockerFixture, base_config):
# Patch AgentEntity to bypass pydantic validation
mock_agent_entity = mocker.patch(
"core.app.app_config.easy_ui_based_app.agent.manager.AgentEntity",
@ -128,7 +129,7 @@ class TestAgentConfigManagerConvert:
mock_validate.assert_called_once()
mock_agent_entity.assert_called_once()
def test_convert_tool_requires_minimum_keys(self, mocker, base_config):
def test_convert_tool_requires_minimum_keys(self, mocker: MockerFixture, base_config):
mock_validate = mocker.patch(
"core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate",
return_value=MagicMock(),

View File

@ -2,6 +2,7 @@ import uuid
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.entities.agent_entities import PlanningStrategy
@ -69,7 +70,7 @@ class TestDatasetConfigManagerConvert:
assert result.dataset_ids == [valid_uuid]
assert result.retrieve_config.query_variable == "query"
def test_convert_single_with_metadata_configs(self, valid_uuid, mocker):
def test_convert_single_with_metadata_configs(self, valid_uuid, mocker: MockerFixture):
mock_retrieve_config = MagicMock()
mock_entity = MagicMock()
mock_entity.dataset_ids = [valid_uuid]
@ -258,7 +259,7 @@ class TestExtractDatasetConfig:
with pytest.raises(ValueError):
DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)
def test_extract_invalid_uuid(self, mocker):
def test_extract_invalid_uuid(self, mocker: MockerFixture):
invalid_uuid = "not-a-uuid"
config = {
"agent_mode": {
@ -270,7 +271,7 @@ class TestExtractDatasetConfig:
with pytest.raises(ValueError):
DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config)
def test_extract_dataset_not_exists(self, valid_uuid, mocker):
def test_extract_dataset_not_exists(self, valid_uuid, mocker: MockerFixture):
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
return_value=None,
@ -292,7 +293,7 @@ class TestExtractDatasetConfig:
class TestIsDatasetExists:
def test_dataset_exists_true(self, mocker, valid_uuid):
def test_dataset_exists_true(self, mocker: MockerFixture, valid_uuid):
mock_dataset = MagicMock()
mock_dataset.tenant_id = "tenant1"
mocker.patch(
@ -302,14 +303,14 @@ class TestIsDatasetExists:
assert DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid)
def test_dataset_exists_false_when_not_found(self, mocker, valid_uuid):
def test_dataset_exists_false_when_not_found(self, mocker: MockerFixture, valid_uuid):
mocker.patch(
"core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset",
return_value=None,
)
assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid)
def test_dataset_exists_false_when_tenant_mismatch(self, mocker, valid_uuid):
def test_dataset_exists_false_when_tenant_mismatch(self, mocker: MockerFixture, valid_uuid):
mock_dataset = MagicMock()
mock_dataset.tenant_id = "other"
mocker.patch(

View File

@ -2,6 +2,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.entities.model_entities import ModelStatus
@ -16,7 +17,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey
class TestModelConfigConverter:
@pytest.fixture(autouse=True)
def patch_response_entity(self, mocker):
def patch_response_entity(self, mocker: MockerFixture):
"""
Patch ModelConfigWithCredentialsEntity to bypass Pydantic validation
and return a simple namespace object instead.
@ -69,7 +70,7 @@ class TestModelConfigConverter:
return bundle
@pytest.fixture
def patch_provider_manager(self, mocker, mock_provider_bundle):
def patch_provider_manager(self, mocker: MockerFixture, mock_provider_bundle):
mock_manager = MagicMock()
mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle
mocker.patch(
@ -99,7 +100,7 @@ class TestModelConfigConverter:
assert result.parameters == {"temperature": 0.7}
assert result.stop == ["\n"]
def test_convert_mode_from_schema_valid(self, mock_app_config, mock_provider_bundle, mocker):
def test_convert_mode_from_schema_valid(self, mock_app_config, mock_provider_bundle, mocker: MockerFixture):
mock_app_config.model.mode = None
mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = {
@ -116,7 +117,9 @@ class TestModelConfigConverter:
result = ModelConfigConverter.convert(mock_app_config)
assert result.mode == LLMMode.COMPLETION
def test_convert_mode_from_schema_invalid_fallback(self, mock_app_config, mock_provider_bundle, mocker):
def test_convert_mode_from_schema_invalid_fallback(
self, mock_app_config, mock_provider_bundle, mocker: MockerFixture
):
mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = {
ModelPropertyKey.MODE: "invalid"
}
@ -135,7 +138,7 @@ class TestModelConfigConverter:
# Credential Errors
# =============================
def test_convert_credentials_none_raises(self, mock_app_config, mock_provider_bundle, mocker):
def test_convert_credentials_none_raises(self, mock_app_config, mock_provider_bundle, mocker: MockerFixture):
mock_provider_bundle.configuration.get_current_credentials.return_value = None
mock_manager = MagicMock()
@ -152,7 +155,7 @@ class TestModelConfigConverter:
# Provider Model Errors
# =============================
def test_convert_provider_model_none_raises(self, mock_app_config, mock_provider_bundle, mocker):
def test_convert_provider_model_none_raises(self, mock_app_config, mock_provider_bundle, mocker: MockerFixture):
mock_provider_bundle.configuration.get_provider_model.return_value = None
mock_manager = MagicMock()
@ -174,7 +177,7 @@ class TestModelConfigConverter:
],
)
def test_convert_provider_model_status_errors(
self, mock_app_config, mock_provider_bundle, mocker, status, expected_exception
self, mock_app_config, mock_provider_bundle, mocker: MockerFixture, status, expected_exception
):
mock_provider = MagicMock()
mock_provider.status = status
@ -194,7 +197,7 @@ class TestModelConfigConverter:
# Schema Errors
# =============================
def test_convert_model_schema_none_raises(self, mock_app_config, mock_provider_bundle, mocker):
def test_convert_model_schema_none_raises(self, mock_app_config, mock_provider_bundle, mocker: MockerFixture):
mock_provider_bundle.model_type_instance.get_model_schema.return_value = None
mock_manager = MagicMock()

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
# Target
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
@ -107,7 +108,9 @@ class TestModelConfigManager:
# validate_and_set_defaults
# ==========================================================
def test_validate_and_set_defaults_success(self, mocker, valid_config, provider_entities, valid_model_list):
def test_validate_and_set_defaults_success(
self, mocker: MockerFixture, valid_config, provider_entities, valid_model_list
):
self._patch_model_assembly(
mocker,
provider_entities=provider_entities,
@ -127,35 +130,37 @@ class TestModelConfigManager:
with pytest.raises(ValueError, match="object type"):
ModelConfigManager.validate_and_set_defaults("tenant1", {"model": "invalid"})
def test_validate_and_set_defaults_missing_provider(self, mocker, provider_entities):
def test_validate_and_set_defaults_missing_provider(self, mocker: MockerFixture, provider_entities):
config = {"model": {"name": "gpt-4", "completion_params": {}}}
self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[])
with pytest.raises(ValueError, match="model.provider is required"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_invalid_provider(self, mocker, provider_entities):
def test_validate_and_set_defaults_invalid_provider(self, mocker: MockerFixture, provider_entities):
config = {"model": {"provider": "invalid/provider", "name": "gpt-4", "completion_params": {}}}
self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[])
with pytest.raises(ValueError, match="model.provider is required"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_missing_name(self, mocker, provider_entities):
def test_validate_and_set_defaults_missing_name(self, mocker: MockerFixture, provider_entities):
config = {"model": {"provider": "openai/gpt", "completion_params": {}}}
self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[])
with pytest.raises(ValueError, match="model.name is required"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_empty_models(self, mocker, provider_entities):
def test_validate_and_set_defaults_empty_models(self, mocker: MockerFixture, provider_entities):
config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}}
self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[])
with pytest.raises(ValueError, match="must be in the specified model list"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_invalid_model_name(self, mocker, provider_entities, valid_model_list):
def test_validate_and_set_defaults_invalid_model_name(
self, mocker: MockerFixture, provider_entities, valid_model_list
):
config = {"model": {"provider": "openai/gpt", "name": "invalid", "completion_params": {}}}
self._patch_model_assembly(
mocker,
@ -166,7 +171,7 @@ class TestModelConfigManager:
with pytest.raises(ValueError, match="must be in the specified model list"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_default_mode_when_missing(self, mocker, provider_entities):
def test_validate_and_set_defaults_default_mode_when_missing(self, mocker: MockerFixture, provider_entities):
model = MagicMock()
model.model = "gpt-4"
model.model_properties = {}
@ -178,7 +183,9 @@ class TestModelConfigManager:
assert updated_config["model"]["mode"] == "completion"
def test_validate_and_set_defaults_missing_completion_params(self, mocker, provider_entities, valid_model_list):
def test_validate_and_set_defaults_missing_completion_params(
self, mocker: MockerFixture, provider_entities, valid_model_list
):
config = {"model": {"provider": "openai/gpt", "name": "gpt-4"}}
self._patch_model_assembly(
mocker,
@ -189,7 +196,7 @@ class TestModelConfigManager:
with pytest.raises(ValueError, match="completion_params is required"):
ModelConfigManager.validate_and_set_defaults("tenant1", config)
def test_validate_and_set_defaults_provider_without_slash_converted(self, mocker, valid_model_list):
def test_validate_and_set_defaults_provider_without_slash_converted(self, mocker: MockerFixture, valid_model_list):
"""
Covers branch where provider does not contain '/' and
ModelProviderID conversion is triggered (line 64).

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.app.app_config.easy_ui_based_app.prompt_template.manager import (
PromptTemplateConfigManager,
@ -38,7 +39,7 @@ class TestPromptTemplateConfigManagerConvert:
with pytest.raises(ValueError, match="prompt_type is required"):
PromptTemplateConfigManager.convert({})
def test_convert_simple_prompt(self, mocker):
def test_convert_simple_prompt(self, mocker: MockerFixture):
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = DummyPromptType()
@ -56,7 +57,7 @@ class TestPromptTemplateConfigManagerConvert:
assert result == "simple_entity"
mock_prompt_entity_cls.assert_called_once_with(prompt_type="simple", simple_prompt_template="hello")
def test_convert_advanced_chat_valid(self, mocker):
def test_convert_advanced_chat_valid(self, mocker: MockerFixture):
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = DummyPromptType()
mock_prompt_entity_cls.return_value = "advanced_entity"
@ -97,7 +98,7 @@ class TestPromptTemplateConfigManagerConvert:
{"text": "hi", "role": 123},
],
)
def test_convert_advanced_invalid_message_fields(self, mocker, message):
def test_convert_advanced_invalid_message_fields(self, mocker: MockerFixture, message):
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = DummyPromptType()
@ -114,7 +115,7 @@ class TestPromptTemplateConfigManagerConvert:
with pytest.raises(ValueError):
PromptTemplateConfigManager.convert(config)
def test_convert_advanced_completion_with_roles(self, mocker):
def test_convert_advanced_completion_with_roles(self, mocker: MockerFixture):
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = DummyPromptType()
mock_prompt_entity_cls.return_value = "advanced_entity"
@ -154,7 +155,7 @@ class TestValidateAndSetDefaults:
def setup_method(self):
self.valid_model = {"mode": "chat"}
def _patch_prompt_type(self, mocker):
def _patch_prompt_type(self, mocker: MockerFixture):
mock_prompt_entity_cls = MagicMock()
mock_prompt_entity_cls.PromptType = DummyPromptType()
mocker.patch(
@ -163,7 +164,7 @@ class TestValidateAndSetDefaults:
)
return mock_prompt_entity_cls
def test_default_prompt_type_set(self, mocker):
def test_default_prompt_type_set(self, mocker: MockerFixture):
self._patch_prompt_type(mocker)
config = {"model": self.valid_model}
@ -173,7 +174,7 @@ class TestValidateAndSetDefaults:
assert result["prompt_type"] == "simple"
assert isinstance(keys, list)
def test_invalid_prompt_type_raises(self, mocker):
def test_invalid_prompt_type_raises(self, mocker: MockerFixture):
class InvalidEnum(DummyPromptType):
def __iter__(self):
return iter([DummyEnumValue("valid")])
@ -191,7 +192,7 @@ class TestValidateAndSetDefaults:
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_invalid_chat_prompt_config_type(self, mocker):
def test_invalid_chat_prompt_config_type(self, mocker: MockerFixture):
self._patch_prompt_type(mocker)
config = {
@ -203,7 +204,7 @@ class TestValidateAndSetDefaults:
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_simple_mode_invalid_pre_prompt_type(self, mocker):
def test_simple_mode_invalid_pre_prompt_type(self, mocker: MockerFixture):
self._patch_prompt_type(mocker)
config = {
@ -215,7 +216,7 @@ class TestValidateAndSetDefaults:
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_advanced_requires_one_config(self, mocker):
def test_advanced_requires_one_config(self, mocker: MockerFixture):
self._patch_prompt_type(mocker)
config = {
@ -228,7 +229,7 @@ class TestValidateAndSetDefaults:
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_advanced_invalid_model_mode(self, mocker):
def test_advanced_invalid_model_mode(self, mocker: MockerFixture):
self._patch_prompt_type(mocker)
config = {
@ -240,7 +241,7 @@ class TestValidateAndSetDefaults:
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_advanced_chat_prompt_length_exceeds(self, mocker):
def test_advanced_chat_prompt_length_exceeds(self, mocker: MockerFixture):
self._patch_prompt_type(mocker)
config = {
@ -252,7 +253,7 @@ class TestValidateAndSetDefaults:
with pytest.raises(ValueError):
PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config)
def test_completion_prefix_defaults_set_when_empty(self, mocker):
def test_completion_prefix_defaults_set_when_empty(self, mocker: MockerFixture):
self._patch_prompt_type(mocker)
config = {

View File

@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockerFixture
from core.app.app_config.easy_ui_based_app.variables.manager import (
BasicVariablesConfigManager,
@ -15,7 +16,7 @@ class TestBasicVariablesConfigManagerConvert:
assert variables == []
assert external == []
def test_convert_external_data_tools_enabled_and_disabled(self, mocker):
def test_convert_external_data_tools_enabled_and_disabled(self, mocker: MockerFixture):
config = {
"external_data_tools": [
{"enabled": False},
@ -232,7 +233,7 @@ class TestValidateExternalDataToolsAndSetDefaults:
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config)
def test_validate_disabled_tool_skipped(self, mocker):
def test_validate_disabled_tool_skipped(self, mocker: MockerFixture):
config = {"external_data_tools": [{"enabled": False}]}
spy = mocker.patch(
@ -250,7 +251,7 @@ class TestValidateExternalDataToolsAndSetDefaults:
with pytest.raises(ValueError):
BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config)
def test_validate_enabled_tool_calls_factory(self, mocker):
def test_validate_enabled_tool_calls_factory(self, mocker: MockerFixture):
config = {"external_data_tools": [{"enabled": True, "type": "tool", "config": {"a": 1}}]}
spy = mocker.patch(
@ -263,7 +264,7 @@ class TestValidateExternalDataToolsAndSetDefaults:
class TestValidateAndSetDefaultsIntegration:
def test_validate_and_set_defaults_calls_both(self, mocker):
def test_validate_and_set_defaults_calls_both(self, mocker: MockerFixture):
config = {}
spy_var = mocker.patch.object(

View File

@ -2,6 +2,7 @@ from collections import UserDict
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@ -12,7 +13,7 @@ class TestBaseAppConfigManager:
return {"key": "value", "another": 123}
@pytest.fixture
def mock_app_additional_features(self, mocker):
def mock_app_additional_features(self, mocker: MockerFixture):
mock_instance = MagicMock()
mocker.patch(
"core.app.app_config.base_app_config_manager.AppAdditionalFeatures",
@ -21,7 +22,7 @@ class TestBaseAppConfigManager:
return mock_instance
@pytest.fixture
def mock_managers(self, mocker):
def mock_managers(self, mocker: MockerFixture):
retrieval = mocker.patch(
"core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert",
return_value="retrieval_result",
@ -72,7 +73,7 @@ class TestBaseAppConfigManager:
)
def test_convert_features_all_modes(
self,
mocker,
mocker: MockerFixture,
mock_config_dict,
mock_app_additional_features,
mock_managers,
@ -107,7 +108,7 @@ class TestBaseAppConfigManager:
mock_managers["speech_to_text"].assert_called_once_with(config=dict(mock_config_dict.items()))
mock_managers["text_to_speech"].assert_called_once_with(config=dict(mock_config_dict.items()))
def test_convert_features_empty_config(self, mocker, mock_app_additional_features, mock_managers):
def test_convert_features_empty_config(self, mocker: MockerFixture, mock_app_additional_features, mock_managers):
# Arrange
empty_config = {}
mock_app_mode = MagicMock()
@ -143,7 +144,7 @@ class TestBaseAppConfigManager:
with pytest.raises((TypeError, AttributeError)):
BaseAppConfigManager.convert_features(invalid_config, "CHAT")
def test_convert_features_manager_exception_propagates(self, mocker, mock_config_dict):
def test_convert_features_manager_exception_propagates(self, mocker: MockerFixture, mock_config_dict):
# Arrange
mocker.patch(
"core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert",
@ -154,7 +155,9 @@ class TestBaseAppConfigManager:
with pytest.raises(RuntimeError):
BaseAppConfigManager.convert_features(mock_config_dict, "CHAT")
def test_convert_features_mapping_subclass(self, mocker, mock_app_additional_features, mock_managers):
def test_convert_features_mapping_subclass(
self, mocker: MockerFixture, mock_app_additional_features, mock_managers
):
# Arrange
class CustomMapping(UserDict):
pass

View File

@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockerFixture
from core.app.app_config.workflow_ui_based_app.variables.manager import (
WorkflowVariablesConfigManager,
@ -10,19 +11,19 @@ from core.app.app_config.workflow_ui_based_app.variables.manager import (
@pytest.fixture
def mock_workflow(mocker):
def mock_workflow(mocker: MockerFixture):
workflow = mocker.MagicMock()
workflow.graph_dict = {"nodes": []}
return workflow
@pytest.fixture
def mock_variable_entity(mocker):
def mock_variable_entity(mocker: MockerFixture):
return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.VariableEntity")
@pytest.fixture
def mock_rag_entity(mocker):
def mock_rag_entity(mocker: MockerFixture):
return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.RagPipelineVariableEntity")

View File

@ -111,7 +111,7 @@ class TestAdvancedChatAppGeneratorInternals:
workflow_id="workflow-id",
)
def test_generate_loads_conversation_and_files(self, monkeypatch):
def test_generate_loads_conversation_and_files(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
app_config = self._build_app_config()
@ -195,7 +195,7 @@ class TestAdvancedChatAppGeneratorInternals:
assert captured["application_generate_entity"].files == built_files
assert build_files_called["called"] is True
def test_resume_delegates_to_generate(self, monkeypatch):
def test_resume_delegates_to_generate(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
task_id="task",
@ -235,7 +235,7 @@ class TestAdvancedChatAppGeneratorInternals:
assert result == {"resumed": True}
assert captured["graph_runtime_state"] is not None
def test_single_iteration_generate_builds_debug_task(self, monkeypatch):
def test_single_iteration_generate_builds_debug_task(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
app_config = self._build_app_config()
captured: dict[str, object] = {}
@ -293,7 +293,7 @@ class TestAdvancedChatAppGeneratorInternals:
assert captured["variable_loader"] is var_loader
assert captured["application_generate_entity"].single_iteration_run.node_id == "node-1"
def test_single_loop_generate_builds_debug_task(self, monkeypatch):
def test_single_loop_generate_builds_debug_task(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
app_config = self._build_app_config()
captured: dict[str, object] = {}
@ -351,7 +351,7 @@ class TestAdvancedChatAppGeneratorInternals:
assert captured["variable_loader"] is var_loader
assert captured["application_generate_entity"].single_loop_run.node_id == "node-2"
def test_generate_internal_flow_initial_conversation_with_pause_layer(self, monkeypatch):
def test_generate_internal_flow_initial_conversation_with_pause_layer(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
generator._dialogue_count = 0
app_config = self._build_app_config()
@ -449,7 +449,7 @@ class TestAdvancedChatAppGeneratorInternals:
assert isinstance(captured["conversation"], ConversationSnapshot)
assert isinstance(captured["message"], MessageSnapshot)
def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch):
def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
generator._dialogue_count = 0
app_config = self._build_app_config()
@ -535,7 +535,7 @@ class TestAdvancedChatAppGeneratorInternals:
db_session.refresh.assert_not_called()
db_session.close.assert_called_once()
def test_generate_worker_raises_when_workflow_not_found(self, monkeypatch):
def test_generate_worker_raises_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
generator._dialogue_count = 1
app_config = self._build_app_config()
@ -594,7 +594,7 @@ class TestAdvancedChatAppGeneratorInternals:
graph_runtime_state=None,
)
def test_generate_worker_raises_when_app_not_found_for_internal_call(self, monkeypatch):
def test_generate_worker_raises_when_app_not_found_for_internal_call(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
generator._dialogue_count = 1
app_config = self._build_app_config()
@ -658,7 +658,7 @@ class TestAdvancedChatAppGeneratorInternals:
graph_runtime_state=None,
)
def test_generate_worker_handles_stopped_error(self, monkeypatch):
def test_generate_worker_handles_stopped_error(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
generator._dialogue_count = 1
app_config = self._build_app_config()
@ -732,7 +732,7 @@ class TestAdvancedChatAppGeneratorInternals:
queue_manager.publish_error.assert_not_called()
def test_generate_worker_handles_validation_error(self, monkeypatch):
def test_generate_worker_handles_validation_error(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
generator._dialogue_count = 1
app_config = self._build_app_config()
@ -816,7 +816,7 @@ class TestAdvancedChatAppGeneratorInternals:
queue_manager.publish_error.assert_called_once()
def test_generate_worker_handles_value_and_unknown_errors(self, monkeypatch):
def test_generate_worker_handles_value_and_unknown_errors(self, monkeypatch: pytest.MonkeyPatch):
app_config = self._build_app_config()
@contextmanager
@ -897,7 +897,7 @@ class TestAdvancedChatAppGeneratorInternals:
queue_manager.publish_error.assert_called_once()
def test_handle_response_closed_file_raises_stopped(self, monkeypatch):
def test_handle_response_closed_file_raises_stopped(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
generator._dialogue_count = 1
@ -953,7 +953,7 @@ class TestAdvancedChatAppGeneratorInternals:
stream=False,
)
def test_handle_response_re_raises_value_error(self, monkeypatch):
def test_handle_response_re_raises_value_error(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
generator._dialogue_count = 1
app_config = self._build_app_config()
@ -1002,7 +1002,7 @@ class TestAdvancedChatAppGeneratorInternals:
logger_exception.assert_called_once()
def test_generate_worker_handles_invoke_auth_error(self, monkeypatch):
def test_generate_worker_handles_invoke_auth_error(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
generator._dialogue_count = 1
@ -1088,7 +1088,7 @@ class TestAdvancedChatAppGeneratorInternals:
assert queue_manager.publish_error.called
def test_generate_debugger_enables_retrieve_source(self, monkeypatch):
def test_generate_debugger_enables_retrieve_source(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
app_config = WorkflowUIBasedAppConfig(
@ -1167,7 +1167,7 @@ class TestAdvancedChatAppGeneratorInternals:
assert app_config.additional_features.show_retrieve_source is True
assert captured["application_generate_entity"].query == "hello"
def test_generate_service_api_sets_parent_message_id(self, monkeypatch):
def test_generate_service_api_sets_parent_message_id(self, monkeypatch: pytest.MonkeyPatch):
generator = AdvancedChatAppGenerator()
app_config = WorkflowUIBasedAppConfig(

View File

@ -224,7 +224,7 @@ class TestAdvancedChatGenerateTaskPipeline:
assert isinstance(responses[0], ValueError)
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch: pytest.MonkeyPatch):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")),
@ -368,7 +368,7 @@ class TestAdvancedChatGenerateTaskPipeline:
assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]
def test_workflow_finish_handlers(self, monkeypatch):
def test_workflow_finish_handlers(self, monkeypatch: pytest.MonkeyPatch):
pipeline = _make_pipeline()
pipeline._workflow_run_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
@ -593,7 +593,7 @@ class TestAdvancedChatGenerateTaskPipeline:
assert message.answer == "hello"
assert message.message_metadata
def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch):
def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch: pytest.MonkeyPatch):
pipeline = _make_pipeline()
pipeline._message_end_to_stream_response = lambda: "end"
saved: list[str] = []
@ -614,7 +614,7 @@ class TestAdvancedChatGenerateTaskPipeline:
assert responses == ["end"]
assert saved == ["saved"]
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch):
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch: pytest.MonkeyPatch):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),

View File

@ -2,6 +2,7 @@ import uuid
from types import SimpleNamespace
import pytest
from pytest_mock import MockerFixture
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.apps.agent_chat.app_config_manager import (
@ -11,7 +12,7 @@ from core.entities.agent_entities import PlanningStrategy
class TestAgentChatAppConfigManagerGetAppConfig:
def test_get_app_config_override_config(self, mocker):
def test_get_app_config_override_config(self, mocker: MockerFixture):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
app_model_config = mocker.MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"ignored": True}
@ -45,7 +46,7 @@ class TestAgentChatAppConfigManagerGetAppConfig:
assert result.variables == "variables"
assert result.external_data_variables == "external"
def test_get_app_config_conversation_specific(self, mocker):
def test_get_app_config_conversation_specific(self, mocker: MockerFixture):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
app_model_config = mocker.MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
@ -76,7 +77,7 @@ class TestAgentChatAppConfigManagerGetAppConfig:
assert result.app_model_config_dict == app_model_config.to_dict.return_value
assert result.app_model_config_from.value == "conversation-specific-config"
def test_get_app_config_latest_config(self, mocker):
def test_get_app_config_latest_config(self, mocker: MockerFixture):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
app_model_config = mocker.MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
@ -107,7 +108,7 @@ class TestAgentChatAppConfigManagerGetAppConfig:
class TestAgentChatAppConfigManagerConfigValidate:
def test_config_validate_filters_related_keys(self, mocker):
def test_config_validate_filters_related_keys(self, mocker: MockerFixture):
config = {
"model": {},
"user_input_form": {},
@ -247,7 +248,7 @@ class TestValidateAgentModeAndSetDefaults:
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": "bad"}}]}},
)
def test_old_tool_dataset_id_not_exists(self, mocker):
def test_old_tool_dataset_id_not_exists(self, mocker: MockerFixture):
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
return_value=False,
@ -275,7 +276,7 @@ class TestValidateAgentModeAndSetDefaults:
"tenant", {"agent_mode": {"enabled": True, "tools": [tool]}}
)
def test_valid_old_and_new_style_tools(self, mocker):
def test_valid_old_and_new_style_tools(self, mocker: MockerFixture):
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
return_value=True,

View File

@ -2,6 +2,7 @@ import contextlib
import pytest
from pydantic import ValidationError
from pytest_mock import MockerFixture
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.exc import GenerateTaskStoppedError
@ -16,7 +17,7 @@ class DummyAccount:
@pytest.fixture
def generator(mocker):
def generator(mocker: MockerFixture):
gen = AgentChatAppGenerator()
mocker.patch(
"core.app.apps.agent_chat.app_generator.current_app",
@ -27,19 +28,19 @@ def generator(mocker):
class TestAgentChatAppGeneratorGenerate:
def test_generate_rejects_blocking_mode(self, generator, mocker):
def test_generate_rejects_blocking_mode(self, generator, mocker: MockerFixture):
app_model = mocker.MagicMock()
user = DummyAccount("user")
with pytest.raises(ValueError):
generator.generate(app_model=app_model, user=user, args={}, invoke_from=mocker.MagicMock(), streaming=False)
def test_generate_requires_query(self, generator, mocker):
def test_generate_requires_query(self, generator, mocker: MockerFixture):
app_model = mocker.MagicMock()
user = DummyAccount("user")
with pytest.raises(ValueError):
generator.generate(app_model=app_model, user=user, args={"inputs": {}}, invoke_from=mocker.MagicMock())
def test_generate_rejects_non_string_query(self, generator, mocker):
def test_generate_rejects_non_string_query(self, generator, mocker: MockerFixture):
app_model = mocker.MagicMock()
user = DummyAccount("user")
with pytest.raises(ValueError):
@ -50,7 +51,7 @@ class TestAgentChatAppGeneratorGenerate:
invoke_from=mocker.MagicMock(),
)
def test_generate_override_requires_debugger(self, generator, mocker):
def test_generate_override_requires_debugger(self, generator, mocker: MockerFixture):
app_model = mocker.MagicMock()
user = DummyAccount("user")
@ -62,7 +63,7 @@ class TestAgentChatAppGeneratorGenerate:
invoke_from=InvokeFrom.WEB_APP,
)
def test_generate_success_with_debugger_override(self, generator, mocker):
def test_generate_success_with_debugger_override(self, generator, mocker: MockerFixture):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
app_model_config = mocker.MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
@ -142,7 +143,7 @@ class TestAgentChatAppGeneratorGenerate:
assert result == {"result": "ok"}
thread_obj.start.assert_called_once()
def test_generate_without_file_config(self, generator, mocker):
def test_generate_without_file_config(self, generator, mocker: MockerFixture):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
app_model_config = mocker.MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
@ -213,14 +214,14 @@ class TestAgentChatAppGeneratorGenerate:
class TestAgentChatAppGeneratorWorker:
@pytest.fixture(autouse=True)
def patch_context(self, mocker):
def patch_context(self, mocker: MockerFixture):
@contextlib.contextmanager
def ctx_manager(*args, **kwargs):
yield
mocker.patch("core.app.apps.agent_chat.app_generator.preserve_flask_contexts", ctx_manager)
def test_generate_worker_handles_generate_task_stopped(self, generator, mocker):
def test_generate_worker_handles_generate_task_stopped(self, generator, mocker: MockerFixture):
queue_manager = mocker.MagicMock()
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
@ -250,7 +251,7 @@ class TestAgentChatAppGeneratorWorker:
Exception("bad"),
],
)
def test_generate_worker_publishes_errors(self, generator, mocker, error):
def test_generate_worker_publishes_errors(self, generator, mocker: MockerFixture, error):
queue_manager = mocker.MagicMock()
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
@ -271,7 +272,7 @@ class TestAgentChatAppGeneratorWorker:
assert queue_manager.publish_error.called
def test_generate_worker_logs_value_error_when_debug(self, generator, mocker):
def test_generate_worker_logs_value_error_when_debug(self, generator, mocker: MockerFixture):
queue_manager = mocker.MagicMock()
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())

View File

@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockerFixture
from core.agent.entities import AgentEntity
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
@ -13,7 +14,7 @@ def runner():
class TestAgentChatAppRunnerRun:
def test_run_app_not_found(self, runner, mocker):
def test_run_app_not_found(self, runner, mocker: MockerFixture):
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
@ -22,7 +23,7 @@ class TestAgentChatAppRunnerRun:
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
def test_run_moderation_error_direct_output(self, runner, mocker):
def test_run_moderation_error_direct_output(self, runner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock()
@ -45,7 +46,7 @@ class TestAgentChatAppRunnerRun:
runner.direct_output.assert_called_once()
def test_run_annotation_reply_short_circuits(self, runner, mocker):
def test_run_annotation_reply_short_circuits(self, runner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock()
@ -74,7 +75,7 @@ class TestAgentChatAppRunnerRun:
queue_manager.publish.assert_called_once()
runner.direct_output.assert_called_once()
def test_run_hosting_moderation_short_circuits(self, runner, mocker):
def test_run_hosting_moderation_short_circuits(self, runner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock()
@ -98,7 +99,7 @@ class TestAgentChatAppRunnerRun:
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
def test_run_model_schema_missing(self, runner, mocker):
def test_run_model_schema_missing(self, runner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
@ -140,7 +141,7 @@ class TestAgentChatAppRunnerRun:
(LLMMode.COMPLETION, "CotCompletionAgentRunner"),
],
)
def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner):
def test_run_chain_of_thought_modes(self, runner, mocker: MockerFixture, mode, expected_runner):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
@ -196,7 +197,7 @@ class TestAgentChatAppRunnerRun:
runner_instance.run.assert_called_once()
runner._handle_invoke_result.assert_called_once()
def test_run_invalid_llm_mode_raises(self, runner, mocker):
def test_run_invalid_llm_mode_raises(self, runner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
@ -242,7 +243,7 @@ class TestAgentChatAppRunnerRun:
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
def test_run_function_calling_strategy_selected_by_features(self, runner, mocker):
def test_run_function_calling_strategy_selected_by_features(self, runner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
@ -298,7 +299,7 @@ class TestAgentChatAppRunnerRun:
assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING
runner_instance.run.assert_called_once()
def test_run_conversation_not_found(self, runner, mocker):
def test_run_conversation_not_found(self, runner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
@ -332,7 +333,7 @@ class TestAgentChatAppRunnerRun:
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
def test_run_message_not_found(self, runner, mocker):
def test_run_message_not_found(self, runner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
@ -366,7 +367,7 @@ class TestAgentChatAppRunnerRun:
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
def test_run_invalid_agent_strategy_raises(self, runner, mocker):
def test_run_invalid_agent_strategy_raises(self, runner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")

View File

@ -2,6 +2,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
import core.app.apps.completion.app_runner as module
from core.app.apps.completion.app_runner import CompletionAppRunner
@ -47,7 +48,7 @@ def _build_generate_entity(app_config, file_upload_config=None):
class TestCompletionAppRunner:
def test_run_app_not_found(self, runner, mocker):
def test_run_app_not_found(self, runner, mocker: MockerFixture):
session = mocker.MagicMock()
session.scalar.return_value = None
mocker.patch.object(module.db, "session", session)
@ -58,7 +59,7 @@ class TestCompletionAppRunner:
with pytest.raises(ValueError):
runner.run(app_generate_entity, MagicMock(), MagicMock())
def test_run_moderation_error_outputs_direct(self, runner, mocker):
def test_run_moderation_error_outputs_direct(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
@ -78,7 +79,7 @@ class TestCompletionAppRunner:
runner.direct_output.assert_called_once()
runner._handle_invoke_result.assert_not_called()
def test_run_hosting_moderation_stops(self, runner, mocker):
def test_run_hosting_moderation_stops(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
@ -97,7 +98,7 @@ class TestCompletionAppRunner:
runner._handle_invoke_result.assert_not_called()
def test_run_dataset_and_external_tools_flow(self, runner, mocker):
def test_run_dataset_and_external_tools_flow(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
@ -140,7 +141,7 @@ class TestCompletionAppRunner:
assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input"
runner._handle_invoke_result.assert_called_once()
def test_run_uses_low_image_detail_default(self, runner, mocker):
def test_run_uses_low_image_detail_default(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()

Some files were not shown because too many files have changed in this diff Show More