chore: add Type to test (#37191)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-06-20 01:44:20 +09:00 committed by GitHub
parent 9eca75c7fc
commit fae607e2fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
60 changed files with 801 additions and 575 deletions

View File

@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast, overload, over
from uuid import UUID
from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask import Request, Response, stream_with_context
from flask_restx import fields
from pydantic import BaseModel, ConfigDict, TypeAdapter, with_config
from pydantic.functional_validators import AfterValidator
@ -167,19 +167,6 @@ def build_avatar_url(avatar: str | None) -> str | None:
return file_helpers.get_signed_file_url(avatar)
class AvatarUrlField(fields.Raw):
@override
def output(self, key, obj, **kwargs):
if obj is None:
return None
from models import Account
if isinstance(obj, Account) and obj.avatar is not None:
return build_avatar_url(obj.avatar)
return None
class TimestampField(fields.Raw):
@override
def schema(self) -> dict[str, object]:
@ -397,7 +384,7 @@ def generate_string(n):
return result
def extract_remote_ip(request) -> str:
def extract_remote_ip(request: Request) -> str:
if request.headers.get("CF-Connecting-IP"):
return cast(str, request.headers.get("CF-Connecting-IP"))
elif request.headers.getlist("X-Forwarded-For"):

View File

@ -42,7 +42,9 @@ def trace_client_factory():
class TestTraceClient:
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname")
def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
def test_init(
self, mock_gethostname: MagicMock, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]
):
mock_gethostname.return_value = "test-host"
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@ -56,7 +58,7 @@ class TestTraceClient:
assert client.done is True
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_export(self, mock_exporter_class, trace_client_factory):
def test_export(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
mock_exporter = mock_exporter_class.return_value
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
spans = [MagicMock(spec=ReadableSpan)]
@ -65,7 +67,9 @@ class TestTraceClient:
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
def test_api_check_success(
self, mock_exporter_class: MagicMock, mock_head: MagicMock, trace_client_factory: type[TraceClient]
):
mock_response = MagicMock()
mock_response.status_code = 405
mock_head.return_value = mock_response
@ -75,7 +79,9 @@ class TestTraceClient:
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
def test_api_check_failure_status(
self, mock_exporter_class: MagicMock, mock_head: MagicMock, trace_client_factory: type[TraceClient]
):
mock_response = MagicMock()
mock_response.status_code = 500
mock_head.return_value = mock_response
@ -85,7 +91,9 @@ class TestTraceClient:
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
def test_api_check_exception(
self, mock_exporter_class: MagicMock, mock_head: MagicMock, trace_client_factory: type[TraceClient]
):
mock_head.side_effect = httpx.RequestError("Connection error")
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@ -93,12 +101,12 @@ class TestTraceClient:
client.api_check()
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_get_project_url(self, mock_exporter_class, trace_client_factory):
def test_get_project_url(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_add_span(self, mock_exporter_class, trace_client_factory):
def test_add_span(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
client = trace_client_factory(
service_name="test-service",
endpoint="http://test-endpoint",
@ -135,7 +143,9 @@ class TestTraceClient:
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.logger")
def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
def test_add_span_queue_full(
self, mock_logger: MagicMock, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]
):
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
span_data = SpanData(
@ -159,7 +169,7 @@ class TestTraceClient:
mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
def test_export_batch_error(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
mock_exporter = mock_exporter_class.return_value
mock_exporter.export.side_effect = Exception("Export failed")
@ -172,13 +182,13 @@ class TestTraceClient:
mock_logger.warning.assert_called()
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_worker_loop(self, mock_exporter_class, trace_client_factory):
def test_worker_loop(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
# We need to test the wait timeout in _worker
# But _worker runs in a thread. Let's mock condition.wait.
client = trace_client_factory(
service_name="test-service",
endpoint="http://test-endpoint",
schedule_delay_sec=0.1,
schedule_delay_sec=1,
)
with patch.object(client.condition, "wait") as mock_wait:
@ -189,7 +199,7 @@ class TestTraceClient:
assert mock_wait.called or client.done
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
def test_shutdown_flushes(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
mock_exporter = mock_exporter_class.return_value
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")

View File

@ -456,7 +456,7 @@ class TestPhoenixParentSpanBridgeHelpers:
assert error.parent_node_execution_id == "outer-node-execution-1"
assert "outer-node-execution-1" in str(error)
def test_resolve_parent_span_context_rejects_payload_without_traceparent(self, monkeypatch):
def test_resolve_parent_span_context_rejects_payload_without_traceparent(self, monkeypatch: pytest.MonkeyPatch):
mock_redis = MagicMock()
mock_redis.get.return_value = '{"tracestate": "vendor=value"}'
monkeypatch.setattr(arize_phoenix_trace_module, "redis_client", mock_redis)

View File

@ -656,7 +656,7 @@ def _patch_workflow_trace_deps(monkeypatch, trace_instance):
trace_instance.add_run = MagicMock()
def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypatch):
def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Chatflow with external trace_id: LangSmith trace_id must be message_id, not external."""
trace_info = _make_workflow_trace_info(
message_id="msg-abc",
@ -677,7 +677,7 @@ def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypa
assert trace_info.metadata.get("external_trace_id") == "external-999"
def test_workflow_trace_id_pure_workflow_uses_run_id(trace_instance, monkeypatch):
def test_workflow_trace_id_pure_workflow_uses_run_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Pure workflow (no message_id) with external trace_id: trace_id must be workflow_run_id."""
trace_info = _make_workflow_trace_info(
message_id=None,

View File

@ -13,7 +13,9 @@ from extensions.ext_database import db
from models import App
def test_run_chat_dispatches_to_chat_handler(flask_app, account_token, app_in_workspace, monkeypatch):
def test_run_chat_dispatches_to_chat_handler(
flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch
):
captured = {}
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
@ -78,7 +80,9 @@ def app_with_mode(flask_app: Flask, workspace_account):
db.session.commit()
def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_workspace, monkeypatch):
def test_run_chat_without_query_returns_422(
flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch
):
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app_in_workspace.id}/run",
@ -89,7 +93,9 @@ def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_wor
assert b"query_required_for_chat" in res.data
def test_run_completion_dispatches_to_completion_handler(flask_app, account_token, app_with_mode, monkeypatch):
def test_run_completion_dispatches_to_completion_handler(
flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch
):
app = app_with_mode("completion")
captured: dict = {}
@ -119,7 +125,9 @@ def test_run_completion_dispatches_to_completion_handler(flask_app, account_toke
assert captured["mode"] == "completion"
def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
def test_run_workflow_with_query_returns_422(
flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch
):
app = app_with_mode("workflow")
client = flask_app.test_client()
res = client.post(
@ -131,7 +139,9 @@ def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_
assert b"query_not_supported_for_workflow" in res.data
def test_run_workflow_no_query_dispatches_to_workflow_handler(flask_app, account_token, app_with_mode, monkeypatch):
def test_run_workflow_no_query_dispatches_to_workflow_handler(
flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch
):
app = app_with_mode("workflow")
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
@ -154,7 +164,9 @@ def test_run_workflow_no_query_dispatches_to_workflow_handler(flask_app, account
assert body["workflow_run_id"] == "wfr"
def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
def test_run_unsupported_mode_returns_422(
flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch
):
app = app_with_mode("channel")
client = flask_app.test_client()
res = client.post(
@ -166,7 +178,7 @@ def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mod
assert b"mode_not_runnable" in res.data
def test_run_without_bearer_returns_401(flask_app, app_in_workspace):
def test_run_without_bearer_returns_401(flask_app: Flask, app_in_workspace):
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app_in_workspace.id}/run",
@ -175,7 +187,9 @@ def test_run_without_bearer_returns_401(flask_app, app_in_workspace):
assert res.status_code == 401
def test_run_with_insufficient_scope_returns_403(flask_app, account_token, app_in_workspace, monkeypatch):
def test_run_with_insufficient_scope_returns_403(
flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch
):
"""Stub the authenticator to return an AuthContext with empty scopes."""
from libs import oauth_bearer
@ -198,7 +212,7 @@ def test_run_with_insufficient_scope_returns_403(flask_app, account_token, app_i
assert res.status_code == 403
def test_run_with_unknown_app_returns_404(flask_app, account_token):
def test_run_with_unknown_app_returns_404(flask_app: Flask, account_token):
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{uuid.uuid4()}/run",
@ -208,7 +222,9 @@ def test_run_with_unknown_app_returns_404(flask_app, account_token):
assert res.status_code == 404
def test_run_streaming_returns_event_stream(flask_app, account_token, app_in_workspace, monkeypatch):
def test_run_streaming_returns_event_stream(
flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch
):
def _stream() -> Generator[str, None, None]:
yield 'event: message\ndata: {"x": 1}\n\n'
@ -228,7 +244,7 @@ def test_run_streaming_returns_event_stream(flask_app, account_token, app_in_wor
assert b"event: message" in res.data
def test_run_without_inputs_returns_422(flask_app, account_token, app_in_workspace):
def test_run_without_inputs_returns_422(flask_app: Flask, account_token, app_in_workspace):
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app_in_workspace.id}/run",

View File

@ -4,6 +4,8 @@ import uuid
from collections.abc import Generator
from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_manager import ModelInstance
@ -91,7 +93,7 @@ def init_llm_node(config: dict) -> LLMNode:
return node
def _mock_db_session_close(monkeypatch) -> None:
def _mock_db_session_close(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(db.session, "close", MagicMock())

View File

@ -3,6 +3,8 @@ import time
import uuid
from unittest.mock import MagicMock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.model_manager import ModelInstance
from core.workflow.node_runtime import DifyPromptMessageSerializer
@ -83,11 +85,11 @@ def init_parameter_extractor_node(config: dict, memory=None):
return node
def _mock_db_session_close(monkeypatch) -> None:
def _mock_db_session_close(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(db.session, "close", MagicMock())
def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch):
def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch: pytest.MonkeyPatch):
"""
Test function calling for parameter extractor.
"""
@ -128,7 +130,7 @@ def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch):
assert result.outputs.get("__reason") == None
def test_instructions(setup_model_mock, monkeypatch):
def test_instructions(setup_model_mock, monkeypatch: pytest.MonkeyPatch):
"""
Test chat parameter extractor.
"""
@ -178,7 +180,7 @@ def test_instructions(setup_model_mock, monkeypatch):
assert "what's the weather in SF" in prompt.get("text")
def test_chat_parameter_extractor(setup_model_mock, monkeypatch):
def test_chat_parameter_extractor(setup_model_mock, monkeypatch: pytest.MonkeyPatch):
"""
Test chat parameter extractor.
"""
@ -229,7 +231,7 @@ def test_chat_parameter_extractor(setup_model_mock, monkeypatch):
assert '<structure>\n{"type": "object"' in prompt.get("text")
def test_completion_parameter_extractor(setup_model_mock, monkeypatch):
def test_completion_parameter_extractor(setup_model_mock, monkeypatch: pytest.MonkeyPatch):
"""
Test completion parameter extractor.
"""
@ -354,7 +356,7 @@ def test_extract_json_from_tool_call():
assert result["location"] == "kawaii"
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch: pytest.MonkeyPatch):
"""
Test chat parameter extractor with memory.
"""

View File

@ -517,11 +517,11 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.TenantService")
def test_should_handle_account_generation_scenarios(
self,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
mock_tenant_service: MagicMock,
mock_account_service: MagicMock,
mock_register_service: MagicMock,
mock_feature_service: MagicMock,
mock_get_account: MagicMock,
app: Flask,
user_info: OAuthUserInfo,
mock_account,
@ -562,11 +562,11 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.TenantService")
def test_should_register_with_lowercase_email(
self,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
mock_tenant_service: MagicMock,
mock_account_service: MagicMock,
mock_register_service: MagicMock,
mock_feature_service: MagicMock,
mock_get_account: MagicMock,
app: Flask,
):
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
@ -593,11 +593,11 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.TenantService")
def test_should_register_with_browser_timezone(
self,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
mock_tenant_service: MagicMock,
mock_account_service: MagicMock,
mock_register_service: MagicMock,
mock_feature_service: MagicMock,
mock_get_account: MagicMock,
app: Flask,
user_info: OAuthUserInfo,
):
@ -624,11 +624,11 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.TenantService")
def test_should_register_with_state_language(
self,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
mock_tenant_service: MagicMock,
mock_account_service: MagicMock,
mock_register_service: MagicMock,
mock_feature_service: MagicMock,
mock_get_account: MagicMock,
app: Flask,
user_info: OAuthUserInfo,
):

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import json
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
from uuid import uuid4
import pytest
@ -83,8 +83,8 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_success(
self,
mock_encrypter,
mock_factory,
mock_encrypter: MagicMock,
mock_factory: MagicMock,
flask_app_with_containers: Flask,
db_session_with_containers: Session,
tenant_id,
@ -107,7 +107,12 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_validation_failed(
self, mock_factory, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, mock_args
self,
mock_factory: MagicMock,
flask_app_with_containers: Flask,
db_session_with_containers: Session,
tenant_id,
mock_args,
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = False
@ -123,8 +128,8 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_encrypts_api_key(
self,
mock_encrypter,
mock_factory,
mock_encrypter: MagicMock,
mock_factory: MagicMock,
flask_app_with_containers: Flask,
db_session_with_containers: Session,
tenant_id,
@ -289,7 +294,7 @@ class TestApiKeyAuthService:
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_factory_exception(self, mock_factory, tenant_id, mock_args):
def test_create_provider_auth_factory_exception(self, mock_factory: MagicMock, tenant_id, mock_args):
mock_factory.side_effect = Exception("Factory error")
with pytest.raises(Exception, match="Factory error"):
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)

View File

@ -47,8 +47,8 @@ class TestGetDynamicSelectOptionsTool:
@patch("services.plugin.plugin_parameter_service.ToolManager")
def test_fetches_credentials_with_credential_id(
self,
mock_tool_mgr,
mock_encrypter_fn,
mock_tool_mgr: MagicMock,
mock_encrypter_fn: MagicMock,
mock_client_cls,
flask_app_with_containers: Flask,
db_session_with_containers: Session,
@ -90,8 +90,8 @@ class TestGetDynamicSelectOptionsTool:
@patch("services.plugin.plugin_parameter_service.ToolManager")
def test_raises_when_tool_provider_not_found(
self,
mock_tool_mgr,
mock_encrypter_fn,
mock_tool_mgr: MagicMock,
mock_encrypter_fn: MagicMock,
flask_app_with_containers: Flask,
db_session_with_containers: Session,
):

View File

@ -258,7 +258,7 @@ class TestUpgradePluginWithMarketplace:
class TestUpgradePluginWithGithub:
@patch("core.plugin.plugin_service.FeatureService")
@patch("core.plugin.plugin_service.PluginInstaller")
def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs):
def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls: MagicMock, mock_fs: MagicMock):
mock_fs.get_system_features.return_value = _make_features()
installer = mock_installer_cls.return_value
installer.upgrade_plugin.return_value = MagicMock()
@ -273,7 +273,7 @@ class TestUpgradePluginWithGithub:
class TestUploadPkg:
@patch("core.plugin.plugin_service.FeatureService")
@patch("core.plugin.plugin_service.PluginInstaller")
def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs):
def test_runs_permission_and_scope_checks(self, mock_installer_cls: MagicMock, mock_fs: MagicMock):
mock_fs.get_system_features.return_value = _make_features()
upload_resp = MagicMock()
upload_resp.verification = None
@ -318,7 +318,7 @@ class TestInstallFromMarketplacePkg:
@patch("core.plugin.plugin_service.FeatureService")
@patch("core.plugin.plugin_service.PluginInstaller")
@patch("core.plugin.plugin_service.dify_config")
def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs):
def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls: MagicMock, mock_fs: MagicMock):
mock_config.MARKETPLACE_ENABLED = True
mock_fs.get_system_features.return_value = _make_features()
installer = mock_installer_cls.return_value

View File

@ -5,6 +5,8 @@ import json
import sys
from pathlib import Path
import pytest
def _load_generate_swagger_markdown_docs_module():
api_dir = Path(__file__).resolve().parents[3]
@ -20,7 +22,9 @@ def _load_generate_swagger_markdown_docs_module():
return module
def test_generate_markdown_docs_keeps_split_docs_and_merges_fastopenapi_into_console(tmp_path, monkeypatch):
def test_generate_markdown_docs_keeps_split_docs_and_merges_fastopenapi_into_console(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
module = _load_generate_swagger_markdown_docs_module()
openapi_dir = tmp_path / "openapi"
markdown_dir = tmp_path / "markdown"
@ -69,7 +73,9 @@ def test_generate_markdown_docs_keeps_split_docs_and_merges_fastopenapi_into_con
assert "FastOpenAPI Preview" not in (markdown_dir / "service-openapi.md").read_text(encoding="utf-8")
def test_generate_markdown_docs_only_removes_generated_specs_from_separate_swagger_dir(tmp_path, monkeypatch):
def test_generate_markdown_docs_only_removes_generated_specs_from_separate_swagger_dir(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
module = _load_generate_swagger_markdown_docs_module()
swagger_dir = tmp_path / "swagger"
markdown_dir = tmp_path / "markdown"
@ -105,7 +111,7 @@ def test_generate_markdown_docs_only_removes_generated_specs_from_separate_swagg
assert not list(swagger_dir.glob("*.json"))
def test_patch_union_schema_markdown_fills_converter_blank_schema_types(tmp_path):
def test_patch_union_schema_markdown_fills_converter_blank_schema_types(tmp_path: Path):
module = _load_generate_swagger_markdown_docs_module()
spec_path = tmp_path / "console-openapi.json"
spec_path.write_text(
@ -239,7 +245,7 @@ def test_patch_union_schema_markdown_ignores_specs_without_schemas(tmp_path):
assert module._patch_union_schema_markdown("unchanged", spec_path) == "unchanged"
def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path):
def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path: Path):
module = _load_generate_swagger_markdown_docs_module()
spec_path = tmp_path / "console-openapi.json"
spec_path.write_text(
@ -285,7 +291,7 @@ def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path):
assert module._patch_union_schema_markdown("#### BrokenUnion\n", spec_path) == "#### BrokenUnion\n"
def test_convert_spec_to_markdown_patches_generated_union_tables(tmp_path, monkeypatch):
def test_convert_spec_to_markdown_patches_generated_union_tables(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
module = _load_generate_swagger_markdown_docs_module()
spec_path = tmp_path / "console-openapi.json"
output_path = tmp_path / "console-openapi.md"

View File

@ -3,11 +3,12 @@ from __future__ import annotations
from datetime import UTC, datetime
import pytest
from flask import Flask
from controllers.console.app import message as message_module
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_chat_messages_query_valid(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test valid ChatMessagesQuery with all fields."""
query = message_module.ChatMessagesQuery(
conversation_id="550e8400-e29b-41d4-a716-446655440000",
@ -17,14 +18,14 @@ def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None
assert query.limit == 50
def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_chat_messages_query_defaults(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test ChatMessagesQuery with defaults."""
query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000")
assert query.first_id is None
assert query.limit == 20
def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_chat_messages_query_empty_first_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test ChatMessagesQuery converts empty first_id to None."""
query = message_module.ChatMessagesQuery(
conversation_id="550e8400-e29b-41d4-a716-446655440000",
@ -33,7 +34,7 @@ def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch
assert query.first_id is None
def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_message_feedback_payload_valid_like(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload with like rating."""
payload = message_module.MessageFeedbackPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
@ -44,7 +45,7 @@ def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatc
assert payload.content == "Good answer"
def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_message_feedback_payload_valid_dislike(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload with dislike rating."""
payload = message_module.MessageFeedbackPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
@ -53,69 +54,69 @@ def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyP
assert payload.rating == "dislike"
def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_message_feedback_payload_no_rating(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload without rating."""
payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000")
assert payload.rating is None
def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_feedback_export_query_defaults(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with default format."""
query = message_module.FeedbackExportQuery()
assert query.format == "csv"
assert query.from_source is None
def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_feedback_export_query_json_format(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with JSON format."""
query = message_module.FeedbackExportQuery(format="json")
assert query.format == "json"
def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_feedback_export_query_has_comment_true(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as true string."""
query = message_module.FeedbackExportQuery(has_comment="true")
assert query.has_comment is True
def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_feedback_export_query_has_comment_false(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as false string."""
query = message_module.FeedbackExportQuery(has_comment="false")
assert query.has_comment is False
def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_feedback_export_query_has_comment_1(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as 1."""
query = message_module.FeedbackExportQuery(has_comment="1")
assert query.has_comment is True
def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_feedback_export_query_has_comment_0(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as 0."""
query = message_module.FeedbackExportQuery(has_comment="0")
assert query.has_comment is False
def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_feedback_export_query_rating_filter(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with rating filter."""
query = message_module.FeedbackExportQuery(rating="like")
assert query.rating == "like"
def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_annotation_count_response(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test AnnotationCountResponse creation."""
response = message_module.AnnotationCountResponse(count=10)
assert response.count == 10
def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_suggested_questions_response(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test SuggestedQuestionsResponse creation."""
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
assert len(response.data) == 2
assert response.data[0] == "What is AI?"
def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_message_detail_response_normalizes_aliases_and_timestamp(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageDetailResponse normalizes alias fields and datetime timestamps."""
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = message_module.MessageDetailResponse.model_validate(

View File

@ -5,6 +5,7 @@ from inspect import unwrap
from types import SimpleNamespace
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest
from controllers.console.app import statistic as statistic_module
@ -38,7 +39,7 @@ def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_daily_message_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = unwrap(api.get)
@ -52,7 +53,7 @@ def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPat
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_daily_conversation_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyConversationStatistic()
method = unwrap(api.get)
@ -66,7 +67,7 @@ def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.Monk
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_daily_token_cost_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTokenCostStatistic()
method = unwrap(api.get)
@ -84,7 +85,7 @@ def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.Monkey
assert data["data"][0]["total_price"] == 0.25
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_daily_terminals_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTerminalsStatistic()
method = unwrap(api.get)
@ -98,7 +99,7 @@ def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyP
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_average_session_interaction_statistic_requires_chat_mode(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that AverageSessionInteractionStatistic is limited to chat/agent modes."""
# This just verifies the decorator is applied correctly
# Actual endpoint testing would require complex JOIN mocking
@ -107,7 +108,7 @@ def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypat
assert callable(method)
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_daily_message_statistic_with_invalid_time_range(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = unwrap(api.get)
@ -123,7 +124,7 @@ def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytes
method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_daily_message_statistic_multiple_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = unwrap(api.get)
@ -142,7 +143,7 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa
assert len(data["data"]) == 3
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_daily_message_statistic_empty_result(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = unwrap(api.get)
@ -155,7 +156,7 @@ def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPat
assert response.get_json() == {"data": []}
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_daily_conversation_statistic_with_time_range(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyConversationStatistic()
method = unwrap(api.get)
@ -174,7 +175,7 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_daily_token_cost_with_multiple_currencies(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTokenCostStatistic()
method = unwrap(api.get)

View File

@ -6,6 +6,7 @@ from types import SimpleNamespace
from unittest.mock import PropertyMock, patch
import pytest
from flask import Flask
from controllers.console.datasets.rag_pipeline import rag_pipeline_workflow as module
from models.account import Account, TenantAccountRole
@ -117,7 +118,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
assert response["has_more"] is False
def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_rag_pipeline_workflow_patch_serializes_response_model(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow = _make_workflow(marked_name="Updated release")
class _SessionContext:

View File

@ -6,6 +6,7 @@ from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Flask
from werkzeug.exceptions import HTTPException, NotFound
from controllers.console.snippets import snippet_workflow as snippet_workflow_module
@ -52,7 +53,7 @@ def test_get_snippet_requires_snippet_id(app):
view()
def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_snippet_injects_resolved_snippet(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
snippet = _snippet()
@snippet_workflow_module.get_snippet
@ -72,7 +73,7 @@ def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPat
assert result is snippet
def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_snippet_raises_not_found_when_snippet_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
@snippet_workflow_module.get_snippet
def view(**kwargs):
return kwargs
@ -89,7 +90,7 @@ def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pyt
view(snippet_id="snippet-1")
def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_draft_workflow_get_raises_when_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
snippet = _snippet()
monkeypatch.setattr(
snippet_workflow_module,
@ -105,7 +106,7 @@ def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyP
handler(api, snippet=snippet)
def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_draft_workflow_post_returns_400_for_invalid_graph(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
user = _account("account-1")
snippet = _snippet()
sync_draft_workflow = Mock(side_effect=ValueError("invalid graph"))
@ -145,7 +146,7 @@ def test_published_workflow_get_returns_none_when_not_published(app) -> None:
assert handler(api, snippet=SimpleNamespace(id="snippet-1", is_published=False)) is None
def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_published_workflow_post_returns_400_when_publish_fails(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
user = _account("account-1")
snippet = _snippet()
merged_snippet = _snippet()
@ -180,7 +181,7 @@ def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch
session.commit.assert_not_called()
def test_default_block_configs_delegates_to_service(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_default_block_configs_delegates_to_service(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
get_default_block_configs = Mock(return_value=[{"type": "llm"}])
monkeypatch.setattr(
snippet_workflow_module,
@ -198,7 +199,7 @@ def test_default_block_configs_delegates_to_service(app, monkeypatch: pytest.Mon
get_default_block_configs.assert_called_once()
def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_restore_published_snippet_workflow_to_draft_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow = SimpleNamespace(
unique_hash="restored-hash",
updated_at=None,
@ -226,7 +227,7 @@ def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: p
assert response["hash"] == "restored-hash"
def test_restore_published_snippet_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_restore_published_snippet_workflow_to_draft_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
user = _account("account-1")
snippet = _snippet()
@ -311,7 +312,7 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_invalid_gra
assert exc.value.description == "invalid snippet workflow graph"
def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_run_detail_raises_not_found_when_run_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
snippet = _snippet()
monkeypatch.setattr(
snippet_workflow_module,
@ -327,7 +328,9 @@ def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch:
handler(api, snippet=snippet, run_id="run-1")
def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_draft_node_last_run_raises_not_found_when_execution_missing(
app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
snippet = _snippet()
draft_workflow = SimpleNamespace(id="workflow-1")
monkeypatch.setattr(
@ -347,7 +350,7 @@ def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkey
handler(api, snippet=snippet, node_id="llm-1")
def test_workflow_task_stop_uses_queue_flag_and_graph_command(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_task_stop_uses_queue_flag_and_graph_command(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
set_stop_flag = Mock()
send_stop_command = Mock()
monkeypatch.setattr(

View File

@ -27,7 +27,7 @@ def _make_account() -> Account:
@pytest.fixture(autouse=True)
def _patch_snippet_service_factory(monkeypatch):
def _patch_snippet_service_factory(monkeypatch: pytest.MonkeyPatch):
def factory():
service_factory = module.SnippetService
if isinstance(service_factory, type):
@ -64,7 +64,7 @@ def test_ensure_snippet_draft_variable_row_allowed_accepts_canvas_node_variable(
module._ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id="var-1")
def test_conversation_variables_returns_empty_list(app):
def test_conversation_variables_returns_empty_list(app: Flask):
api = module.SnippetConversationVariableCollectionApi()
handler = _unwrap(api.get)
@ -74,7 +74,7 @@ def test_conversation_variables_returns_empty_list(app):
assert result == WorkflowDraftVariableList(variables=[])
def test_system_variables_returns_empty_list(app):
def test_system_variables_returns_empty_list(app: Flask):
api = module.SnippetSystemVariableCollectionApi()
handler = _unwrap(api.get)

View File

@ -4,6 +4,7 @@ from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Flask
from controllers.console import init_validate
from controllers.console.error import AlreadySetupError, InitValidateFailedError
@ -35,7 +36,7 @@ def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.status == "not_started"
def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
app.secret_key = "test-secret"
@ -45,7 +46,7 @@ def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPat
init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw"))
def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
@ -57,7 +58,7 @@ def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPa
assert init_validate.session.get("is_init_validated") is False
def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_validate_init_password_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
@ -74,7 +75,7 @@ def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatc
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_init_validate_status_validated_session(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
@ -84,7 +85,7 @@ def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.Mon
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_init_validate_status_setup_exists(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True))
@ -96,7 +97,7 @@ def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPa
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_init_validate_status_not_validated(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False))

View File

@ -8,6 +8,7 @@ from unittest.mock import MagicMock
import httpx
import pytest
from flask import Flask
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
from controllers.console import remote_files as remote_files_module
@ -82,7 +83,7 @@ def _mock_upload_dependencies(
return file_service_cls, current_user
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_remote_file_info_uses_head_when_successful(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = unwrap(api.get)
decoded_url = "https://example.com/test.txt"
@ -103,7 +104,7 @@ def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest
make_request.assert_called_once_with("HEAD", decoded_url)
def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_remote_file_info_preserves_unencoded_target_query(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = unwrap(api.get)
target_url = "http://example.com/api/aiagent/httpview/txt"
@ -124,7 +125,9 @@ def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch:
make_request.assert_called_once_with("HEAD", f"{target_url}?{query}")
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(
app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = unwrap(api.get)
decoded_url = "https://example.com/test.txt"
@ -147,7 +150,7 @@ def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, mo
assert make_request.call_args_list[1].kwargs == {"timeout": 3}
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = unwrap(api.post)
url = "https://example.com/report.txt"
@ -220,7 +223,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content"
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = unwrap(api.post)
url = "https://example.com/fail.txt"
@ -238,7 +241,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
handler(api, _make_account())
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_remote_file_upload_raises_on_httpx_request_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = unwrap(api.post)
url = "https://example.com/fail.txt"
@ -252,7 +255,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
handler(api, _make_account())
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_remote_file_upload_rejects_oversized_file(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = unwrap(api.post)
url = "https://example.com/large.bin"
@ -267,7 +270,9 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk
handler(api, current_user)
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_remote_file_upload_translates_service_file_too_large_error(
app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
api = remote_files_module.RemoteFileUpload()
handler = unwrap(api.post)
url = "https://example.com/large.bin"
@ -282,7 +287,9 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp
handler(api, current_user)
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_remote_file_upload_translates_service_unsupported_type_error(
app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
api = remote_files_module.RemoteFileUpload()
handler = unwrap(api.post)
url = "https://example.com/file.exe"

View File

@ -102,10 +102,10 @@ class TestChangeEmailSend:
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
def test_should_normalize_new_email_phase(
self,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_extract_ip: MagicMock,
mock_is_ip_limit: MagicMock,
mock_send_email: MagicMock,
mock_get_change_data: MagicMock,
app: Flask,
):
mock_account = _build_account("current@example.com", "acc1")
@ -143,10 +143,10 @@ class TestChangeEmailSend:
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified(
self,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_extract_ip: MagicMock,
mock_is_ip_limit: MagicMock,
mock_send_email: MagicMock,
mock_get_change_data: MagicMock,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
@ -178,10 +178,10 @@ class TestChangeEmailSend:
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
def test_should_reject_new_email_phase_when_token_account_id_does_not_match_current_user(
self,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_extract_ip: MagicMock,
mock_is_ip_limit: MagicMock,
mock_send_email: MagicMock,
mock_get_change_data: MagicMock,
app: Flask,
):
from controllers.console.auth.error import InvalidTokenError

View File

@ -603,7 +603,7 @@ class TestRateLimiting:
@patch("controllers.console.wraps.redis_client")
@patch("controllers.console.wraps.db")
def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis):
def test_should_allow_requests_within_rate_limit(self, mock_db: MagicMock, mock_redis: MagicMock):
"""Test that requests within rate limit are allowed"""
# Arrange
mock_rate_limit = MagicMock()
@ -631,7 +631,7 @@ class TestRateLimiting:
@patch("controllers.console.wraps.redis_client")
@patch("controllers.console.wraps.db")
def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis):
def test_should_reject_requests_over_rate_limit(self, mock_db: MagicMock, mock_redis: MagicMock):
"""Test that requests over rate limit are rejected and logged"""
# Arrange
app = create_app_with_login()
@ -720,7 +720,7 @@ class TestSystemSetup:
"""Test system setup decorator"""
@patch("controllers.console.wraps.db")
def test_should_allow_when_setup_complete(self, mock_db):
def test_should_allow_when_setup_complete(self, mock_db: MagicMock):
"""Test that requests are allowed when setup is complete"""
# Arrange
@ -737,7 +737,7 @@ class TestSystemSetup:
@patch("controllers.console.wraps.db")
@patch("controllers.console.wraps.os.environ.get")
def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db: MagicMock):
"""Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
# Arrange
mock_db.session.scalar.return_value = None # No setup
@ -754,7 +754,7 @@ class TestSystemSetup:
@patch("controllers.console.wraps.db")
@patch("controllers.console.wraps.os.environ.get")
def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db: MagicMock):
"""Test NotSetupError when no INIT_PASSWORD and setup not complete"""
# Arrange
mock_db.session.scalar.return_value = None # No setup

View File

@ -3,6 +3,7 @@ from types import SimpleNamespace
from unittest.mock import ANY, Mock
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console.workspace import snippets as snippets_module
@ -69,7 +70,7 @@ def test_normalize_snippet_list_query_args_sorts_indexed_values():
}
def test_list_snippets_returns_pagination(app, monkeypatch):
def test_list_snippets_returns_pagination(app: Flask, monkeypatch: pytest.MonkeyPatch):
snippets = [_snippet()]
tag_id = "11111111-1111-1111-1111-111111111111"
get_snippets = Mock(return_value=(snippets, 1, False))
@ -104,7 +105,7 @@ def test_list_snippets_returns_pagination(app, monkeypatch):
)
def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypatch):
def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, monkeypatch: pytest.MonkeyPatch):
user = _account("account-1")
snippet = _snippet()
create_snippet = Mock(return_value=snippet)
@ -140,7 +141,7 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypat
assert create_snippet.call_args.kwargs["snippet_type"] == snippets_module.SnippetType.NODE
def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch):
def test_create_snippet_rejects_forbidden_nodes(app: Flask, monkeypatch: pytest.MonkeyPatch):
user = _account("account-1")
create_snippet = Mock()
monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet)
@ -169,7 +170,7 @@ def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch):
create_snippet.assert_not_called()
def test_get_snippet_detail_raises_when_missing(app, monkeypatch):
def test_get_snippet_detail_raises_when_missing(app: Flask, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
api = snippets_module.CustomizedSnippetDetailApi()
@ -180,7 +181,7 @@ def test_get_snippet_detail_raises_when_missing(app, monkeypatch):
handler(api, "tenant-1", snippet_id="snippet-1")
def test_get_snippet_detail_returns_snippet(app, monkeypatch):
def test_get_snippet_detail_returns_snippet(app: Flask, monkeypatch: pytest.MonkeyPatch):
snippet = _snippet()
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"}))
@ -195,7 +196,7 @@ def test_get_snippet_detail_returns_snippet(app, monkeypatch):
assert response == {"id": "snippet-1"}
def test_patch_snippet_returns_400_for_empty_payload(app, monkeypatch):
def test_patch_snippet_returns_400_for_empty_payload(app: Flask, monkeypatch: pytest.MonkeyPatch):
snippet = _snippet()
user = _account("user-1")
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
@ -214,7 +215,7 @@ def test_patch_snippet_returns_400_for_empty_payload(app, monkeypatch):
assert response == {"message": "No valid fields to update"}
def test_patch_snippet_updates_and_commits(app, monkeypatch):
def test_patch_snippet_updates_and_commits(app: Flask, monkeypatch: pytest.MonkeyPatch):
user = _account("account-1")
snippet = _snippet()
updated_snippet = _snippet(name="New")
@ -251,7 +252,7 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch):
session.commit.assert_called_once()
def test_delete_snippet_deletes_and_commits(app, monkeypatch):
def test_delete_snippet_deletes_and_commits(app: Flask, monkeypatch: pytest.MonkeyPatch):
snippet = _snippet()
session = SimpleNamespace(merge=Mock(return_value=snippet), commit=Mock())
delete_snippet = Mock()
@ -277,7 +278,7 @@ def test_delete_snippet_deletes_and_commits(app, monkeypatch):
session.commit.assert_called_once()
def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
def test_export_snippet_returns_yaml_attachment(app: Flask, monkeypatch: pytest.MonkeyPatch):
snippet = _snippet(name="Snippet One")
export_snippet_dsl = Mock(return_value="version: 0.1.0\nkind: snippet\n")
session = SimpleNamespace()
@ -308,7 +309,7 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
export_snippet_dsl.assert_called_once_with(snippet=snippet, include_secret=True)
def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
def test_import_snippet_returns_202_for_pending_confirmation(app: Flask, monkeypatch: pytest.MonkeyPatch):
user = _account("account-1")
result = SnippetImportInfo(id="import-1", status=ImportStatus.PENDING, imported_dsl_version="999.0.0")
import_snippet = Mock(return_value=result)
@ -348,7 +349,7 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
session.commit.assert_called_once()
def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
def test_import_snippet_returns_400_for_failed_import(app: Flask, monkeypatch: pytest.MonkeyPatch):
user = _account("account-1")
result = SnippetImportInfo(id="import-1", status=ImportStatus.FAILED, error="Invalid DSL")
import_snippet = Mock(return_value=result)
@ -381,7 +382,7 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
session.commit.assert_called_once()
def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
def test_import_confirm_returns_200_for_completed_import(app: Flask, monkeypatch: pytest.MonkeyPatch):
user = _account("account-1")
result = SnippetImportInfo(id="import-1", status=ImportStatus.COMPLETED, snippet_id="snippet-1")
confirm_import = Mock(return_value=result)
@ -414,7 +415,7 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
session.commit.assert_called_once()
def test_check_dependencies_raises_when_snippet_missing(app, monkeypatch):
def test_check_dependencies_raises_when_snippet_missing(app: Flask, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
api = snippets_module.CustomizedSnippetCheckDependenciesApi()
@ -425,7 +426,7 @@ def test_check_dependencies_raises_when_snippet_missing(app, monkeypatch):
handler(api, "tenant-1", snippet_id="snippet-1")
def test_check_dependencies_returns_dependency_result(app, monkeypatch):
def test_check_dependencies_returns_dependency_result(app: Flask, monkeypatch: pytest.MonkeyPatch):
snippet = _snippet()
check_dependencies = Mock(
return_value=SimpleNamespace(model_dump=Mock(return_value={"dependencies": [], "missing_dependencies": []}))
@ -456,7 +457,7 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch):
check_dependencies.assert_called_once_with(snippet=snippet)
def test_increment_use_count_raises_when_snippet_missing(app, monkeypatch):
def test_increment_use_count_raises_when_snippet_missing(app: Flask, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
api = snippets_module.CustomizedSnippetUseCountIncrementApi()
@ -470,7 +471,7 @@ def test_increment_use_count_raises_when_snippet_missing(app, monkeypatch):
handler(api, "tenant-1", snippet_id="snippet-1")
def test_increment_use_count_returns_refreshed_count(app, monkeypatch):
def test_increment_use_count_returns_refreshed_count(app: Flask, monkeypatch: pytest.MonkeyPatch):
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1", use_count=2)
merged_snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1", use_count=3)
session = SimpleNamespace(merge=Mock(return_value=merged_snippet), commit=Mock(), refresh=Mock())

View File

@ -35,7 +35,7 @@ def _stub_execute(
@pytest.fixture
def bypass_pipeline(monkeypatch):
def bypass_pipeline(monkeypatch: pytest.MonkeyPatch):
"""Stub PipelineRouter._execute so endpoints skip real auth at request time.
Module-level @auth_router.guard(...) captures the real router at import

View File

@ -167,14 +167,14 @@ def _session_auth_data() -> AuthData:
)
def _stub_session_deps(monkeypatch, rows):
def _stub_session_deps(monkeypatch: pytest.MonkeyPatch, rows):
mod = sys.modules[_ACCOUNT_MOD]
monkeypatch.setattr(mod, "get_auth_ctx", lambda: SimpleNamespace())
monkeypatch.setattr(mod, "list_active_sessions", lambda *args, **kwargs: rows)
monkeypatch.setattr(mod, "db", MagicMock())
def test_sessions_list_valid_query_parses_page_and_limit(app, monkeypatch):
def test_sessions_list_valid_query_parses_page_and_limit(app: Flask, monkeypatch: pytest.MonkeyPatch):
"""A valid ?page&limit round-trips through SessionListQuery into the response envelope."""
api = AccountSessionsApi()
_stub_session_deps(monkeypatch, [])
@ -187,7 +187,7 @@ def test_sessions_list_valid_query_parses_page_and_limit(app, monkeypatch):
assert body["data"] == []
def test_sessions_list_defaults_when_query_omitted(app, monkeypatch):
def test_sessions_list_defaults_when_query_omitted(app: Flask, monkeypatch: pytest.MonkeyPatch):
"""No query → the model's defaults (page=1, limit=100) drive the envelope."""
api = AccountSessionsApi()
_stub_session_deps(monkeypatch, [])
@ -209,7 +209,7 @@ def test_sessions_list_defaults_when_query_omitted(app, monkeypatch):
"foo=bar", # extra='forbid'
],
)
def test_sessions_list_rejects_out_of_bounds_query(app, monkeypatch, query):
def test_sessions_list_rejects_out_of_bounds_query(app: Flask, monkeypatch: pytest.MonkeyPatch, query):
"""Out-of-range / unknown query params raise 422 instead of being silently coerced."""
api = AccountSessionsApi()
_stub_session_deps(monkeypatch, [])

View File

@ -6,6 +6,9 @@ import sys
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Flask
from controllers.openapi._models import AppRunRequest
@ -30,7 +33,9 @@ def test_app_run_request_with_query():
assert req.query == "hello"
def test_run_chat_always_calls_generate_with_streaming_true(app, bypass_pipeline, monkeypatch):
def test_run_chat_always_calls_generate_with_streaming_true(
app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
):
"""_run_chat must always invoke AppGenerateService.generate with streaming=True."""
from controllers.openapi.app_run import _run_chat
@ -56,7 +61,7 @@ def test_stop_task_endpoint_registered(openapi_app):
assert "/openapi/v1/apps/<string:app_id>/tasks/<string:task_id>/stop" in rules
def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, monkeypatch):
def test_stop_task_calls_queue_manager_and_graph_engine(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
import uuid
from controllers.openapi.app_run import AppRunTaskStopApi

View File

@ -2,6 +2,8 @@
from __future__ import annotations
import pytest
def test_version_endpoint_returns_200_without_auth(openapi_app):
client = openapi_app.test_client()
@ -30,7 +32,7 @@ def test_version_endpoint_ignores_bearer_header(openapi_app):
assert "edition" in payload
def test_version_endpoint_reflects_edition_config(openapi_app, monkeypatch):
def test_version_endpoint_reflects_edition_config(openapi_app, monkeypatch: pytest.MonkeyPatch):
from configs import dify_config
monkeypatch.setattr(dify_config, "EDITION", "CLOUD")
@ -42,7 +44,7 @@ def test_version_endpoint_reflects_edition_config(openapi_app, monkeypatch):
assert response.get_json()["edition"] == "CLOUD"
def test_version_endpoint_falls_back_to_self_hosted_on_unexpected_edition(openapi_app, monkeypatch):
def test_version_endpoint_falls_back_to_self_hosted_on_unexpected_edition(openapi_app, monkeypatch: pytest.MonkeyPatch):
from configs import dify_config
monkeypatch.setattr(dify_config, "EDITION", "EXPERIMENTAL")

View File

@ -8,6 +8,7 @@ from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.openapi.auth.data import AuthData
@ -51,7 +52,7 @@ class TestOpenApiWorkflowEventsApi:
return OpenApiWorkflowEventsApi()
def test_not_found_when_run_missing(self, app, bypass_pipeline, monkeypatch):
def test_not_found_when_run_missing(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
module = sys.modules["controllers.openapi.workflow_events"]
repo_mock = Mock()
repo_mock.get_workflow_run_by_id_and_tenant_id.return_value = None
@ -76,7 +77,9 @@ class TestOpenApiWorkflowEventsApi:
auth_data=_make_auth_data(app_model, caller, "account"),
)
def test_not_found_when_run_belongs_to_different_app(self, app, bypass_pipeline, monkeypatch):
def test_not_found_when_run_belongs_to_different_app(
self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
):
module = sys.modules["controllers.openapi.workflow_events"]
run = _make_workflow_run(app_id="other-app")
repo_mock = Mock()
@ -102,7 +105,9 @@ class TestOpenApiWorkflowEventsApi:
auth_data=_make_auth_data(app_model, caller, "account"),
)
def test_account_caller_checks_created_by_account(self, app, bypass_pipeline, monkeypatch):
def test_account_caller_checks_created_by_account(
self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
):
"""Account caller must match created_by == caller.id and role == ACCOUNT."""
module = sys.modules["controllers.openapi.workflow_events"]
run = _make_workflow_run(created_by_role=CreatorUserRole.ACCOUNT, created_by="acct-1")
@ -141,7 +146,9 @@ class TestOpenApiWorkflowEventsApi:
)
assert resp.mimetype == "text/event-stream"
def test_account_caller_rejected_for_end_user_run(self, app, bypass_pipeline, monkeypatch):
def test_account_caller_rejected_for_end_user_run(
self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
):
module = sys.modules["controllers.openapi.workflow_events"]
run = _make_workflow_run(created_by_role=CreatorUserRole.END_USER, created_by="eu-1")
repo_mock = Mock()
@ -167,7 +174,9 @@ class TestOpenApiWorkflowEventsApi:
auth_data=_make_auth_data(app_model, caller, "account"),
)
def test_end_user_caller_checks_created_by_end_user(self, app, bypass_pipeline, monkeypatch):
def test_end_user_caller_checks_created_by_end_user(
self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
):
"""End-user caller must match created_by == caller.id and role == END_USER."""
module = sys.modules["controllers.openapi.workflow_events"]
run = _make_workflow_run(created_by_role=CreatorUserRole.END_USER, created_by="eu-1")
@ -202,7 +211,7 @@ class TestOpenApiWorkflowEventsApi:
)
assert resp.mimetype == "text/event-stream"
def test_finished_run_returns_single_sse_event(self, app, bypass_pipeline, monkeypatch):
def test_finished_run_returns_single_sse_event(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
"""A finished run returns a single done-event SSE response without streaming."""
from datetime import UTC, datetime

View File

@ -34,6 +34,7 @@ from werkzeug.exceptions import BadRequest, NotFound, UnprocessableEntity
from controllers.openapi import bp as openapi_bp
from controllers.openapi._errors import MemberLicenseExceeded, MemberLimitExceeded
from controllers.openapi._models import MemberInvitePayload, MemberRoleUpdatePayload
from controllers.openapi.auth.data import AuthData
from controllers.openapi.workspaces import (
WorkspaceMemberApi,
WorkspaceMemberRoleApi,
@ -228,7 +229,7 @@ def test_role_payload_rejects_extra_field():
MemberRoleUpdatePayload.model_validate({"role": "normal", "extra": "x"})
def test_invite_rejects_invalid_body_with_422(app, bypass_pipeline):
def test_invite_rejects_invalid_body_with_422(app: Flask, bypass_pipeline):
"""Invalid invite body → 422 via @accepts (was 400 through _validate_body)."""
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
@ -245,7 +246,7 @@ def test_invite_rejects_invalid_body_with_422(app, bypass_pipeline):
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
def test_update_role_rejects_invalid_body_with_422(app, bypass_pipeline):
def test_update_role_rejects_invalid_body_with_422(app: Flask, bypass_pipeline):
"""Invalid role-update body surfaces as 422 through @accepts (was 400)."""
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
acct_id = uuid.uuid4()
@ -267,7 +268,9 @@ def test_update_role_rejects_invalid_body_with_422(app, bypass_pipeline):
# ---------------------------------------------------------------------------
def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline, monkeypatch):
def test_switch_returns_workspace_detail_with_current_true(
app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
):
"""Happy path: switch service is called, then the workspace+membership
row is re-queried so the returned `current` reflects post-commit state.
"""
@ -298,7 +301,9 @@ def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline,
assert switch_mock.called
def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pipeline, monkeypatch):
def test_switch_404s_when_service_raises_account_not_link_tenant(
app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
):
"""If switch_tenant raises (e.g. Tenant.status != NORMAL), the body
surfaces as NotFound, not 500."""
ws_id = str(uuid.uuid4())
@ -326,7 +331,7 @@ def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pip
# ---------------------------------------------------------------------------
def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch):
def test_members_list_returns_normalized_rows(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMembersApi()
@ -364,7 +369,7 @@ def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch)
assert body["data"][0]["status"] == "active"
def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypatch):
def test_members_list_paginates_with_query_params(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
"""`?page=2&limit=2` slices service output and reports total/has_more."""
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
@ -404,7 +409,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa
assert [d["id"] for d in body["data"]] == ["m-2", "m-3"]
def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypatch):
def test_members_list_rejects_unknown_query_param(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
"""Strict (`extra='forbid'`) — typos like `?pg=2` surface as 422 (unified via @accepts)."""
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
@ -425,7 +430,9 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa
# ---------------------------------------------------------------------------
def test_invite_happy_path_returns_invite_url_and_member_id(app, bypass_pipeline, monkeypatch):
def test_invite_happy_path_returns_invite_url_and_member_id(
app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
):
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMembersApi()
@ -507,7 +514,7 @@ def _invite_request(app, ws_id: str, acct_id: uuid.UUID):
)
def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch):
def test_invite_blocked_by_saas_members_cap(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
"""SaaS billing plan member cap → MemberLimitExceeded (403)."""
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
@ -541,7 +548,7 @@ def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch):
invite_mock.assert_not_called()
def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, monkeypatch):
def test_invite_blocked_by_ee_workspace_members_license(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
"""EE License workspace_members cap → MemberLicenseExceeded (403).
Note: billing.enabled is False (EE without SaaS billing); only the
@ -583,7 +590,7 @@ def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, mo
invite_mock.assert_not_called()
def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypatch):
def test_invite_ce_passes_when_both_caps_disabled(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
"""CE deployment (no billing, no license) → quota gate is a no-op,
invite proceeds normally."""
ws_id = str(uuid.uuid4())
@ -619,7 +626,7 @@ def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypa
assert body["email"] == "new@example.com"
def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch):
def test_invite_400_when_already_in_tenant(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMembersApi()
@ -650,7 +657,7 @@ def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch):
# ---------------------------------------------------------------------------
def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch):
def test_delete_member_happy_path(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMemberApi()
@ -692,7 +699,7 @@ def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch):
(MemberNotInTenantError("not in tenant"), NotFound),
],
)
def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc, expected):
def test_delete_member_exception_mapping(app: Flask, bypass_pipeline, monkeypatch, exc, expected):
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMemberApi()
@ -725,7 +732,7 @@ def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc,
)
def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch):
def test_delete_member_404_when_member_missing(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMemberApi()
@ -757,7 +764,7 @@ def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch
# ---------------------------------------------------------------------------
def test_update_role_happy_path(app, bypass_pipeline, monkeypatch):
def test_update_role_happy_path(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMemberRoleApi()
@ -801,7 +808,7 @@ def test_update_role_happy_path(app, bypass_pipeline, monkeypatch):
(MemberNotInTenantError("not in tenant"), NotFound),
],
)
def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, expected):
def test_update_role_exception_mapping(app: Flask, bypass_pipeline, monkeypatch, exc, expected):
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMemberRoleApi()
@ -841,7 +848,7 @@ def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, e
# ---------------------------------------------------------------------------
def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatch):
def test_load_tenant_rejects_archived_workspace(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
"""Member management against an archived workspace → 404."""
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
@ -869,7 +876,7 @@ def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatc
# ---------------------------------------------------------------------------
def test_invite_400_when_register_error(app, bypass_pipeline, monkeypatch):
def test_invite_400_when_register_error(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
"""AccountRegisterError (frozen email, workspace creation blocked) → 400."""
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()

View File

@ -81,7 +81,7 @@ class TestPluginAgentStrategyInitialization:
class TestGetParameters:
def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None:
def test_get_parameters_returns_parameters(self, strategy: PluginAgentStrategy, mock_declaration) -> None:
result = strategy.get_parameters()
assert result == mock_declaration.parameters
@ -92,7 +92,7 @@ class TestGetParameters:
class TestInitializeParameters:
def test_initialize_parameters_success(self, strategy, mock_declaration) -> None:
def test_initialize_parameters_success(self, strategy: PluginAgentStrategy, mock_declaration) -> None:
params = {"param1": "value1"}
result = strategy.initialize_parameters(params.copy())
@ -114,13 +114,13 @@ class TestInitializeParameters:
{"param1": {}, "param2": "value"},
],
)
def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None:
def test_initialize_parameters_edge_cases(self, strategy: PluginAgentStrategy, input_params) -> None:
result = strategy.initialize_parameters(input_params.copy())
for param in strategy.declaration.parameters:
assert param.name in result
def test_initialize_parameters_invalid_input_type(self, strategy) -> None:
def test_initialize_parameters_invalid_input_type(self, strategy: PluginAgentStrategy) -> None:
with pytest.raises(AttributeError):
strategy.initialize_parameters(None)
@ -131,7 +131,7 @@ class TestInitializeParameters:
class TestInvoke:
def test_invoke_success_all_arguments(self, strategy, mocker) -> None:
def test_invoke_success_all_arguments(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None:
mock_manager = MagicMock()
mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"]))
@ -171,7 +171,7 @@ class TestInvoke:
assert call_kwargs["message_id"] == "msg_1"
assert call_kwargs["context"] is not None
def test_invoke_with_credentials(self, strategy, mocker) -> None:
def test_invoke_with_credentials(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None:
mock_manager = MagicMock()
mock_manager.invoke = MagicMock(return_value=iter([]))
@ -243,7 +243,7 @@ class TestInvoke:
assert result == []
mock_manager.invoke.assert_called_once()
def test_invoke_convert_raises_exception(self, strategy, mocker) -> None:
def test_invoke_convert_raises_exception(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None:
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=MagicMock(),
@ -257,7 +257,7 @@ class TestInvoke:
with pytest.raises(ValueError):
list(strategy._invoke(params={}, user_id="user_1"))
def test_invoke_manager_raises_exception(self, strategy, mocker) -> None:
def test_invoke_manager_raises_exception(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None:
mock_manager = MagicMock()
mock_manager.invoke.side_effect = RuntimeError("invoke failed")

View File

@ -42,13 +42,13 @@ def runner(mocker: MockerFixture, mock_db_session):
class TestRepack:
def test_sets_empty_if_none(self, runner, mocker: MockerFixture):
def test_sets_empty_if_none(self, runner: BaseAgentRunner, mocker: MockerFixture):
entity = mocker.MagicMock()
entity.app_config.prompt_template.simple_prompt_template = None
result = runner._repack_app_generate_entity(entity)
assert result.app_config.prompt_template.simple_prompt_template == ""
def test_keeps_existing(self, runner, mocker: MockerFixture):
def test_keeps_existing(self, runner: BaseAgentRunner, mocker: MockerFixture):
entity = mocker.MagicMock()
entity.app_config.prompt_template.simple_prompt_template = "abc"
result = runner._repack_app_generate_entity(entity)
@ -61,7 +61,7 @@ class TestRepack:
class TestUpdatePromptTool:
def test_replaces_prompt_tool_parameters_with_tool_schema(self, runner, mocker: MockerFixture):
def test_replaces_prompt_tool_parameters_with_tool_schema(self, runner: BaseAgentRunner, mocker: MockerFixture):
tool = mocker.MagicMock()
schema = {
"type": "object",
@ -83,7 +83,7 @@ class TestUpdatePromptTool:
class TestCreateAgentThought:
def test_with_files(self, runner, mock_db_session, mocker: MockerFixture):
def test_with_files(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
mock_thought = mocker.MagicMock(id=10)
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
@ -91,7 +91,7 @@ class TestCreateAgentThought:
assert result == "10"
assert runner.agent_thought_count == 1
def test_without_files(self, runner, mock_db_session, mocker: MockerFixture):
def test_without_files(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
mock_thought = mocker.MagicMock(id=11)
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
@ -112,12 +112,12 @@ class TestSaveAgentThought:
agent.thought = ""
return agent
def test_not_found(self, runner, mock_db_session):
def test_not_found(self, runner: BaseAgentRunner, mock_db_session):
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError):
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
def test_full_update(self, runner, mock_db_session, mocker: MockerFixture):
def test_full_update(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
@ -152,7 +152,7 @@ class TestSaveAgentThought:
assert agent.tokens == 3
assert "tool1" in json.loads(agent.tool_labels_str)
def test_label_fallback_when_none(self, runner, mock_db_session, mocker: MockerFixture):
def test_label_fallback_when_none(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
agent = self.setup_agent(mocker)
agent.tool = "unknown_tool"
mock_db_session.scalar.return_value = agent
@ -162,7 +162,7 @@ class TestSaveAgentThought:
labels = json.loads(agent.tool_labels_str)
assert "unknown_tool" in labels
def test_json_failure_paths(self, runner, mock_db_session, mocker: MockerFixture):
def test_json_failure_paths(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
@ -183,13 +183,13 @@ class TestSaveAgentThought:
assert mock_db_session.commit.called
def test_messages_ids_none(self, runner, mock_db_session, mocker: MockerFixture):
def test_messages_ids_none(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
runner.save_agent_thought("id", None, None, None, None, None, None, None, None)
assert mock_db_session.commit.called
def test_success_dict_serialization(self, runner, mock_db_session, mocker: MockerFixture):
def test_success_dict_serialization(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
@ -215,19 +215,19 @@ class TestSaveAgentThought:
class TestOrganizeUserPrompt:
def test_no_files(self, runner, mock_db_session, mocker: MockerFixture):
def test_no_files(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
mock_db_session.scalars.return_value.all.return_value = []
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
result = runner.organize_agent_user_prompt(msg)
assert result.content == "hello"
def test_with_files_no_config(self, runner, mock_db_session, mocker: MockerFixture):
def test_with_files_no_config(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
result = runner.organize_agent_user_prompt(msg)
assert result.content == "hello"
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker: MockerFixture):
def test_image_detail_low_fallback(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
file_config = mocker.MagicMock()
file_config.image_config = mocker.MagicMock(detail=None)
@ -247,27 +247,27 @@ class TestOrganizeUserPrompt:
class TestOrganizeHistory:
def test_empty(self, runner, mock_db_session, mocker: MockerFixture):
def test_empty(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
mocker.patch.object(module, "extract_thread_messages", return_value=[])
result = runner.organize_agent_history([])
assert result == []
def test_with_answer_only(self, runner, mock_db_session, mocker: MockerFixture):
def test_with_answer_only(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert any(isinstance(x, module.AssistantPromptMessage) for x in result)
def test_skip_current_message(self, runner, mock_db_session, mocker: MockerFixture):
def test_skip_current_message(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert result == []
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker: MockerFixture):
def test_with_tool_calls_invalid_json(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
thought = mocker.MagicMock(
tool="tool1",
tool_input="invalid",
@ -283,7 +283,7 @@ class TestOrganizeHistory:
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_empty_tool_name_split(self, runner, mock_db_session, mocker: MockerFixture):
def test_empty_tool_name_split(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
thought = mocker.MagicMock(tool=";", thought="thinking")
msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None)
@ -292,7 +292,7 @@ class TestOrganizeHistory:
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker: MockerFixture):
def test_valid_json_tool_flow(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
thought = mocker.MagicMock(
tool="tool1",
tool_input=json.dumps({"tool1": {"x": 1}}),
@ -321,7 +321,7 @@ class TestOrganizeHistory:
class TestConvertToolToPromptMessageTool:
def test_basic_conversion(self, runner, mocker: MockerFixture):
def test_basic_conversion(self, runner: BaseAgentRunner, mocker: MockerFixture):
tool = mocker.MagicMock(tool_name="tool1")
tool_entity = mocker.MagicMock()
@ -347,7 +347,7 @@ class TestConvertToolToPromptMessageTool:
class TestInitPromptToolsExtended:
def test_agent_tool_branch(self, runner, mocker: MockerFixture):
def test_agent_tool_branch(self, runner: BaseAgentRunner, mocker: MockerFixture):
agent_tool = mocker.MagicMock(tool_name="agent_tool")
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity"))
@ -355,7 +355,7 @@ class TestInitPromptToolsExtended:
tools, prompts = runner._init_prompt_tools()
assert "agent_tool" in tools
def test_exception_in_conversion(self, runner, mocker: MockerFixture):
def test_exception_in_conversion(self, runner: BaseAgentRunner, mocker: MockerFixture):
agent_tool = mocker.MagicMock(tool_name="bad_tool")
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception)
@ -370,7 +370,7 @@ class TestInitPromptToolsExtended:
class TestAdditionalCoverage:
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker: MockerFixture):
def test_save_agent_thought_existing_labels(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {"tool1": {"en_US": "existing"}}
@ -381,7 +381,7 @@ class TestAdditionalCoverage:
labels = json.loads(agent.tool_labels_str)
assert labels["tool1"]["en_US"] == "existing"
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker: MockerFixture):
def test_save_agent_thought_tool_meta_string(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {}
@ -391,7 +391,7 @@ class TestAdditionalCoverage:
runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None)
assert agent.tool_meta_str == "meta_string"
def test_convert_dataset_retriever_tool(self, runner, mocker: MockerFixture):
def test_convert_dataset_retriever_tool(self, runner: BaseAgentRunner, mocker: MockerFixture):
ds_tool = mocker.MagicMock()
ds_tool.entity.identity.name = "ds"
ds_tool.entity.description.llm = "desc"
@ -408,7 +408,9 @@ class TestAdditionalCoverage:
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
assert prompt is not None
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker: MockerFixture):
def test_organize_user_prompt_with_file_objects(
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
file_config = mocker.MagicMock()
@ -427,7 +429,7 @@ class TestAdditionalCoverage:
result = runner.organize_agent_user_prompt(msg)
assert result is not None
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker: MockerFixture):
def test_organize_history_without_tool_names(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
thought = mocker.MagicMock(tool=None, thought="thinking")
msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None)
@ -437,7 +439,9 @@ class TestAdditionalCoverage:
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker: MockerFixture):
def test_organize_history_multiple_tools_split(
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
):
thought = mocker.MagicMock(
tool="tool1;tool2",
tool_input=json.dumps({"tool1": {}, "tool2": {}}),
@ -455,7 +459,7 @@ class TestAdditionalCoverage:
class TestConvertDatasetRetrieverTool:
def test_required_param_added(self, runner, mocker: MockerFixture):
def test_required_param_added(self, runner: BaseAgentRunner, mocker: MockerFixture):
ds_tool = mocker.MagicMock()
ds_tool.entity.identity.name = "ds"
ds_tool.entity.description.llm = "desc"
@ -518,7 +522,7 @@ class TestBaseAgentRunnerInit:
class TestBaseAgentRunnerCoverage:
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker: MockerFixture):
def test_init_prompt_tools_adds_dataset_tools(self, runner: BaseAgentRunner, mocker: MockerFixture):
dataset_tool = mocker.MagicMock()
dataset_tool.entity.identity.name = "ds"
runner.dataset_tools = [dataset_tool]
@ -530,7 +534,9 @@ class TestBaseAgentRunnerCoverage:
assert tools["ds"] == dataset_tool
assert len(prompt_tools) == 1
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker: MockerFixture):
def test_save_agent_thought_json_dumps_fallbacks(
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {}
@ -568,7 +574,9 @@ class TestBaseAgentRunnerCoverage:
assert isinstance(agent.observation, str)
assert isinstance(agent.tool_meta_str, str)
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker: MockerFixture):
def test_save_agent_thought_skips_empty_tool_name(
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
):
agent = mocker.MagicMock()
agent.tool = "tool1;;"
agent.tool_labels = {}
@ -582,7 +590,9 @@ class TestBaseAgentRunnerCoverage:
labels = json.loads(agent.tool_labels_str)
assert "" not in labels
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker: MockerFixture):
def test_organize_history_includes_system_prompt(
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
mocker.patch.object(module, "extract_thread_messages", return_value=[])
@ -592,7 +602,9 @@ class TestBaseAgentRunnerCoverage:
assert system_message in result
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker: MockerFixture):
def test_organize_history_tool_inputs_and_observation_none(
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
):
thought = mocker.MagicMock(
tool="tool1",
tool_input=None,

View File

@ -95,25 +95,25 @@ class TestFillInputs:
("", {"x": "y"}, ""),
],
)
def test_fill_in_inputs(self, runner, instruction, inputs, expected):
def test_fill_in_inputs(self, runner: DummyRunner, instruction, inputs, expected):
result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs)
assert result == expected
class TestConvertDictToAction:
def test_convert_valid_dict(self, runner):
def test_convert_valid_dict(self, runner: DummyRunner):
action_dict = {"action": "test", "action_input": {"a": 1}}
action = runner._convert_dict_to_action(action_dict)
assert action.action_name == "test"
assert action.action_input == {"a": 1}
def test_convert_missing_keys(self, runner):
def test_convert_missing_keys(self, runner: DummyRunner):
with pytest.raises(KeyError):
runner._convert_dict_to_action({"invalid": 1})
class TestFormatAssistantMessage:
def test_format_assistant_message_multiple_scratchpads(self, runner):
def test_format_assistant_message_multiple_scratchpads(self, runner: DummyRunner):
sp1 = AgentScratchpadUnit(
agent_response="resp1",
thought="thought1",
@ -131,7 +131,7 @@ class TestFormatAssistantMessage:
result = runner._format_assistant_message([sp1, sp2])
assert "Final Answer:" in result
def test_format_with_final(self, runner):
def test_format_with_final(self, runner: DummyRunner):
scratchpad = AgentScratchpadUnit(
agent_response="Done",
thought="",
@ -144,7 +144,7 @@ class TestFormatAssistantMessage:
result = runner._format_assistant_message([scratchpad])
assert "Final Answer" in result
def test_format_with_action_and_observation(self, runner):
def test_format_with_action_and_observation(self, runner: DummyRunner):
scratchpad = AgentScratchpadUnit(
agent_response="resp",
thought="thinking",
@ -161,12 +161,12 @@ class TestFormatAssistantMessage:
class TestHandleInvokeAction:
def test_handle_invoke_action_tool_not_present(self, runner):
def test_handle_invoke_action_tool_not_present(self, runner: DummyRunner):
action = AgentScratchpadUnit.Action(action_name="missing", action_input={})
response, meta = runner._handle_invoke_action(action, {}, [])
assert "there is not a tool named" in response
def test_tool_with_json_string_args(self, runner, mocker: MockerFixture):
def test_tool_with_json_string_args(self, runner: DummyRunner, mocker: MockerFixture):
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
tool_instance = MagicMock()
tool_instances = {"tool": tool_instance}
@ -181,7 +181,7 @@ class TestHandleInvokeAction:
class TestOrganizeHistoricPromptMessages:
def test_empty_history(self, runner, mocker: MockerFixture):
def test_empty_history(self, runner: DummyRunner, mocker: MockerFixture):
mocker.patch(
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
return_value=[],
@ -191,7 +191,7 @@ class TestOrganizeHistoricPromptMessages:
class TestRun:
def test_run_handles_empty_parser_output(self, runner, mocker: MockerFixture):
def test_run_handles_empty_parser_output(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -203,7 +203,7 @@ class TestRun:
results = list(runner.run(message, "query", {}))
assert isinstance(results, list)
def test_run_with_action_and_tool_invocation(self, runner, mocker: MockerFixture):
def test_run_with_action_and_tool_invocation(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -224,7 +224,7 @@ class TestRun:
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {"tool": MagicMock()}))
def test_run_respects_max_iteration_boundary(self, runner, mocker: MockerFixture):
def test_run_respects_max_iteration_boundary(self, runner: DummyRunner, mocker: MockerFixture):
runner.app_config.agent.max_iteration = 1
message = MagicMock()
message.id = "msg-id"
@ -246,7 +246,7 @@ class TestRun:
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {"tool": MagicMock()}))
def test_run_basic_flow(self, runner, mocker: MockerFixture):
def test_run_basic_flow(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -258,7 +258,7 @@ class TestRun:
results = list(runner.run(message, "query", {"name": "John"}))
assert results
def test_run_max_iteration_error(self, runner, mocker: MockerFixture):
def test_run_max_iteration_error(self, runner: DummyRunner, mocker: MockerFixture):
runner.app_config.agent.max_iteration = 0
message = MagicMock()
message.id = "msg-id"
@ -273,7 +273,7 @@ class TestRun:
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {}))
def test_run_increase_usage_aggregation(self, runner, mocker: MockerFixture):
def test_run_increase_usage_aggregation(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
runner.app_config.agent.max_iteration = 2
@ -330,7 +330,7 @@ class TestRun:
assert final_usage.completion_price == 2
assert final_usage.total_price == 4
def test_run_when_no_action_branch(self, runner, mocker: MockerFixture):
def test_run_when_no_action_branch(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -342,7 +342,7 @@ class TestRun:
results = list(runner.run(message, "query", {}))
assert results[-1].delta.message.content == ""
def test_run_usage_missing_key_branch(self, runner, mocker: MockerFixture):
def test_run_usage_missing_key_branch(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -355,7 +355,7 @@ class TestRun:
list(runner.run(message, "query", {}))
def test_run_prompt_tool_update_branch(self, runner, mocker: MockerFixture):
def test_run_prompt_tool_update_branch(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -387,7 +387,7 @@ class TestRun:
runner.update_prompt_message_tool.assert_called_once()
def test_historic_with_assistant_and_tool_calls(self, runner):
def test_historic_with_assistant_and_tool_calls(self, runner: DummyRunner):
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage
assistant = AssistantPromptMessage(content="thinking")
@ -400,7 +400,7 @@ class TestRun:
result = runner._organize_historic_prompt_messages([])
assert isinstance(result, list)
def test_historic_final_flush_branch(self, runner):
def test_historic_final_flush_branch(self, runner: DummyRunner):
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
assistant = AssistantPromptMessage(content="final")
@ -411,7 +411,7 @@ class TestRun:
class TestInitReactState:
def test_init_react_state_resets_state(self, runner, mocker: MockerFixture):
def test_init_react_state_resets_state(self, runner: DummyRunner, mocker: MockerFixture):
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
runner._agent_scratchpad = ["old"]
runner._query = "old"
@ -424,7 +424,7 @@ class TestInitReactState:
class TestHandleInvokeActionExtended:
def test_tool_with_invalid_json_string_args(self, runner, mocker: MockerFixture):
def test_tool_with_invalid_json_string_args(self, runner: DummyRunner, mocker: MockerFixture):
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
tool_instance = MagicMock()
tool_instances = {"tool": tool_instance}
@ -443,11 +443,11 @@ class TestHandleInvokeActionExtended:
class TestFillInputsEdgeCases:
def test_fill_inputs_with_empty_inputs(self, runner):
def test_fill_inputs_with_empty_inputs(self, runner: DummyRunner):
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {})
assert result == "Hello {{x}}"
def test_fill_inputs_with_exception_in_replace(self, runner):
def test_fill_inputs_with_exception_in_replace(self, runner: DummyRunner):
class BadValue:
def __str__(self):
raise Exception("fail")
@ -458,7 +458,7 @@ class TestFillInputsEdgeCases:
class TestOrganizeHistoricPromptMessagesExtended:
def test_user_message_flushes_scratchpad(self, runner, mocker: MockerFixture):
def test_user_message_flushes_scratchpad(self, runner: DummyRunner, mocker: MockerFixture):
from graphon.model_runtime.entities.message_entities import UserPromptMessage
user_message = UserPromptMessage(content="Hi")
@ -473,7 +473,7 @@ class TestOrganizeHistoricPromptMessagesExtended:
result = runner._organize_historic_prompt_messages([])
assert result == ["final"]
def test_tool_message_without_scratchpad_raises(self, runner):
def test_tool_message_without_scratchpad_raises(self, runner: DummyRunner):
from graphon.model_runtime.entities.message_entities import ToolPromptMessage
runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")]
@ -481,7 +481,7 @@ class TestOrganizeHistoricPromptMessagesExtended:
with pytest.raises(NotImplementedError):
runner._organize_historic_prompt_messages([])
def test_agent_history_transform_invocation(self, runner, mocker: MockerFixture):
def test_agent_history_transform_invocation(self, runner: DummyRunner, mocker: MockerFixture):
mock_transform = MagicMock()
mock_transform.get_prompt.return_value = []
@ -496,7 +496,7 @@ class TestOrganizeHistoricPromptMessagesExtended:
class TestRunAdditionalBranches:
def test_run_with_no_action_final_answer_empty(self, runner, mocker: MockerFixture):
def test_run_with_no_action_final_answer_empty(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -508,7 +508,7 @@ class TestRunAdditionalBranches:
results = list(runner.run(message, "query", {}))
assert any(hasattr(r, "delta") for r in results)
def test_run_with_final_answer_action_string(self, runner, mocker: MockerFixture):
def test_run_with_final_answer_action_string(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -522,7 +522,7 @@ class TestRunAdditionalBranches:
results = list(runner.run(message, "query", {}))
assert results[-1].delta.message.content == "done"
def test_run_with_final_answer_action_dict(self, runner, mocker: MockerFixture):
def test_run_with_final_answer_action_dict(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"
@ -536,7 +536,7 @@ class TestRunAdditionalBranches:
results = list(runner.run(message, "query", {}))
assert json.loads(results[-1].delta.message.content) == {"a": 1}
def test_run_with_string_final_answer(self, runner, mocker: MockerFixture):
def test_run_with_string_final_answer(self, runner: DummyRunner, mocker: MockerFixture):
message = MagicMock()
message.id = "msg-id"

View File

@ -40,7 +40,11 @@ def runner(mocker: MockerFixture, dummy_tool_factory):
class TestOrganizeInstructionPrompt:
def test_success_all_placeholders(
self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
self,
runner: CotCompletionAgentRunner,
dummy_app_config_factory,
dummy_agent_config_factory,
dummy_prompt_entity_factory,
):
template = (
"{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}"
@ -58,12 +62,14 @@ class TestOrganizeInstructionPrompt:
tools_payload = json.loads(result.split(" | ")[1])
assert {item["name"] for item in tools_payload} == {"toolA", "toolB"}
def test_agent_none_raises(self, runner, dummy_app_config_factory):
def test_agent_none_raises(self, runner: CotCompletionAgentRunner, dummy_app_config_factory):
runner.app_config = dummy_app_config_factory(agent=None)
with pytest.raises(ValueError, match="Agent configuration is not set"):
runner._organize_instruction_prompt()
def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory):
def test_prompt_entity_none_raises(
self, runner: CotCompletionAgentRunner, dummy_app_config_factory, dummy_agent_config_factory
):
runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None))
with pytest.raises(ValueError, match="prompt entity is not set"):
runner._organize_instruction_prompt()
@ -75,7 +81,7 @@ class TestOrganizeInstructionPrompt:
class TestOrganizeHistoricPrompt:
def test_with_user_and_assistant_string(self, runner, mocker: MockerFixture):
def test_with_user_and_assistant_string(self, runner: CotCompletionAgentRunner, mocker: MockerFixture):
user_msg = UserPromptMessage(content="Hello")
assistant_msg = AssistantPromptMessage(content="Hi there")
@ -90,7 +96,7 @@ class TestOrganizeHistoricPrompt:
assert "Question: Hello" in result
assert "Hi there" in result
def test_assistant_list_with_text_content(self, runner, mocker: MockerFixture):
def test_assistant_list_with_text_content(self, runner: CotCompletionAgentRunner, mocker: MockerFixture):
text_content = TextPromptMessageContent(data="Partial answer")
assistant_msg = AssistantPromptMessage(content=[text_content])
@ -104,7 +110,9 @@ class TestOrganizeHistoricPrompt:
assert "Partial answer" in result
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker: MockerFixture):
def test_assistant_list_with_non_text_content_ignored(
self, runner: CotCompletionAgentRunner, mocker: MockerFixture
):
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
assistant_msg = AssistantPromptMessage(content=[non_text_content])
@ -117,7 +125,7 @@ class TestOrganizeHistoricPrompt:
result = runner._organize_historic_prompt()
assert result == ""
def test_empty_history(self, runner, mocker: MockerFixture):
def test_empty_history(self, runner: CotCompletionAgentRunner, mocker: MockerFixture):
mocker.patch.object(
runner,
"_organize_historic_prompt_messages",

View File

@ -147,12 +147,12 @@ def runner(mocker: MockerFixture):
class TestToolCallChecks:
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
def test_check_tool_calls(self, runner, tool_calls, expected):
def test_check_tool_calls(self, runner: FunctionCallAgentRunner, tool_calls, expected):
chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls))
assert runner.check_tool_calls(chunk) is expected
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
def test_check_blocking_tool_calls(self, runner, tool_calls, expected):
def test_check_blocking_tool_calls(self, runner: FunctionCallAgentRunner, tool_calls, expected):
result = DummyResult(message=DummyMessage(tool_calls=tool_calls))
assert runner.check_blocking_tool_calls(result) is expected
@ -163,7 +163,7 @@ class TestToolCallChecks:
class TestExtractToolCalls:
def test_extract_tool_calls_with_valid_json(self, runner):
def test_extract_tool_calls_with_valid_json(self, runner: FunctionCallAgentRunner):
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
@ -174,7 +174,7 @@ class TestExtractToolCalls:
assert calls == [("1", "tool", {"a": 1})]
def test_extract_tool_calls_empty_arguments(self, runner):
def test_extract_tool_calls_empty_arguments(self, runner: FunctionCallAgentRunner):
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
@ -185,7 +185,7 @@ class TestExtractToolCalls:
assert calls == [("1", "tool", {})]
def test_extract_blocking_tool_calls(self, runner):
def test_extract_blocking_tool_calls(self, runner: FunctionCallAgentRunner):
tool_call = MagicMock()
tool_call.id = "2"
tool_call.function.name = "block"
@ -203,16 +203,16 @@ class TestExtractToolCalls:
class TestInitSystemMessage:
def test_init_system_message_empty_prompt_messages(self, runner):
def test_init_system_message_empty_prompt_messages(self, runner: FunctionCallAgentRunner):
result = runner._init_system_message("system", [])
assert len(result) == 1
def test_init_system_message_insert_at_start(self, runner):
def test_init_system_message_insert_at_start(self, runner: FunctionCallAgentRunner):
msgs = [MagicMock()]
result = runner._init_system_message("system", msgs)
assert result[0].content == "system"
def test_init_system_message_no_template(self, runner):
def test_init_system_message_no_template(self, runner: FunctionCallAgentRunner):
result = runner._init_system_message("", [])
assert result == []
@ -223,15 +223,15 @@ class TestInitSystemMessage:
class TestOrganizeUserQuery:
def test_without_files(self, runner):
def test_without_files(self, runner: FunctionCallAgentRunner):
result = runner._organize_user_query("query", [])
assert len(result) == 1
def test_with_none_query(self, runner):
def test_with_none_query(self, runner: FunctionCallAgentRunner):
result = runner._organize_user_query(None, [])
assert len(result) == 1
def test_with_files_uses_image_detail_config(self, runner, mocker: MockerFixture):
def test_with_files_uses_image_detail_config(self, runner: FunctionCallAgentRunner, mocker: MockerFixture):
file_content = TextPromptMessageContent(data="file-content")
mock_to_prompt = mocker.patch(
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
@ -255,7 +255,7 @@ class TestOrganizeUserQuery:
class TestClearUserPromptImageMessages:
def test_clear_text_and_image_content(self, runner):
def test_clear_text_and_image_content(self, runner: FunctionCallAgentRunner):
text = MagicMock()
text.type = "text"
text.data = "hello"
@ -271,7 +271,7 @@ class TestClearUserPromptImageMessages:
result = runner._clear_user_prompt_image_messages([user_msg])
assert isinstance(result, list)
def test_clear_includes_file_placeholder(self, runner):
def test_clear_includes_file_placeholder(self, runner: FunctionCallAgentRunner):
text = TextPromptMessageContent(data="hello")
image = ImagePromptMessageContent(format="url", mime_type="image/png")
document = DocumentPromptMessageContent(format="url", mime_type="application/pdf")
@ -289,7 +289,7 @@ class TestClearUserPromptImageMessages:
class TestRunMethod:
def test_run_non_streaming_no_tool_calls(self, runner):
def test_run_non_streaming_no_tool_calls(self, runner: FunctionCallAgentRunner):
message = MagicMock(id="m1")
dummy_message = DummyMessage(content="hello")
result = DummyResult(message=dummy_message, usage=build_usage())
@ -303,7 +303,7 @@ class TestRunMethod:
queue_calls = runner.queue_manager.publish.call_args_list
assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls)
def test_run_streaming_branch(self, runner):
def test_run_streaming_branch(self, runner: FunctionCallAgentRunner):
message = MagicMock(id="m1")
runner.stream_tool_call = True
@ -318,7 +318,7 @@ class TestRunMethod:
outputs = list(runner.run(message, "query"))
assert len(outputs) == 1
def test_run_streaming_tool_calls_list_content(self, runner):
def test_run_streaming_tool_calls_list_content(self, runner: FunctionCallAgentRunner):
message = MagicMock(id="m1")
runner.stream_tool_call = True
@ -341,7 +341,7 @@ class TestRunMethod:
outputs = list(runner.run(message, "query"))
assert len(outputs) >= 1
def test_run_non_streaming_list_content(self, runner):
def test_run_non_streaming_list_content(self, runner: FunctionCallAgentRunner):
message = MagicMock(id="m1")
content = [TextPromptMessageContent(data="hi")]
dummy_message = DummyMessage(content=content)
@ -353,7 +353,7 @@ class TestRunMethod:
assert len(outputs) == 1
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker: MockerFixture):
def test_run_streaming_tool_call_inputs_type_error(self, runner: FunctionCallAgentRunner, mocker: MockerFixture):
message = MagicMock(id="m1")
runner.stream_tool_call = True
@ -381,7 +381,7 @@ class TestRunMethod:
outputs = list(runner.run(message, "query"))
assert len(outputs) == 1
def test_run_with_missing_tool_instance(self, runner):
def test_run_with_missing_tool_instance(self, runner: FunctionCallAgentRunner):
message = MagicMock(id="m1")
tool_call = MagicMock()
@ -399,7 +399,7 @@ class TestRunMethod:
outputs = list(runner.run(message, "query"))
assert len(outputs) >= 1
def test_run_with_tool_instance_and_files(self, runner, mocker: MockerFixture):
def test_run_with_tool_instance_and_files(self, runner: FunctionCallAgentRunner, mocker: MockerFixture):
message = MagicMock(id="m1")
tool_call = MagicMock()
@ -434,7 +434,7 @@ class TestRunMethod:
for call in runner.queue_manager.publish.call_args_list
)
def test_run_max_iteration_error(self, runner):
def test_run_max_iteration_error(self, runner: FunctionCallAgentRunner):
runner.app_config.agent.max_iteration = 0
message = MagicMock(id="m1")

View File

@ -36,7 +36,7 @@ def generator(mocker: MockerFixture) -> AgentAppGenerator:
class TestGenerateGuards:
def test_rejects_blocking_mode(self, generator, mocker: MockerFixture):
def test_rejects_blocking_mode(self, generator: AgentAppGenerator, mocker: MockerFixture):
with pytest.raises(AgentAppGeneratorError, match="only supports streaming"):
generator.generate(
app_model=mocker.MagicMock(),
@ -46,7 +46,7 @@ class TestGenerateGuards:
streaming=False,
)
def test_requires_query(self, generator, mocker: MockerFixture):
def test_requires_query(self, generator: AgentAppGenerator, mocker: MockerFixture):
with pytest.raises(AgentAppGeneratorError, match="query is required"):
generator.generate(
app_model=mocker.MagicMock(),
@ -55,7 +55,7 @@ class TestGenerateGuards:
invoke_from=InvokeFrom.WEB_APP,
)
def test_rejects_blank_query(self, generator, mocker: MockerFixture):
def test_rejects_blank_query(self, generator: AgentAppGenerator, mocker: MockerFixture):
with pytest.raises(AgentAppGeneratorError, match="query is required"):
generator.generate(
app_model=mocker.MagicMock(),
@ -113,7 +113,7 @@ class TestGenerateSuccess:
thread_obj.start.assert_called_once()
generator._resolve_agent.assert_called_once_with(app_model)
def test_generate_loads_existing_conversation(self, generator, mocker: MockerFixture):
def test_generate_loads_existing_conversation(self, generator: AgentAppGenerator, mocker: MockerFixture):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent")
generator._resolve_agent = mocker.MagicMock(
return_value=(mocker.MagicMock(id="a"), mocker.MagicMock(id="s"), mocker.MagicMock())
@ -144,7 +144,9 @@ class TestGenerateSuccess:
get_conv.assert_called_once()
def test_generate_does_not_include_trace_session_id_in_extras(self, generator, mocker: MockerFixture):
def test_generate_does_not_include_trace_session_id_in_extras(
self, generator: AgentAppGenerator, mocker: MockerFixture
):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent")
user = DummyAccount("user")
@ -190,7 +192,7 @@ class TestGenerateWorker:
mocker.patch("libs.flask_utils.preserve_flask_contexts", ctx_manager)
def _wire(self, generator, mocker: MockerFixture, *, run_side_effect=None, handled=False):
def _wire(self, generator: AgentAppGenerator, mocker: MockerFixture, *, run_side_effect=None, handled=False):
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock(id="conv"))
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock(id="msg"))
generator._run_input_guards = mocker.MagicMock(return_value=(handled, "query"))
@ -238,7 +240,7 @@ class TestGenerateWorker:
is_resume=is_resume,
)
def test_happy_path_runs_backend(self, generator, mocker: MockerFixture):
def test_happy_path_runs_backend(self, generator: AgentAppGenerator, mocker: MockerFixture):
runner = self._wire(generator, mocker)
queue_manager = mocker.MagicMock()
self._call(generator, mocker, queue_manager)
@ -280,7 +282,7 @@ class TestGenerateWorker:
self._call(generator, mocker, queue_manager)
queue_manager.publish_error.assert_not_called()
def test_unexpected_error_is_published(self, generator, mocker: MockerFixture):
def test_unexpected_error_is_published(self, generator: AgentAppGenerator, mocker: MockerFixture):
self._wire(generator, mocker, run_side_effect=ValueError("boom"))
queue_manager = mocker.MagicMock()
self._call(generator, mocker, queue_manager)

View File

@ -14,7 +14,7 @@ def runner():
class TestAgentChatAppRunnerRun:
def test_run_app_not_found(self, runner, mocker: MockerFixture):
def test_run_app_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
@ -23,7 +23,7 @@ class TestAgentChatAppRunnerRun:
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
def test_run_moderation_error_direct_output(self, runner, mocker: MockerFixture):
def test_run_moderation_error_direct_output(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock()
@ -46,7 +46,7 @@ class TestAgentChatAppRunnerRun:
runner.direct_output.assert_called_once()
def test_run_annotation_reply_short_circuits(self, runner, mocker: MockerFixture):
def test_run_annotation_reply_short_circuits(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock()
@ -75,7 +75,7 @@ class TestAgentChatAppRunnerRun:
queue_manager.publish.assert_called_once()
runner.direct_output.assert_called_once()
def test_run_hosting_moderation_short_circuits(self, runner, mocker: MockerFixture):
def test_run_hosting_moderation_short_circuits(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock()
@ -99,7 +99,7 @@ class TestAgentChatAppRunnerRun:
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
def test_run_model_schema_missing(self, runner, mocker: MockerFixture):
def test_run_model_schema_missing(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
@ -141,7 +141,7 @@ class TestAgentChatAppRunnerRun:
(LLMMode.COMPLETION, "CotCompletionAgentRunner"),
],
)
def test_run_chain_of_thought_modes(self, runner, mocker: MockerFixture, mode, expected_runner):
def test_run_chain_of_thought_modes(self, runner: AgentChatAppRunner, mocker: MockerFixture, mode, expected_runner):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
@ -197,7 +197,7 @@ class TestAgentChatAppRunnerRun:
runner_instance.run.assert_called_once()
runner._handle_invoke_result.assert_called_once()
def test_run_invalid_llm_mode_raises(self, runner, mocker: MockerFixture):
def test_run_invalid_llm_mode_raises(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
@ -243,7 +243,9 @@ class TestAgentChatAppRunnerRun:
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
def test_run_function_calling_strategy_selected_by_features(self, runner, mocker: MockerFixture):
def test_run_function_calling_strategy_selected_by_features(
self, runner: AgentChatAppRunner, mocker: MockerFixture
):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
@ -299,7 +301,7 @@ class TestAgentChatAppRunnerRun:
assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING
runner_instance.run.assert_called_once()
def test_run_conversation_not_found(self, runner, mocker: MockerFixture):
def test_run_conversation_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
@ -333,7 +335,7 @@ class TestAgentChatAppRunnerRun:
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
def test_run_message_not_found(self, runner, mocker: MockerFixture):
def test_run_message_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
@ -367,7 +369,7 @@ class TestAgentChatAppRunnerRun:
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
def test_run_invalid_agent_strategy_raises(self, runner, mocker: MockerFixture):
def test_run_invalid_agent_strategy_raises(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")

View File

@ -112,7 +112,7 @@ def test_run_uses_single_node_execution_branch(
assert entry_kwargs["graph_runtime_state"] is graph_runtime_state
def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
def test_single_node_run_validates_target_node_config(monkeypatch: pytest.MonkeyPatch) -> None:
runner = WorkflowBasedAppRunner(
queue_manager=MagicMock(spec=AppQueueManager),
variable_loader=MagicMock(),

View File

@ -218,7 +218,7 @@ class TestWorkflowPersistenceLayer:
assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED
assert trace_tasks
def test_handle_graph_run_succeeded_enqueues_parent_trace_context(self, monkeypatch):
def test_handle_graph_run_succeeded_enqueues_parent_trace_context(self, monkeypatch: pytest.MonkeyPatch):
trace_tasks: list[TraceTask] = []
trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task))
layer, _, _, _ = _make_layer(

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from types import SimpleNamespace
from uuid import uuid4
import pytest
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
@ -112,7 +113,7 @@ def test_segment_attachment_lookup_grants_returned_upload_files_to_current_scope
assert "upload_files.id IN" in whereclause
def test_knowledge_retrieval_grants_returned_segments_to_current_scope(monkeypatch) -> None:
def test_knowledge_retrieval_grants_returned_segments_to_current_scope(monkeypatch: pytest.MonkeyPatch) -> None:
tenant_id = str(uuid4())
dataset_id = str(uuid4())
document_id = str(uuid4())

View File

@ -64,7 +64,9 @@ class TestCacheEmbeddingMultimodalDocuments:
usage=usage,
)
def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result):
def test_embed_single_multimodal_document_cache_miss(
self, mock_model_instance, sample_multimodal_result: EmbeddingResult
):
"""Test embedding a single multimodal document when cache is empty."""
cache_embedding = CacheEmbedding(mock_model_instance)
documents = [{"file_id": "file123", "content": "test content"}]

View File

@ -97,7 +97,9 @@ class TestRerankModelRunner:
),
]
def test_basic_reranking(self, rerank_runner, mock_model_instance, sample_documents):
def test_basic_reranking(
self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document]
):
"""Test basic reranking with cross-encoder model.
Verifies:
@ -135,7 +137,9 @@ class TestRerankModelRunner:
assert result[3].metadata["score"] == 0.65
assert result[0].page_content == sample_documents[2].page_content
def test_score_threshold_filtering(self, rerank_runner, mock_model_instance, sample_documents):
def test_score_threshold_filtering(
self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document]
):
"""Test score threshold filtering.
Verifies:
@ -163,7 +167,9 @@ class TestRerankModelRunner:
assert result[0].metadata["score"] == 0.90
assert result[1].metadata["score"] == 0.70
def test_top_k_selection(self, rerank_runner, mock_model_instance, sample_documents):
def test_top_k_selection(
self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document]
):
"""Test top-k selection functionality.
Verifies:
@ -191,7 +197,7 @@ class TestRerankModelRunner:
assert result[0].metadata["score"] == 0.95
assert result[1].metadata["score"] == 0.85
def test_document_deduplication_dify_provider(self, rerank_runner, mock_model_instance):
def test_document_deduplication_dify_provider(self, rerank_runner: RerankModelRunner, mock_model_instance):
"""Test document deduplication for dify provider.
Verifies:
@ -235,7 +241,7 @@ class TestRerankModelRunner:
assert len(call_kwargs["docs"]) == 2 # Duplicate removed
assert len(result) == 2
def test_document_deduplication_external_provider(self, rerank_runner, mock_model_instance):
def test_document_deduplication_external_provider(self, rerank_runner: RerankModelRunner, mock_model_instance):
"""Test document deduplication for external provider.
Verifies:
@ -273,7 +279,9 @@ class TestRerankModelRunner:
assert len(call_kwargs["docs"]) == 2
assert len(result) == 2
def test_combined_threshold_and_top_k(self, rerank_runner, mock_model_instance, sample_documents):
def test_combined_threshold_and_top_k(
self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document]
):
"""Test combined score threshold and top-k selection.
Verifies:
@ -307,7 +315,9 @@ class TestRerankModelRunner:
assert result[0].metadata["score"] == 0.95
assert result[1].metadata["score"] == 0.85
def test_metadata_preservation(self, rerank_runner, mock_model_instance, sample_documents):
def test_metadata_preservation(
self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document]
):
"""Test that original metadata is preserved after reranking.
Verifies:
@ -334,7 +344,7 @@ class TestRerankModelRunner:
assert result[0].metadata["score"] == 0.90
assert result[0].provider == "dify"
def test_empty_documents_list(self, rerank_runner, mock_model_instance):
def test_empty_documents_list(self, rerank_runner: RerankModelRunner, mock_model_instance):
"""Test handling of empty documents list.
Verifies:
@ -523,7 +533,9 @@ class TestRerankModelRunnerMultimodal:
docs_arg = mock_text_rerank.call_args.args[1]
assert len(docs_arg) == 1
def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(self, rerank_runner, mock_model_instance):
def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(
self, rerank_runner: RerankModelRunner, mock_model_instance
):
text_doc = Document(
page_content="text-content",
metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},

View File

@ -413,7 +413,7 @@ class TestWorkflowGeneratorTransientRetry:
}
)
def test_retries_transient_invoke_error_then_succeeds(self, monkeypatch):
def test_retries_transient_invoke_error_then_succeeds(self, monkeypatch: pytest.MonkeyPatch):
# The planner's first invoke raises a transient connection error; the
# retry succeeds and the pipeline completes normally. Sleep is patched
# out so the test doesn't actually wait for the backoff.
@ -443,7 +443,7 @@ class TestWorkflowGeneratorTransientRetry:
# planner (failed once + retried) + builder = 3 invocations total.
assert model_instance.invoke_llm.call_count == 3
def test_gives_up_after_exhausting_transient_retries(self, monkeypatch):
def test_gives_up_after_exhausting_transient_retries(self, monkeypatch: pytest.MonkeyPatch):
# Every attempt hits the transient error — once we exhaust the retry
# budget the failure surfaces as a normal error envelope rather than
# hanging or looping forever.
@ -471,7 +471,7 @@ class TestWorkflowGeneratorTransientRetry:
assert model_instance.invoke_llm.call_count == _INVOKE_MAX_ATTEMPTS
def test_does_not_retry_permanent_invoke_error(self, monkeypatch):
def test_does_not_retry_permanent_invoke_error(self, monkeypatch: pytest.MonkeyPatch):
# An auth error is permanent — retrying just burns latency and quota.
# The runner must fail on the first attempt.
# If the code wrongly slept here we'd want the test to still be fast;

View File

@ -332,7 +332,9 @@ def test_email_delivery_method_extracts_variable_selectors() -> None:
assert method.extract_variable_selectors() == [["start", "name"]]
def test_email_delivery_method_extracts_variable_selectors_skips_short_selectors(monkeypatch) -> None:
def test_email_delivery_method_extracts_variable_selectors_skips_short_selectors(
monkeypatch: pytest.MonkeyPatch,
) -> None:
method = EmailDeliveryMethod(
enabled=True,
config=EmailDeliveryConfig(

View File

@ -1,8 +1,10 @@
from pathlib import Path
import pandas as pd
import pytest
def test_pandas_csv(tmp_path, monkeypatch: pytest.MonkeyPatch):
def test_pandas_csv(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.chdir(tmp_path)
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data)
@ -17,7 +19,7 @@ def test_pandas_csv(tmp_path, monkeypatch: pytest.MonkeyPatch):
assert df2[df2.columns[1]].to_list() == data["col2"]
def test_pandas_xlsx(tmp_path, monkeypatch: pytest.MonkeyPatch):
def test_pandas_xlsx(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.chdir(tmp_path)
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data)
@ -32,7 +34,7 @@ def test_pandas_xlsx(tmp_path, monkeypatch: pytest.MonkeyPatch):
assert df2[df2.columns[1]].to_list() == data["col2"]
def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch: pytest.MonkeyPatch):
def test_pandas_xlsx_with_sheets(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.chdir(tmp_path)
data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data1)

View File

@ -75,7 +75,7 @@ def test_enforce_bearer_rate_limit_raises_429_with_retry_after(mock_build):
@patch("libs.rate_limit._build_limiter")
def test_enforce_bearer_rate_limit_disabled_when_limit_is_zero(mock_build, monkeypatch):
def test_enforce_bearer_rate_limit_disabled_when_limit_is_zero(mock_build, monkeypatch: pytest.MonkeyPatch):
# 0 disables the limit — short-circuit before building/consulting a limiter.
monkeypatch.setattr(
"libs.rate_limit.LIMIT_BEARER_PER_TOKEN",

View File

@ -17,6 +17,8 @@ from unittest.mock import Mock, patch
from urllib.parse import parse_qs, urlparse
from uuid import uuid4
import pytest
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import (
AppDatasetJoin,
@ -703,7 +705,7 @@ class TestDocumentSegmentIndexing:
# Assert
assert segment.hit_count == 5
def test_document_segment_attachments_prefers_files_url_for_source_url(self, monkeypatch):
def test_document_segment_attachments_prefers_files_url_for_source_url(self, monkeypatch: pytest.MonkeyPatch):
"""Test attachment source URLs use FILES_URL before falling back to CONSOLE_API_URL."""
# Arrange
segment = DocumentSegment(

View File

@ -125,7 +125,9 @@ def test_rebuild_serialized_graph_files_without_lookup_preserves_scalar_values()
assert rebuild_serialized_graph_files_without_lookup("plain-text") == "plain-text"
def test_build_file_from_stored_mapping_rebuilds_remote_urls_without_record_lookup(monkeypatch) -> None:
def test_build_file_from_stored_mapping_rebuilds_remote_urls_without_record_lookup(
monkeypatch: pytest.MonkeyPatch,
) -> None:
rebuilt_file = File(
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,

View File

@ -2,6 +2,8 @@ import json
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from models.snippet import CustomizedSnippet
@ -11,7 +13,7 @@ def test_graph_dict_returns_empty_without_workflow_id() -> None:
assert snippet.graph_dict == {}
def test_graph_dict_loads_published_workflow_graph(monkeypatch) -> None:
def test_graph_dict_loads_published_workflow_graph(monkeypatch: pytest.MonkeyPatch) -> None:
workflow = SimpleNamespace(graph=json.dumps({"nodes": [{"id": "llm-1"}], "edges": []}))
session = SimpleNamespace(get=Mock(return_value=workflow))
monkeypatch.setattr("models.snippet.db.session", session)
@ -21,7 +23,7 @@ def test_graph_dict_loads_published_workflow_graph(monkeypatch) -> None:
session.get.assert_called_once()
def test_graph_dict_returns_empty_when_workflow_missing(monkeypatch) -> None:
def test_graph_dict_returns_empty_when_workflow_missing(monkeypatch: pytest.MonkeyPatch) -> None:
session = SimpleNamespace(get=Mock(return_value=None))
monkeypatch.setattr("models.snippet.db.session", session)
snippet = CustomizedSnippet(workflow_id="missing-workflow")
@ -36,7 +38,7 @@ def test_input_fields_list_parses_json_or_returns_empty() -> None:
]
def test_tags_returns_query_results_or_empty(monkeypatch) -> None:
def test_tags_returns_query_results_or_empty(monkeypatch: pytest.MonkeyPatch) -> None:
tags = [SimpleNamespace(id="tag-1")]
session = SimpleNamespace(scalars=Mock(return_value=SimpleNamespace(all=Mock(return_value=tags))))
monkeypatch.setattr("models.snippet.db.session", session)
@ -48,7 +50,7 @@ def test_tags_returns_query_results_or_empty(monkeypatch) -> None:
assert snippet.tags == []
def test_account_properties_and_author_name(monkeypatch) -> None:
def test_account_properties_and_author_name(monkeypatch: pytest.MonkeyPatch) -> None:
account = SimpleNamespace(id="account-1", name="Ada")
updated_account = SimpleNamespace(id="account-2", name="Grace")
session = SimpleNamespace(

View File

@ -1324,7 +1324,7 @@ class TestAgentAppBackingAgent:
with pytest.raises(roster_service.AgentNotFoundError):
service.get_agent_app_model(tenant_id="tenant-1", agent_id="agent-x")
def test_duplicate_agent_app_copies_app_config_and_active_soul(self, monkeypatch):
def test_duplicate_agent_app_copies_app_config_and_active_soul(self, monkeypatch: pytest.MonkeyPatch):
source_config = SimpleNamespace(
opening_statement="hello",
suggested_questions='["q1"]',
@ -1463,7 +1463,7 @@ class TestAgentAppBackingAgent:
assert target_agent.updated_by == "account-1"
assert session.commits == 1
def test_duplicate_agent_app_inherits_webapp_access_mode(self, monkeypatch):
def test_duplicate_agent_app_inherits_webapp_access_mode(self, monkeypatch: pytest.MonkeyPatch):
source_app = SimpleNamespace(
id="source-app",
tenant_id="tenant-1",
@ -1522,7 +1522,7 @@ class TestAgentAppBackingAgent:
assert duplicated is target_app
assert access_mode_updates == [("target-app", "private")]
def test_duplicate_agent_app_falls_back_to_public_access_mode(self, monkeypatch):
def test_duplicate_agent_app_falls_back_to_public_access_mode(self, monkeypatch: pytest.MonkeyPatch):
source_app = SimpleNamespace(
id="source-app",
tenant_id="tenant-1",

View File

@ -66,7 +66,7 @@ class TestFirecrawlAuth:
assert str(exc_info.value) == expected_error
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
def test_should_validate_valid_credentials_successfully(self, mock_post: MagicMock, auth_instance: FirecrawlAuth):
"""Test successful credential validation"""
mock_response = MagicMock()
mock_response.status_code = 200
@ -97,7 +97,9 @@ class TestFirecrawlAuth:
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
def test_should_handle_http_errors(
self, mock_post: MagicMock, status_code, error_message, auth_instance: FirecrawlAuth
):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
mock_response.status_code = status_code
@ -120,7 +122,13 @@ class TestFirecrawlAuth:
)
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_unexpected_errors(
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
self,
mock_post: MagicMock,
status_code,
response_text,
has_json_error,
expected_error_contains,
auth_instance: FirecrawlAuth,
):
"""Test handling of unexpected errors with various response formats"""
mock_response = MagicMock()
@ -146,7 +154,9 @@ class TestFirecrawlAuth:
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
def test_should_handle_network_errors(
self, mock_post: MagicMock, exception_type, exception_message, auth_instance: FirecrawlAuth
):
"""Test handling of various network-related errors including timeouts"""
mock_post.side_effect = exception_type(exception_message)
@ -186,7 +196,7 @@ class TestFirecrawlAuth:
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
def test_should_handle_timeout_with_retry_suggestion(self, mock_post: MagicMock, auth_instance: FirecrawlAuth):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")

View File

@ -36,7 +36,7 @@ class TestJinaAuth:
assert str(exc_info.value) == "No API key provided"
@patch("services.auth.jina.jina._http_client.post", autospec=True)
def test_should_validate_valid_credentials_successfully(self, mock_post):
def test_should_validate_valid_credentials_successfully(self, mock_post: MagicMock):
"""Test successful credential validation"""
mock_response = MagicMock()
mock_response.status_code = 200
@ -54,7 +54,7 @@ class TestJinaAuth:
)
@patch("services.auth.jina.jina._http_client.post", autospec=True)
def test_should_handle_http_402_error(self, mock_post):
def test_should_handle_http_402_error(self, mock_post: MagicMock):
"""Test handling of 402 Payment Required error"""
mock_response = MagicMock()
mock_response.status_code = 402
@ -100,7 +100,7 @@ class TestJinaAuth:
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
@patch("services.auth.jina.jina._http_client.post", autospec=True)
def test_should_handle_http_500_error(self, mock_post):
def test_should_handle_http_500_error(self, mock_post: MagicMock):
"""Test handling of 500 Internal Server Error"""
mock_response = MagicMock()
mock_response.status_code = 500
@ -115,7 +115,7 @@ class TestJinaAuth:
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
@patch("services.auth.jina.jina._http_client.post", autospec=True)
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
def test_should_handle_unexpected_error_with_text_response(self, mock_post: MagicMock):
"""Test handling of unexpected errors with text response"""
mock_response = MagicMock()
mock_response.status_code = 403
@ -163,7 +163,7 @@ class TestJinaAuth:
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
@patch("services.auth.jina.jina._http_client.post", autospec=True)
def test_should_handle_network_errors(self, mock_post):
def test_should_handle_network_errors(self, mock_post: MagicMock):
"""Test handling of network connection errors"""
mock_post.side_effect = httpx.ConnectError("Network error")

View File

@ -89,7 +89,7 @@ class _FakeWriteSession:
class TestUpdateFeatures:
def test_persists_new_app_model_config_version(self, monkeypatch):
def test_persists_new_app_model_config_version(self, monkeypatch: pytest.MonkeyPatch):
session = _FakeWriteSession()
monkeypatch.setattr(svc_mod.db, "session", session)
app_model = SimpleNamespace(

View File

@ -137,14 +137,14 @@ class ExternalDatasetServiceTestDataFactory:
@pytest.fixture
def factory():
"""Provide the test data factory to all tests."""
return ExternalDatasetServiceTestDataFactory
return ExternalDatasetServiceTestDataFactory()
class TestExternalDatasetServiceGetAPIs:
"""Test get_external_knowledge_apis operations - comprehensive coverage."""
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_apis_success_basic(self, mock_db, factory):
def test_get_external_knowledge_apis_success_basic(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test successful retrieval of external knowledge APIs with pagination."""
# Arrange
tenant_id = "tenant-123"
@ -171,7 +171,9 @@ class TestExternalDatasetServiceGetAPIs:
mock_db.paginate.assert_called_once()
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_apis_with_search_filter(self, mock_db, factory):
def test_get_external_knowledge_apis_with_search_filter(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test retrieval with search filter."""
# Arrange
tenant_id = "tenant-123"
@ -195,7 +197,7 @@ class TestExternalDatasetServiceGetAPIs:
assert result_items[0].name == "Production API"
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_apis_empty_results(self, mock_db, factory):
def test_get_external_knowledge_apis_empty_results(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test retrieval with no results."""
# Arrange
mock_pagination = MagicMock()
@ -213,7 +215,9 @@ class TestExternalDatasetServiceGetAPIs:
assert result_total == 0
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_apis_large_result_set(self, mock_db, factory):
def test_get_external_knowledge_apis_large_result_set(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test retrieval with large result set."""
# Arrange
apis = [factory.create_external_knowledge_api_mock(api_id=f"api-{i}") for i in range(100)]
@ -233,7 +237,9 @@ class TestExternalDatasetServiceGetAPIs:
assert result_total == 100
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_apis_pagination_last_page(self, mock_db, factory):
def test_get_external_knowledge_apis_pagination_last_page(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test last page pagination with partial results."""
# Arrange
apis = [factory.create_external_knowledge_api_mock(api_id=f"api-{i}") for i in range(95, 100)]
@ -253,7 +259,9 @@ class TestExternalDatasetServiceGetAPIs:
assert result_total == 100
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_apis_case_insensitive_search(self, mock_db, factory):
def test_get_external_knowledge_apis_case_insensitive_search(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test case-insensitive search functionality."""
# Arrange
apis = [
@ -276,7 +284,9 @@ class TestExternalDatasetServiceGetAPIs:
assert result_total == 2
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_apis_special_characters_search(self, mock_db, factory):
def test_get_external_knowledge_apis_special_characters_search(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test search with special characters."""
# Arrange
apis = [factory.create_external_knowledge_api_mock(name="API-v2.0 (beta)")]
@ -295,7 +305,9 @@ class TestExternalDatasetServiceGetAPIs:
assert len(result_items) == 1
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_apis_max_per_page_limit(self, mock_db, factory):
def test_get_external_knowledge_apis_max_per_page_limit(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test that max_per_page limit is enforced."""
# Arrange
apis = [factory.create_external_knowledge_api_mock(api_id=f"api-{i}") for i in range(100)]
@ -315,7 +327,9 @@ class TestExternalDatasetServiceGetAPIs:
assert call_args.kwargs["max_per_page"] == 100
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_apis_ordered_by_created_at_desc(self, mock_db, factory):
def test_get_external_knowledge_apis_ordered_by_created_at_desc(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test that results are ordered by created_at descending."""
# Arrange
apis = [
@ -340,7 +354,7 @@ class TestExternalDatasetServiceGetAPIs:
class TestExternalDatasetServiceValidateAPIList:
"""Test validate_api_list operations."""
def test_validate_api_list_success_with_all_fields(self, factory):
def test_validate_api_list_success_with_all_fields(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test successful validation with all required fields."""
# Arrange
api_settings = {"endpoint": "https://api.example.com", "api_key": "test-key-123"}
@ -348,7 +362,7 @@ class TestExternalDatasetServiceValidateAPIList:
# Act & Assert - should not raise
ExternalDatasetService.validate_api_list(api_settings)
def test_validate_api_list_missing_endpoint(self, factory):
def test_validate_api_list_missing_endpoint(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when endpoint is missing."""
# Arrange
api_settings = {"api_key": "test-key"}
@ -357,7 +371,7 @@ class TestExternalDatasetServiceValidateAPIList:
with pytest.raises(ValueError, match="endpoint is required"):
ExternalDatasetService.validate_api_list(api_settings)
def test_validate_api_list_empty_endpoint(self, factory):
def test_validate_api_list_empty_endpoint(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when endpoint is empty string."""
# Arrange
api_settings = {"endpoint": "", "api_key": "test-key"}
@ -366,7 +380,7 @@ class TestExternalDatasetServiceValidateAPIList:
with pytest.raises(ValueError, match="endpoint is required"):
ExternalDatasetService.validate_api_list(api_settings)
def test_validate_api_list_missing_api_key(self, factory):
def test_validate_api_list_missing_api_key(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when API key is missing."""
# Arrange
api_settings = {"endpoint": "https://api.example.com"}
@ -375,7 +389,7 @@ class TestExternalDatasetServiceValidateAPIList:
with pytest.raises(ValueError, match="api_key is required"):
ExternalDatasetService.validate_api_list(api_settings)
def test_validate_api_list_empty_api_key(self, factory):
def test_validate_api_list_empty_api_key(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when API key is empty string."""
# Arrange
api_settings = {"endpoint": "https://api.example.com", "api_key": ""}
@ -384,7 +398,7 @@ class TestExternalDatasetServiceValidateAPIList:
with pytest.raises(ValueError, match="api_key is required"):
ExternalDatasetService.validate_api_list(api_settings)
def test_validate_api_list_empty_dict(self, factory):
def test_validate_api_list_empty_dict(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when settings are empty dict."""
# Arrange
api_settings = {}
@ -393,7 +407,7 @@ class TestExternalDatasetServiceValidateAPIList:
with pytest.raises(ValueError, match="api list is empty"):
ExternalDatasetService.validate_api_list(api_settings)
def test_validate_api_list_none_value(self, factory):
def test_validate_api_list_none_value(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when settings are None."""
# Arrange
api_settings = None
@ -402,7 +416,7 @@ class TestExternalDatasetServiceValidateAPIList:
with pytest.raises(ValueError, match="api list is empty"):
ExternalDatasetService.validate_api_list(api_settings)
def test_validate_api_list_with_extra_fields(self, factory):
def test_validate_api_list_with_extra_fields(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation succeeds with extra fields present."""
# Arrange
api_settings = {
@ -421,7 +435,9 @@ class TestExternalDatasetServiceCreateAPI:
@patch("services.external_knowledge_service.db")
@patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key")
def test_create_external_knowledge_api_success_full(self, mock_check, mock_db, factory):
def test_create_external_knowledge_api_success_full(
self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test successful creation with all fields."""
# Arrange
tenant_id = "tenant-123"
@ -447,7 +463,9 @@ class TestExternalDatasetServiceCreateAPI:
@patch("services.external_knowledge_service.db")
@patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key")
def test_create_external_knowledge_api_minimal_fields(self, mock_check, mock_db, factory):
def test_create_external_knowledge_api_minimal_fields(
self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test creation with minimal required fields."""
# Arrange
args = {
@ -463,7 +481,9 @@ class TestExternalDatasetServiceCreateAPI:
assert result.description == ""
@patch("services.external_knowledge_service.db")
def test_create_external_knowledge_api_missing_settings(self, mock_db, factory):
def test_create_external_knowledge_api_missing_settings(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test creation fails when settings are missing."""
# Arrange
args = {"name": "Test API", "description": "Test"}
@ -473,7 +493,7 @@ class TestExternalDatasetServiceCreateAPI:
ExternalDatasetService.create_external_knowledge_api("tenant-123", "user-123", args)
@patch("services.external_knowledge_service.db")
def test_create_external_knowledge_api_none_settings(self, mock_db, factory):
def test_create_external_knowledge_api_none_settings(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test creation fails when settings are explicitly None."""
# Arrange
args = {"name": "Test API", "settings": None}
@ -484,7 +504,9 @@ class TestExternalDatasetServiceCreateAPI:
@patch("services.external_knowledge_service.db")
@patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key")
def test_create_external_knowledge_api_settings_json_serialization(self, mock_check, mock_db, factory):
def test_create_external_knowledge_api_settings_json_serialization(
self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test that settings are properly JSON serialized."""
# Arrange
settings = {
@ -504,7 +526,9 @@ class TestExternalDatasetServiceCreateAPI:
@patch("services.external_knowledge_service.db")
@patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key")
def test_create_external_knowledge_api_unicode_handling(self, mock_check, mock_db, factory):
def test_create_external_knowledge_api_unicode_handling(
self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test proper handling of Unicode characters in name and description."""
# Arrange
args = {
@ -522,7 +546,9 @@ class TestExternalDatasetServiceCreateAPI:
@patch("services.external_knowledge_service.db")
@patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key")
def test_create_external_knowledge_api_long_description(self, mock_check, mock_db, factory):
def test_create_external_knowledge_api_long_description(
self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test creation with very long description."""
# Arrange
long_description = "A" * 1000
@ -544,7 +570,7 @@ class TestExternalDatasetServiceCheckEndpoint:
"""Test check_endpoint_and_api_key operations - extensive coverage."""
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_success_https(self, mock_proxy, factory):
def test_check_endpoint_success_https(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test successful validation with HTTPS endpoint."""
# Arrange
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
@ -558,7 +584,7 @@ class TestExternalDatasetServiceCheckEndpoint:
mock_proxy.post.assert_called_once()
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_success_http(self, mock_proxy, factory):
def test_check_endpoint_success_http(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test successful validation with HTTP endpoint."""
# Arrange
settings = {"endpoint": "http://api.example.com", "api_key": "test-key"}
@ -570,7 +596,7 @@ class TestExternalDatasetServiceCheckEndpoint:
# Act & Assert - should not raise
ExternalDatasetService.check_endpoint_and_api_key(settings)
def test_check_endpoint_missing_endpoint_key(self, factory):
def test_check_endpoint_missing_endpoint_key(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when endpoint key is missing."""
# Arrange
settings = {"api_key": "test-key"}
@ -579,7 +605,7 @@ class TestExternalDatasetServiceCheckEndpoint:
with pytest.raises(ValueError, match="endpoint is required"):
ExternalDatasetService.check_endpoint_and_api_key(settings)
def test_check_endpoint_empty_endpoint_string(self, factory):
def test_check_endpoint_empty_endpoint_string(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when endpoint is empty string."""
# Arrange
settings = {"endpoint": "", "api_key": "test-key"}
@ -588,7 +614,7 @@ class TestExternalDatasetServiceCheckEndpoint:
with pytest.raises(ValueError, match="endpoint is required"):
ExternalDatasetService.check_endpoint_and_api_key(settings)
def test_check_endpoint_whitespace_endpoint(self, factory):
def test_check_endpoint_whitespace_endpoint(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when endpoint is only whitespace."""
# Arrange
settings = {"endpoint": " ", "api_key": "test-key"}
@ -597,7 +623,7 @@ class TestExternalDatasetServiceCheckEndpoint:
with pytest.raises(ValueError, match="invalid endpoint"):
ExternalDatasetService.check_endpoint_and_api_key(settings)
def test_check_endpoint_missing_api_key_key(self, factory):
def test_check_endpoint_missing_api_key_key(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when api_key key is missing."""
# Arrange
settings = {"endpoint": "https://api.example.com"}
@ -606,7 +632,7 @@ class TestExternalDatasetServiceCheckEndpoint:
with pytest.raises(ValueError, match="api_key is required"):
ExternalDatasetService.check_endpoint_and_api_key(settings)
def test_check_endpoint_empty_api_key_string(self, factory):
def test_check_endpoint_empty_api_key_string(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when api_key is empty string."""
# Arrange
settings = {"endpoint": "https://api.example.com", "api_key": ""}
@ -615,7 +641,7 @@ class TestExternalDatasetServiceCheckEndpoint:
with pytest.raises(ValueError, match="api_key is required"):
ExternalDatasetService.check_endpoint_and_api_key(settings)
def test_check_endpoint_no_scheme_url(self, factory):
def test_check_endpoint_no_scheme_url(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails for URL without http:// or https://."""
# Arrange
settings = {"endpoint": "api.example.com", "api_key": "test-key"}
@ -624,7 +650,7 @@ class TestExternalDatasetServiceCheckEndpoint:
with pytest.raises(ValueError, match="invalid endpoint.*must start with http"):
ExternalDatasetService.check_endpoint_and_api_key(settings)
def test_check_endpoint_invalid_scheme(self, factory):
def test_check_endpoint_invalid_scheme(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails for URL with invalid scheme."""
# Arrange
settings = {"endpoint": "ftp://api.example.com", "api_key": "test-key"}
@ -633,7 +659,7 @@ class TestExternalDatasetServiceCheckEndpoint:
with pytest.raises(ValueError, match="failed to connect to the endpoint"):
ExternalDatasetService.check_endpoint_and_api_key(settings)
def test_check_endpoint_no_netloc(self, factory):
def test_check_endpoint_no_netloc(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails for URL without network location."""
# Arrange
settings = {"endpoint": "http://", "api_key": "test-key"}
@ -642,7 +668,7 @@ class TestExternalDatasetServiceCheckEndpoint:
with pytest.raises(ValueError, match="invalid endpoint"):
ExternalDatasetService.check_endpoint_and_api_key(settings)
def test_check_endpoint_malformed_url(self, factory):
def test_check_endpoint_malformed_url(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails for malformed URL."""
# Arrange
settings = {"endpoint": "https:///invalid", "api_key": "test-key"}
@ -652,7 +678,7 @@ class TestExternalDatasetServiceCheckEndpoint:
ExternalDatasetService.check_endpoint_and_api_key(settings)
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_connection_timeout(self, mock_proxy, factory):
def test_check_endpoint_connection_timeout(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails on connection timeout."""
# Arrange
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
@ -663,7 +689,7 @@ class TestExternalDatasetServiceCheckEndpoint:
ExternalDatasetService.check_endpoint_and_api_key(settings)
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_network_error(self, mock_proxy, factory):
def test_check_endpoint_network_error(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails on network error."""
# Arrange
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
@ -674,7 +700,7 @@ class TestExternalDatasetServiceCheckEndpoint:
ExternalDatasetService.check_endpoint_and_api_key(settings)
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_502_bad_gateway(self, mock_proxy, factory):
def test_check_endpoint_502_bad_gateway(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails with 502 Bad Gateway."""
# Arrange
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
@ -688,7 +714,7 @@ class TestExternalDatasetServiceCheckEndpoint:
ExternalDatasetService.check_endpoint_and_api_key(settings)
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_404_not_found(self, mock_proxy, factory):
def test_check_endpoint_404_not_found(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails with 404 Not Found."""
# Arrange
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
@ -702,7 +728,7 @@ class TestExternalDatasetServiceCheckEndpoint:
ExternalDatasetService.check_endpoint_and_api_key(settings)
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_403_forbidden(self, mock_proxy, factory):
def test_check_endpoint_403_forbidden(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails with 403 Forbidden (auth failure)."""
# Arrange
settings = {"endpoint": "https://api.example.com", "api_key": "wrong-key"}
@ -716,7 +742,7 @@ class TestExternalDatasetServiceCheckEndpoint:
ExternalDatasetService.check_endpoint_and_api_key(settings)
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_other_4xx_codes_pass(self, mock_proxy, factory):
def test_check_endpoint_other_4xx_codes_pass(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test that other 4xx codes don't raise exceptions."""
# Arrange
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
@ -730,7 +756,7 @@ class TestExternalDatasetServiceCheckEndpoint:
ExternalDatasetService.check_endpoint_and_api_key(settings)
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_5xx_codes_except_502_pass(self, mock_proxy, factory):
def test_check_endpoint_5xx_codes_except_502_pass(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test that 5xx codes except 502 don't raise exceptions."""
# Arrange
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
@ -744,7 +770,7 @@ class TestExternalDatasetServiceCheckEndpoint:
ExternalDatasetService.check_endpoint_and_api_key(settings)
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_with_port_number(self, mock_proxy, factory):
def test_check_endpoint_with_port_number(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation with endpoint including port number."""
# Arrange
settings = {"endpoint": "https://api.example.com:8443", "api_key": "test-key"}
@ -757,7 +783,7 @@ class TestExternalDatasetServiceCheckEndpoint:
ExternalDatasetService.check_endpoint_and_api_key(settings)
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_with_path(self, mock_proxy, factory):
def test_check_endpoint_with_path(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation with endpoint including path."""
# Arrange
settings = {"endpoint": "https://api.example.com/v1/api", "api_key": "test-key"}
@ -773,7 +799,9 @@ class TestExternalDatasetServiceCheckEndpoint:
assert "/retrieval" in call_args[0][0]
@patch("services.external_knowledge_service.ssrf_proxy")
def test_check_endpoint_authorization_header_format(self, mock_proxy, factory):
def test_check_endpoint_authorization_header_format(
self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory
):
"""Test that Authorization header is properly formatted."""
# Arrange
settings = {"endpoint": "https://api.example.com", "api_key": "test-key-123"}
@ -795,7 +823,7 @@ class TestExternalDatasetServiceGetAPI:
"""Test get_external_knowledge_api operations."""
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_api_success(self, mock_db, factory):
def test_get_external_knowledge_api_success(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test successful retrieval of external knowledge API."""
# Arrange
api_id = "api-123"
@ -811,7 +839,7 @@ class TestExternalDatasetServiceGetAPI:
assert result.id == api_id
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_api_not_found(self, mock_db, factory):
def test_get_external_knowledge_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test error when API is not found."""
# Arrange
mock_db.session.scalar.return_value = None
@ -826,7 +854,9 @@ class TestExternalDatasetServiceUpdateAPI:
@patch("services.external_knowledge_service.naive_utc_now")
@patch("services.external_knowledge_service.db")
def test_update_external_knowledge_api_success_all_fields(self, mock_db, mock_now, factory):
def test_update_external_knowledge_api_success_all_fields(
self, mock_db, mock_now, factory: ExternalDatasetServiceTestDataFactory
):
"""Test successful update with all fields."""
# Arrange
api_id = "api-123"
@ -856,7 +886,9 @@ class TestExternalDatasetServiceUpdateAPI:
mock_db.session.commit.assert_called_once()
@patch("services.external_knowledge_service.db")
def test_update_external_knowledge_api_preserve_hidden_api_key(self, mock_db, factory):
def test_update_external_knowledge_api_preserve_hidden_api_key(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test that hidden API key is preserved from existing settings."""
# Arrange
api_id = "api-123"
@ -883,7 +915,7 @@ class TestExternalDatasetServiceUpdateAPI:
assert settings["api_key"] == "original-secret-key"
@patch("services.external_knowledge_service.db")
def test_update_external_knowledge_api_not_found(self, mock_db, factory):
def test_update_external_knowledge_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test error when API is not found."""
# Arrange
mock_db.session.scalar.return_value = None
@ -895,7 +927,9 @@ class TestExternalDatasetServiceUpdateAPI:
ExternalDatasetService.update_external_knowledge_api("tenant-123", "user-123", "api-123", args)
@patch("services.external_knowledge_service.db")
def test_update_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
def test_update_external_knowledge_api_tenant_mismatch(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test error when tenant ID doesn't match."""
# Arrange
mock_db.session.scalar.return_value = None
@ -907,7 +941,7 @@ class TestExternalDatasetServiceUpdateAPI:
ExternalDatasetService.update_external_knowledge_api("wrong-tenant", "user-123", "api-123", args)
@patch("services.external_knowledge_service.db")
def test_update_external_knowledge_api_name_only(self, mock_db, factory):
def test_update_external_knowledge_api_name_only(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test updating only the name field."""
# Arrange
existing_api = factory.create_external_knowledge_api_mock(
@ -930,7 +964,7 @@ class TestExternalDatasetServiceDeleteAPI:
"""Test delete_external_knowledge_api operations."""
@patch("services.external_knowledge_service.db")
def test_delete_external_knowledge_api_success(self, mock_db, factory):
def test_delete_external_knowledge_api_success(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test successful deletion of external knowledge API."""
# Arrange
api_id = "api-123"
@ -948,7 +982,7 @@ class TestExternalDatasetServiceDeleteAPI:
mock_db.session.commit.assert_called_once()
@patch("services.external_knowledge_service.db")
def test_delete_external_knowledge_api_not_found(self, mock_db, factory):
def test_delete_external_knowledge_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test error when API is not found."""
# Arrange
mock_db.session.scalar.return_value = None
@ -958,7 +992,9 @@ class TestExternalDatasetServiceDeleteAPI:
ExternalDatasetService.delete_external_knowledge_api("tenant-123", "api-123")
@patch("services.external_knowledge_service.db")
def test_delete_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
def test_delete_external_knowledge_api_tenant_mismatch(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test error when tenant ID doesn't match."""
# Arrange
mock_db.session.scalar.return_value = None
@ -972,7 +1008,9 @@ class TestExternalDatasetServiceAPIUseCheck:
"""Test external_knowledge_api_use_check operations."""
@patch("services.external_knowledge_service.db")
def test_external_knowledge_api_use_check_in_use_single(self, mock_db, factory):
def test_external_knowledge_api_use_check_in_use_single(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test API use check when API has one binding."""
# Arrange
api_id = "api-123"
@ -989,7 +1027,9 @@ class TestExternalDatasetServiceAPIUseCheck:
assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0])
@patch("services.external_knowledge_service.db")
def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory):
def test_external_knowledge_api_use_check_in_use_multiple(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test API use check with multiple bindings."""
# Arrange
api_id = "api-123"
@ -1005,7 +1045,7 @@ class TestExternalDatasetServiceAPIUseCheck:
assert count == 10
@patch("services.external_knowledge_service.db")
def test_external_knowledge_api_use_check_not_in_use(self, mock_db, factory):
def test_external_knowledge_api_use_check_not_in_use(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test API use check when API is not in use."""
# Arrange
api_id = "api-123"
@ -1025,7 +1065,7 @@ class TestExternalDatasetServiceGetBinding:
"""Test get_external_knowledge_binding_with_dataset_id operations."""
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_binding_success(self, mock_db, factory):
def test_get_external_knowledge_binding_success(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test successful retrieval of external knowledge binding."""
# Arrange
tenant_id = "tenant-123"
@ -1043,7 +1083,7 @@ class TestExternalDatasetServiceGetBinding:
assert result.tenant_id == tenant_id
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_binding_not_found(self, mock_db, factory):
def test_get_external_knowledge_binding_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test error when binding is not found."""
# Arrange
mock_db.session.scalar.return_value = None
@ -1057,7 +1097,9 @@ class TestExternalDatasetServiceDocumentValidate:
"""Test document_create_args_validate operations."""
@patch("services.external_knowledge_service.db")
def test_document_create_args_validate_success_all_params(self, mock_db, factory):
def test_document_create_args_validate_success_all_params(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test successful validation with all required parameters."""
# Arrange
tenant_id = "tenant-123"
@ -1081,7 +1123,9 @@ class TestExternalDatasetServiceDocumentValidate:
ExternalDatasetService.document_create_args_validate(tenant_id, api_id, process_parameter)
@patch("services.external_knowledge_service.db")
def test_document_create_args_validate_missing_required_param(self, mock_db, factory):
def test_document_create_args_validate_missing_required_param(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test validation fails when required parameter is missing."""
# Arrange
tenant_id = "tenant-123"
@ -1100,7 +1144,7 @@ class TestExternalDatasetServiceDocumentValidate:
ExternalDatasetService.document_create_args_validate(tenant_id, api_id, process_parameter)
@patch("services.external_knowledge_service.db")
def test_document_create_args_validate_api_not_found(self, mock_db, factory):
def test_document_create_args_validate_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test validation fails when API is not found."""
# Arrange
mock_db.session.scalar.return_value = None
@ -1110,7 +1154,9 @@ class TestExternalDatasetServiceDocumentValidate:
ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {})
@patch("services.external_knowledge_service.db")
def test_document_create_args_validate_no_custom_parameters(self, mock_db, factory):
def test_document_create_args_validate_no_custom_parameters(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test validation succeeds when no custom parameters defined."""
# Arrange
settings = {}
@ -1122,7 +1168,9 @@ class TestExternalDatasetServiceDocumentValidate:
ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {})
@patch("services.external_knowledge_service.db")
def test_document_create_args_validate_optional_params_not_required(self, mock_db, factory):
def test_document_create_args_validate_optional_params_not_required(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test that optional parameters don't cause validation failure."""
# Arrange
settings = {
@ -1146,7 +1194,7 @@ class TestExternalDatasetServiceProcessAPI:
"""Test process_external_api operations - comprehensive HTTP method coverage."""
@patch("services.external_knowledge_service.ssrf_proxy")
def test_process_external_api_get_request(self, mock_proxy, factory):
def test_process_external_api_get_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test processing GET request."""
# Arrange
settings = factory.create_api_setting_mock(request_method="get")
@ -1162,7 +1210,9 @@ class TestExternalDatasetServiceProcessAPI:
mock_proxy.get.assert_called_once()
@patch("services.external_knowledge_service.ssrf_proxy")
def test_process_external_api_post_request_with_data(self, mock_proxy, factory):
def test_process_external_api_post_request_with_data(
self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory
):
"""Test processing POST request with data."""
# Arrange
settings = factory.create_api_setting_mock(request_method="post", params={"key": "value", "data": "test"})
@ -1180,7 +1230,7 @@ class TestExternalDatasetServiceProcessAPI:
assert "data" in call_kwargs
@patch("services.external_knowledge_service.ssrf_proxy")
def test_process_external_api_put_request(self, mock_proxy, factory):
def test_process_external_api_put_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test processing PUT request."""
# Arrange
settings = factory.create_api_setting_mock(request_method="put")
@ -1196,7 +1246,7 @@ class TestExternalDatasetServiceProcessAPI:
mock_proxy.put.assert_called_once()
@patch("services.external_knowledge_service.ssrf_proxy")
def test_process_external_api_delete_request(self, mock_proxy, factory):
def test_process_external_api_delete_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test processing DELETE request."""
# Arrange
settings = factory.create_api_setting_mock(request_method="delete")
@ -1212,7 +1262,7 @@ class TestExternalDatasetServiceProcessAPI:
mock_proxy.delete.assert_called_once()
@patch("services.external_knowledge_service.ssrf_proxy")
def test_process_external_api_patch_request(self, mock_proxy, factory):
def test_process_external_api_patch_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test processing PATCH request."""
# Arrange
settings = factory.create_api_setting_mock(request_method="patch")
@ -1228,7 +1278,7 @@ class TestExternalDatasetServiceProcessAPI:
mock_proxy.patch.assert_called_once()
@patch("services.external_knowledge_service.ssrf_proxy")
def test_process_external_api_head_request(self, mock_proxy, factory):
def test_process_external_api_head_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test processing HEAD request."""
# Arrange
settings = factory.create_api_setting_mock(request_method="head")
@ -1243,7 +1293,7 @@ class TestExternalDatasetServiceProcessAPI:
assert result == mock_response
mock_proxy.head.assert_called_once()
def test_process_external_api_invalid_method(self, factory):
def test_process_external_api_invalid_method(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test error for invalid HTTP method."""
# Arrange
settings = factory.create_api_setting_mock(request_method="INVALID")
@ -1253,7 +1303,7 @@ class TestExternalDatasetServiceProcessAPI:
ExternalDatasetService.process_external_api(settings, None)
@patch("services.external_knowledge_service.ssrf_proxy")
def test_process_external_api_with_files(self, mock_proxy, factory):
def test_process_external_api_with_files(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test processing request with file uploads."""
# Arrange
settings = factory.create_api_setting_mock(request_method="post")
@ -1272,7 +1322,7 @@ class TestExternalDatasetServiceProcessAPI:
assert call_kwargs["files"] == files
@patch("services.external_knowledge_service.ssrf_proxy")
def test_process_external_api_follow_redirects(self, mock_proxy, factory):
def test_process_external_api_follow_redirects(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
"""Test that follow_redirects is enabled."""
# Arrange
settings = factory.create_api_setting_mock(request_method="get")
@ -1291,7 +1341,7 @@ class TestExternalDatasetServiceProcessAPI:
class TestExternalDatasetServiceAssemblingHeaders:
"""Test assembling_headers operations - comprehensive authorization coverage."""
def test_assembling_headers_bearer_token(self, factory):
def test_assembling_headers_bearer_token(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test assembling headers with Bearer token."""
# Arrange
authorization = factory.create_authorization_mock(token_type="bearer", api_key="secret-key-123")
@ -1302,7 +1352,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
# Assert
assert result["Authorization"] == "Bearer secret-key-123"
def test_assembling_headers_basic_auth(self, factory):
def test_assembling_headers_basic_auth(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test assembling headers with Basic authentication."""
# Arrange
authorization = factory.create_authorization_mock(token_type="basic", api_key="credentials")
@ -1313,7 +1363,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
# Assert
assert result["Authorization"] == "Basic credentials"
def test_assembling_headers_custom_auth(self, factory):
def test_assembling_headers_custom_auth(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test assembling headers with custom authentication."""
# Arrange
authorization = factory.create_authorization_mock(token_type="custom", api_key="custom-token")
@ -1324,7 +1374,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
# Assert
assert result["Authorization"] == "custom-token"
def test_assembling_headers_custom_header_name(self, factory):
def test_assembling_headers_custom_header_name(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test assembling headers with custom header name."""
# Arrange
authorization = factory.create_authorization_mock(token_type="bearer", api_key="key-123", header="X-API-Key")
@ -1336,7 +1386,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
assert result["X-API-Key"] == "Bearer key-123"
assert "Authorization" not in result
def test_assembling_headers_with_existing_headers(self, factory):
def test_assembling_headers_with_existing_headers(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test assembling headers preserves existing headers."""
# Arrange
authorization = factory.create_authorization_mock(token_type="bearer", api_key="key")
@ -1355,7 +1405,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
assert result["X-Custom"] == "value"
assert result["User-Agent"] == "TestAgent/1.0"
def test_assembling_headers_empty_existing_headers(self, factory):
def test_assembling_headers_empty_existing_headers(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test assembling headers with empty existing headers dict."""
# Arrange
authorization = factory.create_authorization_mock(token_type="bearer", api_key="key")
@ -1368,7 +1418,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
assert result["Authorization"] == "Bearer key"
assert len(result) == 1
def test_assembling_headers_missing_api_key(self, factory):
def test_assembling_headers_missing_api_key(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test error when API key is missing."""
# Arrange
config = AuthorizationConfig(api_key=None, type="bearer", header="Authorization")
@ -1378,7 +1428,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
with pytest.raises(ValueError, match="api_key is required"):
ExternalDatasetService.assembling_headers(authorization)
def test_assembling_headers_missing_config(self, factory):
def test_assembling_headers_missing_config(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test error when config is missing."""
# Arrange
authorization = Authorization(type="api-key", config=None)
@ -1387,7 +1437,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
with pytest.raises(ValueError, match="authorization config is required"):
ExternalDatasetService.assembling_headers(authorization)
def test_assembling_headers_default_header_name(self, factory):
def test_assembling_headers_default_header_name(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test that default header name is Authorization when not specified."""
# Arrange
config = AuthorizationConfig(api_key="key", type="bearer", header=None)
@ -1403,7 +1453,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
class TestExternalDatasetServiceGetSettings:
"""Test get_external_knowledge_api_settings operations."""
def test_get_external_knowledge_api_settings_success(self, factory):
def test_get_external_knowledge_api_settings_success(self, factory: ExternalDatasetServiceTestDataFactory):
"""Test successful parsing of API settings."""
# Arrange
settings = {
@ -1428,7 +1478,7 @@ class TestExternalDatasetServiceCreateDataset:
"""Test create_external_dataset operations."""
@patch("services.external_knowledge_service.db")
def test_create_external_dataset_success_full(self, mock_db, factory):
def test_create_external_dataset_success_full(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test successful creation of external dataset with all fields."""
# Arrange
tenant_id = "tenant-123"
@ -1457,7 +1507,9 @@ class TestExternalDatasetServiceCreateDataset:
mock_db.session.commit.assert_called_once()
@patch("services.external_knowledge_service.db")
def test_create_external_dataset_duplicate_name_error(self, mock_db, factory):
def test_create_external_dataset_duplicate_name_error(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test error when dataset name already exists."""
# Arrange
existing_dataset = factory.create_dataset_mock(name="Duplicate Dataset")
@ -1471,7 +1523,7 @@ class TestExternalDatasetServiceCreateDataset:
ExternalDatasetService.create_external_dataset("tenant-123", "user-123", args)
@patch("services.external_knowledge_service.db")
def test_create_external_dataset_api_not_found_error(self, mock_db, factory):
def test_create_external_dataset_api_not_found_error(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
"""Test error when external knowledge API is not found."""
# Arrange
mock_db.session.scalar.side_effect = [None, None]
@ -1483,7 +1535,9 @@ class TestExternalDatasetServiceCreateDataset:
ExternalDatasetService.create_external_dataset("tenant-123", "user-123", args)
@patch("services.external_knowledge_service.db")
def test_create_external_dataset_missing_knowledge_id_error(self, mock_db, factory):
def test_create_external_dataset_missing_knowledge_id_error(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test error when external_knowledge_id is missing."""
# Arrange
api = factory.create_external_knowledge_api_mock()
@ -1497,7 +1551,9 @@ class TestExternalDatasetServiceCreateDataset:
ExternalDatasetService.create_external_dataset("tenant-123", "user-123", args)
@patch("services.external_knowledge_service.db")
def test_create_external_dataset_missing_api_id_error(self, mock_db, factory):
def test_create_external_dataset_missing_api_id_error(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test error when external_knowledge_api_id is missing."""
# Arrange
api = factory.create_external_knowledge_api_mock()
@ -1516,7 +1572,9 @@ class TestExternalDatasetServiceFetchRetrieval:
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_success_with_results(self, mock_db, mock_process, factory):
def test_fetch_external_knowledge_retrieval_success_with_results(
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory
):
"""Test successful external knowledge retrieval with results."""
# Arrange
tenant_id = "tenant-123"
@ -1553,7 +1611,9 @@ class TestExternalDatasetServiceFetchRetrieval:
assert result[1]["score"] == 0.8
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_binding_not_found_error(self, mock_db, factory):
def test_fetch_external_knowledge_retrieval_binding_not_found_error(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test error when external knowledge binding is not found."""
# Arrange
mock_db.session.scalar.return_value = None
@ -1563,7 +1623,9 @@ class TestExternalDatasetServiceFetchRetrieval:
ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {})
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_cross_tenant_api_template_error(self, mock_db, factory):
def test_fetch_external_knowledge_retrieval_cross_tenant_api_template_error(
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
):
"""Test error when a binding points to an API template outside the dataset tenant."""
# Arrange
binding = factory.create_external_knowledge_binding_mock()
@ -1575,7 +1637,9 @@ class TestExternalDatasetServiceFetchRetrieval:
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_empty_results(self, mock_db, mock_process, factory):
def test_fetch_external_knowledge_retrieval_empty_results(
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory
):
"""Test retrieval with empty results."""
# Arrange
binding = factory.create_external_knowledge_binding_mock()
@ -1598,7 +1662,9 @@ class TestExternalDatasetServiceFetchRetrieval:
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_with_score_threshold(self, mock_db, mock_process, factory):
def test_fetch_external_knowledge_retrieval_with_score_threshold(
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory
):
"""Test retrieval with score threshold enabled."""
# Arrange
binding = factory.create_external_knowledge_binding_mock()
@ -1630,7 +1696,9 @@ class TestExternalDatasetServiceFetchRetrieval:
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory):
def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory
):
"""Test that non-200 status code raises Exception with response text."""
# Arrange
binding = factory.create_external_knowledge_binding_mock()
@ -1665,7 +1733,7 @@ class TestExternalDatasetServiceFetchRetrieval:
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_various_error_status_codes(
self, mock_db, mock_process, factory, status_code, error_message
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory, status_code, error_message
):
"""Test that various error status codes raise exceptions with response text."""
# Arrange
@ -1690,7 +1758,9 @@ class TestExternalDatasetServiceFetchRetrieval:
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory):
def test_fetch_external_knowledge_retrieval_empty_response_text(
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory
):
"""Test exception with empty response text."""
# Arrange
binding = factory.create_external_knowledge_binding_mock()

View File

@ -259,7 +259,7 @@ def test_resolve_form_inputs_uses_runtime_select_options(
def test_submit_form_by_token_calls_repository_and_enqueue(
sample_form_record, mock_session_factory, mocker: MockerFixture
sample_form_record: HumanInputFormRecord, mock_session_factory, mocker: MockerFixture
):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
@ -318,7 +318,7 @@ def test_submit_form_by_token_enqueues_agent_app_resume_for_conversation_form(
def test_submit_form_by_token_skips_enqueue_for_delivery_test(
sample_form_record, mock_session_factory, mocker: MockerFixture
sample_form_record: HumanInputFormRecord, mock_session_factory, mocker: MockerFixture
):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
@ -343,7 +343,7 @@ def test_submit_form_by_token_skips_enqueue_for_delivery_test(
def test_submit_form_by_token_passes_submission_user_id(
sample_form_record, mock_session_factory, mocker: MockerFixture
sample_form_record: HumanInputFormRecord, mock_session_factory, mocker: MockerFixture
):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
@ -528,7 +528,7 @@ def test_validate_human_input_submission_rejects_invalid_select_and_file_payload
repo.mark_submitted.assert_not_called()
def test_form_properties(sample_form_record):
def test_form_properties(sample_form_record: HumanInputFormRecord):
form = Form(sample_form_record)
assert form.id == "form-id"
assert form.workflow_run_id == "workflow-run-id"

View File

@ -90,7 +90,7 @@ class TestMessageServicePaginationByFirstId:
return TestMessageServiceFactory()
# Test 01: No user provided
def test_pagination_by_first_id_no_user(self, factory):
def test_pagination_by_first_id_no_user(self, factory: TestMessageServiceFactory):
"""Test pagination returns empty result when no user is provided."""
# Arrange
app = factory.create_app_mock()
@ -111,7 +111,7 @@ class TestMessageServicePaginationByFirstId:
assert result.has_more is False
# Test 02: No conversation_id provided
def test_pagination_by_first_id_no_conversation(self, factory):
def test_pagination_by_first_id_no_conversation(self, factory: TestMessageServiceFactory):
"""Test pagination returns empty result when no conversation_id is provided."""
# Arrange
app = factory.create_app_mock()
@ -137,7 +137,7 @@ class TestMessageServicePaginationByFirstId:
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_without_first_id_desc(
self, mock_conversation_service, mock_db, mock_create_repo, factory
self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory
):
"""Test basic pagination without first_id in descending order."""
# Arrange
@ -180,7 +180,7 @@ class TestMessageServicePaginationByFirstId:
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_without_first_id_asc(
self, mock_conversation_service, mock_db, mock_create_repo, factory
self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory
):
"""Test basic pagination without first_id in ascending order."""
# Arrange
@ -222,7 +222,9 @@ class TestMessageServicePaginationByFirstId:
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_with_first_id(self, mock_conversation_service, mock_db, mock_create_repo, factory):
def test_pagination_by_first_id_with_first_id(
self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory
):
"""Test pagination with first_id to get messages before a specific message."""
# Arrange
app = factory.create_app_mock()
@ -265,7 +267,9 @@ class TestMessageServicePaginationByFirstId:
# Test 06: First message not found
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_first_message_not_exists(self, mock_conversation_service, mock_db, factory):
def test_pagination_by_first_id_first_message_not_exists(
self, mock_conversation_service, mock_db, factory: TestMessageServiceFactory
):
"""Test error handling when first_id doesn't exist."""
# Arrange
app = factory.create_app_mock()
@ -290,7 +294,9 @@ class TestMessageServicePaginationByFirstId:
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_has_more_true(self, mock_conversation_service, mock_db, mock_create_repo, factory):
def test_pagination_by_first_id_has_more_true(
self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory
):
"""Test has_more flag is True when results exceed limit."""
# Arrange
app = factory.create_app_mock()
@ -327,7 +333,9 @@ class TestMessageServicePaginationByFirstId:
# Test 08: Empty conversation
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_empty_conversation(self, mock_conversation_service, mock_db, factory):
def test_pagination_by_first_id_empty_conversation(
self, mock_conversation_service, mock_db, factory: TestMessageServiceFactory
):
"""Test pagination with conversation that has no messages."""
# Arrange
app = factory.create_app_mock()
@ -370,7 +378,7 @@ class TestMessageServicePaginationByLastId:
return TestMessageServiceFactory()
# Test 09: No user provided
def test_pagination_by_last_id_no_user(self, factory):
def test_pagination_by_last_id_no_user(self, factory: TestMessageServiceFactory):
"""Test pagination returns empty result when no user is provided."""
# Arrange
app = factory.create_app_mock()
@ -391,7 +399,7 @@ class TestMessageServicePaginationByLastId:
# Test 10: Basic pagination without last_id
@patch("services.message_service.db")
def test_pagination_by_last_id_without_last_id(self, mock_db, factory):
def test_pagination_by_last_id_without_last_id(self, mock_db, factory: TestMessageServiceFactory):
"""Test basic pagination without last_id."""
# Arrange
app = factory.create_app_mock()
@ -422,7 +430,7 @@ class TestMessageServicePaginationByLastId:
# Test 11: Pagination with last_id
@patch("services.message_service.db")
def test_pagination_by_last_id_with_last_id(self, mock_db, factory):
def test_pagination_by_last_id_with_last_id(self, mock_db, factory: TestMessageServiceFactory):
"""Test pagination with last_id to get messages after a specific message."""
# Arrange
app = factory.create_app_mock()
@ -459,7 +467,7 @@ class TestMessageServicePaginationByLastId:
# Test 12: Last message not found
@patch("services.message_service.db")
def test_pagination_by_last_id_last_message_not_exists(self, mock_db, factory):
def test_pagination_by_last_id_last_message_not_exists(self, mock_db, factory: TestMessageServiceFactory):
"""Test error handling when last_id doesn't exist."""
# Arrange
app = factory.create_app_mock()
@ -479,7 +487,9 @@ class TestMessageServicePaginationByLastId:
# Test 13: Pagination with conversation_id filter
@patch("services.message_service.ConversationService")
@patch("services.message_service.db")
def test_pagination_by_last_id_with_conversation_filter(self, mock_db, mock_conversation_service, factory):
def test_pagination_by_last_id_with_conversation_filter(
self, mock_db, mock_conversation_service, factory: TestMessageServiceFactory
):
"""Test pagination filtered by conversation_id."""
# Arrange
app = factory.create_app_mock()
@ -515,7 +525,7 @@ class TestMessageServicePaginationByLastId:
# Test 14: Pagination with include_ids filter
@patch("services.message_service.db")
def test_pagination_by_last_id_with_include_ids(self, mock_db, factory):
def test_pagination_by_last_id_with_include_ids(self, mock_db, factory: TestMessageServiceFactory):
"""Test pagination filtered by include_ids."""
# Arrange
app = factory.create_app_mock()
@ -545,7 +555,7 @@ class TestMessageServicePaginationByLastId:
# Test 15: Has_more flag when results exceed limit
@patch("services.message_service.db")
def test_pagination_by_last_id_has_more_true(self, mock_db, factory):
def test_pagination_by_last_id_has_more_true(self, mock_db, factory: TestMessageServiceFactory):
"""Test has_more flag is True when results exceed limit."""
# Arrange
app = factory.create_app_mock()
@ -592,7 +602,7 @@ class TestMessageServiceUtilities:
# Test 17: attach_message_extra_contents with messages
@patch("services.message_service._create_execution_extra_content_repository")
def test_attach_message_extra_contents_with_messages(self, mock_create_repo, factory):
def test_attach_message_extra_contents_with_messages(self, mock_create_repo, factory: TestMessageServiceFactory):
"""Test attach_message_extra_contents correctly attaches content."""
# Arrange
messages = [factory.create_message_mock(message_id="msg-1"), factory.create_message_mock(message_id="msg-2")]
@ -618,7 +628,9 @@ class TestMessageServiceUtilities:
# Test 18: attach_message_extra_contents with index out of bounds
@patch("services.message_service._create_execution_extra_content_repository")
def test_attach_message_extra_contents_index_out_of_bounds(self, mock_create_repo, factory):
def test_attach_message_extra_contents_index_out_of_bounds(
self, mock_create_repo, factory: TestMessageServiceFactory
):
"""Test attach_message_extra_contents handles missing content lists."""
# Arrange
messages = [factory.create_message_mock(message_id="msg-1")]
@ -659,7 +671,7 @@ class TestMessageServiceGetMessage:
# Test 20: get_message success for EndUser
@patch("services.message_service.db")
def test_get_message_end_user_success(self, mock_db, factory):
def test_get_message_end_user_success(self, mock_db, factory: TestMessageServiceFactory):
"""Test get_message returns message for EndUser."""
# Arrange
app = factory.create_app_mock()
@ -676,7 +688,7 @@ class TestMessageServiceGetMessage:
# Test 21: get_message success for Account (Admin)
@patch("services.message_service.db")
def test_get_message_account_success(self, mock_db, factory):
def test_get_message_account_success(self, mock_db, factory: TestMessageServiceFactory):
"""Test get_message returns message for Account."""
# Arrange
from models import Account
@ -696,7 +708,7 @@ class TestMessageServiceGetMessage:
# Test 22: get_message not found
@patch("services.message_service.db")
def test_get_message_not_found(self, mock_db, factory):
def test_get_message_not_found(self, mock_db, factory: TestMessageServiceFactory):
"""Test get_message raises MessageNotExistsError when not found."""
# Arrange
app = factory.create_app_mock()
@ -720,7 +732,7 @@ class TestMessageServiceFeedback:
# Test 23: create_feedback - new feedback for EndUser
@patch("services.message_service.db")
@patch.object(MessageService, "get_message")
def test_create_feedback_new_end_user(self, mock_get_message, mock_db, factory):
def test_create_feedback_new_end_user(self, mock_get_message, mock_db, factory: TestMessageServiceFactory):
"""Test creating new feedback for an end user."""
# Arrange
app = factory.create_app_mock()
@ -748,7 +760,7 @@ class TestMessageServiceFeedback:
# Test 24: create_feedback - update feedback for Account
@patch("services.message_service.db")
@patch.object(MessageService, "get_message")
def test_create_feedback_update_account(self, mock_get_message, mock_db, factory):
def test_create_feedback_update_account(self, mock_get_message, mock_db, factory: TestMessageServiceFactory):
"""Test updating existing feedback for an account."""
# Arrange
from models import Account, MessageFeedback
@ -779,7 +791,7 @@ class TestMessageServiceFeedback:
# Test 25: create_feedback - delete feedback (rating is None)
@patch("services.message_service.db")
@patch.object(MessageService, "get_message")
def test_create_feedback_delete(self, mock_get_message, mock_db, factory):
def test_create_feedback_delete(self, mock_get_message, mock_db, factory: TestMessageServiceFactory):
"""Test deleting feedback by passing rating=None."""
# Arrange
app = factory.create_app_mock()
@ -805,7 +817,7 @@ class TestMessageServiceFeedback:
# Test 26: get_all_messages_feedbacks
@patch("services.message_service.db")
def test_get_all_messages_feedbacks(self, mock_db, factory):
def test_get_all_messages_feedbacks(self, mock_db, factory: TestMessageServiceFactory):
"""Test get_all_messages_feedbacks returns list of dicts."""
# Arrange
app = factory.create_app_mock()
@ -830,7 +842,7 @@ class TestMessageServiceSuggestedQuestions:
return TestMessageServiceFactory()
# Test 27: get_suggested_questions_after_answer - user is None
def test_get_suggested_questions_user_none(self, factory):
def test_get_suggested_questions_user_none(self, factory: TestMessageServiceFactory):
app = factory.create_app_mock()
with pytest.raises(ValueError, match="user cannot be None"):
MessageService.get_suggested_questions_after_answer(
@ -856,7 +868,7 @@ class TestMessageServiceSuggestedQuestions:
mock_config_manager,
mock_workflow_service,
mock_model_manager,
factory,
factory: TestMessageServiceFactory,
):
"""Test successful suggested questions generation in Advanced Chat mode."""
from core.app.entities.app_invoke_entities import InvokeFrom
@ -896,14 +908,14 @@ class TestMessageServiceSuggestedQuestions:
@patch("services.message_service.ConversationService")
def test_get_suggested_questions_chat_app_success(
self,
mock_conversation_service,
mock_get_message,
mock_trace_manager,
mock_llm_gen,
mock_memory,
mock_model_manager,
mock_db,
factory,
mock_conversation_service: MagicMock,
mock_get_message: MagicMock,
mock_trace_manager: MagicMock,
mock_llm_gen: MagicMock,
mock_memory: MagicMock,
mock_model_manager: MagicMock,
mock_db: MagicMock,
factory: TestMessageServiceFactory,
):
"""Test successful suggested questions generation in basic Chat mode."""
# Arrange
@ -942,14 +954,14 @@ class TestMessageServiceSuggestedQuestions:
@patch("services.message_service.ConversationService")
def test_get_suggested_questions_chat_app_uses_frontend_model_and_prompt(
self,
mock_conversation_service,
mock_get_message,
mock_trace_manager,
mock_llm_gen,
mock_memory,
mock_model_manager,
mock_db,
factory,
mock_conversation_service: MagicMock,
mock_get_message: MagicMock,
mock_trace_manager: MagicMock,
mock_llm_gen: MagicMock,
mock_memory: MagicMock,
mock_model_manager: MagicMock,
mock_db: MagicMock,
factory: TestMessageServiceFactory,
):
"""Test suggested question generation uses frontend configured model and prompt."""
from core.app.entities.app_invoke_entities import InvokeFrom
@ -1015,14 +1027,14 @@ class TestMessageServiceSuggestedQuestions:
@patch("services.message_service.ConversationService")
def test_get_suggested_questions_chat_app_invalid_frontend_model_fallback_to_default(
self,
mock_conversation_service,
mock_get_message,
mock_trace_manager,
mock_llm_gen,
mock_memory,
mock_model_manager,
mock_db,
factory,
mock_conversation_service: MagicMock,
mock_get_message: MagicMock,
mock_trace_manager: MagicMock,
mock_llm_gen: MagicMock,
mock_memory: MagicMock,
mock_model_manager: MagicMock,
mock_db: MagicMock,
factory: TestMessageServiceFactory,
):
"""Test invalid frontend configured model falls back to tenant default model."""
app = factory.create_app_mock(mode=AppMode.CHAT)
@ -1066,14 +1078,14 @@ class TestMessageServiceSuggestedQuestions:
@patch("services.message_service.ConversationService")
def test_get_suggested_questions_chat_app_uses_compatible_override_model_config(
self,
mock_conversation_service,
mock_get_message,
mock_trace_manager,
mock_llm_gen,
mock_memory,
mock_model_manager,
mock_db,
factory,
mock_conversation_service: MagicMock,
mock_get_message: MagicMock,
mock_trace_manager: MagicMock,
mock_llm_gen: MagicMock,
mock_memory: MagicMock,
mock_model_manager: MagicMock,
mock_db: MagicMock,
factory: TestMessageServiceFactory,
):
"""Test legacy override configs are normalized before suggested questions reads them."""
app = factory.create_app_mock(mode=AppMode.CHAT)
@ -1174,7 +1186,12 @@ class TestMessageServiceSuggestedQuestions:
@patch.object(MessageService, "get_message")
@patch("services.message_service.ConversationService")
def test_get_suggested_questions_disabled_error(
self, mock_conversation_service, mock_get_message, mock_config_manager, mock_workflow_service, factory
self,
mock_conversation_service,
mock_get_message,
mock_config_manager,
mock_workflow_service,
factory: TestMessageServiceFactory,
):
"""Test SuggestedQuestionsAfterAnswerDisabledError is raised when feature is disabled."""
# Arrange

View File

@ -1,6 +1,6 @@
import json
import logging
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
@ -130,7 +130,7 @@ class TestRagPipelineTaskProxy:
assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
def test_features_property(self, mock_feature_service: MagicMock):
"""Test cached_property features."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
@ -149,7 +149,7 @@ class TestRagPipelineTaskProxy:
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
def test_upload_invoke_entities(self, mock_db: MagicMock, mock_file_service_class: MagicMock):
"""Test _upload_invoke_entities method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
@ -181,7 +181,9 @@ class TestRagPipelineTaskProxy:
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
def test_upload_invoke_entities_with_multiple_entities(
self, mock_db: MagicMock, mock_file_service_class: MagicMock
):
"""Test _upload_invoke_entities method with multiple entities."""
# Arrange
entities = [
@ -209,7 +211,7 @@ class TestRagPipelineTaskProxy:
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_direct_queue(self, mock_task):
def test_send_to_direct_queue(self, mock_task: MagicMock):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
@ -229,7 +231,7 @@ class TestRagPipelineTaskProxy:
)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task: MagicMock):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
@ -248,7 +250,7 @@ class TestRagPipelineTaskProxy:
mock_task.delay.assert_not_called()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
def test_send_to_tenant_queue_without_task_key(self, mock_task: MagicMock):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
@ -271,7 +273,7 @@ class TestRagPipelineTaskProxy:
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_default_tenant_queue(self, mock_task):
def test_send_to_default_tenant_queue(self, mock_task: MagicMock):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
@ -285,7 +287,7 @@ class TestRagPipelineTaskProxy:
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
def test_send_to_priority_tenant_queue(self, mock_task):
def test_send_to_priority_tenant_queue(self, mock_task: MagicMock):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
@ -299,7 +301,7 @@ class TestRagPipelineTaskProxy:
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
def test_send_to_priority_direct_queue(self, mock_task):
def test_send_to_priority_direct_queue(self, mock_task: MagicMock):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
@ -315,7 +317,9 @@ class TestRagPipelineTaskProxy:
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
def test_dispatch_with_billing_enabled_sandbox_plan(
self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock
):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
@ -365,7 +369,9 @@ class TestRagPipelineTaskProxy:
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
def test_dispatch_with_billing_disabled(
self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock
):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
@ -386,7 +392,7 @@ class TestRagPipelineTaskProxy:
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
def test_dispatch_with_empty_upload_file_id(self, mock_db: MagicMock, mock_file_service_class: MagicMock):
"""Test _dispatch method when upload_file_id is empty."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
@ -404,7 +410,9 @@ class TestRagPipelineTaskProxy:
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
def test_dispatch_edge_case_empty_plan(
self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock
):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
@ -426,7 +434,9 @@ class TestRagPipelineTaskProxy:
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
def test_dispatch_edge_case_none_plan(
self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock
):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
@ -448,7 +458,9 @@ class TestRagPipelineTaskProxy:
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
def test_delay_method(
self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock
):
"""Test delay method integration."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(

View File

@ -73,7 +73,7 @@ def test_import_snippet_rejects_invalid_yaml_url_scheme() -> None:
assert result.error == "Invalid URL scheme, only http and https are allowed"
def test_import_snippet_returns_failed_when_yaml_url_fetch_fails(monkeypatch) -> None:
def test_import_snippet_returns_failed_when_yaml_url_fetch_fails(monkeypatch: pytest.MonkeyPatch) -> None:
service = SnippetDslService(session=SimpleNamespace())
monkeypatch.setattr(
"services.snippet_dsl_service.ssrf_proxy.get",
@ -90,7 +90,7 @@ def test_import_snippet_returns_failed_when_yaml_url_fetch_fails(monkeypatch) ->
assert result.error == "Failed to fetch YAML from URL: 404"
def test_import_snippet_rejects_oversized_yaml_url_content(monkeypatch) -> None:
def test_import_snippet_rejects_oversized_yaml_url_content(monkeypatch: pytest.MonkeyPatch) -> None:
service = SnippetDslService(session=SimpleNamespace())
monkeypatch.setattr("services.snippet_dsl_service.DSL_MAX_SIZE", 3)
monkeypatch.setattr(
@ -108,7 +108,7 @@ def test_import_snippet_rejects_oversized_yaml_url_content(monkeypatch) -> None:
assert "YAML content size exceeds maximum limit" in result.error
def test_import_snippet_returns_failed_when_yaml_url_fetch_raises(monkeypatch) -> None:
def test_import_snippet_returns_failed_when_yaml_url_fetch_raises(monkeypatch: pytest.MonkeyPatch) -> None:
service = SnippetDslService(session=SimpleNamespace())
monkeypatch.setattr(
"services.snippet_dsl_service.ssrf_proxy.get",
@ -125,7 +125,7 @@ def test_import_snippet_returns_failed_when_yaml_url_fetch_raises(monkeypatch) -
assert result.error == "Failed to fetch YAML from URL: network down"
def test_import_snippet_rejects_oversized_yaml_content(monkeypatch) -> None:
def test_import_snippet_rejects_oversized_yaml_content(monkeypatch: pytest.MonkeyPatch) -> None:
service = SnippetDslService(session=SimpleNamespace())
monkeypatch.setattr("services.snippet_dsl_service.DSL_MAX_SIZE", 3)

View File

@ -529,7 +529,7 @@ def test_delete_snippet_removes_related_records() -> None:
session.delete.assert_called_once_with(snippet)
def test_delete_draft_variable_files_removes_storage_objects(monkeypatch) -> None:
def test_delete_draft_variable_files_removes_storage_objects(monkeypatch: pytest.MonkeyPatch) -> None:
from extensions.ext_storage import storage
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
@ -554,7 +554,7 @@ def test_delete_draft_variable_files_removes_storage_objects(monkeypatch) -> Non
assert "workflow_draft_variable_files" in executed_sql
def test_delete_archived_workflow_run_files_removes_prefixed_objects(monkeypatch) -> None:
def test_delete_archived_workflow_run_files_removes_prefixed_objects(monkeypatch: pytest.MonkeyPatch) -> None:
from configs import dify_config
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")

View File

@ -30,10 +30,10 @@ class TestWorkflowGeneratorService:
@patch("services.workflow_generator_service.format_tool_catalogue")
def test_forwards_model_instance_and_catalogue_text_to_generator(
self,
mock_format_catalogue,
mock_build_catalogue,
mock_model_manager,
mock_workflow_generator,
mock_format_catalogue: MagicMock,
mock_build_catalogue: MagicMock,
mock_model_manager: MagicMock,
mock_workflow_generator: MagicMock,
):
"""Happy path: model_instance + catalogue text + payload flow through."""
# Arrange
@ -110,10 +110,10 @@ class TestWorkflowGeneratorService:
@patch("services.workflow_generator_service.format_tool_catalogue")
def test_defaults_ideal_output_to_empty_string(
self,
mock_format_catalogue,
mock_build_catalogue,
mock_model_manager,
mock_workflow_generator,
mock_format_catalogue: MagicMock,
mock_build_catalogue: MagicMock,
mock_model_manager: MagicMock,
mock_workflow_generator: MagicMock,
):
"""Callers can omit ideal_output; the runner should still receive ""."""
mock_model_manager.for_tenant.return_value.get_model_instance.return_value = MagicMock()
@ -142,10 +142,10 @@ class TestWorkflowGeneratorService:
@patch("services.workflow_generator_service.format_tool_catalogue")
def test_forwards_current_graph_for_refine(
self,
mock_format_catalogue,
mock_build_catalogue,
mock_model_manager,
mock_workflow_generator,
mock_format_catalogue: MagicMock,
mock_build_catalogue: MagicMock,
mock_model_manager: MagicMock,
mock_workflow_generator: MagicMock,
):
"""The cmd+k `/refine` path passes the existing draft graph through to the runner."""
mock_model_manager.for_tenant.return_value.get_model_instance.return_value = MagicMock()
@ -175,10 +175,10 @@ class TestWorkflowGeneratorService:
@patch("services.workflow_generator_service.format_tool_catalogue")
def test_defaults_current_graph_to_none_for_create(
self,
mock_format_catalogue,
mock_build_catalogue,
mock_model_manager,
mock_workflow_generator,
mock_format_catalogue: MagicMock,
mock_build_catalogue: MagicMock,
mock_model_manager: MagicMock,
mock_workflow_generator: MagicMock,
):
"""Omitting current_graph (the `/create` path) forwards None to the runner."""
mock_model_manager.for_tenant.return_value.get_model_instance.return_value = MagicMock()

View File

@ -1,3 +1,4 @@
from pathlib import Path
from textwrap import dedent
import pytest
@ -6,7 +7,7 @@ from core.helper.position_helper import get_position_map, is_filtered, pin_posit
@pytest.fixture
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
def prepare_example_positions_yaml(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str:
monkeypatch.chdir(tmp_path)
tmp_path.joinpath("example_positions.yaml").write_text(
dedent(
@ -25,7 +26,7 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
@pytest.fixture
def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
def prepare_empty_commented_positions_yaml(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str:
monkeypatch.chdir(tmp_path)
tmp_path.joinpath("example_positions_all_commented.yaml").write_text(
dedent(

View File

@ -1,3 +1,4 @@
from pathlib import Path
from textwrap import dedent
import pytest
@ -11,7 +12,7 @@ NON_EXISTING_YAML_FILE = "non_existing_file.yaml"
@pytest.fixture
def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
def prepare_example_yaml_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
file_path.write_text(
@ -34,7 +35,7 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
@pytest.fixture
def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
def prepare_invalid_yaml_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(INVALID_YAML_FILE)
file_path.write_text(