diff --git a/api/libs/helper.py b/api/libs/helper.py index ac85e88ef7f..7066f9eab45 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast, overload, over from uuid import UUID from zoneinfo import available_timezones -from flask import Response, stream_with_context +from flask import Request, Response, stream_with_context from flask_restx import fields from pydantic import BaseModel, ConfigDict, TypeAdapter, with_config from pydantic.functional_validators import AfterValidator @@ -167,19 +167,6 @@ def build_avatar_url(avatar: str | None) -> str | None: return file_helpers.get_signed_file_url(avatar) -class AvatarUrlField(fields.Raw): - @override - def output(self, key, obj, **kwargs): - if obj is None: - return None - - from models import Account - - if isinstance(obj, Account) and obj.avatar is not None: - return build_avatar_url(obj.avatar) - return None - - class TimestampField(fields.Raw): @override def schema(self) -> dict[str, object]: @@ -397,7 +384,7 @@ def generate_string(n): return result -def extract_remote_ip(request) -> str: +def extract_remote_ip(request: Request) -> str: if request.headers.get("CF-Connecting-IP"): return cast(str, request.headers.get("CF-Connecting-IP")) elif request.headers.getlist("X-Forwarded-For"): diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py index ac09060e9d4..12f91212c1f 100644 --- a/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py @@ -42,7 +42,9 @@ def trace_client_factory(): class TestTraceClient: @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") @patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname") - def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory): + def test_init( + self, mock_gethostname: MagicMock, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient] + ): mock_gethostname.return_value = "test-host" client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -56,7 +58,7 @@ class TestTraceClient: assert client.done is True @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") - def test_export(self, mock_exporter_class, trace_client_factory): + def test_export(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") spans = [MagicMock(spec=ReadableSpan)] @@ -65,7 +67,9 @@ class TestTraceClient: @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") - def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory): + def test_api_check_success( + self, mock_exporter_class: MagicMock, mock_head: MagicMock, trace_client_factory: type[TraceClient] + ): mock_response = MagicMock() mock_response.status_code = 405 mock_head.return_value = mock_response @@ -75,7 +79,9 @@ class TestTraceClient: @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") - def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory): + def test_api_check_failure_status( + self, mock_exporter_class: MagicMock, mock_head: MagicMock, trace_client_factory: type[TraceClient] + ): mock_response = MagicMock() mock_response.status_code = 500 mock_head.return_value = mock_response @@ -85,7 +91,9 @@ class TestTraceClient: @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") - def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory): + def test_api_check_exception( + self, mock_exporter_class: MagicMock, mock_head: MagicMock, trace_client_factory: type[TraceClient] + ): mock_head.side_effect = httpx.RequestError("Connection error") client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -93,12 +101,12 @@ class TestTraceClient: client.api_check() @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") - def test_get_project_url(self, mock_exporter_class, trace_client_factory): + def test_get_project_url(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm" @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") - def test_add_span(self, mock_exporter_class, trace_client_factory): + def test_add_span(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]): client = trace_client_factory( service_name="test-service", endpoint="http://test-endpoint", @@ -135,7 +143,9 @@ class TestTraceClient: @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") @patch("dify_trace_aliyun.data_exporter.traceclient.logger") - def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory): + def test_add_span_queue_full( + self, mock_logger: MagicMock, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient] + ): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1) span_data = SpanData( @@ -159,7 +169,7 @@ class TestTraceClient: mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.") @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") - def test_export_batch_error(self, mock_exporter_class, trace_client_factory): + def test_export_batch_error(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]): mock_exporter = mock_exporter_class.return_value mock_exporter.export.side_effect = Exception("Export failed") @@ -172,13 +182,13 @@ class TestTraceClient: mock_logger.warning.assert_called() @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") - def test_worker_loop(self, mock_exporter_class, trace_client_factory): + def test_worker_loop(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]): # We need to test the wait timeout in _worker # But _worker runs in a thread. Let's mock condition.wait. client = trace_client_factory( service_name="test-service", endpoint="http://test-endpoint", - schedule_delay_sec=0.1, + schedule_delay_sec=1, ) with patch.object(client.condition, "wait") as mock_wait: @@ -189,7 +199,7 @@ class TestTraceClient: assert mock_wait.called or client.done @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") - def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory): + def test_shutdown_flushes(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index 0c2ead2bccc..9e3cac2255b 100644 --- a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -456,7 +456,7 @@ class TestPhoenixParentSpanBridgeHelpers: assert error.parent_node_execution_id == "outer-node-execution-1" assert "outer-node-execution-1" in str(error) - def test_resolve_parent_span_context_rejects_payload_without_traceparent(self, monkeypatch): + def test_resolve_parent_span_context_rejects_payload_without_traceparent(self, monkeypatch: pytest.MonkeyPatch): mock_redis = MagicMock() mock_redis.get.return_value = '{"tracestate": "vendor=value"}' monkeypatch.setattr(arize_phoenix_trace_module, "redis_client", mock_redis) diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py index f19bcefd457..76d4c99caf7 100644 --- a/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py @@ -656,7 +656,7 @@ def _patch_workflow_trace_deps(monkeypatch, trace_instance): trace_instance.add_run = MagicMock() -def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypatch): +def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypatch: pytest.MonkeyPatch): """Chatflow with external trace_id: LangSmith trace_id must be message_id, not external.""" trace_info = _make_workflow_trace_info( message_id="msg-abc", @@ -677,7 +677,7 @@ def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypa assert trace_info.metadata.get("external_trace_id") == "external-999" -def test_workflow_trace_id_pure_workflow_uses_run_id(trace_instance, monkeypatch): +def test_workflow_trace_id_pure_workflow_uses_run_id(trace_instance, monkeypatch: pytest.MonkeyPatch): """Pure workflow (no message_id) with external trace_id: trace_id must be workflow_run_id.""" trace_info = _make_workflow_trace_info( message_id=None, diff --git a/api/tests/integration_tests/controllers/openapi/test_app_run.py b/api/tests/integration_tests/controllers/openapi/test_app_run.py index 92e2e993dbf..b4f383a7cea 100644 --- a/api/tests/integration_tests/controllers/openapi/test_app_run.py +++ b/api/tests/integration_tests/controllers/openapi/test_app_run.py @@ -13,7 +13,9 @@ from extensions.ext_database import db from models import App -def test_run_chat_dispatches_to_chat_handler(flask_app, account_token, app_in_workspace, monkeypatch): +def test_run_chat_dispatches_to_chat_handler( + flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch +): captured = {} def _fake_generate(*, app_model, user, args, invoke_from, streaming): @@ -78,7 +80,9 @@ def app_with_mode(flask_app: Flask, workspace_account): db.session.commit() -def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_workspace, monkeypatch): +def test_run_chat_without_query_returns_422( + flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch +): client = flask_app.test_client() res = client.post( f"/openapi/v1/apps/{app_in_workspace.id}/run", @@ -89,7 +93,9 @@ def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_wor assert b"query_required_for_chat" in res.data -def test_run_completion_dispatches_to_completion_handler(flask_app, account_token, app_with_mode, monkeypatch): +def test_run_completion_dispatches_to_completion_handler( + flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch +): app = app_with_mode("completion") captured: dict = {} @@ -119,7 +125,9 @@ def test_run_completion_dispatches_to_completion_handler(flask_app, account_toke assert captured["mode"] == "completion" -def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_mode, monkeypatch): +def test_run_workflow_with_query_returns_422( + flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch +): app = app_with_mode("workflow") client = flask_app.test_client() res = client.post( @@ -131,7 +139,9 @@ def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_ assert b"query_not_supported_for_workflow" in res.data -def test_run_workflow_no_query_dispatches_to_workflow_handler(flask_app, account_token, app_with_mode, monkeypatch): +def test_run_workflow_no_query_dispatches_to_workflow_handler( + flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch +): app = app_with_mode("workflow") def _fake_generate(*, app_model, user, args, invoke_from, streaming): @@ -154,7 +164,9 @@ def test_run_workflow_no_query_dispatches_to_workflow_handler(flask_app, account assert body["workflow_run_id"] == "wfr" -def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mode, monkeypatch): +def test_run_unsupported_mode_returns_422( + flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch +): app = app_with_mode("channel") client = flask_app.test_client() res = client.post( @@ -166,7 +178,7 @@ def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mod assert b"mode_not_runnable" in res.data -def test_run_without_bearer_returns_401(flask_app, app_in_workspace): +def test_run_without_bearer_returns_401(flask_app: Flask, app_in_workspace): client = flask_app.test_client() res = client.post( f"/openapi/v1/apps/{app_in_workspace.id}/run", @@ -175,7 +187,9 @@ def test_run_without_bearer_returns_401(flask_app, app_in_workspace): assert res.status_code == 401 -def test_run_with_insufficient_scope_returns_403(flask_app, account_token, app_in_workspace, monkeypatch): +def test_run_with_insufficient_scope_returns_403( + flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch +): """Stub the authenticator to return an AuthContext with empty scopes.""" from libs import oauth_bearer @@ -198,7 +212,7 @@ def test_run_with_insufficient_scope_returns_403(flask_app, account_token, app_i assert res.status_code == 403 -def test_run_with_unknown_app_returns_404(flask_app, account_token): +def test_run_with_unknown_app_returns_404(flask_app: Flask, account_token): client = flask_app.test_client() res = client.post( f"/openapi/v1/apps/{uuid.uuid4()}/run", @@ -208,7 +222,9 @@ def test_run_with_unknown_app_returns_404(flask_app, account_token): assert res.status_code == 404 -def test_run_streaming_returns_event_stream(flask_app, account_token, app_in_workspace, monkeypatch): +def test_run_streaming_returns_event_stream( + flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch +): def _stream() -> Generator[str, None, None]: yield 'event: message\ndata: {"x": 1}\n\n' @@ -228,7 +244,7 @@ def test_run_streaming_returns_event_stream(flask_app, account_token, app_in_wor assert b"event: message" in res.data -def test_run_without_inputs_returns_422(flask_app, account_token, app_in_workspace): +def test_run_without_inputs_returns_422(flask_app: Flask, account_token, app_in_workspace): client = flask_app.test_client() res = client.post( f"/openapi/v1/apps/{app_in_workspace.id}/run", diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 5b7790f6f44..d8a0a713f12 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,6 +4,8 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch +import pytest + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_manager import ModelInstance @@ -91,7 +93,7 @@ def init_llm_node(config: dict) -> LLMNode: return node -def _mock_db_session_close(monkeypatch) -> None: +def _mock_db_session_close(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(db.session, "close", MagicMock()) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index fc230a2a68d..44708285037 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,6 +3,8 @@ import time import uuid from unittest.mock import MagicMock +import pytest + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_runtime import DifyPromptMessageSerializer @@ -83,11 +85,11 @@ def init_parameter_extractor_node(config: dict, memory=None): return node -def _mock_db_session_close(monkeypatch) -> None: +def _mock_db_session_close(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(db.session, "close", MagicMock()) -def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch): +def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch: pytest.MonkeyPatch): """ Test function calling for parameter extractor. """ @@ -128,7 +130,7 @@ def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch): assert result.outputs.get("__reason") == None -def test_instructions(setup_model_mock, monkeypatch): +def test_instructions(setup_model_mock, monkeypatch: pytest.MonkeyPatch): """ Test chat parameter extractor. """ @@ -178,7 +180,7 @@ def test_instructions(setup_model_mock, monkeypatch): assert "what's the weather in SF" in prompt.get("text") -def test_chat_parameter_extractor(setup_model_mock, monkeypatch): +def test_chat_parameter_extractor(setup_model_mock, monkeypatch: pytest.MonkeyPatch): """ Test chat parameter extractor. """ @@ -229,7 +231,7 @@ def test_chat_parameter_extractor(setup_model_mock, monkeypatch): assert '\n{"type": "object"' in prompt.get("text") -def test_completion_parameter_extractor(setup_model_mock, monkeypatch): +def test_completion_parameter_extractor(setup_model_mock, monkeypatch: pytest.MonkeyPatch): """ Test completion parameter extractor. """ @@ -354,7 +356,7 @@ def test_extract_json_from_tool_call(): assert result["location"] == "kawaii" -def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch): +def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch: pytest.MonkeyPatch): """ Test chat parameter extractor with memory. """ diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 3c496d1fc8c..9ef6b903066 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -517,11 +517,11 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.TenantService") def test_should_handle_account_generation_scenarios( self, - mock_tenant_service, - mock_account_service, - mock_register_service, - mock_feature_service, - mock_get_account, + mock_tenant_service: MagicMock, + mock_account_service: MagicMock, + mock_register_service: MagicMock, + mock_feature_service: MagicMock, + mock_get_account: MagicMock, app: Flask, user_info: OAuthUserInfo, mock_account, @@ -562,11 +562,11 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.TenantService") def test_should_register_with_lowercase_email( self, - mock_tenant_service, - mock_account_service, - mock_register_service, - mock_feature_service, - mock_get_account, + mock_tenant_service: MagicMock, + mock_account_service: MagicMock, + mock_register_service: MagicMock, + mock_feature_service: MagicMock, + mock_get_account: MagicMock, app: Flask, ): user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com") @@ -593,11 +593,11 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.TenantService") def test_should_register_with_browser_timezone( self, - mock_tenant_service, - mock_account_service, - mock_register_service, - mock_feature_service, - mock_get_account, + mock_tenant_service: MagicMock, + mock_account_service: MagicMock, + mock_register_service: MagicMock, + mock_feature_service: MagicMock, + mock_get_account: MagicMock, app: Flask, user_info: OAuthUserInfo, ): @@ -624,11 +624,11 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.TenantService") def test_should_register_with_state_language( self, - mock_tenant_service, - mock_account_service, - mock_register_service, - mock_feature_service, - mock_get_account, + mock_tenant_service: MagicMock, + mock_account_service: MagicMock, + mock_register_service: MagicMock, + mock_feature_service: MagicMock, + mock_get_account: MagicMock, app: Flask, user_info: OAuthUserInfo, ): diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py index 9d588d4c73f..c93e61b2bfb 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 import pytest @@ -83,8 +83,8 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.encrypter") def test_create_provider_auth_success( self, - mock_encrypter, - mock_factory, + mock_encrypter: MagicMock, + mock_factory: MagicMock, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, @@ -107,7 +107,12 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") def test_create_provider_auth_validation_failed( - self, mock_factory, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, mock_args + self, + mock_factory: MagicMock, + flask_app_with_containers: Flask, + db_session_with_containers: Session, + tenant_id, + mock_args, ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = False @@ -123,8 +128,8 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.encrypter") def test_create_provider_auth_encrypts_api_key( self, - mock_encrypter, - mock_factory, + mock_encrypter: MagicMock, + mock_factory: MagicMock, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, @@ -289,7 +294,7 @@ class TestApiKeyAuthService: ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - def test_create_provider_auth_factory_exception(self, mock_factory, tenant_id, mock_args): + def test_create_provider_auth_factory_exception(self, mock_factory: MagicMock, tenant_id, mock_args): mock_factory.side_effect = Exception("Factory error") with pytest.raises(Exception, match="Factory error"): ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) diff --git a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py index 48830c0f43b..398c5979ef2 100644 --- a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py @@ -47,8 +47,8 @@ class TestGetDynamicSelectOptionsTool: @patch("services.plugin.plugin_parameter_service.ToolManager") def test_fetches_credentials_with_credential_id( self, - mock_tool_mgr, - mock_encrypter_fn, + mock_tool_mgr: MagicMock, + mock_encrypter_fn: MagicMock, mock_client_cls, flask_app_with_containers: Flask, db_session_with_containers: Session, @@ -90,8 +90,8 @@ class TestGetDynamicSelectOptionsTool: @patch("services.plugin.plugin_parameter_service.ToolManager") def test_raises_when_tool_provider_not_found( self, - mock_tool_mgr, - mock_encrypter_fn, + mock_tool_mgr: MagicMock, + mock_encrypter_fn: MagicMock, flask_app_with_containers: Flask, db_session_with_containers: Session, ): diff --git a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py index 66ff24f3741..5d2f4f709f1 100644 --- a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py @@ -258,7 +258,7 @@ class TestUpgradePluginWithMarketplace: class TestUpgradePluginWithGithub: @patch("core.plugin.plugin_service.FeatureService") @patch("core.plugin.plugin_service.PluginInstaller") - def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs): + def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls: MagicMock, mock_fs: MagicMock): mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.upgrade_plugin.return_value = MagicMock() @@ -273,7 +273,7 @@ class TestUpgradePluginWithGithub: class TestUploadPkg: @patch("core.plugin.plugin_service.FeatureService") @patch("core.plugin.plugin_service.PluginInstaller") - def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs): + def test_runs_permission_and_scope_checks(self, mock_installer_cls: MagicMock, mock_fs: MagicMock): mock_fs.get_system_features.return_value = _make_features() upload_resp = MagicMock() upload_resp.verification = None @@ -318,7 +318,7 @@ class TestInstallFromMarketplacePkg: @patch("core.plugin.plugin_service.FeatureService") @patch("core.plugin.plugin_service.PluginInstaller") @patch("core.plugin.plugin_service.dify_config") - def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs): + def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls: MagicMock, mock_fs: MagicMock): mock_config.MARKETPLACE_ENABLED = True mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value diff --git a/api/tests/unit_tests/commands/test_generate_swagger_markdown_docs.py b/api/tests/unit_tests/commands/test_generate_swagger_markdown_docs.py index f30bbd24374..e989d54a965 100644 --- a/api/tests/unit_tests/commands/test_generate_swagger_markdown_docs.py +++ b/api/tests/unit_tests/commands/test_generate_swagger_markdown_docs.py @@ -5,6 +5,8 @@ import json import sys from pathlib import Path +import pytest + def _load_generate_swagger_markdown_docs_module(): api_dir = Path(__file__).resolve().parents[3] @@ -20,7 +22,9 @@ def _load_generate_swagger_markdown_docs_module(): return module -def test_generate_markdown_docs_keeps_split_docs_and_merges_fastopenapi_into_console(tmp_path, monkeypatch): +def test_generate_markdown_docs_keeps_split_docs_and_merges_fastopenapi_into_console( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): module = _load_generate_swagger_markdown_docs_module() openapi_dir = tmp_path / "openapi" markdown_dir = tmp_path / "markdown" @@ -69,7 +73,9 @@ def test_generate_markdown_docs_keeps_split_docs_and_merges_fastopenapi_into_con assert "FastOpenAPI Preview" not in (markdown_dir / "service-openapi.md").read_text(encoding="utf-8") -def test_generate_markdown_docs_only_removes_generated_specs_from_separate_swagger_dir(tmp_path, monkeypatch): +def test_generate_markdown_docs_only_removes_generated_specs_from_separate_swagger_dir( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): module = _load_generate_swagger_markdown_docs_module() swagger_dir = tmp_path / "swagger" markdown_dir = tmp_path / "markdown" @@ -105,7 +111,7 @@ def test_generate_markdown_docs_only_removes_generated_specs_from_separate_swagg assert not list(swagger_dir.glob("*.json")) -def test_patch_union_schema_markdown_fills_converter_blank_schema_types(tmp_path): +def test_patch_union_schema_markdown_fills_converter_blank_schema_types(tmp_path: Path): module = _load_generate_swagger_markdown_docs_module() spec_path = tmp_path / "console-openapi.json" spec_path.write_text( @@ -239,7 +245,7 @@ def test_patch_union_schema_markdown_ignores_specs_without_schemas(tmp_path): assert module._patch_union_schema_markdown("unchanged", spec_path) == "unchanged" -def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path): +def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path: Path): module = _load_generate_swagger_markdown_docs_module() spec_path = tmp_path / "console-openapi.json" spec_path.write_text( @@ -285,7 +291,7 @@ def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path): assert module._patch_union_schema_markdown("#### BrokenUnion\n", spec_path) == "#### BrokenUnion\n" -def test_convert_spec_to_markdown_patches_generated_union_tables(tmp_path, monkeypatch): +def test_convert_spec_to_markdown_patches_generated_union_tables(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): module = _load_generate_swagger_markdown_docs_module() spec_path = tmp_path / "console-openapi.json" output_path = tmp_path / "console-openapi.md" diff --git a/api/tests/unit_tests/controllers/console/app/test_message_api.py b/api/tests/unit_tests/controllers/console/app/test_message_api.py index 27bc5e341e9..067edc6fd68 100644 --- a/api/tests/unit_tests/controllers/console/app/test_message_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_message_api.py @@ -3,11 +3,12 @@ from __future__ import annotations from datetime import UTC, datetime import pytest +from flask import Flask from controllers.console.app import message as message_module -def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_chat_messages_query_valid(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test valid ChatMessagesQuery with all fields.""" query = message_module.ChatMessagesQuery( conversation_id="550e8400-e29b-41d4-a716-446655440000", @@ -17,14 +18,14 @@ def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None assert query.limit == 50 -def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_chat_messages_query_defaults(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test ChatMessagesQuery with defaults.""" query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000") assert query.first_id is None assert query.limit == 20 -def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_chat_messages_query_empty_first_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test ChatMessagesQuery converts empty first_id to None.""" query = message_module.ChatMessagesQuery( conversation_id="550e8400-e29b-41d4-a716-446655440000", @@ -33,7 +34,7 @@ def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch assert query.first_id is None -def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_message_feedback_payload_valid_like(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test MessageFeedbackPayload with like rating.""" payload = message_module.MessageFeedbackPayload( message_id="550e8400-e29b-41d4-a716-446655440000", @@ -44,7 +45,7 @@ def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatc assert payload.content == "Good answer" -def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_message_feedback_payload_valid_dislike(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test MessageFeedbackPayload with dislike rating.""" payload = message_module.MessageFeedbackPayload( message_id="550e8400-e29b-41d4-a716-446655440000", @@ -53,69 +54,69 @@ def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyP assert payload.rating == "dislike" -def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_message_feedback_payload_no_rating(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test MessageFeedbackPayload without rating.""" payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000") assert payload.rating is None -def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_feedback_export_query_defaults(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test FeedbackExportQuery with default format.""" query = message_module.FeedbackExportQuery() assert query.format == "csv" assert query.from_source is None -def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_feedback_export_query_json_format(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test FeedbackExportQuery with JSON format.""" query = message_module.FeedbackExportQuery(format="json") assert query.format == "json" -def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_feedback_export_query_has_comment_true(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test FeedbackExportQuery with has_comment as true string.""" query = message_module.FeedbackExportQuery(has_comment="true") assert query.has_comment is True -def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_feedback_export_query_has_comment_false(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test FeedbackExportQuery with has_comment as false string.""" query = message_module.FeedbackExportQuery(has_comment="false") assert query.has_comment is False -def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_feedback_export_query_has_comment_1(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test FeedbackExportQuery with has_comment as 1.""" query = message_module.FeedbackExportQuery(has_comment="1") assert query.has_comment is True -def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_feedback_export_query_has_comment_0(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test FeedbackExportQuery with has_comment as 0.""" query = message_module.FeedbackExportQuery(has_comment="0") assert query.has_comment is False -def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_feedback_export_query_rating_filter(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test FeedbackExportQuery with rating filter.""" query = message_module.FeedbackExportQuery(rating="like") assert query.rating == "like" -def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_annotation_count_response(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test AnnotationCountResponse creation.""" response = message_module.AnnotationCountResponse(count=10) assert response.count == 10 -def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_suggested_questions_response(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test SuggestedQuestionsResponse creation.""" response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"]) assert len(response.data) == 2 assert response.data[0] == "What is AI?" -def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_message_detail_response_normalizes_aliases_and_timestamp(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test MessageDetailResponse normalizes alias fields and datetime timestamps.""" created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) response = message_module.MessageDetailResponse.model_validate( diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic_api.py b/api/tests/unit_tests/controllers/console/app/test_statistic_api.py index b31dccd034f..b0506e348fc 100644 --- a/api/tests/unit_tests/controllers/console/app/test_statistic_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_statistic_api.py @@ -5,6 +5,7 @@ from inspect import unwrap from types import SimpleNamespace import pytest +from flask import Flask from werkzeug.exceptions import BadRequest from controllers.console.app import statistic as statistic_module @@ -38,7 +39,7 @@ def _install_common(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field) -def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_daily_message_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyMessageStatistic() method = unwrap(api.get) @@ -52,7 +53,7 @@ def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPat assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]} -def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_daily_conversation_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyConversationStatistic() method = unwrap(api.get) @@ -66,7 +67,7 @@ def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.Monk assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]} -def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_daily_token_cost_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyTokenCostStatistic() method = unwrap(api.get) @@ -84,7 +85,7 @@ def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.Monkey assert data["data"][0]["total_price"] == 0.25 -def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_daily_terminals_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyTerminalsStatistic() method = unwrap(api.get) @@ -98,7 +99,7 @@ def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyP assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]} -def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_average_session_interaction_statistic_requires_chat_mode(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Test that AverageSessionInteractionStatistic is limited to chat/agent modes.""" # This just verifies the decorator is applied correctly # Actual endpoint testing would require complex JOIN mocking @@ -107,7 +108,7 @@ def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypat assert callable(method) -def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_daily_message_statistic_with_invalid_time_range(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyMessageStatistic() method = unwrap(api.get) @@ -123,7 +124,7 @@ def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytes method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) -def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_daily_message_statistic_multiple_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyMessageStatistic() method = unwrap(api.get) @@ -142,7 +143,7 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa assert len(data["data"]) == 3 -def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_daily_message_statistic_empty_result(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyMessageStatistic() method = unwrap(api.get) @@ -155,7 +156,7 @@ def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPat assert response.get_json() == {"data": []} -def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_daily_conversation_statistic_with_time_range(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyConversationStatistic() method = unwrap(api.get) @@ -174,7 +175,7 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]} -def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_daily_token_cost_with_multiple_currencies(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyTokenCostStatistic() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index a2748ad323e..5cc5af9592b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import PropertyMock, patch import pytest +from flask import Flask from controllers.console.datasets.rag_pipeline import rag_pipeline_workflow as module from models.account import Account, TenantAccountRole @@ -117,7 +118,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes( assert response["has_more"] is False -def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_rag_pipeline_workflow_patch_serializes_response_model(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow = _make_workflow(marked_name="Updated release") class _SessionContext: diff --git a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py index cfc0299cc2c..11916c87b68 100644 --- a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py +++ b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from flask import Flask from werkzeug.exceptions import HTTPException, NotFound from controllers.console.snippets import snippet_workflow as snippet_workflow_module @@ -52,7 +53,7 @@ def test_get_snippet_requires_snippet_id(app): view() -def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_snippet_injects_resolved_snippet(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: snippet = _snippet() @snippet_workflow_module.get_snippet @@ -72,7 +73,7 @@ def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPat assert result is snippet -def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_snippet_raises_not_found_when_snippet_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: @snippet_workflow_module.get_snippet def view(**kwargs): return kwargs @@ -89,7 +90,7 @@ def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pyt view(snippet_id="snippet-1") -def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_draft_workflow_get_raises_when_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: snippet = _snippet() monkeypatch.setattr( snippet_workflow_module, @@ -105,7 +106,7 @@ def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyP handler(api, snippet=snippet) -def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_draft_workflow_post_returns_400_for_invalid_graph(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: user = _account("account-1") snippet = _snippet() sync_draft_workflow = Mock(side_effect=ValueError("invalid graph")) @@ -145,7 +146,7 @@ def test_published_workflow_get_returns_none_when_not_published(app) -> None: assert handler(api, snippet=SimpleNamespace(id="snippet-1", is_published=False)) is None -def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_published_workflow_post_returns_400_when_publish_fails(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: user = _account("account-1") snippet = _snippet() merged_snippet = _snippet() @@ -180,7 +181,7 @@ def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch session.commit.assert_not_called() -def test_default_block_configs_delegates_to_service(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_default_block_configs_delegates_to_service(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: get_default_block_configs = Mock(return_value=[{"type": "llm"}]) monkeypatch.setattr( snippet_workflow_module, @@ -198,7 +199,7 @@ def test_default_block_configs_delegates_to_service(app, monkeypatch: pytest.Mon get_default_block_configs.assert_called_once() -def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_restore_published_snippet_workflow_to_draft_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow = SimpleNamespace( unique_hash="restored-hash", updated_at=None, @@ -226,7 +227,7 @@ def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: p assert response["hash"] == "restored-hash" -def test_restore_published_snippet_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_restore_published_snippet_workflow_to_draft_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: user = _account("account-1") snippet = _snippet() @@ -311,7 +312,7 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_invalid_gra assert exc.value.description == "invalid snippet workflow graph" -def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_run_detail_raises_not_found_when_run_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: snippet = _snippet() monkeypatch.setattr( snippet_workflow_module, @@ -327,7 +328,9 @@ def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch: handler(api, snippet=snippet, run_id="run-1") -def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_draft_node_last_run_raises_not_found_when_execution_missing( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: snippet = _snippet() draft_workflow = SimpleNamespace(id="workflow-1") monkeypatch.setattr( @@ -347,7 +350,7 @@ def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkey handler(api, snippet=snippet, node_id="llm-1") -def test_workflow_task_stop_uses_queue_flag_and_graph_command(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_task_stop_uses_queue_flag_and_graph_command(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: set_stop_flag = Mock() send_stop_command = Mock() monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py index 9a3637d2f43..e6bee6fe1d3 100644 --- a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py @@ -27,7 +27,7 @@ def _make_account() -> Account: @pytest.fixture(autouse=True) -def _patch_snippet_service_factory(monkeypatch): +def _patch_snippet_service_factory(monkeypatch: pytest.MonkeyPatch): def factory(): service_factory = module.SnippetService if isinstance(service_factory, type): @@ -64,7 +64,7 @@ def test_ensure_snippet_draft_variable_row_allowed_accepts_canvas_node_variable( module._ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id="var-1") -def test_conversation_variables_returns_empty_list(app): +def test_conversation_variables_returns_empty_list(app: Flask): api = module.SnippetConversationVariableCollectionApi() handler = _unwrap(api.get) @@ -74,7 +74,7 @@ def test_conversation_variables_returns_empty_list(app): assert result == WorkflowDraftVariableList(variables=[]) -def test_system_variables_returns_empty_list(app): +def test_system_variables_returns_empty_list(app: Flask): api = module.SnippetSystemVariableCollectionApi() handler = _unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/test_init_validate.py b/api/tests/unit_tests/controllers/console/test_init_validate.py index 3077304cbed..4954e0dc96a 100644 --- a/api/tests/unit_tests/controllers/console/test_init_validate.py +++ b/api/tests/unit_tests/controllers/console/test_init_validate.py @@ -4,6 +4,7 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from flask import Flask from controllers.console import init_validate from controllers.console.error import AlreadySetupError, InitValidateFailedError @@ -35,7 +36,7 @@ def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None: assert result.status == "not_started" -def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1) app.secret_key = "test-secret" @@ -45,7 +46,7 @@ def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPat init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw")) -def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0) monkeypatch.setenv("INIT_PASSWORD", "expected") @@ -57,7 +58,7 @@ def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPa assert init_validate.session.get("is_init_validated") is False -def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_validate_init_password_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0) monkeypatch.setenv("INIT_PASSWORD", "expected") @@ -74,7 +75,7 @@ def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatc assert init_validate.get_init_validate_status() is True -def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_init_validate_status_validated_session(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") monkeypatch.setenv("INIT_PASSWORD", "expected") app.secret_key = "test-secret" @@ -84,7 +85,7 @@ def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.Mon assert init_validate.get_init_validate_status() is True -def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_init_validate_status_setup_exists(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") monkeypatch.setenv("INIT_PASSWORD", "expected") monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True)) @@ -96,7 +97,7 @@ def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPa assert init_validate.get_init_validate_status() is True -def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_init_validate_status_not_validated(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") monkeypatch.setenv("INIT_PASSWORD", "expected") monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False)) diff --git a/api/tests/unit_tests/controllers/console/test_remote_files.py b/api/tests/unit_tests/controllers/console/test_remote_files.py index e7127aef236..25245e743c5 100644 --- a/api/tests/unit_tests/controllers/console/test_remote_files.py +++ b/api/tests/unit_tests/controllers/console/test_remote_files.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock import httpx import pytest +from flask import Flask from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError from controllers.console import remote_files as remote_files_module @@ -82,7 +83,7 @@ def _mock_upload_dependencies( return file_service_cls, current_user -def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_remote_file_info_uses_head_when_successful(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.GetRemoteFileInfo() handler = unwrap(api.get) decoded_url = "https://example.com/test.txt" @@ -103,7 +104,7 @@ def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest make_request.assert_called_once_with("HEAD", decoded_url) -def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_remote_file_info_preserves_unencoded_target_query(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.GetRemoteFileInfo() handler = unwrap(api.get) target_url = "http://example.com/api/aiagent/httpview/txt" @@ -124,7 +125,9 @@ def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: make_request.assert_called_once_with("HEAD", f"{target_url}?{query}") -def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: api = remote_files_module.GetRemoteFileInfo() handler = unwrap(api.get) decoded_url = "https://example.com/test.txt" @@ -147,7 +150,7 @@ def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, mo assert make_request.call_args_list[1].kwargs == {"timeout": 3} -def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_remote_file_upload_success_when_fetch_falls_back_to_get(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.RemoteFileUpload() handler = unwrap(api.post) url = "https://example.com/report.txt" @@ -220,7 +223,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds( assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content" -def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.RemoteFileUpload() handler = unwrap(api.post) url = "https://example.com/fail.txt" @@ -238,7 +241,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat handler(api, _make_account()) -def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_remote_file_upload_raises_on_httpx_request_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.RemoteFileUpload() handler = unwrap(api.post) url = "https://example.com/fail.txt" @@ -252,7 +255,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte handler(api, _make_account()) -def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_remote_file_upload_rejects_oversized_file(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.RemoteFileUpload() handler = unwrap(api.post) url = "https://example.com/large.bin" @@ -267,7 +270,9 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk handler(api, current_user) -def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_remote_file_upload_translates_service_file_too_large_error( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: api = remote_files_module.RemoteFileUpload() handler = unwrap(api.post) url = "https://example.com/large.bin" @@ -282,7 +287,9 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp handler(api, current_user) -def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_remote_file_upload_translates_service_unsupported_type_error( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: api = remote_files_module.RemoteFileUpload() handler = unwrap(api.post) url = "https://example.com/file.exe" diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index c39d0930bec..e419428ca66 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -102,10 +102,10 @@ class TestChangeEmailSend: @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") def test_should_normalize_new_email_phase( self, - mock_extract_ip, - mock_is_ip_limit, - mock_send_email, - mock_get_change_data, + mock_extract_ip: MagicMock, + mock_is_ip_limit: MagicMock, + mock_send_email: MagicMock, + mock_get_change_data: MagicMock, app: Flask, ): mock_account = _build_account("current@example.com", "acc1") @@ -143,10 +143,10 @@ class TestChangeEmailSend: @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified( self, - mock_extract_ip, - mock_is_ip_limit, - mock_send_email, - mock_get_change_data, + mock_extract_ip: MagicMock, + mock_is_ip_limit: MagicMock, + mock_send_email: MagicMock, + mock_get_change_data: MagicMock, app: Flask, ): """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step.""" @@ -178,10 +178,10 @@ class TestChangeEmailSend: @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") def test_should_reject_new_email_phase_when_token_account_id_does_not_match_current_user( self, - mock_extract_ip, - mock_is_ip_limit, - mock_send_email, - mock_get_change_data, + mock_extract_ip: MagicMock, + mock_is_ip_limit: MagicMock, + mock_send_email: MagicMock, + mock_get_change_data: MagicMock, app: Flask, ): from controllers.console.auth.error import InvalidTokenError diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index 96f55c85fe5..618a1f52180 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -603,7 +603,7 @@ class TestRateLimiting: @patch("controllers.console.wraps.redis_client") @patch("controllers.console.wraps.db") - def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis): + def test_should_allow_requests_within_rate_limit(self, mock_db: MagicMock, mock_redis: MagicMock): """Test that requests within rate limit are allowed""" # Arrange mock_rate_limit = MagicMock() @@ -631,7 +631,7 @@ class TestRateLimiting: @patch("controllers.console.wraps.redis_client") @patch("controllers.console.wraps.db") - def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis): + def test_should_reject_requests_over_rate_limit(self, mock_db: MagicMock, mock_redis: MagicMock): """Test that requests over rate limit are rejected and logged""" # Arrange app = create_app_with_login() @@ -720,7 +720,7 @@ class TestSystemSetup: """Test system setup decorator""" @patch("controllers.console.wraps.db") - def test_should_allow_when_setup_complete(self, mock_db): + def test_should_allow_when_setup_complete(self, mock_db: MagicMock): """Test that requests are allowed when setup is complete""" # Arrange @@ -737,7 +737,7 @@ class TestSystemSetup: @patch("controllers.console.wraps.db") @patch("controllers.console.wraps.os.environ.get") - def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db): + def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db: MagicMock): """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete""" # Arrange mock_db.session.scalar.return_value = None # No setup @@ -754,7 +754,7 @@ class TestSystemSetup: @patch("controllers.console.wraps.db") @patch("controllers.console.wraps.os.environ.get") - def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db): + def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db: MagicMock): """Test NotSetupError when no INIT_PASSWORD and setup not complete""" # Arrange mock_db.session.scalar.return_value = None # No setup diff --git a/api/tests/unit_tests/controllers/console/workspace/test_snippets.py b/api/tests/unit_tests/controllers/console/workspace/test_snippets.py index a276a181d62..e8e005a1b83 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_snippets.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_snippets.py @@ -3,6 +3,7 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.console.workspace import snippets as snippets_module @@ -69,7 +70,7 @@ def test_normalize_snippet_list_query_args_sorts_indexed_values(): } -def test_list_snippets_returns_pagination(app, monkeypatch): +def test_list_snippets_returns_pagination(app: Flask, monkeypatch: pytest.MonkeyPatch): snippets = [_snippet()] tag_id = "11111111-1111-1111-1111-111111111111" get_snippets = Mock(return_value=(snippets, 1, False)) @@ -104,7 +105,7 @@ def test_list_snippets_returns_pagination(app, monkeypatch): ) -def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypatch): +def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, monkeypatch: pytest.MonkeyPatch): user = _account("account-1") snippet = _snippet() create_snippet = Mock(return_value=snippet) @@ -140,7 +141,7 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypat assert create_snippet.call_args.kwargs["snippet_type"] == snippets_module.SnippetType.NODE -def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch): +def test_create_snippet_rejects_forbidden_nodes(app: Flask, monkeypatch: pytest.MonkeyPatch): user = _account("account-1") create_snippet = Mock() monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet) @@ -169,7 +170,7 @@ def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch): create_snippet.assert_not_called() -def test_get_snippet_detail_raises_when_missing(app, monkeypatch): +def test_get_snippet_detail_raises_when_missing(app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None)) api = snippets_module.CustomizedSnippetDetailApi() @@ -180,7 +181,7 @@ def test_get_snippet_detail_raises_when_missing(app, monkeypatch): handler(api, "tenant-1", snippet_id="snippet-1") -def test_get_snippet_detail_returns_snippet(app, monkeypatch): +def test_get_snippet_detail_returns_snippet(app: Flask, monkeypatch: pytest.MonkeyPatch): snippet = _snippet() monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"})) @@ -195,7 +196,7 @@ def test_get_snippet_detail_returns_snippet(app, monkeypatch): assert response == {"id": "snippet-1"} -def test_patch_snippet_returns_400_for_empty_payload(app, monkeypatch): +def test_patch_snippet_returns_400_for_empty_payload(app: Flask, monkeypatch: pytest.MonkeyPatch): snippet = _snippet() user = _account("user-1") monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) @@ -214,7 +215,7 @@ def test_patch_snippet_returns_400_for_empty_payload(app, monkeypatch): assert response == {"message": "No valid fields to update"} -def test_patch_snippet_updates_and_commits(app, monkeypatch): +def test_patch_snippet_updates_and_commits(app: Flask, monkeypatch: pytest.MonkeyPatch): user = _account("account-1") snippet = _snippet() updated_snippet = _snippet(name="New") @@ -251,7 +252,7 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch): session.commit.assert_called_once() -def test_delete_snippet_deletes_and_commits(app, monkeypatch): +def test_delete_snippet_deletes_and_commits(app: Flask, monkeypatch: pytest.MonkeyPatch): snippet = _snippet() session = SimpleNamespace(merge=Mock(return_value=snippet), commit=Mock()) delete_snippet = Mock() @@ -277,7 +278,7 @@ def test_delete_snippet_deletes_and_commits(app, monkeypatch): session.commit.assert_called_once() -def test_export_snippet_returns_yaml_attachment(app, monkeypatch): +def test_export_snippet_returns_yaml_attachment(app: Flask, monkeypatch: pytest.MonkeyPatch): snippet = _snippet(name="Snippet One") export_snippet_dsl = Mock(return_value="version: 0.1.0\nkind: snippet\n") session = SimpleNamespace() @@ -308,7 +309,7 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch): export_snippet_dsl.assert_called_once_with(snippet=snippet, include_secret=True) -def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch): +def test_import_snippet_returns_202_for_pending_confirmation(app: Flask, monkeypatch: pytest.MonkeyPatch): user = _account("account-1") result = SnippetImportInfo(id="import-1", status=ImportStatus.PENDING, imported_dsl_version="999.0.0") import_snippet = Mock(return_value=result) @@ -348,7 +349,7 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch): session.commit.assert_called_once() -def test_import_snippet_returns_400_for_failed_import(app, monkeypatch): +def test_import_snippet_returns_400_for_failed_import(app: Flask, monkeypatch: pytest.MonkeyPatch): user = _account("account-1") result = SnippetImportInfo(id="import-1", status=ImportStatus.FAILED, error="Invalid DSL") import_snippet = Mock(return_value=result) @@ -381,7 +382,7 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch): session.commit.assert_called_once() -def test_import_confirm_returns_200_for_completed_import(app, monkeypatch): +def test_import_confirm_returns_200_for_completed_import(app: Flask, monkeypatch: pytest.MonkeyPatch): user = _account("account-1") result = SnippetImportInfo(id="import-1", status=ImportStatus.COMPLETED, snippet_id="snippet-1") confirm_import = Mock(return_value=result) @@ -414,7 +415,7 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch): session.commit.assert_called_once() -def test_check_dependencies_raises_when_snippet_missing(app, monkeypatch): +def test_check_dependencies_raises_when_snippet_missing(app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None)) api = snippets_module.CustomizedSnippetCheckDependenciesApi() @@ -425,7 +426,7 @@ def test_check_dependencies_raises_when_snippet_missing(app, monkeypatch): handler(api, "tenant-1", snippet_id="snippet-1") -def test_check_dependencies_returns_dependency_result(app, monkeypatch): +def test_check_dependencies_returns_dependency_result(app: Flask, monkeypatch: pytest.MonkeyPatch): snippet = _snippet() check_dependencies = Mock( return_value=SimpleNamespace(model_dump=Mock(return_value={"dependencies": [], "missing_dependencies": []})) @@ -456,7 +457,7 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch): check_dependencies.assert_called_once_with(snippet=snippet) -def test_increment_use_count_raises_when_snippet_missing(app, monkeypatch): +def test_increment_use_count_raises_when_snippet_missing(app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None)) api = snippets_module.CustomizedSnippetUseCountIncrementApi() @@ -470,7 +471,7 @@ def test_increment_use_count_raises_when_snippet_missing(app, monkeypatch): handler(api, "tenant-1", snippet_id="snippet-1") -def test_increment_use_count_returns_refreshed_count(app, monkeypatch): +def test_increment_use_count_returns_refreshed_count(app: Flask, monkeypatch: pytest.MonkeyPatch): snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1", use_count=2) merged_snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1", use_count=3) session = SimpleNamespace(merge=Mock(return_value=merged_snippet), commit=Mock(), refresh=Mock()) diff --git a/api/tests/unit_tests/controllers/openapi/conftest.py b/api/tests/unit_tests/controllers/openapi/conftest.py index 70302810d2c..d79b5bd642c 100644 --- a/api/tests/unit_tests/controllers/openapi/conftest.py +++ b/api/tests/unit_tests/controllers/openapi/conftest.py @@ -35,7 +35,7 @@ def _stub_execute( @pytest.fixture -def bypass_pipeline(monkeypatch): +def bypass_pipeline(monkeypatch: pytest.MonkeyPatch): """Stub PipelineRouter._execute so endpoints skip real auth at request time. Module-level @auth_router.guard(...) captures the real router at import diff --git a/api/tests/unit_tests/controllers/openapi/test_account.py b/api/tests/unit_tests/controllers/openapi/test_account.py index 5bab035e457..b4af7dda4a6 100644 --- a/api/tests/unit_tests/controllers/openapi/test_account.py +++ b/api/tests/unit_tests/controllers/openapi/test_account.py @@ -167,14 +167,14 @@ def _session_auth_data() -> AuthData: ) -def _stub_session_deps(monkeypatch, rows): +def _stub_session_deps(monkeypatch: pytest.MonkeyPatch, rows): mod = sys.modules[_ACCOUNT_MOD] monkeypatch.setattr(mod, "get_auth_ctx", lambda: SimpleNamespace()) monkeypatch.setattr(mod, "list_active_sessions", lambda *args, **kwargs: rows) monkeypatch.setattr(mod, "db", MagicMock()) -def test_sessions_list_valid_query_parses_page_and_limit(app, monkeypatch): +def test_sessions_list_valid_query_parses_page_and_limit(app: Flask, monkeypatch: pytest.MonkeyPatch): """A valid ?page&limit round-trips through SessionListQuery into the response envelope.""" api = AccountSessionsApi() _stub_session_deps(monkeypatch, []) @@ -187,7 +187,7 @@ def test_sessions_list_valid_query_parses_page_and_limit(app, monkeypatch): assert body["data"] == [] -def test_sessions_list_defaults_when_query_omitted(app, monkeypatch): +def test_sessions_list_defaults_when_query_omitted(app: Flask, monkeypatch: pytest.MonkeyPatch): """No query → the model's defaults (page=1, limit=100) drive the envelope.""" api = AccountSessionsApi() _stub_session_deps(monkeypatch, []) @@ -209,7 +209,7 @@ def test_sessions_list_defaults_when_query_omitted(app, monkeypatch): "foo=bar", # extra='forbid' ], ) -def test_sessions_list_rejects_out_of_bounds_query(app, monkeypatch, query): +def test_sessions_list_rejects_out_of_bounds_query(app: Flask, monkeypatch: pytest.MonkeyPatch, query): """Out-of-range / unknown query params raise 422 instead of being silently coerced.""" api = AccountSessionsApi() _stub_session_deps(monkeypatch, []) diff --git a/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py b/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py index 0dbb595ba11..ddd72f604d6 100644 --- a/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py +++ b/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py @@ -6,6 +6,9 @@ import sys from types import SimpleNamespace from unittest.mock import Mock +import pytest +from flask import Flask + from controllers.openapi._models import AppRunRequest @@ -30,7 +33,9 @@ def test_app_run_request_with_query(): assert req.query == "hello" -def test_run_chat_always_calls_generate_with_streaming_true(app, bypass_pipeline, monkeypatch): +def test_run_chat_always_calls_generate_with_streaming_true( + app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch +): """_run_chat must always invoke AppGenerateService.generate with streaming=True.""" from controllers.openapi.app_run import _run_chat @@ -56,7 +61,7 @@ def test_stop_task_endpoint_registered(openapi_app): assert "/openapi/v1/apps//tasks//stop" in rules -def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, monkeypatch): +def test_stop_task_calls_queue_manager_and_graph_engine(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): import uuid from controllers.openapi.app_run import AppRunTaskStopApi diff --git a/api/tests/unit_tests/controllers/openapi/test_meta_version.py b/api/tests/unit_tests/controllers/openapi/test_meta_version.py index 8f9e1016e8e..57de0c517f4 100644 --- a/api/tests/unit_tests/controllers/openapi/test_meta_version.py +++ b/api/tests/unit_tests/controllers/openapi/test_meta_version.py @@ -2,6 +2,8 @@ from __future__ import annotations +import pytest + def test_version_endpoint_returns_200_without_auth(openapi_app): client = openapi_app.test_client() @@ -30,7 +32,7 @@ def test_version_endpoint_ignores_bearer_header(openapi_app): assert "edition" in payload -def test_version_endpoint_reflects_edition_config(openapi_app, monkeypatch): +def test_version_endpoint_reflects_edition_config(openapi_app, monkeypatch: pytest.MonkeyPatch): from configs import dify_config monkeypatch.setattr(dify_config, "EDITION", "CLOUD") @@ -42,7 +44,7 @@ def test_version_endpoint_reflects_edition_config(openapi_app, monkeypatch): assert response.get_json()["edition"] == "CLOUD" -def test_version_endpoint_falls_back_to_self_hosted_on_unexpected_edition(openapi_app, monkeypatch): +def test_version_endpoint_falls_back_to_self_hosted_on_unexpected_edition(openapi_app, monkeypatch: pytest.MonkeyPatch): from configs import dify_config monkeypatch.setattr(dify_config, "EDITION", "EXPERIMENTAL") diff --git a/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py b/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py index 78f2d0f20d0..51d5ecdd36f 100644 --- a/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py +++ b/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py @@ -8,6 +8,7 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.openapi.auth.data import AuthData @@ -51,7 +52,7 @@ class TestOpenApiWorkflowEventsApi: return OpenApiWorkflowEventsApi() - def test_not_found_when_run_missing(self, app, bypass_pipeline, monkeypatch): + def test_not_found_when_run_missing(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): module = sys.modules["controllers.openapi.workflow_events"] repo_mock = Mock() repo_mock.get_workflow_run_by_id_and_tenant_id.return_value = None @@ -76,7 +77,9 @@ class TestOpenApiWorkflowEventsApi: auth_data=_make_auth_data(app_model, caller, "account"), ) - def test_not_found_when_run_belongs_to_different_app(self, app, bypass_pipeline, monkeypatch): + def test_not_found_when_run_belongs_to_different_app( + self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch + ): module = sys.modules["controllers.openapi.workflow_events"] run = _make_workflow_run(app_id="other-app") repo_mock = Mock() @@ -102,7 +105,9 @@ class TestOpenApiWorkflowEventsApi: auth_data=_make_auth_data(app_model, caller, "account"), ) - def test_account_caller_checks_created_by_account(self, app, bypass_pipeline, monkeypatch): + def test_account_caller_checks_created_by_account( + self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch + ): """Account caller must match created_by == caller.id and role == ACCOUNT.""" module = sys.modules["controllers.openapi.workflow_events"] run = _make_workflow_run(created_by_role=CreatorUserRole.ACCOUNT, created_by="acct-1") @@ -141,7 +146,9 @@ class TestOpenApiWorkflowEventsApi: ) assert resp.mimetype == "text/event-stream" - def test_account_caller_rejected_for_end_user_run(self, app, bypass_pipeline, monkeypatch): + def test_account_caller_rejected_for_end_user_run( + self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch + ): module = sys.modules["controllers.openapi.workflow_events"] run = _make_workflow_run(created_by_role=CreatorUserRole.END_USER, created_by="eu-1") repo_mock = Mock() @@ -167,7 +174,9 @@ class TestOpenApiWorkflowEventsApi: auth_data=_make_auth_data(app_model, caller, "account"), ) - def test_end_user_caller_checks_created_by_end_user(self, app, bypass_pipeline, monkeypatch): + def test_end_user_caller_checks_created_by_end_user( + self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch + ): """End-user caller must match created_by == caller.id and role == END_USER.""" module = sys.modules["controllers.openapi.workflow_events"] run = _make_workflow_run(created_by_role=CreatorUserRole.END_USER, created_by="eu-1") @@ -202,7 +211,7 @@ class TestOpenApiWorkflowEventsApi: ) assert resp.mimetype == "text/event-stream" - def test_finished_run_returns_single_sse_event(self, app, bypass_pipeline, monkeypatch): + def test_finished_run_returns_single_sse_event(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): """A finished run returns a single done-event SSE response without streaming.""" from datetime import UTC, datetime diff --git a/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py b/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py index 4c09491ab59..cf9fa671987 100644 --- a/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py +++ b/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py @@ -34,6 +34,7 @@ from werkzeug.exceptions import BadRequest, NotFound, UnprocessableEntity from controllers.openapi import bp as openapi_bp from controllers.openapi._errors import MemberLicenseExceeded, MemberLimitExceeded from controllers.openapi._models import MemberInvitePayload, MemberRoleUpdatePayload +from controllers.openapi.auth.data import AuthData from controllers.openapi.workspaces import ( WorkspaceMemberApi, WorkspaceMemberRoleApi, @@ -228,7 +229,7 @@ def test_role_payload_rejects_extra_field(): MemberRoleUpdatePayload.model_validate({"role": "normal", "extra": "x"}) -def test_invite_rejects_invalid_body_with_422(app, bypass_pipeline): +def test_invite_rejects_invalid_body_with_422(app: Flask, bypass_pipeline): """Invalid invite body → 422 via @accepts (was 400 through _validate_body).""" ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() @@ -245,7 +246,7 @@ def test_invite_rejects_invalid_body_with_422(app, bypass_pipeline): api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) -def test_update_role_rejects_invalid_body_with_422(app, bypass_pipeline): +def test_update_role_rejects_invalid_body_with_422(app: Flask, bypass_pipeline): """Invalid role-update body surfaces as 422 through @accepts (was 400).""" ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) acct_id = uuid.uuid4() @@ -267,7 +268,9 @@ def test_update_role_rejects_invalid_body_with_422(app, bypass_pipeline): # --------------------------------------------------------------------------- -def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline, monkeypatch): +def test_switch_returns_workspace_detail_with_current_true( + app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch +): """Happy path: switch service is called, then the workspace+membership row is re-queried so the returned `current` reflects post-commit state. """ @@ -298,7 +301,9 @@ def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline, assert switch_mock.called -def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pipeline, monkeypatch): +def test_switch_404s_when_service_raises_account_not_link_tenant( + app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch +): """If switch_tenant raises (e.g. Tenant.status != NORMAL), the body surfaces as NotFound, not 500.""" ws_id = str(uuid.uuid4()) @@ -326,7 +331,7 @@ def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pip # --------------------------------------------------------------------------- -def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch): +def test_members_list_returns_normalized_rows(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() api = WorkspaceMembersApi() @@ -364,7 +369,7 @@ def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch) assert body["data"][0]["status"] == "active" -def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypatch): +def test_members_list_paginates_with_query_params(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): """`?page=2&limit=2` slices service output and reports total/has_more.""" ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() @@ -404,7 +409,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa assert [d["id"] for d in body["data"]] == ["m-2", "m-3"] -def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypatch): +def test_members_list_rejects_unknown_query_param(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): """Strict (`extra='forbid'`) — typos like `?pg=2` surface as 422 (unified via @accepts).""" ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() @@ -425,7 +430,9 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa # --------------------------------------------------------------------------- -def test_invite_happy_path_returns_invite_url_and_member_id(app, bypass_pipeline, monkeypatch): +def test_invite_happy_path_returns_invite_url_and_member_id( + app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch +): ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() api = WorkspaceMembersApi() @@ -507,7 +514,7 @@ def _invite_request(app, ws_id: str, acct_id: uuid.UUID): ) -def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch): +def test_invite_blocked_by_saas_members_cap(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): """SaaS billing plan member cap → MemberLimitExceeded (403).""" ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() @@ -541,7 +548,7 @@ def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch): invite_mock.assert_not_called() -def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, monkeypatch): +def test_invite_blocked_by_ee_workspace_members_license(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): """EE License workspace_members cap → MemberLicenseExceeded (403). Note: billing.enabled is False (EE without SaaS billing); only the @@ -583,7 +590,7 @@ def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, mo invite_mock.assert_not_called() -def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypatch): +def test_invite_ce_passes_when_both_caps_disabled(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): """CE deployment (no billing, no license) → quota gate is a no-op, invite proceeds normally.""" ws_id = str(uuid.uuid4()) @@ -619,7 +626,7 @@ def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypa assert body["email"] == "new@example.com" -def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch): +def test_invite_400_when_already_in_tenant(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() api = WorkspaceMembersApi() @@ -650,7 +657,7 @@ def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch): # --------------------------------------------------------------------------- -def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch): +def test_delete_member_happy_path(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) acct_id = uuid.uuid4() api = WorkspaceMemberApi() @@ -692,7 +699,7 @@ def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch): (MemberNotInTenantError("not in tenant"), NotFound), ], ) -def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc, expected): +def test_delete_member_exception_mapping(app: Flask, bypass_pipeline, monkeypatch, exc, expected): ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) acct_id = uuid.uuid4() api = WorkspaceMemberApi() @@ -725,7 +732,7 @@ def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc, ) -def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch): +def test_delete_member_404_when_member_missing(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) acct_id = uuid.uuid4() api = WorkspaceMemberApi() @@ -757,7 +764,7 @@ def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch # --------------------------------------------------------------------------- -def test_update_role_happy_path(app, bypass_pipeline, monkeypatch): +def test_update_role_happy_path(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) acct_id = uuid.uuid4() api = WorkspaceMemberRoleApi() @@ -801,7 +808,7 @@ def test_update_role_happy_path(app, bypass_pipeline, monkeypatch): (MemberNotInTenantError("not in tenant"), NotFound), ], ) -def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, expected): +def test_update_role_exception_mapping(app: Flask, bypass_pipeline, monkeypatch, exc, expected): ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) acct_id = uuid.uuid4() api = WorkspaceMemberRoleApi() @@ -841,7 +848,7 @@ def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, e # --------------------------------------------------------------------------- -def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatch): +def test_load_tenant_rejects_archived_workspace(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): """Member management against an archived workspace → 404.""" ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() @@ -869,7 +876,7 @@ def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatc # --------------------------------------------------------------------------- -def test_invite_400_when_register_error(app, bypass_pipeline, monkeypatch): +def test_invite_400_when_register_error(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): """AccountRegisterError (frozen email, workspace creation blocked) → 400.""" ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() diff --git a/api/tests/unit_tests/core/agent/strategy/test_plugin.py b/api/tests/unit_tests/core/agent/strategy/test_plugin.py index 0fea04845d2..15c441e5876 100644 --- a/api/tests/unit_tests/core/agent/strategy/test_plugin.py +++ b/api/tests/unit_tests/core/agent/strategy/test_plugin.py @@ -81,7 +81,7 @@ class TestPluginAgentStrategyInitialization: class TestGetParameters: - def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None: + def test_get_parameters_returns_parameters(self, strategy: PluginAgentStrategy, mock_declaration) -> None: result = strategy.get_parameters() assert result == mock_declaration.parameters @@ -92,7 +92,7 @@ class TestGetParameters: class TestInitializeParameters: - def test_initialize_parameters_success(self, strategy, mock_declaration) -> None: + def test_initialize_parameters_success(self, strategy: PluginAgentStrategy, mock_declaration) -> None: params = {"param1": "value1"} result = strategy.initialize_parameters(params.copy()) @@ -114,13 +114,13 @@ class TestInitializeParameters: {"param1": {}, "param2": "value"}, ], ) - def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None: + def test_initialize_parameters_edge_cases(self, strategy: PluginAgentStrategy, input_params) -> None: result = strategy.initialize_parameters(input_params.copy()) for param in strategy.declaration.parameters: assert param.name in result - def test_initialize_parameters_invalid_input_type(self, strategy) -> None: + def test_initialize_parameters_invalid_input_type(self, strategy: PluginAgentStrategy) -> None: with pytest.raises(AttributeError): strategy.initialize_parameters(None) @@ -131,7 +131,7 @@ class TestInitializeParameters: class TestInvoke: - def test_invoke_success_all_arguments(self, strategy, mocker) -> None: + def test_invoke_success_all_arguments(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None: mock_manager = MagicMock() mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"])) @@ -171,7 +171,7 @@ class TestInvoke: assert call_kwargs["message_id"] == "msg_1" assert call_kwargs["context"] is not None - def test_invoke_with_credentials(self, strategy, mocker) -> None: + def test_invoke_with_credentials(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None: mock_manager = MagicMock() mock_manager.invoke = MagicMock(return_value=iter([])) @@ -243,7 +243,7 @@ class TestInvoke: assert result == [] mock_manager.invoke.assert_called_once() - def test_invoke_convert_raises_exception(self, strategy, mocker) -> None: + def test_invoke_convert_raises_exception(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None: mocker.patch( "core.agent.strategy.plugin.PluginAgentClient", return_value=MagicMock(), @@ -257,7 +257,7 @@ class TestInvoke: with pytest.raises(ValueError): list(strategy._invoke(params={}, user_id="user_1")) - def test_invoke_manager_raises_exception(self, strategy, mocker) -> None: + def test_invoke_manager_raises_exception(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None: mock_manager = MagicMock() mock_manager.invoke.side_effect = RuntimeError("invoke failed") diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py index a1af0a87a55..b983a945ad5 100644 --- a/api/tests/unit_tests/core/agent/test_base_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_base_agent_runner.py @@ -42,13 +42,13 @@ def runner(mocker: MockerFixture, mock_db_session): class TestRepack: - def test_sets_empty_if_none(self, runner, mocker: MockerFixture): + def test_sets_empty_if_none(self, runner: BaseAgentRunner, mocker: MockerFixture): entity = mocker.MagicMock() entity.app_config.prompt_template.simple_prompt_template = None result = runner._repack_app_generate_entity(entity) assert result.app_config.prompt_template.simple_prompt_template == "" - def test_keeps_existing(self, runner, mocker: MockerFixture): + def test_keeps_existing(self, runner: BaseAgentRunner, mocker: MockerFixture): entity = mocker.MagicMock() entity.app_config.prompt_template.simple_prompt_template = "abc" result = runner._repack_app_generate_entity(entity) @@ -61,7 +61,7 @@ class TestRepack: class TestUpdatePromptTool: - def test_replaces_prompt_tool_parameters_with_tool_schema(self, runner, mocker: MockerFixture): + def test_replaces_prompt_tool_parameters_with_tool_schema(self, runner: BaseAgentRunner, mocker: MockerFixture): tool = mocker.MagicMock() schema = { "type": "object", @@ -83,7 +83,7 @@ class TestUpdatePromptTool: class TestCreateAgentThought: - def test_with_files(self, runner, mock_db_session, mocker: MockerFixture): + def test_with_files(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): mock_thought = mocker.MagicMock(id=10) mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought) @@ -91,7 +91,7 @@ class TestCreateAgentThought: assert result == "10" assert runner.agent_thought_count == 1 - def test_without_files(self, runner, mock_db_session, mocker: MockerFixture): + def test_without_files(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): mock_thought = mocker.MagicMock(id=11) mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought) @@ -112,12 +112,12 @@ class TestSaveAgentThought: agent.thought = "" return agent - def test_not_found(self, runner, mock_db_session): + def test_not_found(self, runner: BaseAgentRunner, mock_db_session): mock_db_session.scalar.return_value = None with pytest.raises(ValueError): runner.save_agent_thought("id", None, None, None, None, None, None, [], None) - def test_full_update(self, runner, mock_db_session, mocker: MockerFixture): + def test_full_update(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): agent = self.setup_agent(mocker) mock_db_session.scalar.return_value = agent @@ -152,7 +152,7 @@ class TestSaveAgentThought: assert agent.tokens == 3 assert "tool1" in json.loads(agent.tool_labels_str) - def test_label_fallback_when_none(self, runner, mock_db_session, mocker: MockerFixture): + def test_label_fallback_when_none(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): agent = self.setup_agent(mocker) agent.tool = "unknown_tool" mock_db_session.scalar.return_value = agent @@ -162,7 +162,7 @@ class TestSaveAgentThought: labels = json.loads(agent.tool_labels_str) assert "unknown_tool" in labels - def test_json_failure_paths(self, runner, mock_db_session, mocker: MockerFixture): + def test_json_failure_paths(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): agent = self.setup_agent(mocker) mock_db_session.scalar.return_value = agent @@ -183,13 +183,13 @@ class TestSaveAgentThought: assert mock_db_session.commit.called - def test_messages_ids_none(self, runner, mock_db_session, mocker: MockerFixture): + def test_messages_ids_none(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): agent = self.setup_agent(mocker) mock_db_session.scalar.return_value = agent runner.save_agent_thought("id", None, None, None, None, None, None, None, None) assert mock_db_session.commit.called - def test_success_dict_serialization(self, runner, mock_db_session, mocker: MockerFixture): + def test_success_dict_serialization(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): agent = self.setup_agent(mocker) mock_db_session.scalar.return_value = agent @@ -215,19 +215,19 @@ class TestSaveAgentThought: class TestOrganizeUserPrompt: - def test_no_files(self, runner, mock_db_session, mocker: MockerFixture): + def test_no_files(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): mock_db_session.scalars.return_value.all.return_value = [] msg = mocker.MagicMock(id="1", query="hello", app_model_config=None) result = runner.organize_agent_user_prompt(msg) assert result.content == "hello" - def test_with_files_no_config(self, runner, mock_db_session, mocker: MockerFixture): + def test_with_files_no_config(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] msg = mocker.MagicMock(id="1", query="hello", app_model_config=None) result = runner.organize_agent_user_prompt(msg) assert result.content == "hello" - def test_image_detail_low_fallback(self, runner, mock_db_session, mocker: MockerFixture): + def test_image_detail_low_fallback(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] file_config = mocker.MagicMock() file_config.image_config = mocker.MagicMock(detail=None) @@ -247,27 +247,27 @@ class TestOrganizeUserPrompt: class TestOrganizeHistory: - def test_empty(self, runner, mock_db_session, mocker: MockerFixture): + def test_empty(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): mock_db_session.execute.return_value.scalars.return_value.all.return_value = [] mocker.patch.object(module, "extract_thread_messages", return_value=[]) result = runner.organize_agent_history([]) assert result == [] - def test_with_answer_only(self, runner, mock_db_session, mocker: MockerFixture): + def test_with_answer_only(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None) mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) result = runner.organize_agent_history([]) assert any(isinstance(x, module.AssistantPromptMessage) for x in result) - def test_skip_current_message(self, runner, mock_db_session, mocker: MockerFixture): + def test_skip_current_message(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None) mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) result = runner.organize_agent_history([]) assert result == [] - def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker: MockerFixture): + def test_with_tool_calls_invalid_json(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): thought = mocker.MagicMock( tool="tool1", tool_input="invalid", @@ -283,7 +283,7 @@ class TestOrganizeHistory: result = runner.organize_agent_history([]) assert isinstance(result, list) - def test_empty_tool_name_split(self, runner, mock_db_session, mocker: MockerFixture): + def test_empty_tool_name_split(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): thought = mocker.MagicMock(tool=";", thought="thinking") msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None) @@ -292,7 +292,7 @@ class TestOrganizeHistory: result = runner.organize_agent_history([]) assert isinstance(result, list) - def test_valid_json_tool_flow(self, runner, mock_db_session, mocker: MockerFixture): + def test_valid_json_tool_flow(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): thought = mocker.MagicMock( tool="tool1", tool_input=json.dumps({"tool1": {"x": 1}}), @@ -321,7 +321,7 @@ class TestOrganizeHistory: class TestConvertToolToPromptMessageTool: - def test_basic_conversion(self, runner, mocker: MockerFixture): + def test_basic_conversion(self, runner: BaseAgentRunner, mocker: MockerFixture): tool = mocker.MagicMock(tool_name="tool1") tool_entity = mocker.MagicMock() @@ -347,7 +347,7 @@ class TestConvertToolToPromptMessageTool: class TestInitPromptToolsExtended: - def test_agent_tool_branch(self, runner, mocker: MockerFixture): + def test_agent_tool_branch(self, runner: BaseAgentRunner, mocker: MockerFixture): agent_tool = mocker.MagicMock(tool_name="agent_tool") runner.app_config.agent = mocker.MagicMock(tools=[agent_tool]) mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity")) @@ -355,7 +355,7 @@ class TestInitPromptToolsExtended: tools, prompts = runner._init_prompt_tools() assert "agent_tool" in tools - def test_exception_in_conversion(self, runner, mocker: MockerFixture): + def test_exception_in_conversion(self, runner: BaseAgentRunner, mocker: MockerFixture): agent_tool = mocker.MagicMock(tool_name="bad_tool") runner.app_config.agent = mocker.MagicMock(tools=[agent_tool]) mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception) @@ -370,7 +370,7 @@ class TestInitPromptToolsExtended: class TestAdditionalCoverage: - def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker: MockerFixture): + def test_save_agent_thought_existing_labels(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): agent = mocker.MagicMock() agent.tool = "tool1" agent.tool_labels = {"tool1": {"en_US": "existing"}} @@ -381,7 +381,7 @@ class TestAdditionalCoverage: labels = json.loads(agent.tool_labels_str) assert labels["tool1"]["en_US"] == "existing" - def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker: MockerFixture): + def test_save_agent_thought_tool_meta_string(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): agent = mocker.MagicMock() agent.tool = "tool1" agent.tool_labels = {} @@ -391,7 +391,7 @@ class TestAdditionalCoverage: runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None) assert agent.tool_meta_str == "meta_string" - def test_convert_dataset_retriever_tool(self, runner, mocker: MockerFixture): + def test_convert_dataset_retriever_tool(self, runner: BaseAgentRunner, mocker: MockerFixture): ds_tool = mocker.MagicMock() ds_tool.entity.identity.name = "ds" ds_tool.entity.description.llm = "desc" @@ -408,7 +408,9 @@ class TestAdditionalCoverage: prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool) assert prompt is not None - def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker: MockerFixture): + def test_organize_user_prompt_with_file_objects( + self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture + ): mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] file_config = mocker.MagicMock() @@ -427,7 +429,7 @@ class TestAdditionalCoverage: result = runner.organize_agent_user_prompt(msg) assert result is not None - def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker: MockerFixture): + def test_organize_history_without_tool_names(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture): thought = mocker.MagicMock(tool=None, thought="thinking") msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None) @@ -437,7 +439,9 @@ class TestAdditionalCoverage: result = runner.organize_agent_history([]) assert isinstance(result, list) - def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker: MockerFixture): + def test_organize_history_multiple_tools_split( + self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture + ): thought = mocker.MagicMock( tool="tool1;tool2", tool_input=json.dumps({"tool1": {}, "tool2": {}}), @@ -455,7 +459,7 @@ class TestAdditionalCoverage: class TestConvertDatasetRetrieverTool: - def test_required_param_added(self, runner, mocker: MockerFixture): + def test_required_param_added(self, runner: BaseAgentRunner, mocker: MockerFixture): ds_tool = mocker.MagicMock() ds_tool.entity.identity.name = "ds" ds_tool.entity.description.llm = "desc" @@ -518,7 +522,7 @@ class TestBaseAgentRunnerInit: class TestBaseAgentRunnerCoverage: - def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker: MockerFixture): + def test_init_prompt_tools_adds_dataset_tools(self, runner: BaseAgentRunner, mocker: MockerFixture): dataset_tool = mocker.MagicMock() dataset_tool.entity.identity.name = "ds" runner.dataset_tools = [dataset_tool] @@ -530,7 +534,9 @@ class TestBaseAgentRunnerCoverage: assert tools["ds"] == dataset_tool assert len(prompt_tools) == 1 - def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker: MockerFixture): + def test_save_agent_thought_json_dumps_fallbacks( + self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture + ): agent = mocker.MagicMock() agent.tool = "tool1" agent.tool_labels = {} @@ -568,7 +574,9 @@ class TestBaseAgentRunnerCoverage: assert isinstance(agent.observation, str) assert isinstance(agent.tool_meta_str, str) - def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker: MockerFixture): + def test_save_agent_thought_skips_empty_tool_name( + self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture + ): agent = mocker.MagicMock() agent.tool = "tool1;;" agent.tool_labels = {} @@ -582,7 +590,9 @@ class TestBaseAgentRunnerCoverage: labels = json.loads(agent.tool_labels_str) assert "" not in labels - def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker: MockerFixture): + def test_organize_history_includes_system_prompt( + self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture + ): mock_db_session.execute.return_value.scalars.return_value.all.return_value = [] mocker.patch.object(module, "extract_thread_messages", return_value=[]) @@ -592,7 +602,9 @@ class TestBaseAgentRunnerCoverage: assert system_message in result - def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker: MockerFixture): + def test_organize_history_tool_inputs_and_observation_none( + self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture + ): thought = mocker.MagicMock( tool="tool1", tool_input=None, diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index 314305d371a..a6cae351b17 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -95,25 +95,25 @@ class TestFillInputs: ("", {"x": "y"}, ""), ], ) - def test_fill_in_inputs(self, runner, instruction, inputs, expected): + def test_fill_in_inputs(self, runner: DummyRunner, instruction, inputs, expected): result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs) assert result == expected class TestConvertDictToAction: - def test_convert_valid_dict(self, runner): + def test_convert_valid_dict(self, runner: DummyRunner): action_dict = {"action": "test", "action_input": {"a": 1}} action = runner._convert_dict_to_action(action_dict) assert action.action_name == "test" assert action.action_input == {"a": 1} - def test_convert_missing_keys(self, runner): + def test_convert_missing_keys(self, runner: DummyRunner): with pytest.raises(KeyError): runner._convert_dict_to_action({"invalid": 1}) class TestFormatAssistantMessage: - def test_format_assistant_message_multiple_scratchpads(self, runner): + def test_format_assistant_message_multiple_scratchpads(self, runner: DummyRunner): sp1 = AgentScratchpadUnit( agent_response="resp1", thought="thought1", @@ -131,7 +131,7 @@ class TestFormatAssistantMessage: result = runner._format_assistant_message([sp1, sp2]) assert "Final Answer:" in result - def test_format_with_final(self, runner): + def test_format_with_final(self, runner: DummyRunner): scratchpad = AgentScratchpadUnit( agent_response="Done", thought="", @@ -144,7 +144,7 @@ class TestFormatAssistantMessage: result = runner._format_assistant_message([scratchpad]) assert "Final Answer" in result - def test_format_with_action_and_observation(self, runner): + def test_format_with_action_and_observation(self, runner: DummyRunner): scratchpad = AgentScratchpadUnit( agent_response="resp", thought="thinking", @@ -161,12 +161,12 @@ class TestFormatAssistantMessage: class TestHandleInvokeAction: - def test_handle_invoke_action_tool_not_present(self, runner): + def test_handle_invoke_action_tool_not_present(self, runner: DummyRunner): action = AgentScratchpadUnit.Action(action_name="missing", action_input={}) response, meta = runner._handle_invoke_action(action, {}, []) assert "there is not a tool named" in response - def test_tool_with_json_string_args(self, runner, mocker: MockerFixture): + def test_tool_with_json_string_args(self, runner: DummyRunner, mocker: MockerFixture): action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1})) tool_instance = MagicMock() tool_instances = {"tool": tool_instance} @@ -181,7 +181,7 @@ class TestHandleInvokeAction: class TestOrganizeHistoricPromptMessages: - def test_empty_history(self, runner, mocker: MockerFixture): + def test_empty_history(self, runner: DummyRunner, mocker: MockerFixture): mocker.patch( "core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt", return_value=[], @@ -191,7 +191,7 @@ class TestOrganizeHistoricPromptMessages: class TestRun: - def test_run_handles_empty_parser_output(self, runner, mocker: MockerFixture): + def test_run_handles_empty_parser_output(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" @@ -203,7 +203,7 @@ class TestRun: results = list(runner.run(message, "query", {})) assert isinstance(results, list) - def test_run_with_action_and_tool_invocation(self, runner, mocker: MockerFixture): + def test_run_with_action_and_tool_invocation(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" @@ -224,7 +224,7 @@ class TestRun: with pytest.raises(AgentMaxIterationError): list(runner.run(message, "query", {"tool": MagicMock()})) - def test_run_respects_max_iteration_boundary(self, runner, mocker: MockerFixture): + def test_run_respects_max_iteration_boundary(self, runner: DummyRunner, mocker: MockerFixture): runner.app_config.agent.max_iteration = 1 message = MagicMock() message.id = "msg-id" @@ -246,7 +246,7 @@ class TestRun: with pytest.raises(AgentMaxIterationError): list(runner.run(message, "query", {"tool": MagicMock()})) - def test_run_basic_flow(self, runner, mocker: MockerFixture): + def test_run_basic_flow(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" @@ -258,7 +258,7 @@ class TestRun: results = list(runner.run(message, "query", {"name": "John"})) assert results - def test_run_max_iteration_error(self, runner, mocker: MockerFixture): + def test_run_max_iteration_error(self, runner: DummyRunner, mocker: MockerFixture): runner.app_config.agent.max_iteration = 0 message = MagicMock() message.id = "msg-id" @@ -273,7 +273,7 @@ class TestRun: with pytest.raises(AgentMaxIterationError): list(runner.run(message, "query", {})) - def test_run_increase_usage_aggregation(self, runner, mocker: MockerFixture): + def test_run_increase_usage_aggregation(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" runner.app_config.agent.max_iteration = 2 @@ -330,7 +330,7 @@ class TestRun: assert final_usage.completion_price == 2 assert final_usage.total_price == 4 - def test_run_when_no_action_branch(self, runner, mocker: MockerFixture): + def test_run_when_no_action_branch(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" @@ -342,7 +342,7 @@ class TestRun: results = list(runner.run(message, "query", {})) assert results[-1].delta.message.content == "" - def test_run_usage_missing_key_branch(self, runner, mocker: MockerFixture): + def test_run_usage_missing_key_branch(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" @@ -355,7 +355,7 @@ class TestRun: list(runner.run(message, "query", {})) - def test_run_prompt_tool_update_branch(self, runner, mocker: MockerFixture): + def test_run_prompt_tool_update_branch(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" @@ -387,7 +387,7 @@ class TestRun: runner.update_prompt_message_tool.assert_called_once() - def test_historic_with_assistant_and_tool_calls(self, runner): + def test_historic_with_assistant_and_tool_calls(self, runner: DummyRunner): from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage assistant = AssistantPromptMessage(content="thinking") @@ -400,7 +400,7 @@ class TestRun: result = runner._organize_historic_prompt_messages([]) assert isinstance(result, list) - def test_historic_final_flush_branch(self, runner): + def test_historic_final_flush_branch(self, runner: DummyRunner): from graphon.model_runtime.entities.message_entities import AssistantPromptMessage assistant = AssistantPromptMessage(content="final") @@ -411,7 +411,7 @@ class TestRun: class TestInitReactState: - def test_init_react_state_resets_state(self, runner, mocker: MockerFixture): + def test_init_react_state_resets_state(self, runner: DummyRunner, mocker: MockerFixture): mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"]) runner._agent_scratchpad = ["old"] runner._query = "old" @@ -424,7 +424,7 @@ class TestInitReactState: class TestHandleInvokeActionExtended: - def test_tool_with_invalid_json_string_args(self, runner, mocker: MockerFixture): + def test_tool_with_invalid_json_string_args(self, runner: DummyRunner, mocker: MockerFixture): action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json") tool_instance = MagicMock() tool_instances = {"tool": tool_instance} @@ -443,11 +443,11 @@ class TestHandleInvokeActionExtended: class TestFillInputsEdgeCases: - def test_fill_inputs_with_empty_inputs(self, runner): + def test_fill_inputs_with_empty_inputs(self, runner: DummyRunner): result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {}) assert result == "Hello {{x}}" - def test_fill_inputs_with_exception_in_replace(self, runner): + def test_fill_inputs_with_exception_in_replace(self, runner: DummyRunner): class BadValue: def __str__(self): raise Exception("fail") @@ -458,7 +458,7 @@ class TestFillInputsEdgeCases: class TestOrganizeHistoricPromptMessagesExtended: - def test_user_message_flushes_scratchpad(self, runner, mocker: MockerFixture): + def test_user_message_flushes_scratchpad(self, runner: DummyRunner, mocker: MockerFixture): from graphon.model_runtime.entities.message_entities import UserPromptMessage user_message = UserPromptMessage(content="Hi") @@ -473,7 +473,7 @@ class TestOrganizeHistoricPromptMessagesExtended: result = runner._organize_historic_prompt_messages([]) assert result == ["final"] - def test_tool_message_without_scratchpad_raises(self, runner): + def test_tool_message_without_scratchpad_raises(self, runner: DummyRunner): from graphon.model_runtime.entities.message_entities import ToolPromptMessage runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")] @@ -481,7 +481,7 @@ class TestOrganizeHistoricPromptMessagesExtended: with pytest.raises(NotImplementedError): runner._organize_historic_prompt_messages([]) - def test_agent_history_transform_invocation(self, runner, mocker: MockerFixture): + def test_agent_history_transform_invocation(self, runner: DummyRunner, mocker: MockerFixture): mock_transform = MagicMock() mock_transform.get_prompt.return_value = [] @@ -496,7 +496,7 @@ class TestOrganizeHistoricPromptMessagesExtended: class TestRunAdditionalBranches: - def test_run_with_no_action_final_answer_empty(self, runner, mocker: MockerFixture): + def test_run_with_no_action_final_answer_empty(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" @@ -508,7 +508,7 @@ class TestRunAdditionalBranches: results = list(runner.run(message, "query", {})) assert any(hasattr(r, "delta") for r in results) - def test_run_with_final_answer_action_string(self, runner, mocker: MockerFixture): + def test_run_with_final_answer_action_string(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" @@ -522,7 +522,7 @@ class TestRunAdditionalBranches: results = list(runner.run(message, "query", {})) assert results[-1].delta.message.content == "done" - def test_run_with_final_answer_action_dict(self, runner, mocker: MockerFixture): + def test_run_with_final_answer_action_dict(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" @@ -536,7 +536,7 @@ class TestRunAdditionalBranches: results = list(runner.run(message, "query", {})) assert json.loads(results[-1].delta.message.content) == {"a": 1} - def test_run_with_string_final_answer(self, runner, mocker: MockerFixture): + def test_run_with_string_final_answer(self, runner: DummyRunner, mocker: MockerFixture): message = MagicMock() message.id = "msg-id" diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py index e79ea549a74..e910c59d368 100644 --- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -40,7 +40,11 @@ def runner(mocker: MockerFixture, dummy_tool_factory): class TestOrganizeInstructionPrompt: def test_success_all_placeholders( - self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory + self, + runner: CotCompletionAgentRunner, + dummy_app_config_factory, + dummy_agent_config_factory, + dummy_prompt_entity_factory, ): template = ( "{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}" @@ -58,12 +62,14 @@ class TestOrganizeInstructionPrompt: tools_payload = json.loads(result.split(" | ")[1]) assert {item["name"] for item in tools_payload} == {"toolA", "toolB"} - def test_agent_none_raises(self, runner, dummy_app_config_factory): + def test_agent_none_raises(self, runner: CotCompletionAgentRunner, dummy_app_config_factory): runner.app_config = dummy_app_config_factory(agent=None) with pytest.raises(ValueError, match="Agent configuration is not set"): runner._organize_instruction_prompt() - def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory): + def test_prompt_entity_none_raises( + self, runner: CotCompletionAgentRunner, dummy_app_config_factory, dummy_agent_config_factory + ): runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None)) with pytest.raises(ValueError, match="prompt entity is not set"): runner._organize_instruction_prompt() @@ -75,7 +81,7 @@ class TestOrganizeInstructionPrompt: class TestOrganizeHistoricPrompt: - def test_with_user_and_assistant_string(self, runner, mocker: MockerFixture): + def test_with_user_and_assistant_string(self, runner: CotCompletionAgentRunner, mocker: MockerFixture): user_msg = UserPromptMessage(content="Hello") assistant_msg = AssistantPromptMessage(content="Hi there") @@ -90,7 +96,7 @@ class TestOrganizeHistoricPrompt: assert "Question: Hello" in result assert "Hi there" in result - def test_assistant_list_with_text_content(self, runner, mocker: MockerFixture): + def test_assistant_list_with_text_content(self, runner: CotCompletionAgentRunner, mocker: MockerFixture): text_content = TextPromptMessageContent(data="Partial answer") assistant_msg = AssistantPromptMessage(content=[text_content]) @@ -104,7 +110,9 @@ class TestOrganizeHistoricPrompt: assert "Partial answer" in result - def test_assistant_list_with_non_text_content_ignored(self, runner, mocker: MockerFixture): + def test_assistant_list_with_non_text_content_ignored( + self, runner: CotCompletionAgentRunner, mocker: MockerFixture + ): non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png") assistant_msg = AssistantPromptMessage(content=[non_text_content]) @@ -117,7 +125,7 @@ class TestOrganizeHistoricPrompt: result = runner._organize_historic_prompt() assert result == "" - def test_empty_history(self, runner, mocker: MockerFixture): + def test_empty_history(self, runner: CotCompletionAgentRunner, mocker: MockerFixture): mocker.patch.object( runner, "_organize_historic_prompt_messages", diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index 3a4347e7239..9b2a1d70fdf 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -147,12 +147,12 @@ def runner(mocker: MockerFixture): class TestToolCallChecks: @pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)]) - def test_check_tool_calls(self, runner, tool_calls, expected): + def test_check_tool_calls(self, runner: FunctionCallAgentRunner, tool_calls, expected): chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls)) assert runner.check_tool_calls(chunk) is expected @pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)]) - def test_check_blocking_tool_calls(self, runner, tool_calls, expected): + def test_check_blocking_tool_calls(self, runner: FunctionCallAgentRunner, tool_calls, expected): result = DummyResult(message=DummyMessage(tool_calls=tool_calls)) assert runner.check_blocking_tool_calls(result) is expected @@ -163,7 +163,7 @@ class TestToolCallChecks: class TestExtractToolCalls: - def test_extract_tool_calls_with_valid_json(self, runner): + def test_extract_tool_calls_with_valid_json(self, runner: FunctionCallAgentRunner): tool_call = MagicMock() tool_call.id = "1" tool_call.function.name = "tool" @@ -174,7 +174,7 @@ class TestExtractToolCalls: assert calls == [("1", "tool", {"a": 1})] - def test_extract_tool_calls_empty_arguments(self, runner): + def test_extract_tool_calls_empty_arguments(self, runner: FunctionCallAgentRunner): tool_call = MagicMock() tool_call.id = "1" tool_call.function.name = "tool" @@ -185,7 +185,7 @@ class TestExtractToolCalls: assert calls == [("1", "tool", {})] - def test_extract_blocking_tool_calls(self, runner): + def test_extract_blocking_tool_calls(self, runner: FunctionCallAgentRunner): tool_call = MagicMock() tool_call.id = "2" tool_call.function.name = "block" @@ -203,16 +203,16 @@ class TestExtractToolCalls: class TestInitSystemMessage: - def test_init_system_message_empty_prompt_messages(self, runner): + def test_init_system_message_empty_prompt_messages(self, runner: FunctionCallAgentRunner): result = runner._init_system_message("system", []) assert len(result) == 1 - def test_init_system_message_insert_at_start(self, runner): + def test_init_system_message_insert_at_start(self, runner: FunctionCallAgentRunner): msgs = [MagicMock()] result = runner._init_system_message("system", msgs) assert result[0].content == "system" - def test_init_system_message_no_template(self, runner): + def test_init_system_message_no_template(self, runner: FunctionCallAgentRunner): result = runner._init_system_message("", []) assert result == [] @@ -223,15 +223,15 @@ class TestInitSystemMessage: class TestOrganizeUserQuery: - def test_without_files(self, runner): + def test_without_files(self, runner: FunctionCallAgentRunner): result = runner._organize_user_query("query", []) assert len(result) == 1 - def test_with_none_query(self, runner): + def test_with_none_query(self, runner: FunctionCallAgentRunner): result = runner._organize_user_query(None, []) assert len(result) == 1 - def test_with_files_uses_image_detail_config(self, runner, mocker: MockerFixture): + def test_with_files_uses_image_detail_config(self, runner: FunctionCallAgentRunner, mocker: MockerFixture): file_content = TextPromptMessageContent(data="file-content") mock_to_prompt = mocker.patch( "core.agent.fc_agent_runner.file_manager.to_prompt_message_content", @@ -255,7 +255,7 @@ class TestOrganizeUserQuery: class TestClearUserPromptImageMessages: - def test_clear_text_and_image_content(self, runner): + def test_clear_text_and_image_content(self, runner: FunctionCallAgentRunner): text = MagicMock() text.type = "text" text.data = "hello" @@ -271,7 +271,7 @@ class TestClearUserPromptImageMessages: result = runner._clear_user_prompt_image_messages([user_msg]) assert isinstance(result, list) - def test_clear_includes_file_placeholder(self, runner): + def test_clear_includes_file_placeholder(self, runner: FunctionCallAgentRunner): text = TextPromptMessageContent(data="hello") image = ImagePromptMessageContent(format="url", mime_type="image/png") document = DocumentPromptMessageContent(format="url", mime_type="application/pdf") @@ -289,7 +289,7 @@ class TestClearUserPromptImageMessages: class TestRunMethod: - def test_run_non_streaming_no_tool_calls(self, runner): + def test_run_non_streaming_no_tool_calls(self, runner: FunctionCallAgentRunner): message = MagicMock(id="m1") dummy_message = DummyMessage(content="hello") result = DummyResult(message=dummy_message, usage=build_usage()) @@ -303,7 +303,7 @@ class TestRunMethod: queue_calls = runner.queue_manager.publish.call_args_list assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls) - def test_run_streaming_branch(self, runner): + def test_run_streaming_branch(self, runner: FunctionCallAgentRunner): message = MagicMock(id="m1") runner.stream_tool_call = True @@ -318,7 +318,7 @@ class TestRunMethod: outputs = list(runner.run(message, "query")) assert len(outputs) == 1 - def test_run_streaming_tool_calls_list_content(self, runner): + def test_run_streaming_tool_calls_list_content(self, runner: FunctionCallAgentRunner): message = MagicMock(id="m1") runner.stream_tool_call = True @@ -341,7 +341,7 @@ class TestRunMethod: outputs = list(runner.run(message, "query")) assert len(outputs) >= 1 - def test_run_non_streaming_list_content(self, runner): + def test_run_non_streaming_list_content(self, runner: FunctionCallAgentRunner): message = MagicMock(id="m1") content = [TextPromptMessageContent(data="hi")] dummy_message = DummyMessage(content=content) @@ -353,7 +353,7 @@ class TestRunMethod: assert len(outputs) == 1 assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi" - def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker: MockerFixture): + def test_run_streaming_tool_call_inputs_type_error(self, runner: FunctionCallAgentRunner, mocker: MockerFixture): message = MagicMock(id="m1") runner.stream_tool_call = True @@ -381,7 +381,7 @@ class TestRunMethod: outputs = list(runner.run(message, "query")) assert len(outputs) == 1 - def test_run_with_missing_tool_instance(self, runner): + def test_run_with_missing_tool_instance(self, runner: FunctionCallAgentRunner): message = MagicMock(id="m1") tool_call = MagicMock() @@ -399,7 +399,7 @@ class TestRunMethod: outputs = list(runner.run(message, "query")) assert len(outputs) >= 1 - def test_run_with_tool_instance_and_files(self, runner, mocker: MockerFixture): + def test_run_with_tool_instance_and_files(self, runner: FunctionCallAgentRunner, mocker: MockerFixture): message = MagicMock(id="m1") tool_call = MagicMock() @@ -434,7 +434,7 @@ class TestRunMethod: for call in runner.queue_manager.publish.call_args_list ) - def test_run_max_iteration_error(self, runner): + def test_run_max_iteration_error(self, runner: FunctionCallAgentRunner): runner.app_config.agent.max_iteration = 0 message = MagicMock(id="m1") diff --git a/api/tests/unit_tests/core/app/apps/agent_app/test_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_app/test_app_generator.py index 60eab201729..b8cdf471ca3 100644 --- a/api/tests/unit_tests/core/app/apps/agent_app/test_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_app/test_app_generator.py @@ -36,7 +36,7 @@ def generator(mocker: MockerFixture) -> AgentAppGenerator: class TestGenerateGuards: - def test_rejects_blocking_mode(self, generator, mocker: MockerFixture): + def test_rejects_blocking_mode(self, generator: AgentAppGenerator, mocker: MockerFixture): with pytest.raises(AgentAppGeneratorError, match="only supports streaming"): generator.generate( app_model=mocker.MagicMock(), @@ -46,7 +46,7 @@ class TestGenerateGuards: streaming=False, ) - def test_requires_query(self, generator, mocker: MockerFixture): + def test_requires_query(self, generator: AgentAppGenerator, mocker: MockerFixture): with pytest.raises(AgentAppGeneratorError, match="query is required"): generator.generate( app_model=mocker.MagicMock(), @@ -55,7 +55,7 @@ class TestGenerateGuards: invoke_from=InvokeFrom.WEB_APP, ) - def test_rejects_blank_query(self, generator, mocker: MockerFixture): + def test_rejects_blank_query(self, generator: AgentAppGenerator, mocker: MockerFixture): with pytest.raises(AgentAppGeneratorError, match="query is required"): generator.generate( app_model=mocker.MagicMock(), @@ -113,7 +113,7 @@ class TestGenerateSuccess: thread_obj.start.assert_called_once() generator._resolve_agent.assert_called_once_with(app_model) - def test_generate_loads_existing_conversation(self, generator, mocker: MockerFixture): + def test_generate_loads_existing_conversation(self, generator: AgentAppGenerator, mocker: MockerFixture): app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent") generator._resolve_agent = mocker.MagicMock( return_value=(mocker.MagicMock(id="a"), mocker.MagicMock(id="s"), mocker.MagicMock()) @@ -144,7 +144,9 @@ class TestGenerateSuccess: get_conv.assert_called_once() - def test_generate_does_not_include_trace_session_id_in_extras(self, generator, mocker: MockerFixture): + def test_generate_does_not_include_trace_session_id_in_extras( + self, generator: AgentAppGenerator, mocker: MockerFixture + ): app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent") user = DummyAccount("user") @@ -190,7 +192,7 @@ class TestGenerateWorker: mocker.patch("libs.flask_utils.preserve_flask_contexts", ctx_manager) - def _wire(self, generator, mocker: MockerFixture, *, run_side_effect=None, handled=False): + def _wire(self, generator: AgentAppGenerator, mocker: MockerFixture, *, run_side_effect=None, handled=False): generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock(id="conv")) generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock(id="msg")) generator._run_input_guards = mocker.MagicMock(return_value=(handled, "query")) @@ -238,7 +240,7 @@ class TestGenerateWorker: is_resume=is_resume, ) - def test_happy_path_runs_backend(self, generator, mocker: MockerFixture): + def test_happy_path_runs_backend(self, generator: AgentAppGenerator, mocker: MockerFixture): runner = self._wire(generator, mocker) queue_manager = mocker.MagicMock() self._call(generator, mocker, queue_manager) @@ -280,7 +282,7 @@ class TestGenerateWorker: self._call(generator, mocker, queue_manager) queue_manager.publish_error.assert_not_called() - def test_unexpected_error_is_published(self, generator, mocker: MockerFixture): + def test_unexpected_error_is_published(self, generator: AgentAppGenerator, mocker: MockerFixture): self._wire(generator, mocker, run_side_effect=ValueError("boom")) queue_manager = mocker.MagicMock() self._call(generator, mocker, queue_manager) diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 0260235b03a..d7988cbf74d 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -14,7 +14,7 @@ def runner(): class TestAgentChatAppRunnerRun: - def test_run_app_not_found(self, runner, mocker: MockerFixture): + def test_run_app_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture): app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock()) generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True) @@ -23,7 +23,7 @@ class TestAgentChatAppRunnerRun: with pytest.raises(ValueError): runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) - def test_run_moderation_error_direct_output(self, runner, mocker: MockerFixture): + def test_run_moderation_error_direct_output(self, runner: AgentChatAppRunner, mocker: MockerFixture): app_record = mocker.MagicMock(id="app1", tenant_id="tenant") app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) app_config.agent = mocker.MagicMock() @@ -46,7 +46,7 @@ class TestAgentChatAppRunnerRun: runner.direct_output.assert_called_once() - def test_run_annotation_reply_short_circuits(self, runner, mocker: MockerFixture): + def test_run_annotation_reply_short_circuits(self, runner: AgentChatAppRunner, mocker: MockerFixture): app_record = mocker.MagicMock(id="app1", tenant_id="tenant") app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) app_config.agent = mocker.MagicMock() @@ -75,7 +75,7 @@ class TestAgentChatAppRunnerRun: queue_manager.publish.assert_called_once() runner.direct_output.assert_called_once() - def test_run_hosting_moderation_short_circuits(self, runner, mocker: MockerFixture): + def test_run_hosting_moderation_short_circuits(self, runner: AgentChatAppRunner, mocker: MockerFixture): app_record = mocker.MagicMock(id="app1", tenant_id="tenant") app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) app_config.agent = mocker.MagicMock() @@ -99,7 +99,7 @@ class TestAgentChatAppRunnerRun: runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) - def test_run_model_schema_missing(self, runner, mocker: MockerFixture): + def test_run_model_schema_missing(self, runner: AgentChatAppRunner, mocker: MockerFixture): app_record = mocker.MagicMock(id="app1", tenant_id="tenant") app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) @@ -141,7 +141,7 @@ class TestAgentChatAppRunnerRun: (LLMMode.COMPLETION, "CotCompletionAgentRunner"), ], ) - def test_run_chain_of_thought_modes(self, runner, mocker: MockerFixture, mode, expected_runner): + def test_run_chain_of_thought_modes(self, runner: AgentChatAppRunner, mocker: MockerFixture, mode, expected_runner): app_record = mocker.MagicMock(id="app1", tenant_id="tenant") app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) @@ -197,7 +197,7 @@ class TestAgentChatAppRunnerRun: runner_instance.run.assert_called_once() runner._handle_invoke_result.assert_called_once() - def test_run_invalid_llm_mode_raises(self, runner, mocker: MockerFixture): + def test_run_invalid_llm_mode_raises(self, runner: AgentChatAppRunner, mocker: MockerFixture): app_record = mocker.MagicMock(id="app1", tenant_id="tenant") app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) @@ -243,7 +243,9 @@ class TestAgentChatAppRunnerRun: with pytest.raises(ValueError): runner.run(generate_entity, mocker.MagicMock(), conversation, message) - def test_run_function_calling_strategy_selected_by_features(self, runner, mocker: MockerFixture): + def test_run_function_calling_strategy_selected_by_features( + self, runner: AgentChatAppRunner, mocker: MockerFixture + ): app_record = mocker.MagicMock(id="app1", tenant_id="tenant") app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) @@ -299,7 +301,7 @@ class TestAgentChatAppRunnerRun: assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING runner_instance.run.assert_called_once() - def test_run_conversation_not_found(self, runner, mocker: MockerFixture): + def test_run_conversation_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture): app_record = mocker.MagicMock(id="app1", tenant_id="tenant") app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING) @@ -333,7 +335,7 @@ class TestAgentChatAppRunnerRun: with pytest.raises(ValueError): runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg")) - def test_run_message_not_found(self, runner, mocker: MockerFixture): + def test_run_message_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture): app_record = mocker.MagicMock(id="app1", tenant_id="tenant") app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING) @@ -367,7 +369,7 @@ class TestAgentChatAppRunnerRun: with pytest.raises(ValueError): runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg")) - def test_run_invalid_agent_strategy_raises(self, runner, mocker: MockerFixture): + def test_run_invalid_agent_strategy_raises(self, runner: AgentChatAppRunner, mocker: MockerFixture): app_record = mocker.MagicMock(id="app1", tenant_id="tenant") app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m") diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index c02683e13f4..f02994fd61c 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -112,7 +112,7 @@ def test_run_uses_single_node_execution_branch( assert entry_kwargs["graph_runtime_state"] is graph_runtime_state -def test_single_node_run_validates_target_node_config(monkeypatch) -> None: +def test_single_node_run_validates_target_node_config(monkeypatch: pytest.MonkeyPatch) -> None: runner = WorkflowBasedAppRunner( queue_manager=MagicMock(spec=AppQueueManager), variable_loader=MagicMock(), diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index f8e13ca8083..c8efaec33dc 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -218,7 +218,7 @@ class TestWorkflowPersistenceLayer: assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED assert trace_tasks - def test_handle_graph_run_succeeded_enqueues_parent_trace_context(self, monkeypatch): + def test_handle_graph_run_succeeded_enqueues_parent_trace_context(self, monkeypatch: pytest.MonkeyPatch): trace_tasks: list[TraceTask] = [] trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task)) layer, _, _, _ = _make_layer( diff --git a/api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py b/api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py index 426ffc498b2..a0d047f6eda 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py +++ b/api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py @@ -3,6 +3,7 @@ from __future__ import annotations from types import SimpleNamespace from uuid import uuid4 +import pytest from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -112,7 +113,7 @@ def test_segment_attachment_lookup_grants_returned_upload_files_to_current_scope assert "upload_files.id IN" in whereclause -def test_knowledge_retrieval_grants_returned_segments_to_current_scope(monkeypatch) -> None: +def test_knowledge_retrieval_grants_returned_segments_to_current_scope(monkeypatch: pytest.MonkeyPatch) -> None: tenant_id = str(uuid4()) dataset_id = str(uuid4()) document_id = str(uuid4()) diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index cf93f436486..b33a7ba725c 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -64,7 +64,9 @@ class TestCacheEmbeddingMultimodalDocuments: usage=usage, ) - def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result): + def test_embed_single_multimodal_document_cache_miss( + self, mock_model_instance, sample_multimodal_result: EmbeddingResult + ): """Test embedding a single multimodal document when cache is empty.""" cache_embedding = CacheEmbedding(mock_model_instance) documents = [{"file_id": "file123", "content": "test content"}] diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 8bc7dbf70db..565bb85b634 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -97,7 +97,9 @@ class TestRerankModelRunner: ), ] - def test_basic_reranking(self, rerank_runner, mock_model_instance, sample_documents): + def test_basic_reranking( + self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document] + ): """Test basic reranking with cross-encoder model. Verifies: @@ -135,7 +137,9 @@ class TestRerankModelRunner: assert result[3].metadata["score"] == 0.65 assert result[0].page_content == sample_documents[2].page_content - def test_score_threshold_filtering(self, rerank_runner, mock_model_instance, sample_documents): + def test_score_threshold_filtering( + self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document] + ): """Test score threshold filtering. Verifies: @@ -163,7 +167,9 @@ class TestRerankModelRunner: assert result[0].metadata["score"] == 0.90 assert result[1].metadata["score"] == 0.70 - def test_top_k_selection(self, rerank_runner, mock_model_instance, sample_documents): + def test_top_k_selection( + self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document] + ): """Test top-k selection functionality. Verifies: @@ -191,7 +197,7 @@ class TestRerankModelRunner: assert result[0].metadata["score"] == 0.95 assert result[1].metadata["score"] == 0.85 - def test_document_deduplication_dify_provider(self, rerank_runner, mock_model_instance): + def test_document_deduplication_dify_provider(self, rerank_runner: RerankModelRunner, mock_model_instance): """Test document deduplication for dify provider. Verifies: @@ -235,7 +241,7 @@ class TestRerankModelRunner: assert len(call_kwargs["docs"]) == 2 # Duplicate removed assert len(result) == 2 - def test_document_deduplication_external_provider(self, rerank_runner, mock_model_instance): + def test_document_deduplication_external_provider(self, rerank_runner: RerankModelRunner, mock_model_instance): """Test document deduplication for external provider. Verifies: @@ -273,7 +279,9 @@ class TestRerankModelRunner: assert len(call_kwargs["docs"]) == 2 assert len(result) == 2 - def test_combined_threshold_and_top_k(self, rerank_runner, mock_model_instance, sample_documents): + def test_combined_threshold_and_top_k( + self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document] + ): """Test combined score threshold and top-k selection. Verifies: @@ -307,7 +315,9 @@ class TestRerankModelRunner: assert result[0].metadata["score"] == 0.95 assert result[1].metadata["score"] == 0.85 - def test_metadata_preservation(self, rerank_runner, mock_model_instance, sample_documents): + def test_metadata_preservation( + self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document] + ): """Test that original metadata is preserved after reranking. Verifies: @@ -334,7 +344,7 @@ class TestRerankModelRunner: assert result[0].metadata["score"] == 0.90 assert result[0].provider == "dify" - def test_empty_documents_list(self, rerank_runner, mock_model_instance): + def test_empty_documents_list(self, rerank_runner: RerankModelRunner, mock_model_instance): """Test handling of empty documents list. Verifies: @@ -523,7 +533,9 @@ class TestRerankModelRunnerMultimodal: docs_arg = mock_text_rerank.call_args.args[1] assert len(docs_arg) == 1 - def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(self, rerank_runner, mock_model_instance): + def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model( + self, rerank_runner: RerankModelRunner, mock_model_instance + ): text_doc = Document( page_content="text-content", metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, diff --git a/api/tests/unit_tests/core/workflow/generator/test_runner.py b/api/tests/unit_tests/core/workflow/generator/test_runner.py index 067fb1cf950..ec7a8f32dfc 100644 --- a/api/tests/unit_tests/core/workflow/generator/test_runner.py +++ b/api/tests/unit_tests/core/workflow/generator/test_runner.py @@ -413,7 +413,7 @@ class TestWorkflowGeneratorTransientRetry: } ) - def test_retries_transient_invoke_error_then_succeeds(self, monkeypatch): + def test_retries_transient_invoke_error_then_succeeds(self, monkeypatch: pytest.MonkeyPatch): # The planner's first invoke raises a transient connection error; the # retry succeeds and the pipeline completes normally. Sleep is patched # out so the test doesn't actually wait for the backoff. @@ -443,7 +443,7 @@ class TestWorkflowGeneratorTransientRetry: # planner (failed once + retried) + builder = 3 invocations total. assert model_instance.invoke_llm.call_count == 3 - def test_gives_up_after_exhausting_transient_retries(self, monkeypatch): + def test_gives_up_after_exhausting_transient_retries(self, monkeypatch: pytest.MonkeyPatch): # Every attempt hits the transient error — once we exhaust the retry # budget the failure surfaces as a normal error envelope rather than # hanging or looping forever. @@ -471,7 +471,7 @@ class TestWorkflowGeneratorTransientRetry: assert model_instance.invoke_llm.call_count == _INVOKE_MAX_ATTEMPTS - def test_does_not_retry_permanent_invoke_error(self, monkeypatch): + def test_does_not_retry_permanent_invoke_error(self, monkeypatch: pytest.MonkeyPatch): # An auth error is permanent — retrying just burns latency and quota. # The runner must fail on the first attempt. # If the code wrongly slept here we'd want the test to still be fast; diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py index 51049f87923..7a6328ffb4b 100644 --- a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py +++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py @@ -332,7 +332,9 @@ def test_email_delivery_method_extracts_variable_selectors() -> None: assert method.extract_variable_selectors() == [["start", "name"]] -def test_email_delivery_method_extracts_variable_selectors_skips_short_selectors(monkeypatch) -> None: +def test_email_delivery_method_extracts_variable_selectors_skips_short_selectors( + monkeypatch: pytest.MonkeyPatch, +) -> None: method = EmailDeliveryMethod( enabled=True, config=EmailDeliveryConfig( diff --git a/api/tests/unit_tests/libs/test_pandas.py b/api/tests/unit_tests/libs/test_pandas.py index a4739dbbc2b..3eec25815a1 100644 --- a/api/tests/unit_tests/libs/test_pandas.py +++ b/api/tests/unit_tests/libs/test_pandas.py @@ -1,8 +1,10 @@ +from pathlib import Path + import pandas as pd import pytest -def test_pandas_csv(tmp_path, monkeypatch: pytest.MonkeyPatch): +def test_pandas_csv(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): monkeypatch.chdir(tmp_path) data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data) @@ -17,7 +19,7 @@ def test_pandas_csv(tmp_path, monkeypatch: pytest.MonkeyPatch): assert df2[df2.columns[1]].to_list() == data["col2"] -def test_pandas_xlsx(tmp_path, monkeypatch: pytest.MonkeyPatch): +def test_pandas_xlsx(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): monkeypatch.chdir(tmp_path) data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data) @@ -32,7 +34,7 @@ def test_pandas_xlsx(tmp_path, monkeypatch: pytest.MonkeyPatch): assert df2[df2.columns[1]].to_list() == data["col2"] -def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch: pytest.MonkeyPatch): +def test_pandas_xlsx_with_sheets(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): monkeypatch.chdir(tmp_path) data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data1) diff --git a/api/tests/unit_tests/libs/test_rate_limit_bearer.py b/api/tests/unit_tests/libs/test_rate_limit_bearer.py index 62363f5f600..f286bfd875d 100644 --- a/api/tests/unit_tests/libs/test_rate_limit_bearer.py +++ b/api/tests/unit_tests/libs/test_rate_limit_bearer.py @@ -75,7 +75,7 @@ def test_enforce_bearer_rate_limit_raises_429_with_retry_after(mock_build): @patch("libs.rate_limit._build_limiter") -def test_enforce_bearer_rate_limit_disabled_when_limit_is_zero(mock_build, monkeypatch): +def test_enforce_bearer_rate_limit_disabled_when_limit_is_zero(mock_build, monkeypatch: pytest.MonkeyPatch): # 0 disables the limit — short-circuit before building/consulting a limiter. monkeypatch.setattr( "libs.rate_limit.LIMIT_BEARER_PER_TOKEN", diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 44a4e6af98b..bbfd411ca38 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -17,6 +17,8 @@ from unittest.mock import Mock, patch from urllib.parse import parse_qs, urlparse from uuid import uuid4 +import pytest + from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import ( AppDatasetJoin, @@ -703,7 +705,7 @@ class TestDocumentSegmentIndexing: # Assert assert segment.hit_count == 5 - def test_document_segment_attachments_prefers_files_url_for_source_url(self, monkeypatch): + def test_document_segment_attachments_prefers_files_url_for_source_url(self, monkeypatch: pytest.MonkeyPatch): """Test attachment source URLs use FILES_URL before falling back to CONSOLE_API_URL.""" # Arrange segment = DocumentSegment( diff --git a/api/tests/unit_tests/models/test_file_input_compat.py b/api/tests/unit_tests/models/test_file_input_compat.py index 1a41ccec9ea..6d6f2b42201 100644 --- a/api/tests/unit_tests/models/test_file_input_compat.py +++ b/api/tests/unit_tests/models/test_file_input_compat.py @@ -125,7 +125,9 @@ def test_rebuild_serialized_graph_files_without_lookup_preserves_scalar_values() assert rebuild_serialized_graph_files_without_lookup("plain-text") == "plain-text" -def test_build_file_from_stored_mapping_rebuilds_remote_urls_without_record_lookup(monkeypatch) -> None: +def test_build_file_from_stored_mapping_rebuilds_remote_urls_without_record_lookup( + monkeypatch: pytest.MonkeyPatch, +) -> None: rebuilt_file = File( file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, diff --git a/api/tests/unit_tests/models/test_snippet.py b/api/tests/unit_tests/models/test_snippet.py index f7a22f48f89..17f7cb3c9d4 100644 --- a/api/tests/unit_tests/models/test_snippet.py +++ b/api/tests/unit_tests/models/test_snippet.py @@ -2,6 +2,8 @@ import json from types import SimpleNamespace from unittest.mock import Mock +import pytest + from models.snippet import CustomizedSnippet @@ -11,7 +13,7 @@ def test_graph_dict_returns_empty_without_workflow_id() -> None: assert snippet.graph_dict == {} -def test_graph_dict_loads_published_workflow_graph(monkeypatch) -> None: +def test_graph_dict_loads_published_workflow_graph(monkeypatch: pytest.MonkeyPatch) -> None: workflow = SimpleNamespace(graph=json.dumps({"nodes": [{"id": "llm-1"}], "edges": []})) session = SimpleNamespace(get=Mock(return_value=workflow)) monkeypatch.setattr("models.snippet.db.session", session) @@ -21,7 +23,7 @@ def test_graph_dict_loads_published_workflow_graph(monkeypatch) -> None: session.get.assert_called_once() -def test_graph_dict_returns_empty_when_workflow_missing(monkeypatch) -> None: +def test_graph_dict_returns_empty_when_workflow_missing(monkeypatch: pytest.MonkeyPatch) -> None: session = SimpleNamespace(get=Mock(return_value=None)) monkeypatch.setattr("models.snippet.db.session", session) snippet = CustomizedSnippet(workflow_id="missing-workflow") @@ -36,7 +38,7 @@ def test_input_fields_list_parses_json_or_returns_empty() -> None: ] -def test_tags_returns_query_results_or_empty(monkeypatch) -> None: +def test_tags_returns_query_results_or_empty(monkeypatch: pytest.MonkeyPatch) -> None: tags = [SimpleNamespace(id="tag-1")] session = SimpleNamespace(scalars=Mock(return_value=SimpleNamespace(all=Mock(return_value=tags)))) monkeypatch.setattr("models.snippet.db.session", session) @@ -48,7 +50,7 @@ def test_tags_returns_query_results_or_empty(monkeypatch) -> None: assert snippet.tags == [] -def test_account_properties_and_author_name(monkeypatch) -> None: +def test_account_properties_and_author_name(monkeypatch: pytest.MonkeyPatch) -> None: account = SimpleNamespace(id="account-1", name="Ada") updated_account = SimpleNamespace(id="account-2", name="Grace") session = SimpleNamespace( diff --git a/api/tests/unit_tests/services/agent/test_agent_services.py b/api/tests/unit_tests/services/agent/test_agent_services.py index 816dc39ed79..2c077e20b46 100644 --- a/api/tests/unit_tests/services/agent/test_agent_services.py +++ b/api/tests/unit_tests/services/agent/test_agent_services.py @@ -1324,7 +1324,7 @@ class TestAgentAppBackingAgent: with pytest.raises(roster_service.AgentNotFoundError): service.get_agent_app_model(tenant_id="tenant-1", agent_id="agent-x") - def test_duplicate_agent_app_copies_app_config_and_active_soul(self, monkeypatch): + def test_duplicate_agent_app_copies_app_config_and_active_soul(self, monkeypatch: pytest.MonkeyPatch): source_config = SimpleNamespace( opening_statement="hello", suggested_questions='["q1"]', @@ -1463,7 +1463,7 @@ class TestAgentAppBackingAgent: assert target_agent.updated_by == "account-1" assert session.commits == 1 - def test_duplicate_agent_app_inherits_webapp_access_mode(self, monkeypatch): + def test_duplicate_agent_app_inherits_webapp_access_mode(self, monkeypatch: pytest.MonkeyPatch): source_app = SimpleNamespace( id="source-app", tenant_id="tenant-1", @@ -1522,7 +1522,7 @@ class TestAgentAppBackingAgent: assert duplicated is target_app assert access_mode_updates == [("target-app", "private")] - def test_duplicate_agent_app_falls_back_to_public_access_mode(self, monkeypatch): + def test_duplicate_agent_app_falls_back_to_public_access_mode(self, monkeypatch: pytest.MonkeyPatch): source_app = SimpleNamespace( id="source-app", tenant_id="tenant-1", diff --git a/api/tests/unit_tests/services/auth/test_firecrawl_auth.py b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py index 1458180570d..1369ff2ba05 100644 --- a/api/tests/unit_tests/services/auth/test_firecrawl_auth.py +++ b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py @@ -66,7 +66,7 @@ class TestFirecrawlAuth: assert str(exc_info.value) == expected_error @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) - def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance): + def test_should_validate_valid_credentials_successfully(self, mock_post: MagicMock, auth_instance: FirecrawlAuth): """Test successful credential validation""" mock_response = MagicMock() mock_response.status_code = 200 @@ -97,7 +97,9 @@ class TestFirecrawlAuth: ], ) @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) - def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance): + def test_should_handle_http_errors( + self, mock_post: MagicMock, status_code, error_message, auth_instance: FirecrawlAuth + ): """Test handling of various HTTP error codes""" mock_response = MagicMock() mock_response.status_code = status_code @@ -120,7 +122,13 @@ class TestFirecrawlAuth: ) @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) def test_should_handle_unexpected_errors( - self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance + self, + mock_post: MagicMock, + status_code, + response_text, + has_json_error, + expected_error_contains, + auth_instance: FirecrawlAuth, ): """Test handling of unexpected errors with various response formats""" mock_response = MagicMock() @@ -146,7 +154,9 @@ class TestFirecrawlAuth: ], ) @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) - def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance): + def test_should_handle_network_errors( + self, mock_post: MagicMock, exception_type, exception_message, auth_instance: FirecrawlAuth + ): """Test handling of various network-related errors including timeouts""" mock_post.side_effect = exception_type(exception_message) @@ -186,7 +196,7 @@ class TestFirecrawlAuth: assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl" @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) - def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance): + def test_should_handle_timeout_with_retry_suggestion(self, mock_post: MagicMock, auth_instance: FirecrawlAuth): """Test that timeout errors are handled gracefully with appropriate error message""" mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds") diff --git a/api/tests/unit_tests/services/auth/test_jina_auth.py b/api/tests/unit_tests/services/auth/test_jina_auth.py index eb409c61d4e..e9a1035da29 100644 --- a/api/tests/unit_tests/services/auth/test_jina_auth.py +++ b/api/tests/unit_tests/services/auth/test_jina_auth.py @@ -36,7 +36,7 @@ class TestJinaAuth: assert str(exc_info.value) == "No API key provided" @patch("services.auth.jina.jina._http_client.post", autospec=True) - def test_should_validate_valid_credentials_successfully(self, mock_post): + def test_should_validate_valid_credentials_successfully(self, mock_post: MagicMock): """Test successful credential validation""" mock_response = MagicMock() mock_response.status_code = 200 @@ -54,7 +54,7 @@ class TestJinaAuth: ) @patch("services.auth.jina.jina._http_client.post", autospec=True) - def test_should_handle_http_402_error(self, mock_post): + def test_should_handle_http_402_error(self, mock_post: MagicMock): """Test handling of 402 Payment Required error""" mock_response = MagicMock() mock_response.status_code = 402 @@ -100,7 +100,7 @@ class TestJinaAuth: assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error" @patch("services.auth.jina.jina._http_client.post", autospec=True) - def test_should_handle_http_500_error(self, mock_post): + def test_should_handle_http_500_error(self, mock_post: MagicMock): """Test handling of 500 Internal Server Error""" mock_response = MagicMock() mock_response.status_code = 500 @@ -115,7 +115,7 @@ class TestJinaAuth: assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error" @patch("services.auth.jina.jina._http_client.post", autospec=True) - def test_should_handle_unexpected_error_with_text_response(self, mock_post): + def test_should_handle_unexpected_error_with_text_response(self, mock_post: MagicMock): """Test handling of unexpected errors with text response""" mock_response = MagicMock() mock_response.status_code = 403 @@ -163,7 +163,7 @@ class TestJinaAuth: assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404" @patch("services.auth.jina.jina._http_client.post", autospec=True) - def test_should_handle_network_errors(self, mock_post): + def test_should_handle_network_errors(self, mock_post: MagicMock): """Test handling of network connection errors""" mock_post.side_effect = httpx.ConnectError("Network error") diff --git a/api/tests/unit_tests/services/test_agent_app_feature_service.py b/api/tests/unit_tests/services/test_agent_app_feature_service.py index a8553b62a86..5503540356f 100644 --- a/api/tests/unit_tests/services/test_agent_app_feature_service.py +++ b/api/tests/unit_tests/services/test_agent_app_feature_service.py @@ -89,7 +89,7 @@ class _FakeWriteSession: class TestUpdateFeatures: - def test_persists_new_app_model_config_version(self, monkeypatch): + def test_persists_new_app_model_config_version(self, monkeypatch: pytest.MonkeyPatch): session = _FakeWriteSession() monkeypatch.setattr(svc_mod.db, "session", session) app_model = SimpleNamespace( diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py index 143283c0ae9..248a2b1f68a 100644 --- a/api/tests/unit_tests/services/test_external_dataset_service.py +++ b/api/tests/unit_tests/services/test_external_dataset_service.py @@ -137,14 +137,14 @@ class ExternalDatasetServiceTestDataFactory: @pytest.fixture def factory(): """Provide the test data factory to all tests.""" - return ExternalDatasetServiceTestDataFactory + return ExternalDatasetServiceTestDataFactory() class TestExternalDatasetServiceGetAPIs: """Test get_external_knowledge_apis operations - comprehensive coverage.""" @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_apis_success_basic(self, mock_db, factory): + def test_get_external_knowledge_apis_success_basic(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test successful retrieval of external knowledge APIs with pagination.""" # Arrange tenant_id = "tenant-123" @@ -171,7 +171,9 @@ class TestExternalDatasetServiceGetAPIs: mock_db.paginate.assert_called_once() @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_apis_with_search_filter(self, mock_db, factory): + def test_get_external_knowledge_apis_with_search_filter( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test retrieval with search filter.""" # Arrange tenant_id = "tenant-123" @@ -195,7 +197,7 @@ class TestExternalDatasetServiceGetAPIs: assert result_items[0].name == "Production API" @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_apis_empty_results(self, mock_db, factory): + def test_get_external_knowledge_apis_empty_results(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test retrieval with no results.""" # Arrange mock_pagination = MagicMock() @@ -213,7 +215,9 @@ class TestExternalDatasetServiceGetAPIs: assert result_total == 0 @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_apis_large_result_set(self, mock_db, factory): + def test_get_external_knowledge_apis_large_result_set( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test retrieval with large result set.""" # Arrange apis = [factory.create_external_knowledge_api_mock(api_id=f"api-{i}") for i in range(100)] @@ -233,7 +237,9 @@ class TestExternalDatasetServiceGetAPIs: assert result_total == 100 @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_apis_pagination_last_page(self, mock_db, factory): + def test_get_external_knowledge_apis_pagination_last_page( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test last page pagination with partial results.""" # Arrange apis = [factory.create_external_knowledge_api_mock(api_id=f"api-{i}") for i in range(95, 100)] @@ -253,7 +259,9 @@ class TestExternalDatasetServiceGetAPIs: assert result_total == 100 @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_apis_case_insensitive_search(self, mock_db, factory): + def test_get_external_knowledge_apis_case_insensitive_search( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test case-insensitive search functionality.""" # Arrange apis = [ @@ -276,7 +284,9 @@ class TestExternalDatasetServiceGetAPIs: assert result_total == 2 @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_apis_special_characters_search(self, mock_db, factory): + def test_get_external_knowledge_apis_special_characters_search( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test search with special characters.""" # Arrange apis = [factory.create_external_knowledge_api_mock(name="API-v2.0 (beta)")] @@ -295,7 +305,9 @@ class TestExternalDatasetServiceGetAPIs: assert len(result_items) == 1 @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_apis_max_per_page_limit(self, mock_db, factory): + def test_get_external_knowledge_apis_max_per_page_limit( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test that max_per_page limit is enforced.""" # Arrange apis = [factory.create_external_knowledge_api_mock(api_id=f"api-{i}") for i in range(100)] @@ -315,7 +327,9 @@ class TestExternalDatasetServiceGetAPIs: assert call_args.kwargs["max_per_page"] == 100 @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_apis_ordered_by_created_at_desc(self, mock_db, factory): + def test_get_external_knowledge_apis_ordered_by_created_at_desc( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test that results are ordered by created_at descending.""" # Arrange apis = [ @@ -340,7 +354,7 @@ class TestExternalDatasetServiceGetAPIs: class TestExternalDatasetServiceValidateAPIList: """Test validate_api_list operations.""" - def test_validate_api_list_success_with_all_fields(self, factory): + def test_validate_api_list_success_with_all_fields(self, factory: ExternalDatasetServiceTestDataFactory): """Test successful validation with all required fields.""" # Arrange api_settings = {"endpoint": "https://api.example.com", "api_key": "test-key-123"} @@ -348,7 +362,7 @@ class TestExternalDatasetServiceValidateAPIList: # Act & Assert - should not raise ExternalDatasetService.validate_api_list(api_settings) - def test_validate_api_list_missing_endpoint(self, factory): + def test_validate_api_list_missing_endpoint(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when endpoint is missing.""" # Arrange api_settings = {"api_key": "test-key"} @@ -357,7 +371,7 @@ class TestExternalDatasetServiceValidateAPIList: with pytest.raises(ValueError, match="endpoint is required"): ExternalDatasetService.validate_api_list(api_settings) - def test_validate_api_list_empty_endpoint(self, factory): + def test_validate_api_list_empty_endpoint(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when endpoint is empty string.""" # Arrange api_settings = {"endpoint": "", "api_key": "test-key"} @@ -366,7 +380,7 @@ class TestExternalDatasetServiceValidateAPIList: with pytest.raises(ValueError, match="endpoint is required"): ExternalDatasetService.validate_api_list(api_settings) - def test_validate_api_list_missing_api_key(self, factory): + def test_validate_api_list_missing_api_key(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when API key is missing.""" # Arrange api_settings = {"endpoint": "https://api.example.com"} @@ -375,7 +389,7 @@ class TestExternalDatasetServiceValidateAPIList: with pytest.raises(ValueError, match="api_key is required"): ExternalDatasetService.validate_api_list(api_settings) - def test_validate_api_list_empty_api_key(self, factory): + def test_validate_api_list_empty_api_key(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when API key is empty string.""" # Arrange api_settings = {"endpoint": "https://api.example.com", "api_key": ""} @@ -384,7 +398,7 @@ class TestExternalDatasetServiceValidateAPIList: with pytest.raises(ValueError, match="api_key is required"): ExternalDatasetService.validate_api_list(api_settings) - def test_validate_api_list_empty_dict(self, factory): + def test_validate_api_list_empty_dict(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when settings are empty dict.""" # Arrange api_settings = {} @@ -393,7 +407,7 @@ class TestExternalDatasetServiceValidateAPIList: with pytest.raises(ValueError, match="api list is empty"): ExternalDatasetService.validate_api_list(api_settings) - def test_validate_api_list_none_value(self, factory): + def test_validate_api_list_none_value(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when settings are None.""" # Arrange api_settings = None @@ -402,7 +416,7 @@ class TestExternalDatasetServiceValidateAPIList: with pytest.raises(ValueError, match="api list is empty"): ExternalDatasetService.validate_api_list(api_settings) - def test_validate_api_list_with_extra_fields(self, factory): + def test_validate_api_list_with_extra_fields(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation succeeds with extra fields present.""" # Arrange api_settings = { @@ -421,7 +435,9 @@ class TestExternalDatasetServiceCreateAPI: @patch("services.external_knowledge_service.db") @patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key") - def test_create_external_knowledge_api_success_full(self, mock_check, mock_db, factory): + def test_create_external_knowledge_api_success_full( + self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test successful creation with all fields.""" # Arrange tenant_id = "tenant-123" @@ -447,7 +463,9 @@ class TestExternalDatasetServiceCreateAPI: @patch("services.external_knowledge_service.db") @patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key") - def test_create_external_knowledge_api_minimal_fields(self, mock_check, mock_db, factory): + def test_create_external_knowledge_api_minimal_fields( + self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test creation with minimal required fields.""" # Arrange args = { @@ -463,7 +481,9 @@ class TestExternalDatasetServiceCreateAPI: assert result.description == "" @patch("services.external_knowledge_service.db") - def test_create_external_knowledge_api_missing_settings(self, mock_db, factory): + def test_create_external_knowledge_api_missing_settings( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test creation fails when settings are missing.""" # Arrange args = {"name": "Test API", "description": "Test"} @@ -473,7 +493,7 @@ class TestExternalDatasetServiceCreateAPI: ExternalDatasetService.create_external_knowledge_api("tenant-123", "user-123", args) @patch("services.external_knowledge_service.db") - def test_create_external_knowledge_api_none_settings(self, mock_db, factory): + def test_create_external_knowledge_api_none_settings(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test creation fails when settings are explicitly None.""" # Arrange args = {"name": "Test API", "settings": None} @@ -484,7 +504,9 @@ class TestExternalDatasetServiceCreateAPI: @patch("services.external_knowledge_service.db") @patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key") - def test_create_external_knowledge_api_settings_json_serialization(self, mock_check, mock_db, factory): + def test_create_external_knowledge_api_settings_json_serialization( + self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test that settings are properly JSON serialized.""" # Arrange settings = { @@ -504,7 +526,9 @@ class TestExternalDatasetServiceCreateAPI: @patch("services.external_knowledge_service.db") @patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key") - def test_create_external_knowledge_api_unicode_handling(self, mock_check, mock_db, factory): + def test_create_external_knowledge_api_unicode_handling( + self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test proper handling of Unicode characters in name and description.""" # Arrange args = { @@ -522,7 +546,9 @@ class TestExternalDatasetServiceCreateAPI: @patch("services.external_knowledge_service.db") @patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key") - def test_create_external_knowledge_api_long_description(self, mock_check, mock_db, factory): + def test_create_external_knowledge_api_long_description( + self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test creation with very long description.""" # Arrange long_description = "A" * 1000 @@ -544,7 +570,7 @@ class TestExternalDatasetServiceCheckEndpoint: """Test check_endpoint_and_api_key operations - extensive coverage.""" @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_success_https(self, mock_proxy, factory): + def test_check_endpoint_success_https(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test successful validation with HTTPS endpoint.""" # Arrange settings = {"endpoint": "https://api.example.com", "api_key": "test-key"} @@ -558,7 +584,7 @@ class TestExternalDatasetServiceCheckEndpoint: mock_proxy.post.assert_called_once() @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_success_http(self, mock_proxy, factory): + def test_check_endpoint_success_http(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test successful validation with HTTP endpoint.""" # Arrange settings = {"endpoint": "http://api.example.com", "api_key": "test-key"} @@ -570,7 +596,7 @@ class TestExternalDatasetServiceCheckEndpoint: # Act & Assert - should not raise ExternalDatasetService.check_endpoint_and_api_key(settings) - def test_check_endpoint_missing_endpoint_key(self, factory): + def test_check_endpoint_missing_endpoint_key(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when endpoint key is missing.""" # Arrange settings = {"api_key": "test-key"} @@ -579,7 +605,7 @@ class TestExternalDatasetServiceCheckEndpoint: with pytest.raises(ValueError, match="endpoint is required"): ExternalDatasetService.check_endpoint_and_api_key(settings) - def test_check_endpoint_empty_endpoint_string(self, factory): + def test_check_endpoint_empty_endpoint_string(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when endpoint is empty string.""" # Arrange settings = {"endpoint": "", "api_key": "test-key"} @@ -588,7 +614,7 @@ class TestExternalDatasetServiceCheckEndpoint: with pytest.raises(ValueError, match="endpoint is required"): ExternalDatasetService.check_endpoint_and_api_key(settings) - def test_check_endpoint_whitespace_endpoint(self, factory): + def test_check_endpoint_whitespace_endpoint(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when endpoint is only whitespace.""" # Arrange settings = {"endpoint": " ", "api_key": "test-key"} @@ -597,7 +623,7 @@ class TestExternalDatasetServiceCheckEndpoint: with pytest.raises(ValueError, match="invalid endpoint"): ExternalDatasetService.check_endpoint_and_api_key(settings) - def test_check_endpoint_missing_api_key_key(self, factory): + def test_check_endpoint_missing_api_key_key(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when api_key key is missing.""" # Arrange settings = {"endpoint": "https://api.example.com"} @@ -606,7 +632,7 @@ class TestExternalDatasetServiceCheckEndpoint: with pytest.raises(ValueError, match="api_key is required"): ExternalDatasetService.check_endpoint_and_api_key(settings) - def test_check_endpoint_empty_api_key_string(self, factory): + def test_check_endpoint_empty_api_key_string(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when api_key is empty string.""" # Arrange settings = {"endpoint": "https://api.example.com", "api_key": ""} @@ -615,7 +641,7 @@ class TestExternalDatasetServiceCheckEndpoint: with pytest.raises(ValueError, match="api_key is required"): ExternalDatasetService.check_endpoint_and_api_key(settings) - def test_check_endpoint_no_scheme_url(self, factory): + def test_check_endpoint_no_scheme_url(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails for URL without http:// or https://.""" # Arrange settings = {"endpoint": "api.example.com", "api_key": "test-key"} @@ -624,7 +650,7 @@ class TestExternalDatasetServiceCheckEndpoint: with pytest.raises(ValueError, match="invalid endpoint.*must start with http"): ExternalDatasetService.check_endpoint_and_api_key(settings) - def test_check_endpoint_invalid_scheme(self, factory): + def test_check_endpoint_invalid_scheme(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails for URL with invalid scheme.""" # Arrange settings = {"endpoint": "ftp://api.example.com", "api_key": "test-key"} @@ -633,7 +659,7 @@ class TestExternalDatasetServiceCheckEndpoint: with pytest.raises(ValueError, match="failed to connect to the endpoint"): ExternalDatasetService.check_endpoint_and_api_key(settings) - def test_check_endpoint_no_netloc(self, factory): + def test_check_endpoint_no_netloc(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails for URL without network location.""" # Arrange settings = {"endpoint": "http://", "api_key": "test-key"} @@ -642,7 +668,7 @@ class TestExternalDatasetServiceCheckEndpoint: with pytest.raises(ValueError, match="invalid endpoint"): ExternalDatasetService.check_endpoint_and_api_key(settings) - def test_check_endpoint_malformed_url(self, factory): + def test_check_endpoint_malformed_url(self, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails for malformed URL.""" # Arrange settings = {"endpoint": "https:///invalid", "api_key": "test-key"} @@ -652,7 +678,7 @@ class TestExternalDatasetServiceCheckEndpoint: ExternalDatasetService.check_endpoint_and_api_key(settings) @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_connection_timeout(self, mock_proxy, factory): + def test_check_endpoint_connection_timeout(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails on connection timeout.""" # Arrange settings = {"endpoint": "https://api.example.com", "api_key": "test-key"} @@ -663,7 +689,7 @@ class TestExternalDatasetServiceCheckEndpoint: ExternalDatasetService.check_endpoint_and_api_key(settings) @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_network_error(self, mock_proxy, factory): + def test_check_endpoint_network_error(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails on network error.""" # Arrange settings = {"endpoint": "https://api.example.com", "api_key": "test-key"} @@ -674,7 +700,7 @@ class TestExternalDatasetServiceCheckEndpoint: ExternalDatasetService.check_endpoint_and_api_key(settings) @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_502_bad_gateway(self, mock_proxy, factory): + def test_check_endpoint_502_bad_gateway(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails with 502 Bad Gateway.""" # Arrange settings = {"endpoint": "https://api.example.com", "api_key": "test-key"} @@ -688,7 +714,7 @@ class TestExternalDatasetServiceCheckEndpoint: ExternalDatasetService.check_endpoint_and_api_key(settings) @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_404_not_found(self, mock_proxy, factory): + def test_check_endpoint_404_not_found(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails with 404 Not Found.""" # Arrange settings = {"endpoint": "https://api.example.com", "api_key": "test-key"} @@ -702,7 +728,7 @@ class TestExternalDatasetServiceCheckEndpoint: ExternalDatasetService.check_endpoint_and_api_key(settings) @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_403_forbidden(self, mock_proxy, factory): + def test_check_endpoint_403_forbidden(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails with 403 Forbidden (auth failure).""" # Arrange settings = {"endpoint": "https://api.example.com", "api_key": "wrong-key"} @@ -716,7 +742,7 @@ class TestExternalDatasetServiceCheckEndpoint: ExternalDatasetService.check_endpoint_and_api_key(settings) @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_other_4xx_codes_pass(self, mock_proxy, factory): + def test_check_endpoint_other_4xx_codes_pass(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test that other 4xx codes don't raise exceptions.""" # Arrange settings = {"endpoint": "https://api.example.com", "api_key": "test-key"} @@ -730,7 +756,7 @@ class TestExternalDatasetServiceCheckEndpoint: ExternalDatasetService.check_endpoint_and_api_key(settings) @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_5xx_codes_except_502_pass(self, mock_proxy, factory): + def test_check_endpoint_5xx_codes_except_502_pass(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test that 5xx codes except 502 don't raise exceptions.""" # Arrange settings = {"endpoint": "https://api.example.com", "api_key": "test-key"} @@ -744,7 +770,7 @@ class TestExternalDatasetServiceCheckEndpoint: ExternalDatasetService.check_endpoint_and_api_key(settings) @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_with_port_number(self, mock_proxy, factory): + def test_check_endpoint_with_port_number(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test validation with endpoint including port number.""" # Arrange settings = {"endpoint": "https://api.example.com:8443", "api_key": "test-key"} @@ -757,7 +783,7 @@ class TestExternalDatasetServiceCheckEndpoint: ExternalDatasetService.check_endpoint_and_api_key(settings) @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_with_path(self, mock_proxy, factory): + def test_check_endpoint_with_path(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test validation with endpoint including path.""" # Arrange settings = {"endpoint": "https://api.example.com/v1/api", "api_key": "test-key"} @@ -773,7 +799,9 @@ class TestExternalDatasetServiceCheckEndpoint: assert "/retrieval" in call_args[0][0] @patch("services.external_knowledge_service.ssrf_proxy") - def test_check_endpoint_authorization_header_format(self, mock_proxy, factory): + def test_check_endpoint_authorization_header_format( + self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory + ): """Test that Authorization header is properly formatted.""" # Arrange settings = {"endpoint": "https://api.example.com", "api_key": "test-key-123"} @@ -795,7 +823,7 @@ class TestExternalDatasetServiceGetAPI: """Test get_external_knowledge_api operations.""" @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_api_success(self, mock_db, factory): + def test_get_external_knowledge_api_success(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test successful retrieval of external knowledge API.""" # Arrange api_id = "api-123" @@ -811,7 +839,7 @@ class TestExternalDatasetServiceGetAPI: assert result.id == api_id @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_api_not_found(self, mock_db, factory): + def test_get_external_knowledge_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test error when API is not found.""" # Arrange mock_db.session.scalar.return_value = None @@ -826,7 +854,9 @@ class TestExternalDatasetServiceUpdateAPI: @patch("services.external_knowledge_service.naive_utc_now") @patch("services.external_knowledge_service.db") - def test_update_external_knowledge_api_success_all_fields(self, mock_db, mock_now, factory): + def test_update_external_knowledge_api_success_all_fields( + self, mock_db, mock_now, factory: ExternalDatasetServiceTestDataFactory + ): """Test successful update with all fields.""" # Arrange api_id = "api-123" @@ -856,7 +886,9 @@ class TestExternalDatasetServiceUpdateAPI: mock_db.session.commit.assert_called_once() @patch("services.external_knowledge_service.db") - def test_update_external_knowledge_api_preserve_hidden_api_key(self, mock_db, factory): + def test_update_external_knowledge_api_preserve_hidden_api_key( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test that hidden API key is preserved from existing settings.""" # Arrange api_id = "api-123" @@ -883,7 +915,7 @@ class TestExternalDatasetServiceUpdateAPI: assert settings["api_key"] == "original-secret-key" @patch("services.external_knowledge_service.db") - def test_update_external_knowledge_api_not_found(self, mock_db, factory): + def test_update_external_knowledge_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test error when API is not found.""" # Arrange mock_db.session.scalar.return_value = None @@ -895,7 +927,9 @@ class TestExternalDatasetServiceUpdateAPI: ExternalDatasetService.update_external_knowledge_api("tenant-123", "user-123", "api-123", args) @patch("services.external_knowledge_service.db") - def test_update_external_knowledge_api_tenant_mismatch(self, mock_db, factory): + def test_update_external_knowledge_api_tenant_mismatch( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test error when tenant ID doesn't match.""" # Arrange mock_db.session.scalar.return_value = None @@ -907,7 +941,7 @@ class TestExternalDatasetServiceUpdateAPI: ExternalDatasetService.update_external_knowledge_api("wrong-tenant", "user-123", "api-123", args) @patch("services.external_knowledge_service.db") - def test_update_external_knowledge_api_name_only(self, mock_db, factory): + def test_update_external_knowledge_api_name_only(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test updating only the name field.""" # Arrange existing_api = factory.create_external_knowledge_api_mock( @@ -930,7 +964,7 @@ class TestExternalDatasetServiceDeleteAPI: """Test delete_external_knowledge_api operations.""" @patch("services.external_knowledge_service.db") - def test_delete_external_knowledge_api_success(self, mock_db, factory): + def test_delete_external_knowledge_api_success(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test successful deletion of external knowledge API.""" # Arrange api_id = "api-123" @@ -948,7 +982,7 @@ class TestExternalDatasetServiceDeleteAPI: mock_db.session.commit.assert_called_once() @patch("services.external_knowledge_service.db") - def test_delete_external_knowledge_api_not_found(self, mock_db, factory): + def test_delete_external_knowledge_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test error when API is not found.""" # Arrange mock_db.session.scalar.return_value = None @@ -958,7 +992,9 @@ class TestExternalDatasetServiceDeleteAPI: ExternalDatasetService.delete_external_knowledge_api("tenant-123", "api-123") @patch("services.external_knowledge_service.db") - def test_delete_external_knowledge_api_tenant_mismatch(self, mock_db, factory): + def test_delete_external_knowledge_api_tenant_mismatch( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test error when tenant ID doesn't match.""" # Arrange mock_db.session.scalar.return_value = None @@ -972,7 +1008,9 @@ class TestExternalDatasetServiceAPIUseCheck: """Test external_knowledge_api_use_check operations.""" @patch("services.external_knowledge_service.db") - def test_external_knowledge_api_use_check_in_use_single(self, mock_db, factory): + def test_external_knowledge_api_use_check_in_use_single( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test API use check when API has one binding.""" # Arrange api_id = "api-123" @@ -989,7 +1027,9 @@ class TestExternalDatasetServiceAPIUseCheck: assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0]) @patch("services.external_knowledge_service.db") - def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory): + def test_external_knowledge_api_use_check_in_use_multiple( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test API use check with multiple bindings.""" # Arrange api_id = "api-123" @@ -1005,7 +1045,7 @@ class TestExternalDatasetServiceAPIUseCheck: assert count == 10 @patch("services.external_knowledge_service.db") - def test_external_knowledge_api_use_check_not_in_use(self, mock_db, factory): + def test_external_knowledge_api_use_check_not_in_use(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test API use check when API is not in use.""" # Arrange api_id = "api-123" @@ -1025,7 +1065,7 @@ class TestExternalDatasetServiceGetBinding: """Test get_external_knowledge_binding_with_dataset_id operations.""" @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_binding_success(self, mock_db, factory): + def test_get_external_knowledge_binding_success(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test successful retrieval of external knowledge binding.""" # Arrange tenant_id = "tenant-123" @@ -1043,7 +1083,7 @@ class TestExternalDatasetServiceGetBinding: assert result.tenant_id == tenant_id @patch("services.external_knowledge_service.db") - def test_get_external_knowledge_binding_not_found(self, mock_db, factory): + def test_get_external_knowledge_binding_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test error when binding is not found.""" # Arrange mock_db.session.scalar.return_value = None @@ -1057,7 +1097,9 @@ class TestExternalDatasetServiceDocumentValidate: """Test document_create_args_validate operations.""" @patch("services.external_knowledge_service.db") - def test_document_create_args_validate_success_all_params(self, mock_db, factory): + def test_document_create_args_validate_success_all_params( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test successful validation with all required parameters.""" # Arrange tenant_id = "tenant-123" @@ -1081,7 +1123,9 @@ class TestExternalDatasetServiceDocumentValidate: ExternalDatasetService.document_create_args_validate(tenant_id, api_id, process_parameter) @patch("services.external_knowledge_service.db") - def test_document_create_args_validate_missing_required_param(self, mock_db, factory): + def test_document_create_args_validate_missing_required_param( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test validation fails when required parameter is missing.""" # Arrange tenant_id = "tenant-123" @@ -1100,7 +1144,7 @@ class TestExternalDatasetServiceDocumentValidate: ExternalDatasetService.document_create_args_validate(tenant_id, api_id, process_parameter) @patch("services.external_knowledge_service.db") - def test_document_create_args_validate_api_not_found(self, mock_db, factory): + def test_document_create_args_validate_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test validation fails when API is not found.""" # Arrange mock_db.session.scalar.return_value = None @@ -1110,7 +1154,9 @@ class TestExternalDatasetServiceDocumentValidate: ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {}) @patch("services.external_knowledge_service.db") - def test_document_create_args_validate_no_custom_parameters(self, mock_db, factory): + def test_document_create_args_validate_no_custom_parameters( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test validation succeeds when no custom parameters defined.""" # Arrange settings = {} @@ -1122,7 +1168,9 @@ class TestExternalDatasetServiceDocumentValidate: ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {}) @patch("services.external_knowledge_service.db") - def test_document_create_args_validate_optional_params_not_required(self, mock_db, factory): + def test_document_create_args_validate_optional_params_not_required( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test that optional parameters don't cause validation failure.""" # Arrange settings = { @@ -1146,7 +1194,7 @@ class TestExternalDatasetServiceProcessAPI: """Test process_external_api operations - comprehensive HTTP method coverage.""" @patch("services.external_knowledge_service.ssrf_proxy") - def test_process_external_api_get_request(self, mock_proxy, factory): + def test_process_external_api_get_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test processing GET request.""" # Arrange settings = factory.create_api_setting_mock(request_method="get") @@ -1162,7 +1210,9 @@ class TestExternalDatasetServiceProcessAPI: mock_proxy.get.assert_called_once() @patch("services.external_knowledge_service.ssrf_proxy") - def test_process_external_api_post_request_with_data(self, mock_proxy, factory): + def test_process_external_api_post_request_with_data( + self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory + ): """Test processing POST request with data.""" # Arrange settings = factory.create_api_setting_mock(request_method="post", params={"key": "value", "data": "test"}) @@ -1180,7 +1230,7 @@ class TestExternalDatasetServiceProcessAPI: assert "data" in call_kwargs @patch("services.external_knowledge_service.ssrf_proxy") - def test_process_external_api_put_request(self, mock_proxy, factory): + def test_process_external_api_put_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test processing PUT request.""" # Arrange settings = factory.create_api_setting_mock(request_method="put") @@ -1196,7 +1246,7 @@ class TestExternalDatasetServiceProcessAPI: mock_proxy.put.assert_called_once() @patch("services.external_knowledge_service.ssrf_proxy") - def test_process_external_api_delete_request(self, mock_proxy, factory): + def test_process_external_api_delete_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test processing DELETE request.""" # Arrange settings = factory.create_api_setting_mock(request_method="delete") @@ -1212,7 +1262,7 @@ class TestExternalDatasetServiceProcessAPI: mock_proxy.delete.assert_called_once() @patch("services.external_knowledge_service.ssrf_proxy") - def test_process_external_api_patch_request(self, mock_proxy, factory): + def test_process_external_api_patch_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test processing PATCH request.""" # Arrange settings = factory.create_api_setting_mock(request_method="patch") @@ -1228,7 +1278,7 @@ class TestExternalDatasetServiceProcessAPI: mock_proxy.patch.assert_called_once() @patch("services.external_knowledge_service.ssrf_proxy") - def test_process_external_api_head_request(self, mock_proxy, factory): + def test_process_external_api_head_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test processing HEAD request.""" # Arrange settings = factory.create_api_setting_mock(request_method="head") @@ -1243,7 +1293,7 @@ class TestExternalDatasetServiceProcessAPI: assert result == mock_response mock_proxy.head.assert_called_once() - def test_process_external_api_invalid_method(self, factory): + def test_process_external_api_invalid_method(self, factory: ExternalDatasetServiceTestDataFactory): """Test error for invalid HTTP method.""" # Arrange settings = factory.create_api_setting_mock(request_method="INVALID") @@ -1253,7 +1303,7 @@ class TestExternalDatasetServiceProcessAPI: ExternalDatasetService.process_external_api(settings, None) @patch("services.external_knowledge_service.ssrf_proxy") - def test_process_external_api_with_files(self, mock_proxy, factory): + def test_process_external_api_with_files(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test processing request with file uploads.""" # Arrange settings = factory.create_api_setting_mock(request_method="post") @@ -1272,7 +1322,7 @@ class TestExternalDatasetServiceProcessAPI: assert call_kwargs["files"] == files @patch("services.external_knowledge_service.ssrf_proxy") - def test_process_external_api_follow_redirects(self, mock_proxy, factory): + def test_process_external_api_follow_redirects(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory): """Test that follow_redirects is enabled.""" # Arrange settings = factory.create_api_setting_mock(request_method="get") @@ -1291,7 +1341,7 @@ class TestExternalDatasetServiceProcessAPI: class TestExternalDatasetServiceAssemblingHeaders: """Test assembling_headers operations - comprehensive authorization coverage.""" - def test_assembling_headers_bearer_token(self, factory): + def test_assembling_headers_bearer_token(self, factory: ExternalDatasetServiceTestDataFactory): """Test assembling headers with Bearer token.""" # Arrange authorization = factory.create_authorization_mock(token_type="bearer", api_key="secret-key-123") @@ -1302,7 +1352,7 @@ class TestExternalDatasetServiceAssemblingHeaders: # Assert assert result["Authorization"] == "Bearer secret-key-123" - def test_assembling_headers_basic_auth(self, factory): + def test_assembling_headers_basic_auth(self, factory: ExternalDatasetServiceTestDataFactory): """Test assembling headers with Basic authentication.""" # Arrange authorization = factory.create_authorization_mock(token_type="basic", api_key="credentials") @@ -1313,7 +1363,7 @@ class TestExternalDatasetServiceAssemblingHeaders: # Assert assert result["Authorization"] == "Basic credentials" - def test_assembling_headers_custom_auth(self, factory): + def test_assembling_headers_custom_auth(self, factory: ExternalDatasetServiceTestDataFactory): """Test assembling headers with custom authentication.""" # Arrange authorization = factory.create_authorization_mock(token_type="custom", api_key="custom-token") @@ -1324,7 +1374,7 @@ class TestExternalDatasetServiceAssemblingHeaders: # Assert assert result["Authorization"] == "custom-token" - def test_assembling_headers_custom_header_name(self, factory): + def test_assembling_headers_custom_header_name(self, factory: ExternalDatasetServiceTestDataFactory): """Test assembling headers with custom header name.""" # Arrange authorization = factory.create_authorization_mock(token_type="bearer", api_key="key-123", header="X-API-Key") @@ -1336,7 +1386,7 @@ class TestExternalDatasetServiceAssemblingHeaders: assert result["X-API-Key"] == "Bearer key-123" assert "Authorization" not in result - def test_assembling_headers_with_existing_headers(self, factory): + def test_assembling_headers_with_existing_headers(self, factory: ExternalDatasetServiceTestDataFactory): """Test assembling headers preserves existing headers.""" # Arrange authorization = factory.create_authorization_mock(token_type="bearer", api_key="key") @@ -1355,7 +1405,7 @@ class TestExternalDatasetServiceAssemblingHeaders: assert result["X-Custom"] == "value" assert result["User-Agent"] == "TestAgent/1.0" - def test_assembling_headers_empty_existing_headers(self, factory): + def test_assembling_headers_empty_existing_headers(self, factory: ExternalDatasetServiceTestDataFactory): """Test assembling headers with empty existing headers dict.""" # Arrange authorization = factory.create_authorization_mock(token_type="bearer", api_key="key") @@ -1368,7 +1418,7 @@ class TestExternalDatasetServiceAssemblingHeaders: assert result["Authorization"] == "Bearer key" assert len(result) == 1 - def test_assembling_headers_missing_api_key(self, factory): + def test_assembling_headers_missing_api_key(self, factory: ExternalDatasetServiceTestDataFactory): """Test error when API key is missing.""" # Arrange config = AuthorizationConfig(api_key=None, type="bearer", header="Authorization") @@ -1378,7 +1428,7 @@ class TestExternalDatasetServiceAssemblingHeaders: with pytest.raises(ValueError, match="api_key is required"): ExternalDatasetService.assembling_headers(authorization) - def test_assembling_headers_missing_config(self, factory): + def test_assembling_headers_missing_config(self, factory: ExternalDatasetServiceTestDataFactory): """Test error when config is missing.""" # Arrange authorization = Authorization(type="api-key", config=None) @@ -1387,7 +1437,7 @@ class TestExternalDatasetServiceAssemblingHeaders: with pytest.raises(ValueError, match="authorization config is required"): ExternalDatasetService.assembling_headers(authorization) - def test_assembling_headers_default_header_name(self, factory): + def test_assembling_headers_default_header_name(self, factory: ExternalDatasetServiceTestDataFactory): """Test that default header name is Authorization when not specified.""" # Arrange config = AuthorizationConfig(api_key="key", type="bearer", header=None) @@ -1403,7 +1453,7 @@ class TestExternalDatasetServiceAssemblingHeaders: class TestExternalDatasetServiceGetSettings: """Test get_external_knowledge_api_settings operations.""" - def test_get_external_knowledge_api_settings_success(self, factory): + def test_get_external_knowledge_api_settings_success(self, factory: ExternalDatasetServiceTestDataFactory): """Test successful parsing of API settings.""" # Arrange settings = { @@ -1428,7 +1478,7 @@ class TestExternalDatasetServiceCreateDataset: """Test create_external_dataset operations.""" @patch("services.external_knowledge_service.db") - def test_create_external_dataset_success_full(self, mock_db, factory): + def test_create_external_dataset_success_full(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test successful creation of external dataset with all fields.""" # Arrange tenant_id = "tenant-123" @@ -1457,7 +1507,9 @@ class TestExternalDatasetServiceCreateDataset: mock_db.session.commit.assert_called_once() @patch("services.external_knowledge_service.db") - def test_create_external_dataset_duplicate_name_error(self, mock_db, factory): + def test_create_external_dataset_duplicate_name_error( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test error when dataset name already exists.""" # Arrange existing_dataset = factory.create_dataset_mock(name="Duplicate Dataset") @@ -1471,7 +1523,7 @@ class TestExternalDatasetServiceCreateDataset: ExternalDatasetService.create_external_dataset("tenant-123", "user-123", args) @patch("services.external_knowledge_service.db") - def test_create_external_dataset_api_not_found_error(self, mock_db, factory): + def test_create_external_dataset_api_not_found_error(self, mock_db, factory: ExternalDatasetServiceTestDataFactory): """Test error when external knowledge API is not found.""" # Arrange mock_db.session.scalar.side_effect = [None, None] @@ -1483,7 +1535,9 @@ class TestExternalDatasetServiceCreateDataset: ExternalDatasetService.create_external_dataset("tenant-123", "user-123", args) @patch("services.external_knowledge_service.db") - def test_create_external_dataset_missing_knowledge_id_error(self, mock_db, factory): + def test_create_external_dataset_missing_knowledge_id_error( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test error when external_knowledge_id is missing.""" # Arrange api = factory.create_external_knowledge_api_mock() @@ -1497,7 +1551,9 @@ class TestExternalDatasetServiceCreateDataset: ExternalDatasetService.create_external_dataset("tenant-123", "user-123", args) @patch("services.external_knowledge_service.db") - def test_create_external_dataset_missing_api_id_error(self, mock_db, factory): + def test_create_external_dataset_missing_api_id_error( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test error when external_knowledge_api_id is missing.""" # Arrange api = factory.create_external_knowledge_api_mock() @@ -1516,7 +1572,9 @@ class TestExternalDatasetServiceFetchRetrieval: @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_success_with_results(self, mock_db, mock_process, factory): + def test_fetch_external_knowledge_retrieval_success_with_results( + self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory + ): """Test successful external knowledge retrieval with results.""" # Arrange tenant_id = "tenant-123" @@ -1553,7 +1611,9 @@ class TestExternalDatasetServiceFetchRetrieval: assert result[1]["score"] == 0.8 @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_binding_not_found_error(self, mock_db, factory): + def test_fetch_external_knowledge_retrieval_binding_not_found_error( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test error when external knowledge binding is not found.""" # Arrange mock_db.session.scalar.return_value = None @@ -1563,7 +1623,9 @@ class TestExternalDatasetServiceFetchRetrieval: ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {}) @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_cross_tenant_api_template_error(self, mock_db, factory): + def test_fetch_external_knowledge_retrieval_cross_tenant_api_template_error( + self, mock_db, factory: ExternalDatasetServiceTestDataFactory + ): """Test error when a binding points to an API template outside the dataset tenant.""" # Arrange binding = factory.create_external_knowledge_binding_mock() @@ -1575,7 +1637,9 @@ class TestExternalDatasetServiceFetchRetrieval: @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_empty_results(self, mock_db, mock_process, factory): + def test_fetch_external_knowledge_retrieval_empty_results( + self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory + ): """Test retrieval with empty results.""" # Arrange binding = factory.create_external_knowledge_binding_mock() @@ -1598,7 +1662,9 @@ class TestExternalDatasetServiceFetchRetrieval: @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_with_score_threshold(self, mock_db, mock_process, factory): + def test_fetch_external_knowledge_retrieval_with_score_threshold( + self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory + ): """Test retrieval with score threshold enabled.""" # Arrange binding = factory.create_external_knowledge_binding_mock() @@ -1630,7 +1696,9 @@ class TestExternalDatasetServiceFetchRetrieval: @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory): + def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception( + self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory + ): """Test that non-200 status code raises Exception with response text.""" # Arrange binding = factory.create_external_knowledge_binding_mock() @@ -1665,7 +1733,7 @@ class TestExternalDatasetServiceFetchRetrieval: @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") def test_fetch_external_knowledge_retrieval_various_error_status_codes( - self, mock_db, mock_process, factory, status_code, error_message + self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory, status_code, error_message ): """Test that various error status codes raise exceptions with response text.""" # Arrange @@ -1690,7 +1758,9 @@ class TestExternalDatasetServiceFetchRetrieval: @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory): + def test_fetch_external_knowledge_retrieval_empty_response_text( + self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory + ): """Test exception with empty response text.""" # Arrange binding = factory.create_external_knowledge_binding_mock() diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 6995410c5e6..4c4abbbb8ec 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -259,7 +259,7 @@ def test_resolve_form_inputs_uses_runtime_select_options( def test_submit_form_by_token_calls_repository_and_enqueue( - sample_form_record, mock_session_factory, mocker: MockerFixture + sample_form_record: HumanInputFormRecord, mock_session_factory, mocker: MockerFixture ): session_factory, _ = mock_session_factory repo = MagicMock(spec=HumanInputFormSubmissionRepository) @@ -318,7 +318,7 @@ def test_submit_form_by_token_enqueues_agent_app_resume_for_conversation_form( def test_submit_form_by_token_skips_enqueue_for_delivery_test( - sample_form_record, mock_session_factory, mocker: MockerFixture + sample_form_record: HumanInputFormRecord, mock_session_factory, mocker: MockerFixture ): session_factory, _ = mock_session_factory repo = MagicMock(spec=HumanInputFormSubmissionRepository) @@ -343,7 +343,7 @@ def test_submit_form_by_token_skips_enqueue_for_delivery_test( def test_submit_form_by_token_passes_submission_user_id( - sample_form_record, mock_session_factory, mocker: MockerFixture + sample_form_record: HumanInputFormRecord, mock_session_factory, mocker: MockerFixture ): session_factory, _ = mock_session_factory repo = MagicMock(spec=HumanInputFormSubmissionRepository) @@ -528,7 +528,7 @@ def test_validate_human_input_submission_rejects_invalid_select_and_file_payload repo.mark_submitted.assert_not_called() -def test_form_properties(sample_form_record): +def test_form_properties(sample_form_record: HumanInputFormRecord): form = Form(sample_form_record) assert form.id == "form-id" assert form.workflow_run_id == "workflow-run-id" diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index 005dcec886e..6588c8a8de6 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -90,7 +90,7 @@ class TestMessageServicePaginationByFirstId: return TestMessageServiceFactory() # Test 01: No user provided - def test_pagination_by_first_id_no_user(self, factory): + def test_pagination_by_first_id_no_user(self, factory: TestMessageServiceFactory): """Test pagination returns empty result when no user is provided.""" # Arrange app = factory.create_app_mock() @@ -111,7 +111,7 @@ class TestMessageServicePaginationByFirstId: assert result.has_more is False # Test 02: No conversation_id provided - def test_pagination_by_first_id_no_conversation(self, factory): + def test_pagination_by_first_id_no_conversation(self, factory: TestMessageServiceFactory): """Test pagination returns empty result when no conversation_id is provided.""" # Arrange app = factory.create_app_mock() @@ -137,7 +137,7 @@ class TestMessageServicePaginationByFirstId: @patch("services.message_service.db") @patch("services.message_service.ConversationService") def test_pagination_by_first_id_without_first_id_desc( - self, mock_conversation_service, mock_db, mock_create_repo, factory + self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory ): """Test basic pagination without first_id in descending order.""" # Arrange @@ -180,7 +180,7 @@ class TestMessageServicePaginationByFirstId: @patch("services.message_service.db") @patch("services.message_service.ConversationService") def test_pagination_by_first_id_without_first_id_asc( - self, mock_conversation_service, mock_db, mock_create_repo, factory + self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory ): """Test basic pagination without first_id in ascending order.""" # Arrange @@ -222,7 +222,9 @@ class TestMessageServicePaginationByFirstId: @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db") @patch("services.message_service.ConversationService") - def test_pagination_by_first_id_with_first_id(self, mock_conversation_service, mock_db, mock_create_repo, factory): + def test_pagination_by_first_id_with_first_id( + self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory + ): """Test pagination with first_id to get messages before a specific message.""" # Arrange app = factory.create_app_mock() @@ -265,7 +267,9 @@ class TestMessageServicePaginationByFirstId: # Test 06: First message not found @patch("services.message_service.db") @patch("services.message_service.ConversationService") - def test_pagination_by_first_id_first_message_not_exists(self, mock_conversation_service, mock_db, factory): + def test_pagination_by_first_id_first_message_not_exists( + self, mock_conversation_service, mock_db, factory: TestMessageServiceFactory + ): """Test error handling when first_id doesn't exist.""" # Arrange app = factory.create_app_mock() @@ -290,7 +294,9 @@ class TestMessageServicePaginationByFirstId: @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db") @patch("services.message_service.ConversationService") - def test_pagination_by_first_id_has_more_true(self, mock_conversation_service, mock_db, mock_create_repo, factory): + def test_pagination_by_first_id_has_more_true( + self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory + ): """Test has_more flag is True when results exceed limit.""" # Arrange app = factory.create_app_mock() @@ -327,7 +333,9 @@ class TestMessageServicePaginationByFirstId: # Test 08: Empty conversation @patch("services.message_service.db") @patch("services.message_service.ConversationService") - def test_pagination_by_first_id_empty_conversation(self, mock_conversation_service, mock_db, factory): + def test_pagination_by_first_id_empty_conversation( + self, mock_conversation_service, mock_db, factory: TestMessageServiceFactory + ): """Test pagination with conversation that has no messages.""" # Arrange app = factory.create_app_mock() @@ -370,7 +378,7 @@ class TestMessageServicePaginationByLastId: return TestMessageServiceFactory() # Test 09: No user provided - def test_pagination_by_last_id_no_user(self, factory): + def test_pagination_by_last_id_no_user(self, factory: TestMessageServiceFactory): """Test pagination returns empty result when no user is provided.""" # Arrange app = factory.create_app_mock() @@ -391,7 +399,7 @@ class TestMessageServicePaginationByLastId: # Test 10: Basic pagination without last_id @patch("services.message_service.db") - def test_pagination_by_last_id_without_last_id(self, mock_db, factory): + def test_pagination_by_last_id_without_last_id(self, mock_db, factory: TestMessageServiceFactory): """Test basic pagination without last_id.""" # Arrange app = factory.create_app_mock() @@ -422,7 +430,7 @@ class TestMessageServicePaginationByLastId: # Test 11: Pagination with last_id @patch("services.message_service.db") - def test_pagination_by_last_id_with_last_id(self, mock_db, factory): + def test_pagination_by_last_id_with_last_id(self, mock_db, factory: TestMessageServiceFactory): """Test pagination with last_id to get messages after a specific message.""" # Arrange app = factory.create_app_mock() @@ -459,7 +467,7 @@ class TestMessageServicePaginationByLastId: # Test 12: Last message not found @patch("services.message_service.db") - def test_pagination_by_last_id_last_message_not_exists(self, mock_db, factory): + def test_pagination_by_last_id_last_message_not_exists(self, mock_db, factory: TestMessageServiceFactory): """Test error handling when last_id doesn't exist.""" # Arrange app = factory.create_app_mock() @@ -479,7 +487,9 @@ class TestMessageServicePaginationByLastId: # Test 13: Pagination with conversation_id filter @patch("services.message_service.ConversationService") @patch("services.message_service.db") - def test_pagination_by_last_id_with_conversation_filter(self, mock_db, mock_conversation_service, factory): + def test_pagination_by_last_id_with_conversation_filter( + self, mock_db, mock_conversation_service, factory: TestMessageServiceFactory + ): """Test pagination filtered by conversation_id.""" # Arrange app = factory.create_app_mock() @@ -515,7 +525,7 @@ class TestMessageServicePaginationByLastId: # Test 14: Pagination with include_ids filter @patch("services.message_service.db") - def test_pagination_by_last_id_with_include_ids(self, mock_db, factory): + def test_pagination_by_last_id_with_include_ids(self, mock_db, factory: TestMessageServiceFactory): """Test pagination filtered by include_ids.""" # Arrange app = factory.create_app_mock() @@ -545,7 +555,7 @@ class TestMessageServicePaginationByLastId: # Test 15: Has_more flag when results exceed limit @patch("services.message_service.db") - def test_pagination_by_last_id_has_more_true(self, mock_db, factory): + def test_pagination_by_last_id_has_more_true(self, mock_db, factory: TestMessageServiceFactory): """Test has_more flag is True when results exceed limit.""" # Arrange app = factory.create_app_mock() @@ -592,7 +602,7 @@ class TestMessageServiceUtilities: # Test 17: attach_message_extra_contents with messages @patch("services.message_service._create_execution_extra_content_repository") - def test_attach_message_extra_contents_with_messages(self, mock_create_repo, factory): + def test_attach_message_extra_contents_with_messages(self, mock_create_repo, factory: TestMessageServiceFactory): """Test attach_message_extra_contents correctly attaches content.""" # Arrange messages = [factory.create_message_mock(message_id="msg-1"), factory.create_message_mock(message_id="msg-2")] @@ -618,7 +628,9 @@ class TestMessageServiceUtilities: # Test 18: attach_message_extra_contents with index out of bounds @patch("services.message_service._create_execution_extra_content_repository") - def test_attach_message_extra_contents_index_out_of_bounds(self, mock_create_repo, factory): + def test_attach_message_extra_contents_index_out_of_bounds( + self, mock_create_repo, factory: TestMessageServiceFactory + ): """Test attach_message_extra_contents handles missing content lists.""" # Arrange messages = [factory.create_message_mock(message_id="msg-1")] @@ -659,7 +671,7 @@ class TestMessageServiceGetMessage: # Test 20: get_message success for EndUser @patch("services.message_service.db") - def test_get_message_end_user_success(self, mock_db, factory): + def test_get_message_end_user_success(self, mock_db, factory: TestMessageServiceFactory): """Test get_message returns message for EndUser.""" # Arrange app = factory.create_app_mock() @@ -676,7 +688,7 @@ class TestMessageServiceGetMessage: # Test 21: get_message success for Account (Admin) @patch("services.message_service.db") - def test_get_message_account_success(self, mock_db, factory): + def test_get_message_account_success(self, mock_db, factory: TestMessageServiceFactory): """Test get_message returns message for Account.""" # Arrange from models import Account @@ -696,7 +708,7 @@ class TestMessageServiceGetMessage: # Test 22: get_message not found @patch("services.message_service.db") - def test_get_message_not_found(self, mock_db, factory): + def test_get_message_not_found(self, mock_db, factory: TestMessageServiceFactory): """Test get_message raises MessageNotExistsError when not found.""" # Arrange app = factory.create_app_mock() @@ -720,7 +732,7 @@ class TestMessageServiceFeedback: # Test 23: create_feedback - new feedback for EndUser @patch("services.message_service.db") @patch.object(MessageService, "get_message") - def test_create_feedback_new_end_user(self, mock_get_message, mock_db, factory): + def test_create_feedback_new_end_user(self, mock_get_message, mock_db, factory: TestMessageServiceFactory): """Test creating new feedback for an end user.""" # Arrange app = factory.create_app_mock() @@ -748,7 +760,7 @@ class TestMessageServiceFeedback: # Test 24: create_feedback - update feedback for Account @patch("services.message_service.db") @patch.object(MessageService, "get_message") - def test_create_feedback_update_account(self, mock_get_message, mock_db, factory): + def test_create_feedback_update_account(self, mock_get_message, mock_db, factory: TestMessageServiceFactory): """Test updating existing feedback for an account.""" # Arrange from models import Account, MessageFeedback @@ -779,7 +791,7 @@ class TestMessageServiceFeedback: # Test 25: create_feedback - delete feedback (rating is None) @patch("services.message_service.db") @patch.object(MessageService, "get_message") - def test_create_feedback_delete(self, mock_get_message, mock_db, factory): + def test_create_feedback_delete(self, mock_get_message, mock_db, factory: TestMessageServiceFactory): """Test deleting feedback by passing rating=None.""" # Arrange app = factory.create_app_mock() @@ -805,7 +817,7 @@ class TestMessageServiceFeedback: # Test 26: get_all_messages_feedbacks @patch("services.message_service.db") - def test_get_all_messages_feedbacks(self, mock_db, factory): + def test_get_all_messages_feedbacks(self, mock_db, factory: TestMessageServiceFactory): """Test get_all_messages_feedbacks returns list of dicts.""" # Arrange app = factory.create_app_mock() @@ -830,7 +842,7 @@ class TestMessageServiceSuggestedQuestions: return TestMessageServiceFactory() # Test 27: get_suggested_questions_after_answer - user is None - def test_get_suggested_questions_user_none(self, factory): + def test_get_suggested_questions_user_none(self, factory: TestMessageServiceFactory): app = factory.create_app_mock() with pytest.raises(ValueError, match="user cannot be None"): MessageService.get_suggested_questions_after_answer( @@ -856,7 +868,7 @@ class TestMessageServiceSuggestedQuestions: mock_config_manager, mock_workflow_service, mock_model_manager, - factory, + factory: TestMessageServiceFactory, ): """Test successful suggested questions generation in Advanced Chat mode.""" from core.app.entities.app_invoke_entities import InvokeFrom @@ -896,14 +908,14 @@ class TestMessageServiceSuggestedQuestions: @patch("services.message_service.ConversationService") def test_get_suggested_questions_chat_app_success( self, - mock_conversation_service, - mock_get_message, - mock_trace_manager, - mock_llm_gen, - mock_memory, - mock_model_manager, - mock_db, - factory, + mock_conversation_service: MagicMock, + mock_get_message: MagicMock, + mock_trace_manager: MagicMock, + mock_llm_gen: MagicMock, + mock_memory: MagicMock, + mock_model_manager: MagicMock, + mock_db: MagicMock, + factory: TestMessageServiceFactory, ): """Test successful suggested questions generation in basic Chat mode.""" # Arrange @@ -942,14 +954,14 @@ class TestMessageServiceSuggestedQuestions: @patch("services.message_service.ConversationService") def test_get_suggested_questions_chat_app_uses_frontend_model_and_prompt( self, - mock_conversation_service, - mock_get_message, - mock_trace_manager, - mock_llm_gen, - mock_memory, - mock_model_manager, - mock_db, - factory, + mock_conversation_service: MagicMock, + mock_get_message: MagicMock, + mock_trace_manager: MagicMock, + mock_llm_gen: MagicMock, + mock_memory: MagicMock, + mock_model_manager: MagicMock, + mock_db: MagicMock, + factory: TestMessageServiceFactory, ): """Test suggested question generation uses frontend configured model and prompt.""" from core.app.entities.app_invoke_entities import InvokeFrom @@ -1015,14 +1027,14 @@ class TestMessageServiceSuggestedQuestions: @patch("services.message_service.ConversationService") def test_get_suggested_questions_chat_app_invalid_frontend_model_fallback_to_default( self, - mock_conversation_service, - mock_get_message, - mock_trace_manager, - mock_llm_gen, - mock_memory, - mock_model_manager, - mock_db, - factory, + mock_conversation_service: MagicMock, + mock_get_message: MagicMock, + mock_trace_manager: MagicMock, + mock_llm_gen: MagicMock, + mock_memory: MagicMock, + mock_model_manager: MagicMock, + mock_db: MagicMock, + factory: TestMessageServiceFactory, ): """Test invalid frontend configured model falls back to tenant default model.""" app = factory.create_app_mock(mode=AppMode.CHAT) @@ -1066,14 +1078,14 @@ class TestMessageServiceSuggestedQuestions: @patch("services.message_service.ConversationService") def test_get_suggested_questions_chat_app_uses_compatible_override_model_config( self, - mock_conversation_service, - mock_get_message, - mock_trace_manager, - mock_llm_gen, - mock_memory, - mock_model_manager, - mock_db, - factory, + mock_conversation_service: MagicMock, + mock_get_message: MagicMock, + mock_trace_manager: MagicMock, + mock_llm_gen: MagicMock, + mock_memory: MagicMock, + mock_model_manager: MagicMock, + mock_db: MagicMock, + factory: TestMessageServiceFactory, ): """Test legacy override configs are normalized before suggested questions reads them.""" app = factory.create_app_mock(mode=AppMode.CHAT) @@ -1174,7 +1186,12 @@ class TestMessageServiceSuggestedQuestions: @patch.object(MessageService, "get_message") @patch("services.message_service.ConversationService") def test_get_suggested_questions_disabled_error( - self, mock_conversation_service, mock_get_message, mock_config_manager, mock_workflow_service, factory + self, + mock_conversation_service, + mock_get_message, + mock_config_manager, + mock_workflow_service, + factory: TestMessageServiceFactory, ): """Test SuggestedQuestionsAfterAnswerDisabledError is raised when feature is disabled.""" # Arrange diff --git a/api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py b/api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py index 733f0d17ca2..305045cb6ee 100644 --- a/api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py +++ b/api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py @@ -1,6 +1,6 @@ import json import logging -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest @@ -130,7 +130,7 @@ class TestRagPipelineTaskProxy: assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3" @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") - def test_features_property(self, mock_feature_service): + def test_features_property(self, mock_feature_service: MagicMock): """Test cached_property features.""" # Arrange mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features() @@ -149,7 +149,7 @@ class TestRagPipelineTaskProxy: @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") - def test_upload_invoke_entities(self, mock_db, mock_file_service_class): + def test_upload_invoke_entities(self, mock_db: MagicMock, mock_file_service_class: MagicMock): """Test _upload_invoke_entities method.""" # Arrange proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() @@ -181,7 +181,9 @@ class TestRagPipelineTaskProxy: @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") - def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class): + def test_upload_invoke_entities_with_multiple_entities( + self, mock_db: MagicMock, mock_file_service_class: MagicMock + ): """Test _upload_invoke_entities method with multiple entities.""" # Arrange entities = [ @@ -209,7 +211,7 @@ class TestRagPipelineTaskProxy: assert parsed_json[1]["pipeline_id"] == "pipeline-2" @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") - def test_send_to_direct_queue(self, mock_task): + def test_send_to_direct_queue(self, mock_task: MagicMock): """Test _send_to_direct_queue method.""" # Arrange proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() @@ -229,7 +231,7 @@ class TestRagPipelineTaskProxy: ) @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") - def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): + def test_send_to_tenant_queue_with_existing_task_key(self, mock_task: MagicMock): """Test _send_to_tenant_queue when task key exists.""" # Arrange proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() @@ -248,7 +250,7 @@ class TestRagPipelineTaskProxy: mock_task.delay.assert_not_called() @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") - def test_send_to_tenant_queue_without_task_key(self, mock_task): + def test_send_to_tenant_queue_without_task_key(self, mock_task: MagicMock): """Test _send_to_tenant_queue when no task key exists.""" # Arrange proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() @@ -271,7 +273,7 @@ class TestRagPipelineTaskProxy: proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") - def test_send_to_default_tenant_queue(self, mock_task): + def test_send_to_default_tenant_queue(self, mock_task: MagicMock): """Test _send_to_default_tenant_queue method.""" # Arrange proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() @@ -285,7 +287,7 @@ class TestRagPipelineTaskProxy: proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task) @patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task") - def test_send_to_priority_tenant_queue(self, mock_task): + def test_send_to_priority_tenant_queue(self, mock_task: MagicMock): """Test _send_to_priority_tenant_queue method.""" # Arrange proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() @@ -299,7 +301,7 @@ class TestRagPipelineTaskProxy: proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task) @patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task") - def test_send_to_priority_direct_queue(self, mock_task): + def test_send_to_priority_direct_queue(self, mock_task: MagicMock): """Test _send_to_priority_direct_queue method.""" # Arrange proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() @@ -315,7 +317,9 @@ class TestRagPipelineTaskProxy: @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") - def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service): + def test_dispatch_with_billing_enabled_sandbox_plan( + self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock + ): """Test _dispatch method when billing is enabled with sandbox plan.""" # Arrange mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features( @@ -365,7 +369,9 @@ class TestRagPipelineTaskProxy: @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") - def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service): + def test_dispatch_with_billing_disabled( + self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock + ): """Test _dispatch method when billing is disabled.""" # Arrange mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) @@ -386,7 +392,7 @@ class TestRagPipelineTaskProxy: @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") - def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class): + def test_dispatch_with_empty_upload_file_id(self, mock_db: MagicMock, mock_file_service_class: MagicMock): """Test _dispatch method when upload_file_id is empty.""" # Arrange proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() @@ -404,7 +410,9 @@ class TestRagPipelineTaskProxy: @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") - def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service): + def test_dispatch_edge_case_empty_plan( + self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock + ): """Test _dispatch method with empty plan string.""" # Arrange mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="") @@ -426,7 +434,9 @@ class TestRagPipelineTaskProxy: @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") - def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service): + def test_dispatch_edge_case_none_plan( + self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock + ): """Test _dispatch method with None plan.""" # Arrange mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None) @@ -448,7 +458,9 @@ class TestRagPipelineTaskProxy: @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") - def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service): + def test_delay_method( + self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock + ): """Test delay method integration.""" # Arrange mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features( diff --git a/api/tests/unit_tests/services/test_snippet_dsl_service.py b/api/tests/unit_tests/services/test_snippet_dsl_service.py index 4866c570fdf..c155d3f8330 100644 --- a/api/tests/unit_tests/services/test_snippet_dsl_service.py +++ b/api/tests/unit_tests/services/test_snippet_dsl_service.py @@ -73,7 +73,7 @@ def test_import_snippet_rejects_invalid_yaml_url_scheme() -> None: assert result.error == "Invalid URL scheme, only http and https are allowed" -def test_import_snippet_returns_failed_when_yaml_url_fetch_fails(monkeypatch) -> None: +def test_import_snippet_returns_failed_when_yaml_url_fetch_fails(monkeypatch: pytest.MonkeyPatch) -> None: service = SnippetDslService(session=SimpleNamespace()) monkeypatch.setattr( "services.snippet_dsl_service.ssrf_proxy.get", @@ -90,7 +90,7 @@ def test_import_snippet_returns_failed_when_yaml_url_fetch_fails(monkeypatch) -> assert result.error == "Failed to fetch YAML from URL: 404" -def test_import_snippet_rejects_oversized_yaml_url_content(monkeypatch) -> None: +def test_import_snippet_rejects_oversized_yaml_url_content(monkeypatch: pytest.MonkeyPatch) -> None: service = SnippetDslService(session=SimpleNamespace()) monkeypatch.setattr("services.snippet_dsl_service.DSL_MAX_SIZE", 3) monkeypatch.setattr( @@ -108,7 +108,7 @@ def test_import_snippet_rejects_oversized_yaml_url_content(monkeypatch) -> None: assert "YAML content size exceeds maximum limit" in result.error -def test_import_snippet_returns_failed_when_yaml_url_fetch_raises(monkeypatch) -> None: +def test_import_snippet_returns_failed_when_yaml_url_fetch_raises(monkeypatch: pytest.MonkeyPatch) -> None: service = SnippetDslService(session=SimpleNamespace()) monkeypatch.setattr( "services.snippet_dsl_service.ssrf_proxy.get", @@ -125,7 +125,7 @@ def test_import_snippet_returns_failed_when_yaml_url_fetch_raises(monkeypatch) - assert result.error == "Failed to fetch YAML from URL: network down" -def test_import_snippet_rejects_oversized_yaml_content(monkeypatch) -> None: +def test_import_snippet_rejects_oversized_yaml_content(monkeypatch: pytest.MonkeyPatch) -> None: service = SnippetDslService(session=SimpleNamespace()) monkeypatch.setattr("services.snippet_dsl_service.DSL_MAX_SIZE", 3) diff --git a/api/tests/unit_tests/services/test_snippet_service.py b/api/tests/unit_tests/services/test_snippet_service.py index 008b6cfa418..16e1e58baae 100644 --- a/api/tests/unit_tests/services/test_snippet_service.py +++ b/api/tests/unit_tests/services/test_snippet_service.py @@ -529,7 +529,7 @@ def test_delete_snippet_removes_related_records() -> None: session.delete.assert_called_once_with(snippet) -def test_delete_draft_variable_files_removes_storage_objects(monkeypatch) -> None: +def test_delete_draft_variable_files_removes_storage_objects(monkeypatch: pytest.MonkeyPatch) -> None: from extensions.ext_storage import storage snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") @@ -554,7 +554,7 @@ def test_delete_draft_variable_files_removes_storage_objects(monkeypatch) -> Non assert "workflow_draft_variable_files" in executed_sql -def test_delete_archived_workflow_run_files_removes_prefixed_objects(monkeypatch) -> None: +def test_delete_archived_workflow_run_files_removes_prefixed_objects(monkeypatch: pytest.MonkeyPatch) -> None: from configs import dify_config snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") diff --git a/api/tests/unit_tests/services/test_workflow_generator_service.py b/api/tests/unit_tests/services/test_workflow_generator_service.py index e1d3a488be3..184cd5f29e5 100644 --- a/api/tests/unit_tests/services/test_workflow_generator_service.py +++ b/api/tests/unit_tests/services/test_workflow_generator_service.py @@ -30,10 +30,10 @@ class TestWorkflowGeneratorService: @patch("services.workflow_generator_service.format_tool_catalogue") def test_forwards_model_instance_and_catalogue_text_to_generator( self, - mock_format_catalogue, - mock_build_catalogue, - mock_model_manager, - mock_workflow_generator, + mock_format_catalogue: MagicMock, + mock_build_catalogue: MagicMock, + mock_model_manager: MagicMock, + mock_workflow_generator: MagicMock, ): """Happy path: model_instance + catalogue text + payload flow through.""" # Arrange @@ -110,10 +110,10 @@ class TestWorkflowGeneratorService: @patch("services.workflow_generator_service.format_tool_catalogue") def test_defaults_ideal_output_to_empty_string( self, - mock_format_catalogue, - mock_build_catalogue, - mock_model_manager, - mock_workflow_generator, + mock_format_catalogue: MagicMock, + mock_build_catalogue: MagicMock, + mock_model_manager: MagicMock, + mock_workflow_generator: MagicMock, ): """Callers can omit ideal_output; the runner should still receive "".""" mock_model_manager.for_tenant.return_value.get_model_instance.return_value = MagicMock() @@ -142,10 +142,10 @@ class TestWorkflowGeneratorService: @patch("services.workflow_generator_service.format_tool_catalogue") def test_forwards_current_graph_for_refine( self, - mock_format_catalogue, - mock_build_catalogue, - mock_model_manager, - mock_workflow_generator, + mock_format_catalogue: MagicMock, + mock_build_catalogue: MagicMock, + mock_model_manager: MagicMock, + mock_workflow_generator: MagicMock, ): """The cmd+k `/refine` path passes the existing draft graph through to the runner.""" mock_model_manager.for_tenant.return_value.get_model_instance.return_value = MagicMock() @@ -175,10 +175,10 @@ class TestWorkflowGeneratorService: @patch("services.workflow_generator_service.format_tool_catalogue") def test_defaults_current_graph_to_none_for_create( self, - mock_format_catalogue, - mock_build_catalogue, - mock_model_manager, - mock_workflow_generator, + mock_format_catalogue: MagicMock, + mock_build_catalogue: MagicMock, + mock_model_manager: MagicMock, + mock_workflow_generator: MagicMock, ): """Omitting current_graph (the `/create` path) forwards None to the runner.""" mock_model_manager.for_tenant.return_value.get_model_instance.return_value = MagicMock() diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index dbd8f05098a..dc94dbdb3d0 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -1,3 +1,4 @@ +from pathlib import Path from textwrap import dedent import pytest @@ -6,7 +7,7 @@ from core.helper.position_helper import get_position_map, is_filtered, pin_posit @pytest.fixture -def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: +def prepare_example_positions_yaml(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str: monkeypatch.chdir(tmp_path) tmp_path.joinpath("example_positions.yaml").write_text( dedent( @@ -25,7 +26,7 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: @pytest.fixture -def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: +def prepare_empty_commented_positions_yaml(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str: monkeypatch.chdir(tmp_path) tmp_path.joinpath("example_positions_all_commented.yaml").write_text( dedent( diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index 9e2b0659c04..fd4e4a3664d 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -1,3 +1,4 @@ +from pathlib import Path from textwrap import dedent import pytest @@ -11,7 +12,7 @@ NON_EXISTING_YAML_FILE = "non_existing_file.yaml" @pytest.fixture -def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: +def prepare_example_yaml_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE) file_path.write_text( @@ -34,7 +35,7 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: @pytest.fixture -def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: +def prepare_invalid_yaml_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(INVALID_YAML_FILE) file_path.write_text(