mirror of
https://github.com/langgenius/dify.git
synced 2026-06-22 19:21:13 +08:00
chore: add Type to test (#37191)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
9eca75c7fc
commit
fae607e2fe
@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast, overload, over
|
||||
from uuid import UUID
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask import Request, Response, stream_with_context
|
||||
from flask_restx import fields
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter, with_config
|
||||
from pydantic.functional_validators import AfterValidator
|
||||
@ -167,19 +167,6 @@ def build_avatar_url(avatar: str | None) -> str | None:
|
||||
return file_helpers.get_signed_file_url(avatar)
|
||||
|
||||
|
||||
class AvatarUrlField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj, **kwargs):
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
from models import Account
|
||||
|
||||
if isinstance(obj, Account) and obj.avatar is not None:
|
||||
return build_avatar_url(obj.avatar)
|
||||
return None
|
||||
|
||||
|
||||
class TimestampField(fields.Raw):
|
||||
@override
|
||||
def schema(self) -> dict[str, object]:
|
||||
@ -397,7 +384,7 @@ def generate_string(n):
|
||||
return result
|
||||
|
||||
|
||||
def extract_remote_ip(request) -> str:
|
||||
def extract_remote_ip(request: Request) -> str:
|
||||
if request.headers.get("CF-Connecting-IP"):
|
||||
return cast(str, request.headers.get("CF-Connecting-IP"))
|
||||
elif request.headers.getlist("X-Forwarded-For"):
|
||||
|
||||
@ -42,7 +42,9 @@ def trace_client_factory():
|
||||
class TestTraceClient:
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname")
|
||||
def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
|
||||
def test_init(
|
||||
self, mock_gethostname: MagicMock, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]
|
||||
):
|
||||
mock_gethostname.return_value = "test-host"
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
|
||||
@ -56,7 +58,7 @@ class TestTraceClient:
|
||||
assert client.done is True
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_export(self, mock_exporter_class, trace_client_factory):
|
||||
def test_export(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
|
||||
mock_exporter = mock_exporter_class.return_value
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
spans = [MagicMock(spec=ReadableSpan)]
|
||||
@ -65,7 +67,9 @@ class TestTraceClient:
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
|
||||
def test_api_check_success(
|
||||
self, mock_exporter_class: MagicMock, mock_head: MagicMock, trace_client_factory: type[TraceClient]
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 405
|
||||
mock_head.return_value = mock_response
|
||||
@ -75,7 +79,9 @@ class TestTraceClient:
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
|
||||
def test_api_check_failure_status(
|
||||
self, mock_exporter_class: MagicMock, mock_head: MagicMock, trace_client_factory: type[TraceClient]
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_head.return_value = mock_response
|
||||
@ -85,7 +91,9 @@ class TestTraceClient:
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
|
||||
def test_api_check_exception(
|
||||
self, mock_exporter_class: MagicMock, mock_head: MagicMock, trace_client_factory: type[TraceClient]
|
||||
):
|
||||
mock_head.side_effect = httpx.RequestError("Connection error")
|
||||
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
@ -93,12 +101,12 @@ class TestTraceClient:
|
||||
client.api_check()
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_get_project_url(self, mock_exporter_class, trace_client_factory):
|
||||
def test_get_project_url(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_add_span(self, mock_exporter_class, trace_client_factory):
|
||||
def test_add_span(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
|
||||
client = trace_client_factory(
|
||||
service_name="test-service",
|
||||
endpoint="http://test-endpoint",
|
||||
@ -135,7 +143,9 @@ class TestTraceClient:
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.logger")
|
||||
def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
|
||||
def test_add_span_queue_full(
|
||||
self, mock_logger: MagicMock, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]
|
||||
):
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
|
||||
|
||||
span_data = SpanData(
|
||||
@ -159,7 +169,7 @@ class TestTraceClient:
|
||||
mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
|
||||
def test_export_batch_error(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
|
||||
mock_exporter = mock_exporter_class.return_value
|
||||
mock_exporter.export.side_effect = Exception("Export failed")
|
||||
|
||||
@ -172,13 +182,13 @@ class TestTraceClient:
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_worker_loop(self, mock_exporter_class, trace_client_factory):
|
||||
def test_worker_loop(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
|
||||
# We need to test the wait timeout in _worker
|
||||
# But _worker runs in a thread. Let's mock condition.wait.
|
||||
client = trace_client_factory(
|
||||
service_name="test-service",
|
||||
endpoint="http://test-endpoint",
|
||||
schedule_delay_sec=0.1,
|
||||
schedule_delay_sec=1,
|
||||
)
|
||||
|
||||
with patch.object(client.condition, "wait") as mock_wait:
|
||||
@ -189,7 +199,7 @@ class TestTraceClient:
|
||||
assert mock_wait.called or client.done
|
||||
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
|
||||
def test_shutdown_flushes(self, mock_exporter_class: MagicMock, trace_client_factory: type[TraceClient]):
|
||||
mock_exporter = mock_exporter_class.return_value
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
|
||||
|
||||
@ -456,7 +456,7 @@ class TestPhoenixParentSpanBridgeHelpers:
|
||||
assert error.parent_node_execution_id == "outer-node-execution-1"
|
||||
assert "outer-node-execution-1" in str(error)
|
||||
|
||||
def test_resolve_parent_span_context_rejects_payload_without_traceparent(self, monkeypatch):
|
||||
def test_resolve_parent_span_context_rejects_payload_without_traceparent(self, monkeypatch: pytest.MonkeyPatch):
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get.return_value = '{"tracestate": "vendor=value"}'
|
||||
monkeypatch.setattr(arize_phoenix_trace_module, "redis_client", mock_redis)
|
||||
|
||||
@ -656,7 +656,7 @@ def _patch_workflow_trace_deps(monkeypatch, trace_instance):
|
||||
trace_instance.add_run = MagicMock()
|
||||
|
||||
|
||||
def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Chatflow with external trace_id: LangSmith trace_id must be message_id, not external."""
|
||||
trace_info = _make_workflow_trace_info(
|
||||
message_id="msg-abc",
|
||||
@ -677,7 +677,7 @@ def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypa
|
||||
assert trace_info.metadata.get("external_trace_id") == "external-999"
|
||||
|
||||
|
||||
def test_workflow_trace_id_pure_workflow_uses_run_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_id_pure_workflow_uses_run_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Pure workflow (no message_id) with external trace_id: trace_id must be workflow_run_id."""
|
||||
trace_info = _make_workflow_trace_info(
|
||||
message_id=None,
|
||||
|
||||
@ -13,7 +13,9 @@ from extensions.ext_database import db
|
||||
from models import App
|
||||
|
||||
|
||||
def test_run_chat_dispatches_to_chat_handler(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
def test_run_chat_dispatches_to_chat_handler(
|
||||
flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
captured = {}
|
||||
|
||||
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||
@ -78,7 +80,9 @@ def app_with_mode(flask_app: Flask, workspace_account):
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
def test_run_chat_without_query_returns_422(
|
||||
flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
@ -89,7 +93,9 @@ def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_wor
|
||||
assert b"query_required_for_chat" in res.data
|
||||
|
||||
|
||||
def test_run_completion_dispatches_to_completion_handler(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
def test_run_completion_dispatches_to_completion_handler(
|
||||
flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
app = app_with_mode("completion")
|
||||
|
||||
captured: dict = {}
|
||||
@ -119,7 +125,9 @@ def test_run_completion_dispatches_to_completion_handler(flask_app, account_toke
|
||||
assert captured["mode"] == "completion"
|
||||
|
||||
|
||||
def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
def test_run_workflow_with_query_returns_422(
|
||||
flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
app = app_with_mode("workflow")
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
@ -131,7 +139,9 @@ def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_
|
||||
assert b"query_not_supported_for_workflow" in res.data
|
||||
|
||||
|
||||
def test_run_workflow_no_query_dispatches_to_workflow_handler(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
def test_run_workflow_no_query_dispatches_to_workflow_handler(
|
||||
flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
app = app_with_mode("workflow")
|
||||
|
||||
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||
@ -154,7 +164,9 @@ def test_run_workflow_no_query_dispatches_to_workflow_handler(flask_app, account
|
||||
assert body["workflow_run_id"] == "wfr"
|
||||
|
||||
|
||||
def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
def test_run_unsupported_mode_returns_422(
|
||||
flask_app: Flask, account_token, app_with_mode, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
app = app_with_mode("channel")
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
@ -166,7 +178,7 @@ def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mod
|
||||
assert b"mode_not_runnable" in res.data
|
||||
|
||||
|
||||
def test_run_without_bearer_returns_401(flask_app, app_in_workspace):
|
||||
def test_run_without_bearer_returns_401(flask_app: Flask, app_in_workspace):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
@ -175,7 +187,9 @@ def test_run_without_bearer_returns_401(flask_app, app_in_workspace):
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_run_with_insufficient_scope_returns_403(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
def test_run_with_insufficient_scope_returns_403(
|
||||
flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Stub the authenticator to return an AuthContext with empty scopes."""
|
||||
from libs import oauth_bearer
|
||||
|
||||
@ -198,7 +212,7 @@ def test_run_with_insufficient_scope_returns_403(flask_app, account_token, app_i
|
||||
assert res.status_code == 403
|
||||
|
||||
|
||||
def test_run_with_unknown_app_returns_404(flask_app, account_token):
|
||||
def test_run_with_unknown_app_returns_404(flask_app: Flask, account_token):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{uuid.uuid4()}/run",
|
||||
@ -208,7 +222,9 @@ def test_run_with_unknown_app_returns_404(flask_app, account_token):
|
||||
assert res.status_code == 404
|
||||
|
||||
|
||||
def test_run_streaming_returns_event_stream(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
def test_run_streaming_returns_event_stream(
|
||||
flask_app: Flask, account_token, app_in_workspace, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
def _stream() -> Generator[str, None, None]:
|
||||
yield 'event: message\ndata: {"x": 1}\n\n'
|
||||
|
||||
@ -228,7 +244,7 @@ def test_run_streaming_returns_event_stream(flask_app, account_token, app_in_wor
|
||||
assert b"event: message" in res.data
|
||||
|
||||
|
||||
def test_run_without_inputs_returns_422(flask_app, account_token, app_in_workspace):
|
||||
def test_run_without_inputs_returns_422(flask_app: Flask, account_token, app_in_workspace):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
|
||||
@ -4,6 +4,8 @@ import uuid
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.llm_generator.output_parser.structured_output import _parse_structured_output
|
||||
from core.model_manager import ModelInstance
|
||||
@ -91,7 +93,7 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
return node
|
||||
|
||||
|
||||
def _mock_db_session_close(monkeypatch) -> None:
|
||||
def _mock_db_session_close(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(db.session, "close", MagicMock())
|
||||
|
||||
|
||||
|
||||
@ -3,6 +3,8 @@ import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.node_runtime import DifyPromptMessageSerializer
|
||||
@ -83,11 +85,11 @@ def init_parameter_extractor_node(config: dict, memory=None):
|
||||
return node
|
||||
|
||||
|
||||
def _mock_db_session_close(monkeypatch) -> None:
|
||||
def _mock_db_session_close(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(db.session, "close", MagicMock())
|
||||
|
||||
|
||||
def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch):
|
||||
def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test function calling for parameter extractor.
|
||||
"""
|
||||
@ -128,7 +130,7 @@ def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch):
|
||||
assert result.outputs.get("__reason") == None
|
||||
|
||||
|
||||
def test_instructions(setup_model_mock, monkeypatch):
|
||||
def test_instructions(setup_model_mock, monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test chat parameter extractor.
|
||||
"""
|
||||
@ -178,7 +180,7 @@ def test_instructions(setup_model_mock, monkeypatch):
|
||||
assert "what's the weather in SF" in prompt.get("text")
|
||||
|
||||
|
||||
def test_chat_parameter_extractor(setup_model_mock, monkeypatch):
|
||||
def test_chat_parameter_extractor(setup_model_mock, monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test chat parameter extractor.
|
||||
"""
|
||||
@ -229,7 +231,7 @@ def test_chat_parameter_extractor(setup_model_mock, monkeypatch):
|
||||
assert '<structure>\n{"type": "object"' in prompt.get("text")
|
||||
|
||||
|
||||
def test_completion_parameter_extractor(setup_model_mock, monkeypatch):
|
||||
def test_completion_parameter_extractor(setup_model_mock, monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test completion parameter extractor.
|
||||
"""
|
||||
@ -354,7 +356,7 @@ def test_extract_json_from_tool_call():
|
||||
assert result["location"] == "kawaii"
|
||||
|
||||
|
||||
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
|
||||
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test chat parameter extractor with memory.
|
||||
"""
|
||||
|
||||
@ -517,11 +517,11 @@ class TestAccountGeneration:
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
def test_should_handle_account_generation_scenarios(
|
||||
self,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
mock_tenant_service: MagicMock,
|
||||
mock_account_service: MagicMock,
|
||||
mock_register_service: MagicMock,
|
||||
mock_feature_service: MagicMock,
|
||||
mock_get_account: MagicMock,
|
||||
app: Flask,
|
||||
user_info: OAuthUserInfo,
|
||||
mock_account,
|
||||
@ -562,11 +562,11 @@ class TestAccountGeneration:
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
def test_should_register_with_lowercase_email(
|
||||
self,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
mock_tenant_service: MagicMock,
|
||||
mock_account_service: MagicMock,
|
||||
mock_register_service: MagicMock,
|
||||
mock_feature_service: MagicMock,
|
||||
mock_get_account: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
|
||||
@ -593,11 +593,11 @@ class TestAccountGeneration:
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
def test_should_register_with_browser_timezone(
|
||||
self,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
mock_tenant_service: MagicMock,
|
||||
mock_account_service: MagicMock,
|
||||
mock_register_service: MagicMock,
|
||||
mock_feature_service: MagicMock,
|
||||
mock_get_account: MagicMock,
|
||||
app: Flask,
|
||||
user_info: OAuthUserInfo,
|
||||
):
|
||||
@ -624,11 +624,11 @@ class TestAccountGeneration:
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
def test_should_register_with_state_language(
|
||||
self,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
mock_tenant_service: MagicMock,
|
||||
mock_account_service: MagicMock,
|
||||
mock_register_service: MagicMock,
|
||||
mock_feature_service: MagicMock,
|
||||
mock_get_account: MagicMock,
|
||||
app: Flask,
|
||||
user_info: OAuthUserInfo,
|
||||
):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -83,8 +83,8 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
def test_create_provider_auth_success(
|
||||
self,
|
||||
mock_encrypter,
|
||||
mock_factory,
|
||||
mock_encrypter: MagicMock,
|
||||
mock_factory: MagicMock,
|
||||
flask_app_with_containers: Flask,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id,
|
||||
@ -107,7 +107,12 @@ class TestApiKeyAuthService:
|
||||
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
def test_create_provider_auth_validation_failed(
|
||||
self, mock_factory, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, mock_args
|
||||
self,
|
||||
mock_factory: MagicMock,
|
||||
flask_app_with_containers: Flask,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id,
|
||||
mock_args,
|
||||
):
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = False
|
||||
@ -123,8 +128,8 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
def test_create_provider_auth_encrypts_api_key(
|
||||
self,
|
||||
mock_encrypter,
|
||||
mock_factory,
|
||||
mock_encrypter: MagicMock,
|
||||
mock_factory: MagicMock,
|
||||
flask_app_with_containers: Flask,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id,
|
||||
@ -289,7 +294,7 @@ class TestApiKeyAuthService:
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
def test_create_provider_auth_factory_exception(self, mock_factory, tenant_id, mock_args):
|
||||
def test_create_provider_auth_factory_exception(self, mock_factory: MagicMock, tenant_id, mock_args):
|
||||
mock_factory.side_effect = Exception("Factory error")
|
||||
with pytest.raises(Exception, match="Factory error"):
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
|
||||
@ -47,8 +47,8 @@ class TestGetDynamicSelectOptionsTool:
|
||||
@patch("services.plugin.plugin_parameter_service.ToolManager")
|
||||
def test_fetches_credentials_with_credential_id(
|
||||
self,
|
||||
mock_tool_mgr,
|
||||
mock_encrypter_fn,
|
||||
mock_tool_mgr: MagicMock,
|
||||
mock_encrypter_fn: MagicMock,
|
||||
mock_client_cls,
|
||||
flask_app_with_containers: Flask,
|
||||
db_session_with_containers: Session,
|
||||
@ -90,8 +90,8 @@ class TestGetDynamicSelectOptionsTool:
|
||||
@patch("services.plugin.plugin_parameter_service.ToolManager")
|
||||
def test_raises_when_tool_provider_not_found(
|
||||
self,
|
||||
mock_tool_mgr,
|
||||
mock_encrypter_fn,
|
||||
mock_tool_mgr: MagicMock,
|
||||
mock_encrypter_fn: MagicMock,
|
||||
flask_app_with_containers: Flask,
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
|
||||
@ -258,7 +258,7 @@ class TestUpgradePluginWithMarketplace:
|
||||
class TestUpgradePluginWithGithub:
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs):
|
||||
def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls: MagicMock, mock_fs: MagicMock):
|
||||
mock_fs.get_system_features.return_value = _make_features()
|
||||
installer = mock_installer_cls.return_value
|
||||
installer.upgrade_plugin.return_value = MagicMock()
|
||||
@ -273,7 +273,7 @@ class TestUpgradePluginWithGithub:
|
||||
class TestUploadPkg:
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs):
|
||||
def test_runs_permission_and_scope_checks(self, mock_installer_cls: MagicMock, mock_fs: MagicMock):
|
||||
mock_fs.get_system_features.return_value = _make_features()
|
||||
upload_resp = MagicMock()
|
||||
upload_resp.verification = None
|
||||
@ -318,7 +318,7 @@ class TestInstallFromMarketplacePkg:
|
||||
@patch("core.plugin.plugin_service.FeatureService")
|
||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
||||
@patch("core.plugin.plugin_service.dify_config")
|
||||
def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs):
|
||||
def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls: MagicMock, mock_fs: MagicMock):
|
||||
mock_config.MARKETPLACE_ENABLED = True
|
||||
mock_fs.get_system_features.return_value = _make_features()
|
||||
installer = mock_installer_cls.return_value
|
||||
|
||||
@ -5,6 +5,8 @@ import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _load_generate_swagger_markdown_docs_module():
|
||||
api_dir = Path(__file__).resolve().parents[3]
|
||||
@ -20,7 +22,9 @@ def _load_generate_swagger_markdown_docs_module():
|
||||
return module
|
||||
|
||||
|
||||
def test_generate_markdown_docs_keeps_split_docs_and_merges_fastopenapi_into_console(tmp_path, monkeypatch):
|
||||
def test_generate_markdown_docs_keeps_split_docs_and_merges_fastopenapi_into_console(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
module = _load_generate_swagger_markdown_docs_module()
|
||||
openapi_dir = tmp_path / "openapi"
|
||||
markdown_dir = tmp_path / "markdown"
|
||||
@ -69,7 +73,9 @@ def test_generate_markdown_docs_keeps_split_docs_and_merges_fastopenapi_into_con
|
||||
assert "FastOpenAPI Preview" not in (markdown_dir / "service-openapi.md").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_generate_markdown_docs_only_removes_generated_specs_from_separate_swagger_dir(tmp_path, monkeypatch):
|
||||
def test_generate_markdown_docs_only_removes_generated_specs_from_separate_swagger_dir(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
module = _load_generate_swagger_markdown_docs_module()
|
||||
swagger_dir = tmp_path / "swagger"
|
||||
markdown_dir = tmp_path / "markdown"
|
||||
@ -105,7 +111,7 @@ def test_generate_markdown_docs_only_removes_generated_specs_from_separate_swagg
|
||||
assert not list(swagger_dir.glob("*.json"))
|
||||
|
||||
|
||||
def test_patch_union_schema_markdown_fills_converter_blank_schema_types(tmp_path):
|
||||
def test_patch_union_schema_markdown_fills_converter_blank_schema_types(tmp_path: Path):
|
||||
module = _load_generate_swagger_markdown_docs_module()
|
||||
spec_path = tmp_path / "console-openapi.json"
|
||||
spec_path.write_text(
|
||||
@ -239,7 +245,7 @@ def test_patch_union_schema_markdown_ignores_specs_without_schemas(tmp_path):
|
||||
assert module._patch_union_schema_markdown("unchanged", spec_path) == "unchanged"
|
||||
|
||||
|
||||
def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path):
|
||||
def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path: Path):
|
||||
module = _load_generate_swagger_markdown_docs_module()
|
||||
spec_path = tmp_path / "console-openapi.json"
|
||||
spec_path.write_text(
|
||||
@ -285,7 +291,7 @@ def test_patch_union_schema_markdown_ignores_unrenderable_shapes(tmp_path):
|
||||
assert module._patch_union_schema_markdown("#### BrokenUnion\n", spec_path) == "#### BrokenUnion\n"
|
||||
|
||||
|
||||
def test_convert_spec_to_markdown_patches_generated_union_tables(tmp_path, monkeypatch):
|
||||
def test_convert_spec_to_markdown_patches_generated_union_tables(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
module = _load_generate_swagger_markdown_docs_module()
|
||||
spec_path = tmp_path / "console-openapi.json"
|
||||
output_path = tmp_path / "console-openapi.md"
|
||||
|
||||
@ -3,11 +3,12 @@ from __future__ import annotations
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.app import message as message_module
|
||||
|
||||
|
||||
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_chat_messages_query_valid(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test valid ChatMessagesQuery with all fields."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
conversation_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
@ -17,14 +18,14 @@ def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None
|
||||
assert query.limit == 50
|
||||
|
||||
|
||||
def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_chat_messages_query_defaults(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test ChatMessagesQuery with defaults."""
|
||||
query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert query.first_id is None
|
||||
assert query.limit == 20
|
||||
|
||||
|
||||
def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_chat_messages_query_empty_first_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test ChatMessagesQuery converts empty first_id to None."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
conversation_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
@ -33,7 +34,7 @@ def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch
|
||||
assert query.first_id is None
|
||||
|
||||
|
||||
def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_message_feedback_payload_valid_like(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload with like rating."""
|
||||
payload = message_module.MessageFeedbackPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
@ -44,7 +45,7 @@ def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatc
|
||||
assert payload.content == "Good answer"
|
||||
|
||||
|
||||
def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_message_feedback_payload_valid_dislike(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload with dislike rating."""
|
||||
payload = message_module.MessageFeedbackPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
@ -53,69 +54,69 @@ def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyP
|
||||
assert payload.rating == "dislike"
|
||||
|
||||
|
||||
def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_message_feedback_payload_no_rating(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload without rating."""
|
||||
payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert payload.rating is None
|
||||
|
||||
|
||||
def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_feedback_export_query_defaults(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with default format."""
|
||||
query = message_module.FeedbackExportQuery()
|
||||
assert query.format == "csv"
|
||||
assert query.from_source is None
|
||||
|
||||
|
||||
def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_feedback_export_query_json_format(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with JSON format."""
|
||||
query = message_module.FeedbackExportQuery(format="json")
|
||||
assert query.format == "json"
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_feedback_export_query_has_comment_true(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as true string."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="true")
|
||||
assert query.has_comment is True
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_feedback_export_query_has_comment_false(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as false string."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="false")
|
||||
assert query.has_comment is False
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_feedback_export_query_has_comment_1(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as 1."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="1")
|
||||
assert query.has_comment is True
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_feedback_export_query_has_comment_0(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as 0."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="0")
|
||||
assert query.has_comment is False
|
||||
|
||||
|
||||
def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_feedback_export_query_rating_filter(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with rating filter."""
|
||||
query = message_module.FeedbackExportQuery(rating="like")
|
||||
assert query.rating == "like"
|
||||
|
||||
|
||||
def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_annotation_count_response(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test AnnotationCountResponse creation."""
|
||||
response = message_module.AnnotationCountResponse(count=10)
|
||||
assert response.count == 10
|
||||
|
||||
|
||||
def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_suggested_questions_response(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test SuggestedQuestionsResponse creation."""
|
||||
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0] == "What is AI?"
|
||||
|
||||
|
||||
def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_message_detail_response_normalizes_aliases_and_timestamp(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageDetailResponse normalizes alias fields and datetime timestamps."""
|
||||
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
|
||||
response = message_module.MessageDetailResponse.model_validate(
|
||||
|
||||
@ -5,6 +5,7 @@ from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console.app import statistic as statistic_module
|
||||
@ -38,7 +39,7 @@ def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
|
||||
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_daily_message_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -52,7 +53,7 @@ def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPat
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_daily_conversation_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -66,7 +67,7 @@ def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.Monk
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_daily_token_cost_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -84,7 +85,7 @@ def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.Monkey
|
||||
assert data["data"][0]["total_price"] == 0.25
|
||||
|
||||
|
||||
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_daily_terminals_statistic_returns_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTerminalsStatistic()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -98,7 +99,7 @@ def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyP
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
|
||||
|
||||
|
||||
def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_average_session_interaction_statistic_requires_chat_mode(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that AverageSessionInteractionStatistic is limited to chat/agent modes."""
|
||||
# This just verifies the decorator is applied correctly
|
||||
# Actual endpoint testing would require complex JOIN mocking
|
||||
@ -107,7 +108,7 @@ def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypat
|
||||
assert callable(method)
|
||||
|
||||
|
||||
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_daily_message_statistic_with_invalid_time_range(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -123,7 +124,7 @@ def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytes
|
||||
method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_daily_message_statistic_multiple_rows(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -142,7 +143,7 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa
|
||||
assert len(data["data"]) == 3
|
||||
|
||||
|
||||
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_daily_message_statistic_empty_result(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -155,7 +156,7 @@ def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPat
|
||||
assert response.get_json() == {"data": []}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_daily_conversation_statistic_with_time_range(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -174,7 +175,7 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_daily_token_cost_with_multiple_currencies(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = unwrap(api.get)
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.datasets.rag_pipeline import rag_pipeline_workflow as module
|
||||
from models.account import Account, TenantAccountRole
|
||||
@ -117,7 +118,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
|
||||
assert response["has_more"] is False
|
||||
|
||||
|
||||
def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_rag_pipeline_workflow_patch_serializes_response_model(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = _make_workflow(marked_name="Updated release")
|
||||
|
||||
class _SessionContext:
|
||||
|
||||
@ -6,6 +6,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import HTTPException, NotFound
|
||||
|
||||
from controllers.console.snippets import snippet_workflow as snippet_workflow_module
|
||||
@ -52,7 +53,7 @@ def test_get_snippet_requires_snippet_id(app):
|
||||
view()
|
||||
|
||||
|
||||
def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_snippet_injects_resolved_snippet(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
snippet = _snippet()
|
||||
|
||||
@snippet_workflow_module.get_snippet
|
||||
@ -72,7 +73,7 @@ def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPat
|
||||
assert result is snippet
|
||||
|
||||
|
||||
def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_snippet_raises_not_found_when_snippet_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@snippet_workflow_module.get_snippet
|
||||
def view(**kwargs):
|
||||
return kwargs
|
||||
@ -89,7 +90,7 @@ def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pyt
|
||||
view(snippet_id="snippet-1")
|
||||
|
||||
|
||||
def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_draft_workflow_get_raises_when_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
snippet = _snippet()
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
@ -105,7 +106,7 @@ def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyP
|
||||
handler(api, snippet=snippet)
|
||||
|
||||
|
||||
def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_draft_workflow_post_returns_400_for_invalid_graph(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
sync_draft_workflow = Mock(side_effect=ValueError("invalid graph"))
|
||||
@ -145,7 +146,7 @@ def test_published_workflow_get_returns_none_when_not_published(app) -> None:
|
||||
assert handler(api, snippet=SimpleNamespace(id="snippet-1", is_published=False)) is None
|
||||
|
||||
|
||||
def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_published_workflow_post_returns_400_when_publish_fails(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
merged_snippet = _snippet()
|
||||
@ -180,7 +181,7 @@ def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch
|
||||
session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_default_block_configs_delegates_to_service(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_default_block_configs_delegates_to_service(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
get_default_block_configs = Mock(return_value=[{"type": "llm"}])
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
@ -198,7 +199,7 @@ def test_default_block_configs_delegates_to_service(app, monkeypatch: pytest.Mon
|
||||
get_default_block_configs.assert_called_once()
|
||||
|
||||
|
||||
def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_restore_published_snippet_workflow_to_draft_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = SimpleNamespace(
|
||||
unique_hash="restored-hash",
|
||||
updated_at=None,
|
||||
@ -226,7 +227,7 @@ def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: p
|
||||
assert response["hash"] == "restored-hash"
|
||||
|
||||
|
||||
def test_restore_published_snippet_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_restore_published_snippet_workflow_to_draft_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
|
||||
@ -311,7 +312,7 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_invalid_gra
|
||||
assert exc.value.description == "invalid snippet workflow graph"
|
||||
|
||||
|
||||
def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_run_detail_raises_not_found_when_run_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
snippet = _snippet()
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
@ -327,7 +328,9 @@ def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch:
|
||||
handler(api, snippet=snippet, run_id="run-1")
|
||||
|
||||
|
||||
def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_draft_node_last_run_raises_not_found_when_execution_missing(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
snippet = _snippet()
|
||||
draft_workflow = SimpleNamespace(id="workflow-1")
|
||||
monkeypatch.setattr(
|
||||
@ -347,7 +350,7 @@ def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkey
|
||||
handler(api, snippet=snippet, node_id="llm-1")
|
||||
|
||||
|
||||
def test_workflow_task_stop_uses_queue_flag_and_graph_command(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_task_stop_uses_queue_flag_and_graph_command(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
set_stop_flag = Mock()
|
||||
send_stop_command = Mock()
|
||||
monkeypatch.setattr(
|
||||
|
||||
@ -27,7 +27,7 @@ def _make_account() -> Account:
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_snippet_service_factory(monkeypatch):
|
||||
def _patch_snippet_service_factory(monkeypatch: pytest.MonkeyPatch):
|
||||
def factory():
|
||||
service_factory = module.SnippetService
|
||||
if isinstance(service_factory, type):
|
||||
@ -64,7 +64,7 @@ def test_ensure_snippet_draft_variable_row_allowed_accepts_canvas_node_variable(
|
||||
module._ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id="var-1")
|
||||
|
||||
|
||||
def test_conversation_variables_returns_empty_list(app):
|
||||
def test_conversation_variables_returns_empty_list(app: Flask):
|
||||
api = module.SnippetConversationVariableCollectionApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
@ -74,7 +74,7 @@ def test_conversation_variables_returns_empty_list(app):
|
||||
assert result == WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
def test_system_variables_returns_empty_list(app):
|
||||
def test_system_variables_returns_empty_list(app: Flask):
|
||||
api = module.SnippetSystemVariableCollectionApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import init_validate
|
||||
from controllers.console.error import AlreadySetupError, InitValidateFailedError
|
||||
@ -35,7 +36,7 @@ def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert result.status == "not_started"
|
||||
|
||||
|
||||
def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
|
||||
app.secret_key = "test-secret"
|
||||
@ -45,7 +46,7 @@ def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPat
|
||||
init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw"))
|
||||
|
||||
|
||||
def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
@ -57,7 +58,7 @@ def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPa
|
||||
assert init_validate.session.get("is_init_validated") is False
|
||||
|
||||
|
||||
def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_validate_init_password_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
@ -74,7 +75,7 @@ def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatc
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_init_validate_status_validated_session(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
@ -84,7 +85,7 @@ def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.Mon
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_init_validate_status_setup_exists(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True))
|
||||
@ -96,7 +97,7 @@ def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPa
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_init_validate_status_not_validated(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False))
|
||||
|
||||
@ -8,6 +8,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
|
||||
from controllers.console import remote_files as remote_files_module
|
||||
@ -82,7 +83,7 @@ def _mock_upload_dependencies(
|
||||
return file_service_cls, current_user
|
||||
|
||||
|
||||
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_remote_file_info_uses_head_when_successful(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = unwrap(api.get)
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
@ -103,7 +104,7 @@ def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest
|
||||
make_request.assert_called_once_with("HEAD", decoded_url)
|
||||
|
||||
|
||||
def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_remote_file_info_preserves_unencoded_target_query(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = unwrap(api.get)
|
||||
target_url = "http://example.com/api/aiagent/httpview/txt"
|
||||
@ -124,7 +125,9 @@ def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch:
|
||||
make_request.assert_called_once_with("HEAD", f"{target_url}?{query}")
|
||||
|
||||
|
||||
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = unwrap(api.get)
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
@ -147,7 +150,7 @@ def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, mo
|
||||
assert make_request.call_args_list[1].kwargs == {"timeout": 3}
|
||||
|
||||
|
||||
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/report.txt"
|
||||
@ -220,7 +223,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content"
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
@ -238,7 +241,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
|
||||
handler(api, _make_account())
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_remote_file_upload_raises_on_httpx_request_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
@ -252,7 +255,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
|
||||
handler(api, _make_account())
|
||||
|
||||
|
||||
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_remote_file_upload_rejects_oversized_file(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/large.bin"
|
||||
@ -267,7 +270,9 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk
|
||||
handler(api, current_user)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_remote_file_upload_translates_service_file_too_large_error(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/large.bin"
|
||||
@ -282,7 +287,9 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp
|
||||
handler(api, current_user)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_remote_file_upload_translates_service_unsupported_type_error(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/file.exe"
|
||||
|
||||
@ -102,10 +102,10 @@ class TestChangeEmailSend:
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_should_normalize_new_email_phase(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_extract_ip: MagicMock,
|
||||
mock_is_ip_limit: MagicMock,
|
||||
mock_send_email: MagicMock,
|
||||
mock_get_change_data: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
@ -143,10 +143,10 @@ class TestChangeEmailSend:
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_extract_ip: MagicMock,
|
||||
mock_is_ip_limit: MagicMock,
|
||||
mock_send_email: MagicMock,
|
||||
mock_get_change_data: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
|
||||
@ -178,10 +178,10 @@ class TestChangeEmailSend:
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_should_reject_new_email_phase_when_token_account_id_does_not_match_current_user(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_extract_ip: MagicMock,
|
||||
mock_is_ip_limit: MagicMock,
|
||||
mock_send_email: MagicMock,
|
||||
mock_get_change_data: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
@ -603,7 +603,7 @@ class TestRateLimiting:
|
||||
|
||||
@patch("controllers.console.wraps.redis_client")
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis):
|
||||
def test_should_allow_requests_within_rate_limit(self, mock_db: MagicMock, mock_redis: MagicMock):
|
||||
"""Test that requests within rate limit are allowed"""
|
||||
# Arrange
|
||||
mock_rate_limit = MagicMock()
|
||||
@ -631,7 +631,7 @@ class TestRateLimiting:
|
||||
|
||||
@patch("controllers.console.wraps.redis_client")
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis):
|
||||
def test_should_reject_requests_over_rate_limit(self, mock_db: MagicMock, mock_redis: MagicMock):
|
||||
"""Test that requests over rate limit are rejected and logged"""
|
||||
# Arrange
|
||||
app = create_app_with_login()
|
||||
@ -720,7 +720,7 @@ class TestSystemSetup:
|
||||
"""Test system setup decorator"""
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_should_allow_when_setup_complete(self, mock_db):
|
||||
def test_should_allow_when_setup_complete(self, mock_db: MagicMock):
|
||||
"""Test that requests are allowed when setup is complete"""
|
||||
# Arrange
|
||||
|
||||
@ -737,7 +737,7 @@ class TestSystemSetup:
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.wraps.os.environ.get")
|
||||
def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
|
||||
def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db: MagicMock):
|
||||
"""Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None # No setup
|
||||
@ -754,7 +754,7 @@ class TestSystemSetup:
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.wraps.os.environ.get")
|
||||
def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
|
||||
def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db: MagicMock):
|
||||
"""Test NotSetupError when no INIT_PASSWORD and setup not complete"""
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None # No setup
|
||||
|
||||
@ -3,6 +3,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.workspace import snippets as snippets_module
|
||||
@ -69,7 +70,7 @@ def test_normalize_snippet_list_query_args_sorts_indexed_values():
|
||||
}
|
||||
|
||||
|
||||
def test_list_snippets_returns_pagination(app, monkeypatch):
|
||||
def test_list_snippets_returns_pagination(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
snippets = [_snippet()]
|
||||
tag_id = "11111111-1111-1111-1111-111111111111"
|
||||
get_snippets = Mock(return_value=(snippets, 1, False))
|
||||
@ -104,7 +105,7 @@ def test_list_snippets_returns_pagination(app, monkeypatch):
|
||||
)
|
||||
|
||||
|
||||
def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypatch):
|
||||
def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
create_snippet = Mock(return_value=snippet)
|
||||
@ -140,7 +141,7 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypat
|
||||
assert create_snippet.call_args.kwargs["snippet_type"] == snippets_module.SnippetType.NODE
|
||||
|
||||
|
||||
def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch):
|
||||
def test_create_snippet_rejects_forbidden_nodes(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _account("account-1")
|
||||
create_snippet = Mock()
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet)
|
||||
@ -169,7 +170,7 @@ def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch):
|
||||
create_snippet.assert_not_called()
|
||||
|
||||
|
||||
def test_get_snippet_detail_raises_when_missing(app, monkeypatch):
|
||||
def test_get_snippet_detail_raises_when_missing(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
|
||||
|
||||
api = snippets_module.CustomizedSnippetDetailApi()
|
||||
@ -180,7 +181,7 @@ def test_get_snippet_detail_raises_when_missing(app, monkeypatch):
|
||||
handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
|
||||
def test_get_snippet_detail_returns_snippet(app, monkeypatch):
|
||||
def test_get_snippet_detail_returns_snippet(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
snippet = _snippet()
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"}))
|
||||
@ -195,7 +196,7 @@ def test_get_snippet_detail_returns_snippet(app, monkeypatch):
|
||||
assert response == {"id": "snippet-1"}
|
||||
|
||||
|
||||
def test_patch_snippet_returns_400_for_empty_payload(app, monkeypatch):
|
||||
def test_patch_snippet_returns_400_for_empty_payload(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
snippet = _snippet()
|
||||
user = _account("user-1")
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
@ -214,7 +215,7 @@ def test_patch_snippet_returns_400_for_empty_payload(app, monkeypatch):
|
||||
assert response == {"message": "No valid fields to update"}
|
||||
|
||||
|
||||
def test_patch_snippet_updates_and_commits(app, monkeypatch):
|
||||
def test_patch_snippet_updates_and_commits(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
updated_snippet = _snippet(name="New")
|
||||
@ -251,7 +252,7 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch):
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_snippet_deletes_and_commits(app, monkeypatch):
|
||||
def test_delete_snippet_deletes_and_commits(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
snippet = _snippet()
|
||||
session = SimpleNamespace(merge=Mock(return_value=snippet), commit=Mock())
|
||||
delete_snippet = Mock()
|
||||
@ -277,7 +278,7 @@ def test_delete_snippet_deletes_and_commits(app, monkeypatch):
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
|
||||
def test_export_snippet_returns_yaml_attachment(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
snippet = _snippet(name="Snippet One")
|
||||
export_snippet_dsl = Mock(return_value="version: 0.1.0\nkind: snippet\n")
|
||||
session = SimpleNamespace()
|
||||
@ -308,7 +309,7 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
|
||||
export_snippet_dsl.assert_called_once_with(snippet=snippet, include_secret=True)
|
||||
|
||||
|
||||
def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
|
||||
def test_import_snippet_returns_202_for_pending_confirmation(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _account("account-1")
|
||||
result = SnippetImportInfo(id="import-1", status=ImportStatus.PENDING, imported_dsl_version="999.0.0")
|
||||
import_snippet = Mock(return_value=result)
|
||||
@ -348,7 +349,7 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
|
||||
def test_import_snippet_returns_400_for_failed_import(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _account("account-1")
|
||||
result = SnippetImportInfo(id="import-1", status=ImportStatus.FAILED, error="Invalid DSL")
|
||||
import_snippet = Mock(return_value=result)
|
||||
@ -381,7 +382,7 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
|
||||
def test_import_confirm_returns_200_for_completed_import(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _account("account-1")
|
||||
result = SnippetImportInfo(id="import-1", status=ImportStatus.COMPLETED, snippet_id="snippet-1")
|
||||
confirm_import = Mock(return_value=result)
|
||||
@ -414,7 +415,7 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_check_dependencies_raises_when_snippet_missing(app, monkeypatch):
|
||||
def test_check_dependencies_raises_when_snippet_missing(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
|
||||
|
||||
api = snippets_module.CustomizedSnippetCheckDependenciesApi()
|
||||
@ -425,7 +426,7 @@ def test_check_dependencies_raises_when_snippet_missing(app, monkeypatch):
|
||||
handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
|
||||
def test_check_dependencies_returns_dependency_result(app, monkeypatch):
|
||||
def test_check_dependencies_returns_dependency_result(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
snippet = _snippet()
|
||||
check_dependencies = Mock(
|
||||
return_value=SimpleNamespace(model_dump=Mock(return_value={"dependencies": [], "missing_dependencies": []}))
|
||||
@ -456,7 +457,7 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch):
|
||||
check_dependencies.assert_called_once_with(snippet=snippet)
|
||||
|
||||
|
||||
def test_increment_use_count_raises_when_snippet_missing(app, monkeypatch):
|
||||
def test_increment_use_count_raises_when_snippet_missing(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
|
||||
|
||||
api = snippets_module.CustomizedSnippetUseCountIncrementApi()
|
||||
@ -470,7 +471,7 @@ def test_increment_use_count_raises_when_snippet_missing(app, monkeypatch):
|
||||
handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
|
||||
def test_increment_use_count_returns_refreshed_count(app, monkeypatch):
|
||||
def test_increment_use_count_returns_refreshed_count(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1", use_count=2)
|
||||
merged_snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1", use_count=3)
|
||||
session = SimpleNamespace(merge=Mock(return_value=merged_snippet), commit=Mock(), refresh=Mock())
|
||||
|
||||
@ -35,7 +35,7 @@ def _stub_execute(
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bypass_pipeline(monkeypatch):
|
||||
def bypass_pipeline(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Stub PipelineRouter._execute so endpoints skip real auth at request time.
|
||||
|
||||
Module-level @auth_router.guard(...) captures the real router at import
|
||||
|
||||
@ -167,14 +167,14 @@ def _session_auth_data() -> AuthData:
|
||||
)
|
||||
|
||||
|
||||
def _stub_session_deps(monkeypatch, rows):
|
||||
def _stub_session_deps(monkeypatch: pytest.MonkeyPatch, rows):
|
||||
mod = sys.modules[_ACCOUNT_MOD]
|
||||
monkeypatch.setattr(mod, "get_auth_ctx", lambda: SimpleNamespace())
|
||||
monkeypatch.setattr(mod, "list_active_sessions", lambda *args, **kwargs: rows)
|
||||
monkeypatch.setattr(mod, "db", MagicMock())
|
||||
|
||||
|
||||
def test_sessions_list_valid_query_parses_page_and_limit(app, monkeypatch):
|
||||
def test_sessions_list_valid_query_parses_page_and_limit(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
"""A valid ?page&limit round-trips through SessionListQuery into the response envelope."""
|
||||
api = AccountSessionsApi()
|
||||
_stub_session_deps(monkeypatch, [])
|
||||
@ -187,7 +187,7 @@ def test_sessions_list_valid_query_parses_page_and_limit(app, monkeypatch):
|
||||
assert body["data"] == []
|
||||
|
||||
|
||||
def test_sessions_list_defaults_when_query_omitted(app, monkeypatch):
|
||||
def test_sessions_list_defaults_when_query_omitted(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
"""No query → the model's defaults (page=1, limit=100) drive the envelope."""
|
||||
api = AccountSessionsApi()
|
||||
_stub_session_deps(monkeypatch, [])
|
||||
@ -209,7 +209,7 @@ def test_sessions_list_defaults_when_query_omitted(app, monkeypatch):
|
||||
"foo=bar", # extra='forbid'
|
||||
],
|
||||
)
|
||||
def test_sessions_list_rejects_out_of_bounds_query(app, monkeypatch, query):
|
||||
def test_sessions_list_rejects_out_of_bounds_query(app: Flask, monkeypatch: pytest.MonkeyPatch, query):
|
||||
"""Out-of-range / unknown query params raise 422 instead of being silently coerced."""
|
||||
api = AccountSessionsApi()
|
||||
_stub_session_deps(monkeypatch, [])
|
||||
|
||||
@ -6,6 +6,9 @@ import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
|
||||
|
||||
@ -30,7 +33,9 @@ def test_app_run_request_with_query():
|
||||
assert req.query == "hello"
|
||||
|
||||
|
||||
def test_run_chat_always_calls_generate_with_streaming_true(app, bypass_pipeline, monkeypatch):
|
||||
def test_run_chat_always_calls_generate_with_streaming_true(
|
||||
app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""_run_chat must always invoke AppGenerateService.generate with streaming=True."""
|
||||
from controllers.openapi.app_run import _run_chat
|
||||
|
||||
@ -56,7 +61,7 @@ def test_stop_task_endpoint_registered(openapi_app):
|
||||
assert "/openapi/v1/apps/<string:app_id>/tasks/<string:task_id>/stop" in rules
|
||||
|
||||
|
||||
def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, monkeypatch):
|
||||
def test_stop_task_calls_queue_manager_and_graph_engine(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
import uuid
|
||||
|
||||
from controllers.openapi.app_run import AppRunTaskStopApi
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_version_endpoint_returns_200_without_auth(openapi_app):
|
||||
client = openapi_app.test_client()
|
||||
@ -30,7 +32,7 @@ def test_version_endpoint_ignores_bearer_header(openapi_app):
|
||||
assert "edition" in payload
|
||||
|
||||
|
||||
def test_version_endpoint_reflects_edition_config(openapi_app, monkeypatch):
|
||||
def test_version_endpoint_reflects_edition_config(openapi_app, monkeypatch: pytest.MonkeyPatch):
|
||||
from configs import dify_config
|
||||
|
||||
monkeypatch.setattr(dify_config, "EDITION", "CLOUD")
|
||||
@ -42,7 +44,7 @@ def test_version_endpoint_reflects_edition_config(openapi_app, monkeypatch):
|
||||
assert response.get_json()["edition"] == "CLOUD"
|
||||
|
||||
|
||||
def test_version_endpoint_falls_back_to_self_hosted_on_unexpected_edition(openapi_app, monkeypatch):
|
||||
def test_version_endpoint_falls_back_to_self_hosted_on_unexpected_edition(openapi_app, monkeypatch: pytest.MonkeyPatch):
|
||||
from configs import dify_config
|
||||
|
||||
monkeypatch.setattr(dify_config, "EDITION", "EXPERIMENTAL")
|
||||
|
||||
@ -8,6 +8,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
@ -51,7 +52,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
|
||||
return OpenApiWorkflowEventsApi()
|
||||
|
||||
def test_not_found_when_run_missing(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_not_found_when_run_missing(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
module = sys.modules["controllers.openapi.workflow_events"]
|
||||
repo_mock = Mock()
|
||||
repo_mock.get_workflow_run_by_id_and_tenant_id.return_value = None
|
||||
@ -76,7 +77,9 @@ class TestOpenApiWorkflowEventsApi:
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
def test_not_found_when_run_belongs_to_different_app(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_not_found_when_run_belongs_to_different_app(
|
||||
self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
module = sys.modules["controllers.openapi.workflow_events"]
|
||||
run = _make_workflow_run(app_id="other-app")
|
||||
repo_mock = Mock()
|
||||
@ -102,7 +105,9 @@ class TestOpenApiWorkflowEventsApi:
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
def test_account_caller_checks_created_by_account(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_account_caller_checks_created_by_account(
|
||||
self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Account caller must match created_by == caller.id and role == ACCOUNT."""
|
||||
module = sys.modules["controllers.openapi.workflow_events"]
|
||||
run = _make_workflow_run(created_by_role=CreatorUserRole.ACCOUNT, created_by="acct-1")
|
||||
@ -141,7 +146,9 @@ class TestOpenApiWorkflowEventsApi:
|
||||
)
|
||||
assert resp.mimetype == "text/event-stream"
|
||||
|
||||
def test_account_caller_rejected_for_end_user_run(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_account_caller_rejected_for_end_user_run(
|
||||
self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
module = sys.modules["controllers.openapi.workflow_events"]
|
||||
run = _make_workflow_run(created_by_role=CreatorUserRole.END_USER, created_by="eu-1")
|
||||
repo_mock = Mock()
|
||||
@ -167,7 +174,9 @@ class TestOpenApiWorkflowEventsApi:
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
def test_end_user_caller_checks_created_by_end_user(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_end_user_caller_checks_created_by_end_user(
|
||||
self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""End-user caller must match created_by == caller.id and role == END_USER."""
|
||||
module = sys.modules["controllers.openapi.workflow_events"]
|
||||
run = _make_workflow_run(created_by_role=CreatorUserRole.END_USER, created_by="eu-1")
|
||||
@ -202,7 +211,7 @@ class TestOpenApiWorkflowEventsApi:
|
||||
)
|
||||
assert resp.mimetype == "text/event-stream"
|
||||
|
||||
def test_finished_run_returns_single_sse_event(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_finished_run_returns_single_sse_event(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
"""A finished run returns a single done-event SSE response without streaming."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
@ -34,6 +34,7 @@ from werkzeug.exceptions import BadRequest, NotFound, UnprocessableEntity
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi._errors import MemberLicenseExceeded, MemberLimitExceeded
|
||||
from controllers.openapi._models import MemberInvitePayload, MemberRoleUpdatePayload
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.workspaces import (
|
||||
WorkspaceMemberApi,
|
||||
WorkspaceMemberRoleApi,
|
||||
@ -228,7 +229,7 @@ def test_role_payload_rejects_extra_field():
|
||||
MemberRoleUpdatePayload.model_validate({"role": "normal", "extra": "x"})
|
||||
|
||||
|
||||
def test_invite_rejects_invalid_body_with_422(app, bypass_pipeline):
|
||||
def test_invite_rejects_invalid_body_with_422(app: Flask, bypass_pipeline):
|
||||
"""Invalid invite body → 422 via @accepts (was 400 through _validate_body)."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
@ -245,7 +246,7 @@ def test_invite_rejects_invalid_body_with_422(app, bypass_pipeline):
|
||||
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
|
||||
|
||||
|
||||
def test_update_role_rejects_invalid_body_with_422(app, bypass_pipeline):
|
||||
def test_update_role_rejects_invalid_body_with_422(app: Flask, bypass_pipeline):
|
||||
"""Invalid role-update body surfaces as 422 through @accepts (was 400)."""
|
||||
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
@ -267,7 +268,9 @@ def test_update_role_rejects_invalid_body_with_422(app, bypass_pipeline):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline, monkeypatch):
|
||||
def test_switch_returns_workspace_detail_with_current_true(
|
||||
app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Happy path: switch service is called, then the workspace+membership
|
||||
row is re-queried so the returned `current` reflects post-commit state.
|
||||
"""
|
||||
@ -298,7 +301,9 @@ def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline,
|
||||
assert switch_mock.called
|
||||
|
||||
|
||||
def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pipeline, monkeypatch):
|
||||
def test_switch_404s_when_service_raises_account_not_link_tenant(
|
||||
app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""If switch_tenant raises (e.g. Tenant.status != NORMAL), the body
|
||||
surfaces as NotFound, not 500."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
@ -326,7 +331,7 @@ def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch):
|
||||
def test_members_list_returns_normalized_rows(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMembersApi()
|
||||
@ -364,7 +369,7 @@ def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch)
|
||||
assert body["data"][0]["status"] == "active"
|
||||
|
||||
|
||||
def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypatch):
|
||||
def test_members_list_paginates_with_query_params(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
"""`?page=2&limit=2` slices service output and reports total/has_more."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
@ -404,7 +409,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa
|
||||
assert [d["id"] for d in body["data"]] == ["m-2", "m-3"]
|
||||
|
||||
|
||||
def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypatch):
|
||||
def test_members_list_rejects_unknown_query_param(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Strict (`extra='forbid'`) — typos like `?pg=2` surface as 422 (unified via @accepts)."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
@ -425,7 +430,9 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_invite_happy_path_returns_invite_url_and_member_id(app, bypass_pipeline, monkeypatch):
|
||||
def test_invite_happy_path_returns_invite_url_and_member_id(
|
||||
app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMembersApi()
|
||||
@ -507,7 +514,7 @@ def _invite_request(app, ws_id: str, acct_id: uuid.UUID):
|
||||
)
|
||||
|
||||
|
||||
def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch):
|
||||
def test_invite_blocked_by_saas_members_cap(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
"""SaaS billing plan member cap → MemberLimitExceeded (403)."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
@ -541,7 +548,7 @@ def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch):
|
||||
invite_mock.assert_not_called()
|
||||
|
||||
|
||||
def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, monkeypatch):
|
||||
def test_invite_blocked_by_ee_workspace_members_license(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
"""EE License workspace_members cap → MemberLicenseExceeded (403).
|
||||
|
||||
Note: billing.enabled is False (EE without SaaS billing); only the
|
||||
@ -583,7 +590,7 @@ def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, mo
|
||||
invite_mock.assert_not_called()
|
||||
|
||||
|
||||
def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypatch):
|
||||
def test_invite_ce_passes_when_both_caps_disabled(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
"""CE deployment (no billing, no license) → quota gate is a no-op,
|
||||
invite proceeds normally."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
@ -619,7 +626,7 @@ def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypa
|
||||
assert body["email"] == "new@example.com"
|
||||
|
||||
|
||||
def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch):
|
||||
def test_invite_400_when_already_in_tenant(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMembersApi()
|
||||
@ -650,7 +657,7 @@ def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch):
|
||||
def test_delete_member_happy_path(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMemberApi()
|
||||
@ -692,7 +699,7 @@ def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch):
|
||||
(MemberNotInTenantError("not in tenant"), NotFound),
|
||||
],
|
||||
)
|
||||
def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc, expected):
|
||||
def test_delete_member_exception_mapping(app: Flask, bypass_pipeline, monkeypatch, exc, expected):
|
||||
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMemberApi()
|
||||
@ -725,7 +732,7 @@ def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc,
|
||||
)
|
||||
|
||||
|
||||
def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch):
|
||||
def test_delete_member_404_when_member_missing(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMemberApi()
|
||||
@ -757,7 +764,7 @@ def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_update_role_happy_path(app, bypass_pipeline, monkeypatch):
|
||||
def test_update_role_happy_path(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMemberRoleApi()
|
||||
@ -801,7 +808,7 @@ def test_update_role_happy_path(app, bypass_pipeline, monkeypatch):
|
||||
(MemberNotInTenantError("not in tenant"), NotFound),
|
||||
],
|
||||
)
|
||||
def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, expected):
|
||||
def test_update_role_exception_mapping(app: Flask, bypass_pipeline, monkeypatch, exc, expected):
|
||||
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
api = WorkspaceMemberRoleApi()
|
||||
@ -841,7 +848,7 @@ def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, e
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatch):
|
||||
def test_load_tenant_rejects_archived_workspace(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Member management against an archived workspace → 404."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
@ -869,7 +876,7 @@ def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatc
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_invite_400_when_register_error(app, bypass_pipeline, monkeypatch):
|
||||
def test_invite_400_when_register_error(app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
"""AccountRegisterError (frozen email, workspace creation blocked) → 400."""
|
||||
ws_id = str(uuid.uuid4())
|
||||
acct_id = uuid.uuid4()
|
||||
|
||||
@ -81,7 +81,7 @@ class TestPluginAgentStrategyInitialization:
|
||||
|
||||
|
||||
class TestGetParameters:
|
||||
def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None:
|
||||
def test_get_parameters_returns_parameters(self, strategy: PluginAgentStrategy, mock_declaration) -> None:
|
||||
result = strategy.get_parameters()
|
||||
assert result == mock_declaration.parameters
|
||||
|
||||
@ -92,7 +92,7 @@ class TestGetParameters:
|
||||
|
||||
|
||||
class TestInitializeParameters:
|
||||
def test_initialize_parameters_success(self, strategy, mock_declaration) -> None:
|
||||
def test_initialize_parameters_success(self, strategy: PluginAgentStrategy, mock_declaration) -> None:
|
||||
params = {"param1": "value1"}
|
||||
|
||||
result = strategy.initialize_parameters(params.copy())
|
||||
@ -114,13 +114,13 @@ class TestInitializeParameters:
|
||||
{"param1": {}, "param2": "value"},
|
||||
],
|
||||
)
|
||||
def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None:
|
||||
def test_initialize_parameters_edge_cases(self, strategy: PluginAgentStrategy, input_params) -> None:
|
||||
result = strategy.initialize_parameters(input_params.copy())
|
||||
|
||||
for param in strategy.declaration.parameters:
|
||||
assert param.name in result
|
||||
|
||||
def test_initialize_parameters_invalid_input_type(self, strategy) -> None:
|
||||
def test_initialize_parameters_invalid_input_type(self, strategy: PluginAgentStrategy) -> None:
|
||||
with pytest.raises(AttributeError):
|
||||
strategy.initialize_parameters(None)
|
||||
|
||||
@ -131,7 +131,7 @@ class TestInitializeParameters:
|
||||
|
||||
|
||||
class TestInvoke:
|
||||
def test_invoke_success_all_arguments(self, strategy, mocker) -> None:
|
||||
def test_invoke_success_all_arguments(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"]))
|
||||
|
||||
@ -171,7 +171,7 @@ class TestInvoke:
|
||||
assert call_kwargs["message_id"] == "msg_1"
|
||||
assert call_kwargs["context"] is not None
|
||||
|
||||
def test_invoke_with_credentials(self, strategy, mocker) -> None:
|
||||
def test_invoke_with_credentials(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter([]))
|
||||
|
||||
@ -243,7 +243,7 @@ class TestInvoke:
|
||||
assert result == []
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
def test_invoke_convert_raises_exception(self, strategy, mocker) -> None:
|
||||
def test_invoke_convert_raises_exception(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None:
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=MagicMock(),
|
||||
@ -257,7 +257,7 @@ class TestInvoke:
|
||||
with pytest.raises(ValueError):
|
||||
list(strategy._invoke(params={}, user_id="user_1"))
|
||||
|
||||
def test_invoke_manager_raises_exception(self, strategy, mocker) -> None:
|
||||
def test_invoke_manager_raises_exception(self, strategy: PluginAgentStrategy, mocker: MockerFixture) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke.side_effect = RuntimeError("invoke failed")
|
||||
|
||||
|
||||
@ -42,13 +42,13 @@ def runner(mocker: MockerFixture, mock_db_session):
|
||||
|
||||
|
||||
class TestRepack:
|
||||
def test_sets_empty_if_none(self, runner, mocker: MockerFixture):
|
||||
def test_sets_empty_if_none(self, runner: BaseAgentRunner, mocker: MockerFixture):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = None
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
assert result.app_config.prompt_template.simple_prompt_template == ""
|
||||
|
||||
def test_keeps_existing(self, runner, mocker: MockerFixture):
|
||||
def test_keeps_existing(self, runner: BaseAgentRunner, mocker: MockerFixture):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = "abc"
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
@ -61,7 +61,7 @@ class TestRepack:
|
||||
|
||||
|
||||
class TestUpdatePromptTool:
|
||||
def test_replaces_prompt_tool_parameters_with_tool_schema(self, runner, mocker: MockerFixture):
|
||||
def test_replaces_prompt_tool_parameters_with_tool_schema(self, runner: BaseAgentRunner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock()
|
||||
schema = {
|
||||
"type": "object",
|
||||
@ -83,7 +83,7 @@ class TestUpdatePromptTool:
|
||||
|
||||
|
||||
class TestCreateAgentThought:
|
||||
def test_with_files(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_with_files(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
mock_thought = mocker.MagicMock(id=10)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
@ -91,7 +91,7 @@ class TestCreateAgentThought:
|
||||
assert result == "10"
|
||||
assert runner.agent_thought_count == 1
|
||||
|
||||
def test_without_files(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_without_files(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
mock_thought = mocker.MagicMock(id=11)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
@ -112,12 +112,12 @@ class TestSaveAgentThought:
|
||||
agent.thought = ""
|
||||
return agent
|
||||
|
||||
def test_not_found(self, runner, mock_db_session):
|
||||
def test_not_found(self, runner: BaseAgentRunner, mock_db_session):
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(ValueError):
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
|
||||
def test_full_update(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_full_update(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
@ -152,7 +152,7 @@ class TestSaveAgentThought:
|
||||
assert agent.tokens == 3
|
||||
assert "tool1" in json.loads(agent.tool_labels_str)
|
||||
|
||||
def test_label_fallback_when_none(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_label_fallback_when_none(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
agent = self.setup_agent(mocker)
|
||||
agent.tool = "unknown_tool"
|
||||
mock_db_session.scalar.return_value = agent
|
||||
@ -162,7 +162,7 @@ class TestSaveAgentThought:
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "unknown_tool" in labels
|
||||
|
||||
def test_json_failure_paths(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_json_failure_paths(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
@ -183,13 +183,13 @@ class TestSaveAgentThought:
|
||||
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_messages_ids_none(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_messages_ids_none(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, None, None)
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_success_dict_serialization(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_success_dict_serialization(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
@ -215,19 +215,19 @@ class TestSaveAgentThought:
|
||||
|
||||
|
||||
class TestOrganizeUserPrompt:
|
||||
def test_no_files(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_no_files(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_with_files_no_config(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_with_files_no_config(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_image_detail_low_fallback(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
file_config = mocker.MagicMock()
|
||||
file_config.image_config = mocker.MagicMock(detail=None)
|
||||
@ -247,27 +247,27 @@ class TestOrganizeUserPrompt:
|
||||
|
||||
|
||||
class TestOrganizeHistory:
|
||||
def test_empty(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_empty(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_answer_only(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_with_answer_only(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert any(isinstance(x, module.AssistantPromptMessage) for x in result)
|
||||
|
||||
def test_skip_current_message(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_skip_current_message(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_with_tool_calls_invalid_json(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input="invalid",
|
||||
@ -283,7 +283,7 @@ class TestOrganizeHistory:
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_empty_tool_name_split(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_empty_tool_name_split(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
thought = mocker.MagicMock(tool=";", thought="thinking")
|
||||
msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
@ -292,7 +292,7 @@ class TestOrganizeHistory:
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_valid_json_tool_flow(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=json.dumps({"tool1": {"x": 1}}),
|
||||
@ -321,7 +321,7 @@ class TestOrganizeHistory:
|
||||
|
||||
|
||||
class TestConvertToolToPromptMessageTool:
|
||||
def test_basic_conversion(self, runner, mocker: MockerFixture):
|
||||
def test_basic_conversion(self, runner: BaseAgentRunner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
@ -347,7 +347,7 @@ class TestConvertToolToPromptMessageTool:
|
||||
|
||||
|
||||
class TestInitPromptToolsExtended:
|
||||
def test_agent_tool_branch(self, runner, mocker: MockerFixture):
|
||||
def test_agent_tool_branch(self, runner: BaseAgentRunner, mocker: MockerFixture):
|
||||
agent_tool = mocker.MagicMock(tool_name="agent_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity"))
|
||||
@ -355,7 +355,7 @@ class TestInitPromptToolsExtended:
|
||||
tools, prompts = runner._init_prompt_tools()
|
||||
assert "agent_tool" in tools
|
||||
|
||||
def test_exception_in_conversion(self, runner, mocker: MockerFixture):
|
||||
def test_exception_in_conversion(self, runner: BaseAgentRunner, mocker: MockerFixture):
|
||||
agent_tool = mocker.MagicMock(tool_name="bad_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception)
|
||||
@ -370,7 +370,7 @@ class TestInitPromptToolsExtended:
|
||||
|
||||
|
||||
class TestAdditionalCoverage:
|
||||
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_save_agent_thought_existing_labels(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {"tool1": {"en_US": "existing"}}
|
||||
@ -381,7 +381,7 @@ class TestAdditionalCoverage:
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert labels["tool1"]["en_US"] == "existing"
|
||||
|
||||
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_save_agent_thought_tool_meta_string(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
@ -391,7 +391,7 @@ class TestAdditionalCoverage:
|
||||
runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None)
|
||||
assert agent.tool_meta_str == "meta_string"
|
||||
|
||||
def test_convert_dataset_retriever_tool(self, runner, mocker: MockerFixture):
|
||||
def test_convert_dataset_retriever_tool(self, runner: BaseAgentRunner, mocker: MockerFixture):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
@ -408,7 +408,9 @@ class TestAdditionalCoverage:
|
||||
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
|
||||
assert prompt is not None
|
||||
|
||||
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_organize_user_prompt_with_file_objects(
|
||||
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
|
||||
):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
|
||||
file_config = mocker.MagicMock()
|
||||
@ -427,7 +429,7 @@ class TestAdditionalCoverage:
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result is not None
|
||||
|
||||
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_organize_history_without_tool_names(self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture):
|
||||
thought = mocker.MagicMock(tool=None, thought="thinking")
|
||||
msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
@ -437,7 +439,9 @@ class TestAdditionalCoverage:
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_organize_history_multiple_tools_split(
|
||||
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
|
||||
):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1;tool2",
|
||||
tool_input=json.dumps({"tool1": {}, "tool2": {}}),
|
||||
@ -455,7 +459,7 @@ class TestAdditionalCoverage:
|
||||
|
||||
|
||||
class TestConvertDatasetRetrieverTool:
|
||||
def test_required_param_added(self, runner, mocker: MockerFixture):
|
||||
def test_required_param_added(self, runner: BaseAgentRunner, mocker: MockerFixture):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
@ -518,7 +522,7 @@ class TestBaseAgentRunnerInit:
|
||||
|
||||
|
||||
class TestBaseAgentRunnerCoverage:
|
||||
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker: MockerFixture):
|
||||
def test_init_prompt_tools_adds_dataset_tools(self, runner: BaseAgentRunner, mocker: MockerFixture):
|
||||
dataset_tool = mocker.MagicMock()
|
||||
dataset_tool.entity.identity.name = "ds"
|
||||
runner.dataset_tools = [dataset_tool]
|
||||
@ -530,7 +534,9 @@ class TestBaseAgentRunnerCoverage:
|
||||
assert tools["ds"] == dataset_tool
|
||||
assert len(prompt_tools) == 1
|
||||
|
||||
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_save_agent_thought_json_dumps_fallbacks(
|
||||
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
|
||||
):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
@ -568,7 +574,9 @@ class TestBaseAgentRunnerCoverage:
|
||||
assert isinstance(agent.observation, str)
|
||||
assert isinstance(agent.tool_meta_str, str)
|
||||
|
||||
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_save_agent_thought_skips_empty_tool_name(
|
||||
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
|
||||
):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;;"
|
||||
agent.tool_labels = {}
|
||||
@ -582,7 +590,9 @@ class TestBaseAgentRunnerCoverage:
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "" not in labels
|
||||
|
||||
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_organize_history_includes_system_prompt(
|
||||
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
|
||||
):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
|
||||
@ -592,7 +602,9 @@ class TestBaseAgentRunnerCoverage:
|
||||
|
||||
assert system_message in result
|
||||
|
||||
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
def test_organize_history_tool_inputs_and_observation_none(
|
||||
self, runner: BaseAgentRunner, mock_db_session, mocker: MockerFixture
|
||||
):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=None,
|
||||
|
||||
@ -95,25 +95,25 @@ class TestFillInputs:
|
||||
("", {"x": "y"}, ""),
|
||||
],
|
||||
)
|
||||
def test_fill_in_inputs(self, runner, instruction, inputs, expected):
|
||||
def test_fill_in_inputs(self, runner: DummyRunner, instruction, inputs, expected):
|
||||
result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestConvertDictToAction:
|
||||
def test_convert_valid_dict(self, runner):
|
||||
def test_convert_valid_dict(self, runner: DummyRunner):
|
||||
action_dict = {"action": "test", "action_input": {"a": 1}}
|
||||
action = runner._convert_dict_to_action(action_dict)
|
||||
assert action.action_name == "test"
|
||||
assert action.action_input == {"a": 1}
|
||||
|
||||
def test_convert_missing_keys(self, runner):
|
||||
def test_convert_missing_keys(self, runner: DummyRunner):
|
||||
with pytest.raises(KeyError):
|
||||
runner._convert_dict_to_action({"invalid": 1})
|
||||
|
||||
|
||||
class TestFormatAssistantMessage:
|
||||
def test_format_assistant_message_multiple_scratchpads(self, runner):
|
||||
def test_format_assistant_message_multiple_scratchpads(self, runner: DummyRunner):
|
||||
sp1 = AgentScratchpadUnit(
|
||||
agent_response="resp1",
|
||||
thought="thought1",
|
||||
@ -131,7 +131,7 @@ class TestFormatAssistantMessage:
|
||||
result = runner._format_assistant_message([sp1, sp2])
|
||||
assert "Final Answer:" in result
|
||||
|
||||
def test_format_with_final(self, runner):
|
||||
def test_format_with_final(self, runner: DummyRunner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="Done",
|
||||
thought="",
|
||||
@ -144,7 +144,7 @@ class TestFormatAssistantMessage:
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Final Answer" in result
|
||||
|
||||
def test_format_with_action_and_observation(self, runner):
|
||||
def test_format_with_action_and_observation(self, runner: DummyRunner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="resp",
|
||||
thought="thinking",
|
||||
@ -161,12 +161,12 @@ class TestFormatAssistantMessage:
|
||||
|
||||
|
||||
class TestHandleInvokeAction:
|
||||
def test_handle_invoke_action_tool_not_present(self, runner):
|
||||
def test_handle_invoke_action_tool_not_present(self, runner: DummyRunner):
|
||||
action = AgentScratchpadUnit.Action(action_name="missing", action_input={})
|
||||
response, meta = runner._handle_invoke_action(action, {}, [])
|
||||
assert "there is not a tool named" in response
|
||||
|
||||
def test_tool_with_json_string_args(self, runner, mocker: MockerFixture):
|
||||
def test_tool_with_json_string_args(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
@ -181,7 +181,7 @@ class TestHandleInvokeAction:
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessages:
|
||||
def test_empty_history(self, runner, mocker: MockerFixture):
|
||||
def test_empty_history(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
|
||||
return_value=[],
|
||||
@ -191,7 +191,7 @@ class TestOrganizeHistoricPromptMessages:
|
||||
|
||||
|
||||
class TestRun:
|
||||
def test_run_handles_empty_parser_output(self, runner, mocker: MockerFixture):
|
||||
def test_run_handles_empty_parser_output(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -203,7 +203,7 @@ class TestRun:
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert isinstance(results, list)
|
||||
|
||||
def test_run_with_action_and_tool_invocation(self, runner, mocker: MockerFixture):
|
||||
def test_run_with_action_and_tool_invocation(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -224,7 +224,7 @@ class TestRun:
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_respects_max_iteration_boundary(self, runner, mocker: MockerFixture):
|
||||
def test_run_respects_max_iteration_boundary(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
runner.app_config.agent.max_iteration = 1
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
@ -246,7 +246,7 @@ class TestRun:
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_basic_flow(self, runner, mocker: MockerFixture):
|
||||
def test_run_basic_flow(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -258,7 +258,7 @@ class TestRun:
|
||||
results = list(runner.run(message, "query", {"name": "John"}))
|
||||
assert results
|
||||
|
||||
def test_run_max_iteration_error(self, runner, mocker: MockerFixture):
|
||||
def test_run_max_iteration_error(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
@ -273,7 +273,7 @@ class TestRun:
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_increase_usage_aggregation(self, runner, mocker: MockerFixture):
|
||||
def test_run_increase_usage_aggregation(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
runner.app_config.agent.max_iteration = 2
|
||||
@ -330,7 +330,7 @@ class TestRun:
|
||||
assert final_usage.completion_price == 2
|
||||
assert final_usage.total_price == 4
|
||||
|
||||
def test_run_when_no_action_branch(self, runner, mocker: MockerFixture):
|
||||
def test_run_when_no_action_branch(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -342,7 +342,7 @@ class TestRun:
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == ""
|
||||
|
||||
def test_run_usage_missing_key_branch(self, runner, mocker: MockerFixture):
|
||||
def test_run_usage_missing_key_branch(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -355,7 +355,7 @@ class TestRun:
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_prompt_tool_update_branch(self, runner, mocker: MockerFixture):
|
||||
def test_run_prompt_tool_update_branch(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -387,7 +387,7 @@ class TestRun:
|
||||
|
||||
runner.update_prompt_message_tool.assert_called_once()
|
||||
|
||||
def test_historic_with_assistant_and_tool_calls(self, runner):
|
||||
def test_historic_with_assistant_and_tool_calls(self, runner: DummyRunner):
|
||||
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="thinking")
|
||||
@ -400,7 +400,7 @@ class TestRun:
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_historic_final_flush_branch(self, runner):
|
||||
def test_historic_final_flush_branch(self, runner: DummyRunner):
|
||||
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="final")
|
||||
@ -411,7 +411,7 @@ class TestRun:
|
||||
|
||||
|
||||
class TestInitReactState:
|
||||
def test_init_react_state_resets_state(self, runner, mocker: MockerFixture):
|
||||
def test_init_react_state_resets_state(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
|
||||
runner._agent_scratchpad = ["old"]
|
||||
runner._query = "old"
|
||||
@ -424,7 +424,7 @@ class TestInitReactState:
|
||||
|
||||
|
||||
class TestHandleInvokeActionExtended:
|
||||
def test_tool_with_invalid_json_string_args(self, runner, mocker: MockerFixture):
|
||||
def test_tool_with_invalid_json_string_args(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
@ -443,11 +443,11 @@ class TestHandleInvokeActionExtended:
|
||||
|
||||
|
||||
class TestFillInputsEdgeCases:
|
||||
def test_fill_inputs_with_empty_inputs(self, runner):
|
||||
def test_fill_inputs_with_empty_inputs(self, runner: DummyRunner):
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
def test_fill_inputs_with_exception_in_replace(self, runner):
|
||||
def test_fill_inputs_with_exception_in_replace(self, runner: DummyRunner):
|
||||
class BadValue:
|
||||
def __str__(self):
|
||||
raise Exception("fail")
|
||||
@ -458,7 +458,7 @@ class TestFillInputsEdgeCases:
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessagesExtended:
|
||||
def test_user_message_flushes_scratchpad(self, runner, mocker: MockerFixture):
|
||||
def test_user_message_flushes_scratchpad(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
from graphon.model_runtime.entities.message_entities import UserPromptMessage
|
||||
|
||||
user_message = UserPromptMessage(content="Hi")
|
||||
@ -473,7 +473,7 @@ class TestOrganizeHistoricPromptMessagesExtended:
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == ["final"]
|
||||
|
||||
def test_tool_message_without_scratchpad_raises(self, runner):
|
||||
def test_tool_message_without_scratchpad_raises(self, runner: DummyRunner):
|
||||
from graphon.model_runtime.entities.message_entities import ToolPromptMessage
|
||||
|
||||
runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")]
|
||||
@ -481,7 +481,7 @@ class TestOrganizeHistoricPromptMessagesExtended:
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner._organize_historic_prompt_messages([])
|
||||
|
||||
def test_agent_history_transform_invocation(self, runner, mocker: MockerFixture):
|
||||
def test_agent_history_transform_invocation(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
mock_transform = MagicMock()
|
||||
mock_transform.get_prompt.return_value = []
|
||||
|
||||
@ -496,7 +496,7 @@ class TestOrganizeHistoricPromptMessagesExtended:
|
||||
|
||||
|
||||
class TestRunAdditionalBranches:
|
||||
def test_run_with_no_action_final_answer_empty(self, runner, mocker: MockerFixture):
|
||||
def test_run_with_no_action_final_answer_empty(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -508,7 +508,7 @@ class TestRunAdditionalBranches:
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert any(hasattr(r, "delta") for r in results)
|
||||
|
||||
def test_run_with_final_answer_action_string(self, runner, mocker: MockerFixture):
|
||||
def test_run_with_final_answer_action_string(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -522,7 +522,7 @@ class TestRunAdditionalBranches:
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "done"
|
||||
|
||||
def test_run_with_final_answer_action_dict(self, runner, mocker: MockerFixture):
|
||||
def test_run_with_final_answer_action_dict(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -536,7 +536,7 @@ class TestRunAdditionalBranches:
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert json.loads(results[-1].delta.message.content) == {"a": 1}
|
||||
|
||||
def test_run_with_string_final_answer(self, runner, mocker: MockerFixture):
|
||||
def test_run_with_string_final_answer(self, runner: DummyRunner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
|
||||
@ -40,7 +40,11 @@ def runner(mocker: MockerFixture, dummy_tool_factory):
|
||||
|
||||
class TestOrganizeInstructionPrompt:
|
||||
def test_success_all_placeholders(
|
||||
self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
self,
|
||||
runner: CotCompletionAgentRunner,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
):
|
||||
template = (
|
||||
"{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}"
|
||||
@ -58,12 +62,14 @@ class TestOrganizeInstructionPrompt:
|
||||
tools_payload = json.loads(result.split(" | ")[1])
|
||||
assert {item["name"] for item in tools_payload} == {"toolA", "toolB"}
|
||||
|
||||
def test_agent_none_raises(self, runner, dummy_app_config_factory):
|
||||
def test_agent_none_raises(self, runner: CotCompletionAgentRunner, dummy_app_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=None)
|
||||
with pytest.raises(ValueError, match="Agent configuration is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory):
|
||||
def test_prompt_entity_none_raises(
|
||||
self, runner: CotCompletionAgentRunner, dummy_app_config_factory, dummy_agent_config_factory
|
||||
):
|
||||
runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None))
|
||||
with pytest.raises(ValueError, match="prompt entity is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
@ -75,7 +81,7 @@ class TestOrganizeInstructionPrompt:
|
||||
|
||||
|
||||
class TestOrganizeHistoricPrompt:
|
||||
def test_with_user_and_assistant_string(self, runner, mocker: MockerFixture):
|
||||
def test_with_user_and_assistant_string(self, runner: CotCompletionAgentRunner, mocker: MockerFixture):
|
||||
user_msg = UserPromptMessage(content="Hello")
|
||||
assistant_msg = AssistantPromptMessage(content="Hi there")
|
||||
|
||||
@ -90,7 +96,7 @@ class TestOrganizeHistoricPrompt:
|
||||
assert "Question: Hello" in result
|
||||
assert "Hi there" in result
|
||||
|
||||
def test_assistant_list_with_text_content(self, runner, mocker: MockerFixture):
|
||||
def test_assistant_list_with_text_content(self, runner: CotCompletionAgentRunner, mocker: MockerFixture):
|
||||
text_content = TextPromptMessageContent(data="Partial answer")
|
||||
assistant_msg = AssistantPromptMessage(content=[text_content])
|
||||
|
||||
@ -104,7 +110,9 @@ class TestOrganizeHistoricPrompt:
|
||||
|
||||
assert "Partial answer" in result
|
||||
|
||||
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker: MockerFixture):
|
||||
def test_assistant_list_with_non_text_content_ignored(
|
||||
self, runner: CotCompletionAgentRunner, mocker: MockerFixture
|
||||
):
|
||||
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
assistant_msg = AssistantPromptMessage(content=[non_text_content])
|
||||
|
||||
@ -117,7 +125,7 @@ class TestOrganizeHistoricPrompt:
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
def test_empty_history(self, runner, mocker: MockerFixture):
|
||||
def test_empty_history(self, runner: CotCompletionAgentRunner, mocker: MockerFixture):
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
|
||||
@ -147,12 +147,12 @@ def runner(mocker: MockerFixture):
|
||||
|
||||
class TestToolCallChecks:
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_tool_calls(self, runner, tool_calls, expected):
|
||||
def test_check_tool_calls(self, runner: FunctionCallAgentRunner, tool_calls, expected):
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_tool_calls(chunk) is expected
|
||||
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_blocking_tool_calls(self, runner, tool_calls, expected):
|
||||
def test_check_blocking_tool_calls(self, runner: FunctionCallAgentRunner, tool_calls, expected):
|
||||
result = DummyResult(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_blocking_tool_calls(result) is expected
|
||||
|
||||
@ -163,7 +163,7 @@ class TestToolCallChecks:
|
||||
|
||||
|
||||
class TestExtractToolCalls:
|
||||
def test_extract_tool_calls_with_valid_json(self, runner):
|
||||
def test_extract_tool_calls_with_valid_json(self, runner: FunctionCallAgentRunner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
@ -174,7 +174,7 @@ class TestExtractToolCalls:
|
||||
|
||||
assert calls == [("1", "tool", {"a": 1})]
|
||||
|
||||
def test_extract_tool_calls_empty_arguments(self, runner):
|
||||
def test_extract_tool_calls_empty_arguments(self, runner: FunctionCallAgentRunner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
@ -185,7 +185,7 @@ class TestExtractToolCalls:
|
||||
|
||||
assert calls == [("1", "tool", {})]
|
||||
|
||||
def test_extract_blocking_tool_calls(self, runner):
|
||||
def test_extract_blocking_tool_calls(self, runner: FunctionCallAgentRunner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "2"
|
||||
tool_call.function.name = "block"
|
||||
@ -203,16 +203,16 @@ class TestExtractToolCalls:
|
||||
|
||||
|
||||
class TestInitSystemMessage:
|
||||
def test_init_system_message_empty_prompt_messages(self, runner):
|
||||
def test_init_system_message_empty_prompt_messages(self, runner: FunctionCallAgentRunner):
|
||||
result = runner._init_system_message("system", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_init_system_message_insert_at_start(self, runner):
|
||||
def test_init_system_message_insert_at_start(self, runner: FunctionCallAgentRunner):
|
||||
msgs = [MagicMock()]
|
||||
result = runner._init_system_message("system", msgs)
|
||||
assert result[0].content == "system"
|
||||
|
||||
def test_init_system_message_no_template(self, runner):
|
||||
def test_init_system_message_no_template(self, runner: FunctionCallAgentRunner):
|
||||
result = runner._init_system_message("", [])
|
||||
assert result == []
|
||||
|
||||
@ -223,15 +223,15 @@ class TestInitSystemMessage:
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
def test_without_files(self, runner):
|
||||
def test_without_files(self, runner: FunctionCallAgentRunner):
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_none_query(self, runner):
|
||||
def test_with_none_query(self, runner: FunctionCallAgentRunner):
|
||||
result = runner._organize_user_query(None, [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_files_uses_image_detail_config(self, runner, mocker: MockerFixture):
|
||||
def test_with_files_uses_image_detail_config(self, runner: FunctionCallAgentRunner, mocker: MockerFixture):
|
||||
file_content = TextPromptMessageContent(data="file-content")
|
||||
mock_to_prompt = mocker.patch(
|
||||
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
|
||||
@ -255,7 +255,7 @@ class TestOrganizeUserQuery:
|
||||
|
||||
|
||||
class TestClearUserPromptImageMessages:
|
||||
def test_clear_text_and_image_content(self, runner):
|
||||
def test_clear_text_and_image_content(self, runner: FunctionCallAgentRunner):
|
||||
text = MagicMock()
|
||||
text.type = "text"
|
||||
text.data = "hello"
|
||||
@ -271,7 +271,7 @@ class TestClearUserPromptImageMessages:
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_clear_includes_file_placeholder(self, runner):
|
||||
def test_clear_includes_file_placeholder(self, runner: FunctionCallAgentRunner):
|
||||
text = TextPromptMessageContent(data="hello")
|
||||
image = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
document = DocumentPromptMessageContent(format="url", mime_type="application/pdf")
|
||||
@ -289,7 +289,7 @@ class TestClearUserPromptImageMessages:
|
||||
|
||||
|
||||
class TestRunMethod:
|
||||
def test_run_non_streaming_no_tool_calls(self, runner):
|
||||
def test_run_non_streaming_no_tool_calls(self, runner: FunctionCallAgentRunner):
|
||||
message = MagicMock(id="m1")
|
||||
dummy_message = DummyMessage(content="hello")
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
@ -303,7 +303,7 @@ class TestRunMethod:
|
||||
queue_calls = runner.queue_manager.publish.call_args_list
|
||||
assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls)
|
||||
|
||||
def test_run_streaming_branch(self, runner):
|
||||
def test_run_streaming_branch(self, runner: FunctionCallAgentRunner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
@ -318,7 +318,7 @@ class TestRunMethod:
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_streaming_tool_calls_list_content(self, runner):
|
||||
def test_run_streaming_tool_calls_list_content(self, runner: FunctionCallAgentRunner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
@ -341,7 +341,7 @@ class TestRunMethod:
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_non_streaming_list_content(self, runner):
|
||||
def test_run_non_streaming_list_content(self, runner: FunctionCallAgentRunner):
|
||||
message = MagicMock(id="m1")
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
dummy_message = DummyMessage(content=content)
|
||||
@ -353,7 +353,7 @@ class TestRunMethod:
|
||||
assert len(outputs) == 1
|
||||
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
|
||||
|
||||
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker: MockerFixture):
|
||||
def test_run_streaming_tool_call_inputs_type_error(self, runner: FunctionCallAgentRunner, mocker: MockerFixture):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
@ -381,7 +381,7 @@ class TestRunMethod:
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_with_missing_tool_instance(self, runner):
|
||||
def test_run_with_missing_tool_instance(self, runner: FunctionCallAgentRunner):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
@ -399,7 +399,7 @@ class TestRunMethod:
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_with_tool_instance_and_files(self, runner, mocker: MockerFixture):
|
||||
def test_run_with_tool_instance_and_files(self, runner: FunctionCallAgentRunner, mocker: MockerFixture):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
@ -434,7 +434,7 @@ class TestRunMethod:
|
||||
for call in runner.queue_manager.publish.call_args_list
|
||||
)
|
||||
|
||||
def test_run_max_iteration_error(self, runner):
|
||||
def test_run_max_iteration_error(self, runner: FunctionCallAgentRunner):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
@ -36,7 +36,7 @@ def generator(mocker: MockerFixture) -> AgentAppGenerator:
|
||||
|
||||
|
||||
class TestGenerateGuards:
|
||||
def test_rejects_blocking_mode(self, generator, mocker: MockerFixture):
|
||||
def test_rejects_blocking_mode(self, generator: AgentAppGenerator, mocker: MockerFixture):
|
||||
with pytest.raises(AgentAppGeneratorError, match="only supports streaming"):
|
||||
generator.generate(
|
||||
app_model=mocker.MagicMock(),
|
||||
@ -46,7 +46,7 @@ class TestGenerateGuards:
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_requires_query(self, generator, mocker: MockerFixture):
|
||||
def test_requires_query(self, generator: AgentAppGenerator, mocker: MockerFixture):
|
||||
with pytest.raises(AgentAppGeneratorError, match="query is required"):
|
||||
generator.generate(
|
||||
app_model=mocker.MagicMock(),
|
||||
@ -55,7 +55,7 @@ class TestGenerateGuards:
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_rejects_blank_query(self, generator, mocker: MockerFixture):
|
||||
def test_rejects_blank_query(self, generator: AgentAppGenerator, mocker: MockerFixture):
|
||||
with pytest.raises(AgentAppGeneratorError, match="query is required"):
|
||||
generator.generate(
|
||||
app_model=mocker.MagicMock(),
|
||||
@ -113,7 +113,7 @@ class TestGenerateSuccess:
|
||||
thread_obj.start.assert_called_once()
|
||||
generator._resolve_agent.assert_called_once_with(app_model)
|
||||
|
||||
def test_generate_loads_existing_conversation(self, generator, mocker: MockerFixture):
|
||||
def test_generate_loads_existing_conversation(self, generator: AgentAppGenerator, mocker: MockerFixture):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent")
|
||||
generator._resolve_agent = mocker.MagicMock(
|
||||
return_value=(mocker.MagicMock(id="a"), mocker.MagicMock(id="s"), mocker.MagicMock())
|
||||
@ -144,7 +144,9 @@ class TestGenerateSuccess:
|
||||
|
||||
get_conv.assert_called_once()
|
||||
|
||||
def test_generate_does_not_include_trace_session_id_in_extras(self, generator, mocker: MockerFixture):
|
||||
def test_generate_does_not_include_trace_session_id_in_extras(
|
||||
self, generator: AgentAppGenerator, mocker: MockerFixture
|
||||
):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent")
|
||||
user = DummyAccount("user")
|
||||
|
||||
@ -190,7 +192,7 @@ class TestGenerateWorker:
|
||||
|
||||
mocker.patch("libs.flask_utils.preserve_flask_contexts", ctx_manager)
|
||||
|
||||
def _wire(self, generator, mocker: MockerFixture, *, run_side_effect=None, handled=False):
|
||||
def _wire(self, generator: AgentAppGenerator, mocker: MockerFixture, *, run_side_effect=None, handled=False):
|
||||
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock(id="conv"))
|
||||
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock(id="msg"))
|
||||
generator._run_input_guards = mocker.MagicMock(return_value=(handled, "query"))
|
||||
@ -238,7 +240,7 @@ class TestGenerateWorker:
|
||||
is_resume=is_resume,
|
||||
)
|
||||
|
||||
def test_happy_path_runs_backend(self, generator, mocker: MockerFixture):
|
||||
def test_happy_path_runs_backend(self, generator: AgentAppGenerator, mocker: MockerFixture):
|
||||
runner = self._wire(generator, mocker)
|
||||
queue_manager = mocker.MagicMock()
|
||||
self._call(generator, mocker, queue_manager)
|
||||
@ -280,7 +282,7 @@ class TestGenerateWorker:
|
||||
self._call(generator, mocker, queue_manager)
|
||||
queue_manager.publish_error.assert_not_called()
|
||||
|
||||
def test_unexpected_error_is_published(self, generator, mocker: MockerFixture):
|
||||
def test_unexpected_error_is_published(self, generator: AgentAppGenerator, mocker: MockerFixture):
|
||||
self._wire(generator, mocker, run_side_effect=ValueError("boom"))
|
||||
queue_manager = mocker.MagicMock()
|
||||
self._call(generator, mocker, queue_manager)
|
||||
|
||||
@ -14,7 +14,7 @@ def runner():
|
||||
|
||||
|
||||
class TestAgentChatAppRunnerRun:
|
||||
def test_run_app_not_found(self, runner, mocker: MockerFixture):
|
||||
def test_run_app_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
|
||||
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
|
||||
|
||||
@ -23,7 +23,7 @@ class TestAgentChatAppRunnerRun:
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
def test_run_moderation_error_direct_output(self, runner, mocker: MockerFixture):
|
||||
def test_run_moderation_error_direct_output(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
@ -46,7 +46,7 @@ class TestAgentChatAppRunnerRun:
|
||||
|
||||
runner.direct_output.assert_called_once()
|
||||
|
||||
def test_run_annotation_reply_short_circuits(self, runner, mocker: MockerFixture):
|
||||
def test_run_annotation_reply_short_circuits(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
@ -75,7 +75,7 @@ class TestAgentChatAppRunnerRun:
|
||||
queue_manager.publish.assert_called_once()
|
||||
runner.direct_output.assert_called_once()
|
||||
|
||||
def test_run_hosting_moderation_short_circuits(self, runner, mocker: MockerFixture):
|
||||
def test_run_hosting_moderation_short_circuits(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
@ -99,7 +99,7 @@ class TestAgentChatAppRunnerRun:
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
def test_run_model_schema_missing(self, runner, mocker: MockerFixture):
|
||||
def test_run_model_schema_missing(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
@ -141,7 +141,7 @@ class TestAgentChatAppRunnerRun:
|
||||
(LLMMode.COMPLETION, "CotCompletionAgentRunner"),
|
||||
],
|
||||
)
|
||||
def test_run_chain_of_thought_modes(self, runner, mocker: MockerFixture, mode, expected_runner):
|
||||
def test_run_chain_of_thought_modes(self, runner: AgentChatAppRunner, mocker: MockerFixture, mode, expected_runner):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
@ -197,7 +197,7 @@ class TestAgentChatAppRunnerRun:
|
||||
runner_instance.run.assert_called_once()
|
||||
runner._handle_invoke_result.assert_called_once()
|
||||
|
||||
def test_run_invalid_llm_mode_raises(self, runner, mocker: MockerFixture):
|
||||
def test_run_invalid_llm_mode_raises(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
@ -243,7 +243,9 @@ class TestAgentChatAppRunnerRun:
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
def test_run_function_calling_strategy_selected_by_features(self, runner, mocker: MockerFixture):
|
||||
def test_run_function_calling_strategy_selected_by_features(
|
||||
self, runner: AgentChatAppRunner, mocker: MockerFixture
|
||||
):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
@ -299,7 +301,7 @@ class TestAgentChatAppRunnerRun:
|
||||
assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING
|
||||
runner_instance.run.assert_called_once()
|
||||
|
||||
def test_run_conversation_not_found(self, runner, mocker: MockerFixture):
|
||||
def test_run_conversation_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
|
||||
@ -333,7 +335,7 @@ class TestAgentChatAppRunnerRun:
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
|
||||
|
||||
def test_run_message_not_found(self, runner, mocker: MockerFixture):
|
||||
def test_run_message_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
|
||||
@ -367,7 +369,7 @@ class TestAgentChatAppRunnerRun:
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
|
||||
|
||||
def test_run_invalid_agent_strategy_raises(self, runner, mocker: MockerFixture):
|
||||
def test_run_invalid_agent_strategy_raises(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")
|
||||
|
||||
@ -112,7 +112,7 @@ def test_run_uses_single_node_execution_branch(
|
||||
assert entry_kwargs["graph_runtime_state"] is graph_runtime_state
|
||||
|
||||
|
||||
def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
def test_single_node_run_validates_target_node_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
runner = WorkflowBasedAppRunner(
|
||||
queue_manager=MagicMock(spec=AppQueueManager),
|
||||
variable_loader=MagicMock(),
|
||||
|
||||
@ -218,7 +218,7 @@ class TestWorkflowPersistenceLayer:
|
||||
assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED
|
||||
assert trace_tasks
|
||||
|
||||
def test_handle_graph_run_succeeded_enqueues_parent_trace_context(self, monkeypatch):
|
||||
def test_handle_graph_run_succeeded_enqueues_parent_trace_context(self, monkeypatch: pytest.MonkeyPatch):
|
||||
trace_tasks: list[TraceTask] = []
|
||||
trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task))
|
||||
layer, _, _, _ = _make_layer(
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
@ -112,7 +113,7 @@ def test_segment_attachment_lookup_grants_returned_upload_files_to_current_scope
|
||||
assert "upload_files.id IN" in whereclause
|
||||
|
||||
|
||||
def test_knowledge_retrieval_grants_returned_segments_to_current_scope(monkeypatch) -> None:
|
||||
def test_knowledge_retrieval_grants_returned_segments_to_current_scope(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
tenant_id = str(uuid4())
|
||||
dataset_id = str(uuid4())
|
||||
document_id = str(uuid4())
|
||||
|
||||
@ -64,7 +64,9 @@ class TestCacheEmbeddingMultimodalDocuments:
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result):
|
||||
def test_embed_single_multimodal_document_cache_miss(
|
||||
self, mock_model_instance, sample_multimodal_result: EmbeddingResult
|
||||
):
|
||||
"""Test embedding a single multimodal document when cache is empty."""
|
||||
cache_embedding = CacheEmbedding(mock_model_instance)
|
||||
documents = [{"file_id": "file123", "content": "test content"}]
|
||||
|
||||
@ -97,7 +97,9 @@ class TestRerankModelRunner:
|
||||
),
|
||||
]
|
||||
|
||||
def test_basic_reranking(self, rerank_runner, mock_model_instance, sample_documents):
|
||||
def test_basic_reranking(
|
||||
self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document]
|
||||
):
|
||||
"""Test basic reranking with cross-encoder model.
|
||||
|
||||
Verifies:
|
||||
@ -135,7 +137,9 @@ class TestRerankModelRunner:
|
||||
assert result[3].metadata["score"] == 0.65
|
||||
assert result[0].page_content == sample_documents[2].page_content
|
||||
|
||||
def test_score_threshold_filtering(self, rerank_runner, mock_model_instance, sample_documents):
|
||||
def test_score_threshold_filtering(
|
||||
self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document]
|
||||
):
|
||||
"""Test score threshold filtering.
|
||||
|
||||
Verifies:
|
||||
@ -163,7 +167,9 @@ class TestRerankModelRunner:
|
||||
assert result[0].metadata["score"] == 0.90
|
||||
assert result[1].metadata["score"] == 0.70
|
||||
|
||||
def test_top_k_selection(self, rerank_runner, mock_model_instance, sample_documents):
|
||||
def test_top_k_selection(
|
||||
self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document]
|
||||
):
|
||||
"""Test top-k selection functionality.
|
||||
|
||||
Verifies:
|
||||
@ -191,7 +197,7 @@ class TestRerankModelRunner:
|
||||
assert result[0].metadata["score"] == 0.95
|
||||
assert result[1].metadata["score"] == 0.85
|
||||
|
||||
def test_document_deduplication_dify_provider(self, rerank_runner, mock_model_instance):
|
||||
def test_document_deduplication_dify_provider(self, rerank_runner: RerankModelRunner, mock_model_instance):
|
||||
"""Test document deduplication for dify provider.
|
||||
|
||||
Verifies:
|
||||
@ -235,7 +241,7 @@ class TestRerankModelRunner:
|
||||
assert len(call_kwargs["docs"]) == 2 # Duplicate removed
|
||||
assert len(result) == 2
|
||||
|
||||
def test_document_deduplication_external_provider(self, rerank_runner, mock_model_instance):
|
||||
def test_document_deduplication_external_provider(self, rerank_runner: RerankModelRunner, mock_model_instance):
|
||||
"""Test document deduplication for external provider.
|
||||
|
||||
Verifies:
|
||||
@ -273,7 +279,9 @@ class TestRerankModelRunner:
|
||||
assert len(call_kwargs["docs"]) == 2
|
||||
assert len(result) == 2
|
||||
|
||||
def test_combined_threshold_and_top_k(self, rerank_runner, mock_model_instance, sample_documents):
|
||||
def test_combined_threshold_and_top_k(
|
||||
self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document]
|
||||
):
|
||||
"""Test combined score threshold and top-k selection.
|
||||
|
||||
Verifies:
|
||||
@ -307,7 +315,9 @@ class TestRerankModelRunner:
|
||||
assert result[0].metadata["score"] == 0.95
|
||||
assert result[1].metadata["score"] == 0.85
|
||||
|
||||
def test_metadata_preservation(self, rerank_runner, mock_model_instance, sample_documents):
|
||||
def test_metadata_preservation(
|
||||
self, rerank_runner: RerankModelRunner, mock_model_instance, sample_documents: list[Document]
|
||||
):
|
||||
"""Test that original metadata is preserved after reranking.
|
||||
|
||||
Verifies:
|
||||
@ -334,7 +344,7 @@ class TestRerankModelRunner:
|
||||
assert result[0].metadata["score"] == 0.90
|
||||
assert result[0].provider == "dify"
|
||||
|
||||
def test_empty_documents_list(self, rerank_runner, mock_model_instance):
|
||||
def test_empty_documents_list(self, rerank_runner: RerankModelRunner, mock_model_instance):
|
||||
"""Test handling of empty documents list.
|
||||
|
||||
Verifies:
|
||||
@ -523,7 +533,9 @@ class TestRerankModelRunnerMultimodal:
|
||||
docs_arg = mock_text_rerank.call_args.args[1]
|
||||
assert len(docs_arg) == 1
|
||||
|
||||
def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(self, rerank_runner, mock_model_instance):
|
||||
def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(
|
||||
self, rerank_runner: RerankModelRunner, mock_model_instance
|
||||
):
|
||||
text_doc = Document(
|
||||
page_content="text-content",
|
||||
metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},
|
||||
|
||||
@ -413,7 +413,7 @@ class TestWorkflowGeneratorTransientRetry:
|
||||
}
|
||||
)
|
||||
|
||||
def test_retries_transient_invoke_error_then_succeeds(self, monkeypatch):
|
||||
def test_retries_transient_invoke_error_then_succeeds(self, monkeypatch: pytest.MonkeyPatch):
|
||||
# The planner's first invoke raises a transient connection error; the
|
||||
# retry succeeds and the pipeline completes normally. Sleep is patched
|
||||
# out so the test doesn't actually wait for the backoff.
|
||||
@ -443,7 +443,7 @@ class TestWorkflowGeneratorTransientRetry:
|
||||
# planner (failed once + retried) + builder = 3 invocations total.
|
||||
assert model_instance.invoke_llm.call_count == 3
|
||||
|
||||
def test_gives_up_after_exhausting_transient_retries(self, monkeypatch):
|
||||
def test_gives_up_after_exhausting_transient_retries(self, monkeypatch: pytest.MonkeyPatch):
|
||||
# Every attempt hits the transient error — once we exhaust the retry
|
||||
# budget the failure surfaces as a normal error envelope rather than
|
||||
# hanging or looping forever.
|
||||
@ -471,7 +471,7 @@ class TestWorkflowGeneratorTransientRetry:
|
||||
|
||||
assert model_instance.invoke_llm.call_count == _INVOKE_MAX_ATTEMPTS
|
||||
|
||||
def test_does_not_retry_permanent_invoke_error(self, monkeypatch):
|
||||
def test_does_not_retry_permanent_invoke_error(self, monkeypatch: pytest.MonkeyPatch):
|
||||
# An auth error is permanent — retrying just burns latency and quota.
|
||||
# The runner must fail on the first attempt.
|
||||
# If the code wrongly slept here we'd want the test to still be fast;
|
||||
|
||||
@ -332,7 +332,9 @@ def test_email_delivery_method_extracts_variable_selectors() -> None:
|
||||
assert method.extract_variable_selectors() == [["start", "name"]]
|
||||
|
||||
|
||||
def test_email_delivery_method_extracts_variable_selectors_skips_short_selectors(monkeypatch) -> None:
|
||||
def test_email_delivery_method_extracts_variable_selectors_skips_short_selectors(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
method = EmailDeliveryMethod(
|
||||
enabled=True,
|
||||
config=EmailDeliveryConfig(
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
|
||||
def test_pandas_csv(tmp_path, monkeypatch: pytest.MonkeyPatch):
|
||||
def test_pandas_csv(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
|
||||
df1 = pd.DataFrame(data)
|
||||
@ -17,7 +19,7 @@ def test_pandas_csv(tmp_path, monkeypatch: pytest.MonkeyPatch):
|
||||
assert df2[df2.columns[1]].to_list() == data["col2"]
|
||||
|
||||
|
||||
def test_pandas_xlsx(tmp_path, monkeypatch: pytest.MonkeyPatch):
|
||||
def test_pandas_xlsx(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
|
||||
df1 = pd.DataFrame(data)
|
||||
@ -32,7 +34,7 @@ def test_pandas_xlsx(tmp_path, monkeypatch: pytest.MonkeyPatch):
|
||||
assert df2[df2.columns[1]].to_list() == data["col2"]
|
||||
|
||||
|
||||
def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch: pytest.MonkeyPatch):
|
||||
def test_pandas_xlsx_with_sheets(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]}
|
||||
df1 = pd.DataFrame(data1)
|
||||
|
||||
@ -75,7 +75,7 @@ def test_enforce_bearer_rate_limit_raises_429_with_retry_after(mock_build):
|
||||
|
||||
|
||||
@patch("libs.rate_limit._build_limiter")
|
||||
def test_enforce_bearer_rate_limit_disabled_when_limit_is_zero(mock_build, monkeypatch):
|
||||
def test_enforce_bearer_rate_limit_disabled_when_limit_is_zero(mock_build, monkeypatch: pytest.MonkeyPatch):
|
||||
# 0 disables the limit — short-circuit before building/consulting a limiter.
|
||||
monkeypatch.setattr(
|
||||
"libs.rate_limit.LIMIT_BEARER_PER_TOKEN",
|
||||
|
||||
@ -17,6 +17,8 @@ from unittest.mock import Mock, patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
@ -703,7 +705,7 @@ class TestDocumentSegmentIndexing:
|
||||
# Assert
|
||||
assert segment.hit_count == 5
|
||||
|
||||
def test_document_segment_attachments_prefers_files_url_for_source_url(self, monkeypatch):
|
||||
def test_document_segment_attachments_prefers_files_url_for_source_url(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test attachment source URLs use FILES_URL before falling back to CONSOLE_API_URL."""
|
||||
# Arrange
|
||||
segment = DocumentSegment(
|
||||
|
||||
@ -125,7 +125,9 @@ def test_rebuild_serialized_graph_files_without_lookup_preserves_scalar_values()
|
||||
assert rebuild_serialized_graph_files_without_lookup("plain-text") == "plain-text"
|
||||
|
||||
|
||||
def test_build_file_from_stored_mapping_rebuilds_remote_urls_without_record_lookup(monkeypatch) -> None:
|
||||
def test_build_file_from_stored_mapping_rebuilds_remote_urls_without_record_lookup(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
rebuilt_file = File(
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
|
||||
@ -2,6 +2,8 @@ import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from models.snippet import CustomizedSnippet
|
||||
|
||||
|
||||
@ -11,7 +13,7 @@ def test_graph_dict_returns_empty_without_workflow_id() -> None:
|
||||
assert snippet.graph_dict == {}
|
||||
|
||||
|
||||
def test_graph_dict_loads_published_workflow_graph(monkeypatch) -> None:
|
||||
def test_graph_dict_loads_published_workflow_graph(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = SimpleNamespace(graph=json.dumps({"nodes": [{"id": "llm-1"}], "edges": []}))
|
||||
session = SimpleNamespace(get=Mock(return_value=workflow))
|
||||
monkeypatch.setattr("models.snippet.db.session", session)
|
||||
@ -21,7 +23,7 @@ def test_graph_dict_loads_published_workflow_graph(monkeypatch) -> None:
|
||||
session.get.assert_called_once()
|
||||
|
||||
|
||||
def test_graph_dict_returns_empty_when_workflow_missing(monkeypatch) -> None:
|
||||
def test_graph_dict_returns_empty_when_workflow_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session = SimpleNamespace(get=Mock(return_value=None))
|
||||
monkeypatch.setattr("models.snippet.db.session", session)
|
||||
snippet = CustomizedSnippet(workflow_id="missing-workflow")
|
||||
@ -36,7 +38,7 @@ def test_input_fields_list_parses_json_or_returns_empty() -> None:
|
||||
]
|
||||
|
||||
|
||||
def test_tags_returns_query_results_or_empty(monkeypatch) -> None:
|
||||
def test_tags_returns_query_results_or_empty(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
tags = [SimpleNamespace(id="tag-1")]
|
||||
session = SimpleNamespace(scalars=Mock(return_value=SimpleNamespace(all=Mock(return_value=tags))))
|
||||
monkeypatch.setattr("models.snippet.db.session", session)
|
||||
@ -48,7 +50,7 @@ def test_tags_returns_query_results_or_empty(monkeypatch) -> None:
|
||||
assert snippet.tags == []
|
||||
|
||||
|
||||
def test_account_properties_and_author_name(monkeypatch) -> None:
|
||||
def test_account_properties_and_author_name(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = SimpleNamespace(id="account-1", name="Ada")
|
||||
updated_account = SimpleNamespace(id="account-2", name="Grace")
|
||||
session = SimpleNamespace(
|
||||
|
||||
@ -1324,7 +1324,7 @@ class TestAgentAppBackingAgent:
|
||||
with pytest.raises(roster_service.AgentNotFoundError):
|
||||
service.get_agent_app_model(tenant_id="tenant-1", agent_id="agent-x")
|
||||
|
||||
def test_duplicate_agent_app_copies_app_config_and_active_soul(self, monkeypatch):
|
||||
def test_duplicate_agent_app_copies_app_config_and_active_soul(self, monkeypatch: pytest.MonkeyPatch):
|
||||
source_config = SimpleNamespace(
|
||||
opening_statement="hello",
|
||||
suggested_questions='["q1"]',
|
||||
@ -1463,7 +1463,7 @@ class TestAgentAppBackingAgent:
|
||||
assert target_agent.updated_by == "account-1"
|
||||
assert session.commits == 1
|
||||
|
||||
def test_duplicate_agent_app_inherits_webapp_access_mode(self, monkeypatch):
|
||||
def test_duplicate_agent_app_inherits_webapp_access_mode(self, monkeypatch: pytest.MonkeyPatch):
|
||||
source_app = SimpleNamespace(
|
||||
id="source-app",
|
||||
tenant_id="tenant-1",
|
||||
@ -1522,7 +1522,7 @@ class TestAgentAppBackingAgent:
|
||||
assert duplicated is target_app
|
||||
assert access_mode_updates == [("target-app", "private")]
|
||||
|
||||
def test_duplicate_agent_app_falls_back_to_public_access_mode(self, monkeypatch):
|
||||
def test_duplicate_agent_app_falls_back_to_public_access_mode(self, monkeypatch: pytest.MonkeyPatch):
|
||||
source_app = SimpleNamespace(
|
||||
id="source-app",
|
||||
tenant_id="tenant-1",
|
||||
|
||||
@ -66,7 +66,7 @@ class TestFirecrawlAuth:
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post: MagicMock, auth_instance: FirecrawlAuth):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@ -97,7 +97,9 @@ class TestFirecrawlAuth:
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
|
||||
def test_should_handle_http_errors(
|
||||
self, mock_post: MagicMock, status_code, error_message, auth_instance: FirecrawlAuth
|
||||
):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
@ -120,7 +122,13 @@ class TestFirecrawlAuth:
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
self,
|
||||
mock_post: MagicMock,
|
||||
status_code,
|
||||
response_text,
|
||||
has_json_error,
|
||||
expected_error_contains,
|
||||
auth_instance: FirecrawlAuth,
|
||||
):
|
||||
"""Test handling of unexpected errors with various response formats"""
|
||||
mock_response = MagicMock()
|
||||
@ -146,7 +154,9 @@ class TestFirecrawlAuth:
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
|
||||
def test_should_handle_network_errors(
|
||||
self, mock_post: MagicMock, exception_type, exception_message, auth_instance: FirecrawlAuth
|
||||
):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_post.side_effect = exception_type(exception_message)
|
||||
|
||||
@ -186,7 +196,7 @@ class TestFirecrawlAuth:
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post: MagicMock, auth_instance: FirecrawlAuth):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ class TestJinaAuth:
|
||||
assert str(exc_info.value) == "No API key provided"
|
||||
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post):
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post: MagicMock):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@ -54,7 +54,7 @@ class TestJinaAuth:
|
||||
)
|
||||
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_handle_http_402_error(self, mock_post):
|
||||
def test_should_handle_http_402_error(self, mock_post: MagicMock):
|
||||
"""Test handling of 402 Payment Required error"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 402
|
||||
@ -100,7 +100,7 @@ class TestJinaAuth:
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
|
||||
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_handle_http_500_error(self, mock_post):
|
||||
def test_should_handle_http_500_error(self, mock_post: MagicMock):
|
||||
"""Test handling of 500 Internal Server Error"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
@ -115,7 +115,7 @@ class TestJinaAuth:
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
|
||||
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
|
||||
def test_should_handle_unexpected_error_with_text_response(self, mock_post: MagicMock):
|
||||
"""Test handling of unexpected errors with text response"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 403
|
||||
@ -163,7 +163,7 @@ class TestJinaAuth:
|
||||
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
|
||||
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_post):
|
||||
def test_should_handle_network_errors(self, mock_post: MagicMock):
|
||||
"""Test handling of network connection errors"""
|
||||
mock_post.side_effect = httpx.ConnectError("Network error")
|
||||
|
||||
|
||||
@ -89,7 +89,7 @@ class _FakeWriteSession:
|
||||
|
||||
|
||||
class TestUpdateFeatures:
|
||||
def test_persists_new_app_model_config_version(self, monkeypatch):
|
||||
def test_persists_new_app_model_config_version(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session = _FakeWriteSession()
|
||||
monkeypatch.setattr(svc_mod.db, "session", session)
|
||||
app_model = SimpleNamespace(
|
||||
|
||||
@ -137,14 +137,14 @@ class ExternalDatasetServiceTestDataFactory:
|
||||
@pytest.fixture
|
||||
def factory():
|
||||
"""Provide the test data factory to all tests."""
|
||||
return ExternalDatasetServiceTestDataFactory
|
||||
return ExternalDatasetServiceTestDataFactory()
|
||||
|
||||
|
||||
class TestExternalDatasetServiceGetAPIs:
|
||||
"""Test get_external_knowledge_apis operations - comprehensive coverage."""
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_apis_success_basic(self, mock_db, factory):
|
||||
def test_get_external_knowledge_apis_success_basic(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test successful retrieval of external knowledge APIs with pagination."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
@ -171,7 +171,9 @@ class TestExternalDatasetServiceGetAPIs:
|
||||
mock_db.paginate.assert_called_once()
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_apis_with_search_filter(self, mock_db, factory):
|
||||
def test_get_external_knowledge_apis_with_search_filter(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test retrieval with search filter."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
@ -195,7 +197,7 @@ class TestExternalDatasetServiceGetAPIs:
|
||||
assert result_items[0].name == "Production API"
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_apis_empty_results(self, mock_db, factory):
|
||||
def test_get_external_knowledge_apis_empty_results(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test retrieval with no results."""
|
||||
# Arrange
|
||||
mock_pagination = MagicMock()
|
||||
@ -213,7 +215,9 @@ class TestExternalDatasetServiceGetAPIs:
|
||||
assert result_total == 0
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_apis_large_result_set(self, mock_db, factory):
|
||||
def test_get_external_knowledge_apis_large_result_set(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test retrieval with large result set."""
|
||||
# Arrange
|
||||
apis = [factory.create_external_knowledge_api_mock(api_id=f"api-{i}") for i in range(100)]
|
||||
@ -233,7 +237,9 @@ class TestExternalDatasetServiceGetAPIs:
|
||||
assert result_total == 100
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_apis_pagination_last_page(self, mock_db, factory):
|
||||
def test_get_external_knowledge_apis_pagination_last_page(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test last page pagination with partial results."""
|
||||
# Arrange
|
||||
apis = [factory.create_external_knowledge_api_mock(api_id=f"api-{i}") for i in range(95, 100)]
|
||||
@ -253,7 +259,9 @@ class TestExternalDatasetServiceGetAPIs:
|
||||
assert result_total == 100
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_apis_case_insensitive_search(self, mock_db, factory):
|
||||
def test_get_external_knowledge_apis_case_insensitive_search(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test case-insensitive search functionality."""
|
||||
# Arrange
|
||||
apis = [
|
||||
@ -276,7 +284,9 @@ class TestExternalDatasetServiceGetAPIs:
|
||||
assert result_total == 2
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_apis_special_characters_search(self, mock_db, factory):
|
||||
def test_get_external_knowledge_apis_special_characters_search(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test search with special characters."""
|
||||
# Arrange
|
||||
apis = [factory.create_external_knowledge_api_mock(name="API-v2.0 (beta)")]
|
||||
@ -295,7 +305,9 @@ class TestExternalDatasetServiceGetAPIs:
|
||||
assert len(result_items) == 1
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_apis_max_per_page_limit(self, mock_db, factory):
|
||||
def test_get_external_knowledge_apis_max_per_page_limit(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test that max_per_page limit is enforced."""
|
||||
# Arrange
|
||||
apis = [factory.create_external_knowledge_api_mock(api_id=f"api-{i}") for i in range(100)]
|
||||
@ -315,7 +327,9 @@ class TestExternalDatasetServiceGetAPIs:
|
||||
assert call_args.kwargs["max_per_page"] == 100
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_apis_ordered_by_created_at_desc(self, mock_db, factory):
|
||||
def test_get_external_knowledge_apis_ordered_by_created_at_desc(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test that results are ordered by created_at descending."""
|
||||
# Arrange
|
||||
apis = [
|
||||
@ -340,7 +354,7 @@ class TestExternalDatasetServiceGetAPIs:
|
||||
class TestExternalDatasetServiceValidateAPIList:
|
||||
"""Test validate_api_list operations."""
|
||||
|
||||
def test_validate_api_list_success_with_all_fields(self, factory):
|
||||
def test_validate_api_list_success_with_all_fields(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test successful validation with all required fields."""
|
||||
# Arrange
|
||||
api_settings = {"endpoint": "https://api.example.com", "api_key": "test-key-123"}
|
||||
@ -348,7 +362,7 @@ class TestExternalDatasetServiceValidateAPIList:
|
||||
# Act & Assert - should not raise
|
||||
ExternalDatasetService.validate_api_list(api_settings)
|
||||
|
||||
def test_validate_api_list_missing_endpoint(self, factory):
|
||||
def test_validate_api_list_missing_endpoint(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when endpoint is missing."""
|
||||
# Arrange
|
||||
api_settings = {"api_key": "test-key"}
|
||||
@ -357,7 +371,7 @@ class TestExternalDatasetServiceValidateAPIList:
|
||||
with pytest.raises(ValueError, match="endpoint is required"):
|
||||
ExternalDatasetService.validate_api_list(api_settings)
|
||||
|
||||
def test_validate_api_list_empty_endpoint(self, factory):
|
||||
def test_validate_api_list_empty_endpoint(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when endpoint is empty string."""
|
||||
# Arrange
|
||||
api_settings = {"endpoint": "", "api_key": "test-key"}
|
||||
@ -366,7 +380,7 @@ class TestExternalDatasetServiceValidateAPIList:
|
||||
with pytest.raises(ValueError, match="endpoint is required"):
|
||||
ExternalDatasetService.validate_api_list(api_settings)
|
||||
|
||||
def test_validate_api_list_missing_api_key(self, factory):
|
||||
def test_validate_api_list_missing_api_key(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when API key is missing."""
|
||||
# Arrange
|
||||
api_settings = {"endpoint": "https://api.example.com"}
|
||||
@ -375,7 +389,7 @@ class TestExternalDatasetServiceValidateAPIList:
|
||||
with pytest.raises(ValueError, match="api_key is required"):
|
||||
ExternalDatasetService.validate_api_list(api_settings)
|
||||
|
||||
def test_validate_api_list_empty_api_key(self, factory):
|
||||
def test_validate_api_list_empty_api_key(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when API key is empty string."""
|
||||
# Arrange
|
||||
api_settings = {"endpoint": "https://api.example.com", "api_key": ""}
|
||||
@ -384,7 +398,7 @@ class TestExternalDatasetServiceValidateAPIList:
|
||||
with pytest.raises(ValueError, match="api_key is required"):
|
||||
ExternalDatasetService.validate_api_list(api_settings)
|
||||
|
||||
def test_validate_api_list_empty_dict(self, factory):
|
||||
def test_validate_api_list_empty_dict(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when settings are empty dict."""
|
||||
# Arrange
|
||||
api_settings = {}
|
||||
@ -393,7 +407,7 @@ class TestExternalDatasetServiceValidateAPIList:
|
||||
with pytest.raises(ValueError, match="api list is empty"):
|
||||
ExternalDatasetService.validate_api_list(api_settings)
|
||||
|
||||
def test_validate_api_list_none_value(self, factory):
|
||||
def test_validate_api_list_none_value(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when settings are None."""
|
||||
# Arrange
|
||||
api_settings = None
|
||||
@ -402,7 +416,7 @@ class TestExternalDatasetServiceValidateAPIList:
|
||||
with pytest.raises(ValueError, match="api list is empty"):
|
||||
ExternalDatasetService.validate_api_list(api_settings)
|
||||
|
||||
def test_validate_api_list_with_extra_fields(self, factory):
|
||||
def test_validate_api_list_with_extra_fields(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation succeeds with extra fields present."""
|
||||
# Arrange
|
||||
api_settings = {
|
||||
@ -421,7 +435,9 @@ class TestExternalDatasetServiceCreateAPI:
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key")
|
||||
def test_create_external_knowledge_api_success_full(self, mock_check, mock_db, factory):
|
||||
def test_create_external_knowledge_api_success_full(
|
||||
self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test successful creation with all fields."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
@ -447,7 +463,9 @@ class TestExternalDatasetServiceCreateAPI:
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key")
|
||||
def test_create_external_knowledge_api_minimal_fields(self, mock_check, mock_db, factory):
|
||||
def test_create_external_knowledge_api_minimal_fields(
|
||||
self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test creation with minimal required fields."""
|
||||
# Arrange
|
||||
args = {
|
||||
@ -463,7 +481,9 @@ class TestExternalDatasetServiceCreateAPI:
|
||||
assert result.description == ""
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_create_external_knowledge_api_missing_settings(self, mock_db, factory):
|
||||
def test_create_external_knowledge_api_missing_settings(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test creation fails when settings are missing."""
|
||||
# Arrange
|
||||
args = {"name": "Test API", "description": "Test"}
|
||||
@ -473,7 +493,7 @@ class TestExternalDatasetServiceCreateAPI:
|
||||
ExternalDatasetService.create_external_knowledge_api("tenant-123", "user-123", args)
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_create_external_knowledge_api_none_settings(self, mock_db, factory):
|
||||
def test_create_external_knowledge_api_none_settings(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test creation fails when settings are explicitly None."""
|
||||
# Arrange
|
||||
args = {"name": "Test API", "settings": None}
|
||||
@ -484,7 +504,9 @@ class TestExternalDatasetServiceCreateAPI:
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key")
|
||||
def test_create_external_knowledge_api_settings_json_serialization(self, mock_check, mock_db, factory):
|
||||
def test_create_external_knowledge_api_settings_json_serialization(
|
||||
self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test that settings are properly JSON serialized."""
|
||||
# Arrange
|
||||
settings = {
|
||||
@ -504,7 +526,9 @@ class TestExternalDatasetServiceCreateAPI:
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key")
|
||||
def test_create_external_knowledge_api_unicode_handling(self, mock_check, mock_db, factory):
|
||||
def test_create_external_knowledge_api_unicode_handling(
|
||||
self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test proper handling of Unicode characters in name and description."""
|
||||
# Arrange
|
||||
args = {
|
||||
@ -522,7 +546,9 @@ class TestExternalDatasetServiceCreateAPI:
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.check_endpoint_and_api_key")
|
||||
def test_create_external_knowledge_api_long_description(self, mock_check, mock_db, factory):
|
||||
def test_create_external_knowledge_api_long_description(
|
||||
self, mock_check, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test creation with very long description."""
|
||||
# Arrange
|
||||
long_description = "A" * 1000
|
||||
@ -544,7 +570,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
"""Test check_endpoint_and_api_key operations - extensive coverage."""
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_success_https(self, mock_proxy, factory):
|
||||
def test_check_endpoint_success_https(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test successful validation with HTTPS endpoint."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
|
||||
@ -558,7 +584,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
mock_proxy.post.assert_called_once()
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_success_http(self, mock_proxy, factory):
|
||||
def test_check_endpoint_success_http(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test successful validation with HTTP endpoint."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "http://api.example.com", "api_key": "test-key"}
|
||||
@ -570,7 +596,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
# Act & Assert - should not raise
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
def test_check_endpoint_missing_endpoint_key(self, factory):
|
||||
def test_check_endpoint_missing_endpoint_key(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when endpoint key is missing."""
|
||||
# Arrange
|
||||
settings = {"api_key": "test-key"}
|
||||
@ -579,7 +605,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
with pytest.raises(ValueError, match="endpoint is required"):
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
def test_check_endpoint_empty_endpoint_string(self, factory):
|
||||
def test_check_endpoint_empty_endpoint_string(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when endpoint is empty string."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "", "api_key": "test-key"}
|
||||
@ -588,7 +614,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
with pytest.raises(ValueError, match="endpoint is required"):
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
def test_check_endpoint_whitespace_endpoint(self, factory):
|
||||
def test_check_endpoint_whitespace_endpoint(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when endpoint is only whitespace."""
|
||||
# Arrange
|
||||
settings = {"endpoint": " ", "api_key": "test-key"}
|
||||
@ -597,7 +623,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
with pytest.raises(ValueError, match="invalid endpoint"):
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
def test_check_endpoint_missing_api_key_key(self, factory):
|
||||
def test_check_endpoint_missing_api_key_key(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when api_key key is missing."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com"}
|
||||
@ -606,7 +632,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
with pytest.raises(ValueError, match="api_key is required"):
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
def test_check_endpoint_empty_api_key_string(self, factory):
|
||||
def test_check_endpoint_empty_api_key_string(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when api_key is empty string."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com", "api_key": ""}
|
||||
@ -615,7 +641,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
with pytest.raises(ValueError, match="api_key is required"):
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
def test_check_endpoint_no_scheme_url(self, factory):
|
||||
def test_check_endpoint_no_scheme_url(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails for URL without http:// or https://."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "api.example.com", "api_key": "test-key"}
|
||||
@ -624,7 +650,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
with pytest.raises(ValueError, match="invalid endpoint.*must start with http"):
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
def test_check_endpoint_invalid_scheme(self, factory):
|
||||
def test_check_endpoint_invalid_scheme(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails for URL with invalid scheme."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "ftp://api.example.com", "api_key": "test-key"}
|
||||
@ -633,7 +659,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
with pytest.raises(ValueError, match="failed to connect to the endpoint"):
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
def test_check_endpoint_no_netloc(self, factory):
|
||||
def test_check_endpoint_no_netloc(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails for URL without network location."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "http://", "api_key": "test-key"}
|
||||
@ -642,7 +668,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
with pytest.raises(ValueError, match="invalid endpoint"):
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
def test_check_endpoint_malformed_url(self, factory):
|
||||
def test_check_endpoint_malformed_url(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails for malformed URL."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https:///invalid", "api_key": "test-key"}
|
||||
@ -652,7 +678,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_connection_timeout(self, mock_proxy, factory):
|
||||
def test_check_endpoint_connection_timeout(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails on connection timeout."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
|
||||
@ -663,7 +689,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_network_error(self, mock_proxy, factory):
|
||||
def test_check_endpoint_network_error(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails on network error."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
|
||||
@ -674,7 +700,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_502_bad_gateway(self, mock_proxy, factory):
|
||||
def test_check_endpoint_502_bad_gateway(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails with 502 Bad Gateway."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
|
||||
@ -688,7 +714,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_404_not_found(self, mock_proxy, factory):
|
||||
def test_check_endpoint_404_not_found(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails with 404 Not Found."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
|
||||
@ -702,7 +728,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_403_forbidden(self, mock_proxy, factory):
|
||||
def test_check_endpoint_403_forbidden(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails with 403 Forbidden (auth failure)."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com", "api_key": "wrong-key"}
|
||||
@ -716,7 +742,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_other_4xx_codes_pass(self, mock_proxy, factory):
|
||||
def test_check_endpoint_other_4xx_codes_pass(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test that other 4xx codes don't raise exceptions."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
|
||||
@ -730,7 +756,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_5xx_codes_except_502_pass(self, mock_proxy, factory):
|
||||
def test_check_endpoint_5xx_codes_except_502_pass(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test that 5xx codes except 502 don't raise exceptions."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com", "api_key": "test-key"}
|
||||
@ -744,7 +770,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_with_port_number(self, mock_proxy, factory):
|
||||
def test_check_endpoint_with_port_number(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation with endpoint including port number."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com:8443", "api_key": "test-key"}
|
||||
@ -757,7 +783,7 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_with_path(self, mock_proxy, factory):
|
||||
def test_check_endpoint_with_path(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation with endpoint including path."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com/v1/api", "api_key": "test-key"}
|
||||
@ -773,7 +799,9 @@ class TestExternalDatasetServiceCheckEndpoint:
|
||||
assert "/retrieval" in call_args[0][0]
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_check_endpoint_authorization_header_format(self, mock_proxy, factory):
|
||||
def test_check_endpoint_authorization_header_format(
|
||||
self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test that Authorization header is properly formatted."""
|
||||
# Arrange
|
||||
settings = {"endpoint": "https://api.example.com", "api_key": "test-key-123"}
|
||||
@ -795,7 +823,7 @@ class TestExternalDatasetServiceGetAPI:
|
||||
"""Test get_external_knowledge_api operations."""
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_api_success(self, mock_db, factory):
|
||||
def test_get_external_knowledge_api_success(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test successful retrieval of external knowledge API."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
@ -811,7 +839,7 @@ class TestExternalDatasetServiceGetAPI:
|
||||
assert result.id == api_id
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_api_not_found(self, mock_db, factory):
|
||||
def test_get_external_knowledge_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test error when API is not found."""
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
@ -826,7 +854,9 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
|
||||
@patch("services.external_knowledge_service.naive_utc_now")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_update_external_knowledge_api_success_all_fields(self, mock_db, mock_now, factory):
|
||||
def test_update_external_knowledge_api_success_all_fields(
|
||||
self, mock_db, mock_now, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test successful update with all fields."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
@ -856,7 +886,9 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_update_external_knowledge_api_preserve_hidden_api_key(self, mock_db, factory):
|
||||
def test_update_external_knowledge_api_preserve_hidden_api_key(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test that hidden API key is preserved from existing settings."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
@ -883,7 +915,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
assert settings["api_key"] == "original-secret-key"
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_update_external_knowledge_api_not_found(self, mock_db, factory):
|
||||
def test_update_external_knowledge_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test error when API is not found."""
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
@ -895,7 +927,9 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
ExternalDatasetService.update_external_knowledge_api("tenant-123", "user-123", "api-123", args)
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_update_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
|
||||
def test_update_external_knowledge_api_tenant_mismatch(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test error when tenant ID doesn't match."""
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
@ -907,7 +941,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
ExternalDatasetService.update_external_knowledge_api("wrong-tenant", "user-123", "api-123", args)
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_update_external_knowledge_api_name_only(self, mock_db, factory):
|
||||
def test_update_external_knowledge_api_name_only(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test updating only the name field."""
|
||||
# Arrange
|
||||
existing_api = factory.create_external_knowledge_api_mock(
|
||||
@ -930,7 +964,7 @@ class TestExternalDatasetServiceDeleteAPI:
|
||||
"""Test delete_external_knowledge_api operations."""
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_delete_external_knowledge_api_success(self, mock_db, factory):
|
||||
def test_delete_external_knowledge_api_success(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test successful deletion of external knowledge API."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
@ -948,7 +982,7 @@ class TestExternalDatasetServiceDeleteAPI:
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_delete_external_knowledge_api_not_found(self, mock_db, factory):
|
||||
def test_delete_external_knowledge_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test error when API is not found."""
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
@ -958,7 +992,9 @@ class TestExternalDatasetServiceDeleteAPI:
|
||||
ExternalDatasetService.delete_external_knowledge_api("tenant-123", "api-123")
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_delete_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
|
||||
def test_delete_external_knowledge_api_tenant_mismatch(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test error when tenant ID doesn't match."""
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
@ -972,7 +1008,9 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
"""Test external_knowledge_api_use_check operations."""
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_external_knowledge_api_use_check_in_use_single(self, mock_db, factory):
|
||||
def test_external_knowledge_api_use_check_in_use_single(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test API use check when API has one binding."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
@ -989,7 +1027,9 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0])
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory):
|
||||
def test_external_knowledge_api_use_check_in_use_multiple(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test API use check with multiple bindings."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
@ -1005,7 +1045,7 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
assert count == 10
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_external_knowledge_api_use_check_not_in_use(self, mock_db, factory):
|
||||
def test_external_knowledge_api_use_check_not_in_use(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test API use check when API is not in use."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
@ -1025,7 +1065,7 @@ class TestExternalDatasetServiceGetBinding:
|
||||
"""Test get_external_knowledge_binding_with_dataset_id operations."""
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_binding_success(self, mock_db, factory):
|
||||
def test_get_external_knowledge_binding_success(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test successful retrieval of external knowledge binding."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
@ -1043,7 +1083,7 @@ class TestExternalDatasetServiceGetBinding:
|
||||
assert result.tenant_id == tenant_id
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_binding_not_found(self, mock_db, factory):
|
||||
def test_get_external_knowledge_binding_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test error when binding is not found."""
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
@ -1057,7 +1097,9 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
"""Test document_create_args_validate operations."""
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_document_create_args_validate_success_all_params(self, mock_db, factory):
|
||||
def test_document_create_args_validate_success_all_params(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test successful validation with all required parameters."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
@ -1081,7 +1123,9 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
ExternalDatasetService.document_create_args_validate(tenant_id, api_id, process_parameter)
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_document_create_args_validate_missing_required_param(self, mock_db, factory):
|
||||
def test_document_create_args_validate_missing_required_param(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test validation fails when required parameter is missing."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
@ -1100,7 +1144,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
ExternalDatasetService.document_create_args_validate(tenant_id, api_id, process_parameter)
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_document_create_args_validate_api_not_found(self, mock_db, factory):
|
||||
def test_document_create_args_validate_api_not_found(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test validation fails when API is not found."""
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
@ -1110,7 +1154,9 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {})
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_document_create_args_validate_no_custom_parameters(self, mock_db, factory):
|
||||
def test_document_create_args_validate_no_custom_parameters(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test validation succeeds when no custom parameters defined."""
|
||||
# Arrange
|
||||
settings = {}
|
||||
@ -1122,7 +1168,9 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {})
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_document_create_args_validate_optional_params_not_required(self, mock_db, factory):
|
||||
def test_document_create_args_validate_optional_params_not_required(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test that optional parameters don't cause validation failure."""
|
||||
# Arrange
|
||||
settings = {
|
||||
@ -1146,7 +1194,7 @@ class TestExternalDatasetServiceProcessAPI:
|
||||
"""Test process_external_api operations - comprehensive HTTP method coverage."""
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_process_external_api_get_request(self, mock_proxy, factory):
|
||||
def test_process_external_api_get_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test processing GET request."""
|
||||
# Arrange
|
||||
settings = factory.create_api_setting_mock(request_method="get")
|
||||
@ -1162,7 +1210,9 @@ class TestExternalDatasetServiceProcessAPI:
|
||||
mock_proxy.get.assert_called_once()
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_process_external_api_post_request_with_data(self, mock_proxy, factory):
|
||||
def test_process_external_api_post_request_with_data(
|
||||
self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test processing POST request with data."""
|
||||
# Arrange
|
||||
settings = factory.create_api_setting_mock(request_method="post", params={"key": "value", "data": "test"})
|
||||
@ -1180,7 +1230,7 @@ class TestExternalDatasetServiceProcessAPI:
|
||||
assert "data" in call_kwargs
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_process_external_api_put_request(self, mock_proxy, factory):
|
||||
def test_process_external_api_put_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test processing PUT request."""
|
||||
# Arrange
|
||||
settings = factory.create_api_setting_mock(request_method="put")
|
||||
@ -1196,7 +1246,7 @@ class TestExternalDatasetServiceProcessAPI:
|
||||
mock_proxy.put.assert_called_once()
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_process_external_api_delete_request(self, mock_proxy, factory):
|
||||
def test_process_external_api_delete_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test processing DELETE request."""
|
||||
# Arrange
|
||||
settings = factory.create_api_setting_mock(request_method="delete")
|
||||
@ -1212,7 +1262,7 @@ class TestExternalDatasetServiceProcessAPI:
|
||||
mock_proxy.delete.assert_called_once()
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_process_external_api_patch_request(self, mock_proxy, factory):
|
||||
def test_process_external_api_patch_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test processing PATCH request."""
|
||||
# Arrange
|
||||
settings = factory.create_api_setting_mock(request_method="patch")
|
||||
@ -1228,7 +1278,7 @@ class TestExternalDatasetServiceProcessAPI:
|
||||
mock_proxy.patch.assert_called_once()
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_process_external_api_head_request(self, mock_proxy, factory):
|
||||
def test_process_external_api_head_request(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test processing HEAD request."""
|
||||
# Arrange
|
||||
settings = factory.create_api_setting_mock(request_method="head")
|
||||
@ -1243,7 +1293,7 @@ class TestExternalDatasetServiceProcessAPI:
|
||||
assert result == mock_response
|
||||
mock_proxy.head.assert_called_once()
|
||||
|
||||
def test_process_external_api_invalid_method(self, factory):
|
||||
def test_process_external_api_invalid_method(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test error for invalid HTTP method."""
|
||||
# Arrange
|
||||
settings = factory.create_api_setting_mock(request_method="INVALID")
|
||||
@ -1253,7 +1303,7 @@ class TestExternalDatasetServiceProcessAPI:
|
||||
ExternalDatasetService.process_external_api(settings, None)
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_process_external_api_with_files(self, mock_proxy, factory):
|
||||
def test_process_external_api_with_files(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test processing request with file uploads."""
|
||||
# Arrange
|
||||
settings = factory.create_api_setting_mock(request_method="post")
|
||||
@ -1272,7 +1322,7 @@ class TestExternalDatasetServiceProcessAPI:
|
||||
assert call_kwargs["files"] == files
|
||||
|
||||
@patch("services.external_knowledge_service.ssrf_proxy")
|
||||
def test_process_external_api_follow_redirects(self, mock_proxy, factory):
|
||||
def test_process_external_api_follow_redirects(self, mock_proxy, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test that follow_redirects is enabled."""
|
||||
# Arrange
|
||||
settings = factory.create_api_setting_mock(request_method="get")
|
||||
@ -1291,7 +1341,7 @@ class TestExternalDatasetServiceProcessAPI:
|
||||
class TestExternalDatasetServiceAssemblingHeaders:
|
||||
"""Test assembling_headers operations - comprehensive authorization coverage."""
|
||||
|
||||
def test_assembling_headers_bearer_token(self, factory):
|
||||
def test_assembling_headers_bearer_token(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test assembling headers with Bearer token."""
|
||||
# Arrange
|
||||
authorization = factory.create_authorization_mock(token_type="bearer", api_key="secret-key-123")
|
||||
@ -1302,7 +1352,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
|
||||
# Assert
|
||||
assert result["Authorization"] == "Bearer secret-key-123"
|
||||
|
||||
def test_assembling_headers_basic_auth(self, factory):
|
||||
def test_assembling_headers_basic_auth(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test assembling headers with Basic authentication."""
|
||||
# Arrange
|
||||
authorization = factory.create_authorization_mock(token_type="basic", api_key="credentials")
|
||||
@ -1313,7 +1363,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
|
||||
# Assert
|
||||
assert result["Authorization"] == "Basic credentials"
|
||||
|
||||
def test_assembling_headers_custom_auth(self, factory):
|
||||
def test_assembling_headers_custom_auth(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test assembling headers with custom authentication."""
|
||||
# Arrange
|
||||
authorization = factory.create_authorization_mock(token_type="custom", api_key="custom-token")
|
||||
@ -1324,7 +1374,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
|
||||
# Assert
|
||||
assert result["Authorization"] == "custom-token"
|
||||
|
||||
def test_assembling_headers_custom_header_name(self, factory):
|
||||
def test_assembling_headers_custom_header_name(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test assembling headers with custom header name."""
|
||||
# Arrange
|
||||
authorization = factory.create_authorization_mock(token_type="bearer", api_key="key-123", header="X-API-Key")
|
||||
@ -1336,7 +1386,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
|
||||
assert result["X-API-Key"] == "Bearer key-123"
|
||||
assert "Authorization" not in result
|
||||
|
||||
def test_assembling_headers_with_existing_headers(self, factory):
|
||||
def test_assembling_headers_with_existing_headers(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test assembling headers preserves existing headers."""
|
||||
# Arrange
|
||||
authorization = factory.create_authorization_mock(token_type="bearer", api_key="key")
|
||||
@ -1355,7 +1405,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
|
||||
assert result["X-Custom"] == "value"
|
||||
assert result["User-Agent"] == "TestAgent/1.0"
|
||||
|
||||
def test_assembling_headers_empty_existing_headers(self, factory):
|
||||
def test_assembling_headers_empty_existing_headers(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test assembling headers with empty existing headers dict."""
|
||||
# Arrange
|
||||
authorization = factory.create_authorization_mock(token_type="bearer", api_key="key")
|
||||
@ -1368,7 +1418,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
|
||||
assert result["Authorization"] == "Bearer key"
|
||||
assert len(result) == 1
|
||||
|
||||
def test_assembling_headers_missing_api_key(self, factory):
|
||||
def test_assembling_headers_missing_api_key(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test error when API key is missing."""
|
||||
# Arrange
|
||||
config = AuthorizationConfig(api_key=None, type="bearer", header="Authorization")
|
||||
@ -1378,7 +1428,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
|
||||
with pytest.raises(ValueError, match="api_key is required"):
|
||||
ExternalDatasetService.assembling_headers(authorization)
|
||||
|
||||
def test_assembling_headers_missing_config(self, factory):
|
||||
def test_assembling_headers_missing_config(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test error when config is missing."""
|
||||
# Arrange
|
||||
authorization = Authorization(type="api-key", config=None)
|
||||
@ -1387,7 +1437,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
|
||||
with pytest.raises(ValueError, match="authorization config is required"):
|
||||
ExternalDatasetService.assembling_headers(authorization)
|
||||
|
||||
def test_assembling_headers_default_header_name(self, factory):
|
||||
def test_assembling_headers_default_header_name(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test that default header name is Authorization when not specified."""
|
||||
# Arrange
|
||||
config = AuthorizationConfig(api_key="key", type="bearer", header=None)
|
||||
@ -1403,7 +1453,7 @@ class TestExternalDatasetServiceAssemblingHeaders:
|
||||
class TestExternalDatasetServiceGetSettings:
|
||||
"""Test get_external_knowledge_api_settings operations."""
|
||||
|
||||
def test_get_external_knowledge_api_settings_success(self, factory):
|
||||
def test_get_external_knowledge_api_settings_success(self, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test successful parsing of API settings."""
|
||||
# Arrange
|
||||
settings = {
|
||||
@ -1428,7 +1478,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
"""Test create_external_dataset operations."""
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_create_external_dataset_success_full(self, mock_db, factory):
|
||||
def test_create_external_dataset_success_full(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test successful creation of external dataset with all fields."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
@ -1457,7 +1507,9 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_create_external_dataset_duplicate_name_error(self, mock_db, factory):
|
||||
def test_create_external_dataset_duplicate_name_error(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test error when dataset name already exists."""
|
||||
# Arrange
|
||||
existing_dataset = factory.create_dataset_mock(name="Duplicate Dataset")
|
||||
@ -1471,7 +1523,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
ExternalDatasetService.create_external_dataset("tenant-123", "user-123", args)
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_create_external_dataset_api_not_found_error(self, mock_db, factory):
|
||||
def test_create_external_dataset_api_not_found_error(self, mock_db, factory: ExternalDatasetServiceTestDataFactory):
|
||||
"""Test error when external knowledge API is not found."""
|
||||
# Arrange
|
||||
mock_db.session.scalar.side_effect = [None, None]
|
||||
@ -1483,7 +1535,9 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
ExternalDatasetService.create_external_dataset("tenant-123", "user-123", args)
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_create_external_dataset_missing_knowledge_id_error(self, mock_db, factory):
|
||||
def test_create_external_dataset_missing_knowledge_id_error(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test error when external_knowledge_id is missing."""
|
||||
# Arrange
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
@ -1497,7 +1551,9 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
ExternalDatasetService.create_external_dataset("tenant-123", "user-123", args)
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_create_external_dataset_missing_api_id_error(self, mock_db, factory):
|
||||
def test_create_external_dataset_missing_api_id_error(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test error when external_knowledge_api_id is missing."""
|
||||
# Arrange
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
@ -1516,7 +1572,9 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_success_with_results(self, mock_db, mock_process, factory):
|
||||
def test_fetch_external_knowledge_retrieval_success_with_results(
|
||||
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test successful external knowledge retrieval with results."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
@ -1553,7 +1611,9 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
assert result[1]["score"] == 0.8
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_binding_not_found_error(self, mock_db, factory):
|
||||
def test_fetch_external_knowledge_retrieval_binding_not_found_error(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test error when external knowledge binding is not found."""
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
@ -1563,7 +1623,9 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {})
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_cross_tenant_api_template_error(self, mock_db, factory):
|
||||
def test_fetch_external_knowledge_retrieval_cross_tenant_api_template_error(
|
||||
self, mock_db, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test error when a binding points to an API template outside the dataset tenant."""
|
||||
# Arrange
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
@ -1575,7 +1637,9 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_empty_results(self, mock_db, mock_process, factory):
|
||||
def test_fetch_external_knowledge_retrieval_empty_results(
|
||||
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test retrieval with empty results."""
|
||||
# Arrange
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
@ -1598,7 +1662,9 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_with_score_threshold(self, mock_db, mock_process, factory):
|
||||
def test_fetch_external_knowledge_retrieval_with_score_threshold(
|
||||
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test retrieval with score threshold enabled."""
|
||||
# Arrange
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
@ -1630,7 +1696,9 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory):
|
||||
def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(
|
||||
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test that non-200 status code raises Exception with response text."""
|
||||
# Arrange
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
@ -1665,7 +1733,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_various_error_status_codes(
|
||||
self, mock_db, mock_process, factory, status_code, error_message
|
||||
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory, status_code, error_message
|
||||
):
|
||||
"""Test that various error status codes raise exceptions with response text."""
|
||||
# Arrange
|
||||
@ -1690,7 +1758,9 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory):
|
||||
def test_fetch_external_knowledge_retrieval_empty_response_text(
|
||||
self, mock_db, mock_process, factory: ExternalDatasetServiceTestDataFactory
|
||||
):
|
||||
"""Test exception with empty response text."""
|
||||
# Arrange
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
|
||||
@ -259,7 +259,7 @@ def test_resolve_form_inputs_uses_runtime_select_options(
|
||||
|
||||
|
||||
def test_submit_form_by_token_calls_repository_and_enqueue(
|
||||
sample_form_record, mock_session_factory, mocker: MockerFixture
|
||||
sample_form_record: HumanInputFormRecord, mock_session_factory, mocker: MockerFixture
|
||||
):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
@ -318,7 +318,7 @@ def test_submit_form_by_token_enqueues_agent_app_resume_for_conversation_form(
|
||||
|
||||
|
||||
def test_submit_form_by_token_skips_enqueue_for_delivery_test(
|
||||
sample_form_record, mock_session_factory, mocker: MockerFixture
|
||||
sample_form_record: HumanInputFormRecord, mock_session_factory, mocker: MockerFixture
|
||||
):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
@ -343,7 +343,7 @@ def test_submit_form_by_token_skips_enqueue_for_delivery_test(
|
||||
|
||||
|
||||
def test_submit_form_by_token_passes_submission_user_id(
|
||||
sample_form_record, mock_session_factory, mocker: MockerFixture
|
||||
sample_form_record: HumanInputFormRecord, mock_session_factory, mocker: MockerFixture
|
||||
):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
@ -528,7 +528,7 @@ def test_validate_human_input_submission_rejects_invalid_select_and_file_payload
|
||||
repo.mark_submitted.assert_not_called()
|
||||
|
||||
|
||||
def test_form_properties(sample_form_record):
|
||||
def test_form_properties(sample_form_record: HumanInputFormRecord):
|
||||
form = Form(sample_form_record)
|
||||
assert form.id == "form-id"
|
||||
assert form.workflow_run_id == "workflow-run-id"
|
||||
|
||||
@ -90,7 +90,7 @@ class TestMessageServicePaginationByFirstId:
|
||||
return TestMessageServiceFactory()
|
||||
|
||||
# Test 01: No user provided
|
||||
def test_pagination_by_first_id_no_user(self, factory):
|
||||
def test_pagination_by_first_id_no_user(self, factory: TestMessageServiceFactory):
|
||||
"""Test pagination returns empty result when no user is provided."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -111,7 +111,7 @@ class TestMessageServicePaginationByFirstId:
|
||||
assert result.has_more is False
|
||||
|
||||
# Test 02: No conversation_id provided
|
||||
def test_pagination_by_first_id_no_conversation(self, factory):
|
||||
def test_pagination_by_first_id_no_conversation(self, factory: TestMessageServiceFactory):
|
||||
"""Test pagination returns empty result when no conversation_id is provided."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -137,7 +137,7 @@ class TestMessageServicePaginationByFirstId:
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_without_first_id_desc(
|
||||
self, mock_conversation_service, mock_db, mock_create_repo, factory
|
||||
self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory
|
||||
):
|
||||
"""Test basic pagination without first_id in descending order."""
|
||||
# Arrange
|
||||
@ -180,7 +180,7 @@ class TestMessageServicePaginationByFirstId:
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_without_first_id_asc(
|
||||
self, mock_conversation_service, mock_db, mock_create_repo, factory
|
||||
self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory
|
||||
):
|
||||
"""Test basic pagination without first_id in ascending order."""
|
||||
# Arrange
|
||||
@ -222,7 +222,9 @@ class TestMessageServicePaginationByFirstId:
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_conversation_service, mock_db, mock_create_repo, factory):
|
||||
def test_pagination_by_first_id_with_first_id(
|
||||
self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory
|
||||
):
|
||||
"""Test pagination with first_id to get messages before a specific message."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -265,7 +267,9 @@ class TestMessageServicePaginationByFirstId:
|
||||
# Test 06: First message not found
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_first_message_not_exists(self, mock_conversation_service, mock_db, factory):
|
||||
def test_pagination_by_first_id_first_message_not_exists(
|
||||
self, mock_conversation_service, mock_db, factory: TestMessageServiceFactory
|
||||
):
|
||||
"""Test error handling when first_id doesn't exist."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -290,7 +294,9 @@ class TestMessageServicePaginationByFirstId:
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_has_more_true(self, mock_conversation_service, mock_db, mock_create_repo, factory):
|
||||
def test_pagination_by_first_id_has_more_true(
|
||||
self, mock_conversation_service, mock_db, mock_create_repo, factory: TestMessageServiceFactory
|
||||
):
|
||||
"""Test has_more flag is True when results exceed limit."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -327,7 +333,9 @@ class TestMessageServicePaginationByFirstId:
|
||||
# Test 08: Empty conversation
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_empty_conversation(self, mock_conversation_service, mock_db, factory):
|
||||
def test_pagination_by_first_id_empty_conversation(
|
||||
self, mock_conversation_service, mock_db, factory: TestMessageServiceFactory
|
||||
):
|
||||
"""Test pagination with conversation that has no messages."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -370,7 +378,7 @@ class TestMessageServicePaginationByLastId:
|
||||
return TestMessageServiceFactory()
|
||||
|
||||
# Test 09: No user provided
|
||||
def test_pagination_by_last_id_no_user(self, factory):
|
||||
def test_pagination_by_last_id_no_user(self, factory: TestMessageServiceFactory):
|
||||
"""Test pagination returns empty result when no user is provided."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -391,7 +399,7 @@ class TestMessageServicePaginationByLastId:
|
||||
|
||||
# Test 10: Basic pagination without last_id
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_without_last_id(self, mock_db, factory):
|
||||
def test_pagination_by_last_id_without_last_id(self, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test basic pagination without last_id."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -422,7 +430,7 @@ class TestMessageServicePaginationByLastId:
|
||||
|
||||
# Test 11: Pagination with last_id
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_with_last_id(self, mock_db, factory):
|
||||
def test_pagination_by_last_id_with_last_id(self, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test pagination with last_id to get messages after a specific message."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -459,7 +467,7 @@ class TestMessageServicePaginationByLastId:
|
||||
|
||||
# Test 12: Last message not found
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_last_message_not_exists(self, mock_db, factory):
|
||||
def test_pagination_by_last_id_last_message_not_exists(self, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test error handling when last_id doesn't exist."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -479,7 +487,9 @@ class TestMessageServicePaginationByLastId:
|
||||
# Test 13: Pagination with conversation_id filter
|
||||
@patch("services.message_service.ConversationService")
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_with_conversation_filter(self, mock_db, mock_conversation_service, factory):
|
||||
def test_pagination_by_last_id_with_conversation_filter(
|
||||
self, mock_db, mock_conversation_service, factory: TestMessageServiceFactory
|
||||
):
|
||||
"""Test pagination filtered by conversation_id."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -515,7 +525,7 @@ class TestMessageServicePaginationByLastId:
|
||||
|
||||
# Test 14: Pagination with include_ids filter
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_with_include_ids(self, mock_db, factory):
|
||||
def test_pagination_by_last_id_with_include_ids(self, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test pagination filtered by include_ids."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -545,7 +555,7 @@ class TestMessageServicePaginationByLastId:
|
||||
|
||||
# Test 15: Has_more flag when results exceed limit
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_has_more_true(self, mock_db, factory):
|
||||
def test_pagination_by_last_id_has_more_true(self, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test has_more flag is True when results exceed limit."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -592,7 +602,7 @@ class TestMessageServiceUtilities:
|
||||
|
||||
# Test 17: attach_message_extra_contents with messages
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
def test_attach_message_extra_contents_with_messages(self, mock_create_repo, factory):
|
||||
def test_attach_message_extra_contents_with_messages(self, mock_create_repo, factory: TestMessageServiceFactory):
|
||||
"""Test attach_message_extra_contents correctly attaches content."""
|
||||
# Arrange
|
||||
messages = [factory.create_message_mock(message_id="msg-1"), factory.create_message_mock(message_id="msg-2")]
|
||||
@ -618,7 +628,9 @@ class TestMessageServiceUtilities:
|
||||
|
||||
# Test 18: attach_message_extra_contents with index out of bounds
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
def test_attach_message_extra_contents_index_out_of_bounds(self, mock_create_repo, factory):
|
||||
def test_attach_message_extra_contents_index_out_of_bounds(
|
||||
self, mock_create_repo, factory: TestMessageServiceFactory
|
||||
):
|
||||
"""Test attach_message_extra_contents handles missing content lists."""
|
||||
# Arrange
|
||||
messages = [factory.create_message_mock(message_id="msg-1")]
|
||||
@ -659,7 +671,7 @@ class TestMessageServiceGetMessage:
|
||||
|
||||
# Test 20: get_message success for EndUser
|
||||
@patch("services.message_service.db")
|
||||
def test_get_message_end_user_success(self, mock_db, factory):
|
||||
def test_get_message_end_user_success(self, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test get_message returns message for EndUser."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -676,7 +688,7 @@ class TestMessageServiceGetMessage:
|
||||
|
||||
# Test 21: get_message success for Account (Admin)
|
||||
@patch("services.message_service.db")
|
||||
def test_get_message_account_success(self, mock_db, factory):
|
||||
def test_get_message_account_success(self, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test get_message returns message for Account."""
|
||||
# Arrange
|
||||
from models import Account
|
||||
@ -696,7 +708,7 @@ class TestMessageServiceGetMessage:
|
||||
|
||||
# Test 22: get_message not found
|
||||
@patch("services.message_service.db")
|
||||
def test_get_message_not_found(self, mock_db, factory):
|
||||
def test_get_message_not_found(self, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test get_message raises MessageNotExistsError when not found."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -720,7 +732,7 @@ class TestMessageServiceFeedback:
|
||||
# Test 23: create_feedback - new feedback for EndUser
|
||||
@patch("services.message_service.db")
|
||||
@patch.object(MessageService, "get_message")
|
||||
def test_create_feedback_new_end_user(self, mock_get_message, mock_db, factory):
|
||||
def test_create_feedback_new_end_user(self, mock_get_message, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test creating new feedback for an end user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -748,7 +760,7 @@ class TestMessageServiceFeedback:
|
||||
# Test 24: create_feedback - update feedback for Account
|
||||
@patch("services.message_service.db")
|
||||
@patch.object(MessageService, "get_message")
|
||||
def test_create_feedback_update_account(self, mock_get_message, mock_db, factory):
|
||||
def test_create_feedback_update_account(self, mock_get_message, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test updating existing feedback for an account."""
|
||||
# Arrange
|
||||
from models import Account, MessageFeedback
|
||||
@ -779,7 +791,7 @@ class TestMessageServiceFeedback:
|
||||
# Test 25: create_feedback - delete feedback (rating is None)
|
||||
@patch("services.message_service.db")
|
||||
@patch.object(MessageService, "get_message")
|
||||
def test_create_feedback_delete(self, mock_get_message, mock_db, factory):
|
||||
def test_create_feedback_delete(self, mock_get_message, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test deleting feedback by passing rating=None."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -805,7 +817,7 @@ class TestMessageServiceFeedback:
|
||||
|
||||
# Test 26: get_all_messages_feedbacks
|
||||
@patch("services.message_service.db")
|
||||
def test_get_all_messages_feedbacks(self, mock_db, factory):
|
||||
def test_get_all_messages_feedbacks(self, mock_db, factory: TestMessageServiceFactory):
|
||||
"""Test get_all_messages_feedbacks returns list of dicts."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -830,7 +842,7 @@ class TestMessageServiceSuggestedQuestions:
|
||||
return TestMessageServiceFactory()
|
||||
|
||||
# Test 27: get_suggested_questions_after_answer - user is None
|
||||
def test_get_suggested_questions_user_none(self, factory):
|
||||
def test_get_suggested_questions_user_none(self, factory: TestMessageServiceFactory):
|
||||
app = factory.create_app_mock()
|
||||
with pytest.raises(ValueError, match="user cannot be None"):
|
||||
MessageService.get_suggested_questions_after_answer(
|
||||
@ -856,7 +868,7 @@ class TestMessageServiceSuggestedQuestions:
|
||||
mock_config_manager,
|
||||
mock_workflow_service,
|
||||
mock_model_manager,
|
||||
factory,
|
||||
factory: TestMessageServiceFactory,
|
||||
):
|
||||
"""Test successful suggested questions generation in Advanced Chat mode."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -896,14 +908,14 @@ class TestMessageServiceSuggestedQuestions:
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_get_suggested_questions_chat_app_success(
|
||||
self,
|
||||
mock_conversation_service,
|
||||
mock_get_message,
|
||||
mock_trace_manager,
|
||||
mock_llm_gen,
|
||||
mock_memory,
|
||||
mock_model_manager,
|
||||
mock_db,
|
||||
factory,
|
||||
mock_conversation_service: MagicMock,
|
||||
mock_get_message: MagicMock,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_llm_gen: MagicMock,
|
||||
mock_memory: MagicMock,
|
||||
mock_model_manager: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
factory: TestMessageServiceFactory,
|
||||
):
|
||||
"""Test successful suggested questions generation in basic Chat mode."""
|
||||
# Arrange
|
||||
@ -942,14 +954,14 @@ class TestMessageServiceSuggestedQuestions:
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_get_suggested_questions_chat_app_uses_frontend_model_and_prompt(
|
||||
self,
|
||||
mock_conversation_service,
|
||||
mock_get_message,
|
||||
mock_trace_manager,
|
||||
mock_llm_gen,
|
||||
mock_memory,
|
||||
mock_model_manager,
|
||||
mock_db,
|
||||
factory,
|
||||
mock_conversation_service: MagicMock,
|
||||
mock_get_message: MagicMock,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_llm_gen: MagicMock,
|
||||
mock_memory: MagicMock,
|
||||
mock_model_manager: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
factory: TestMessageServiceFactory,
|
||||
):
|
||||
"""Test suggested question generation uses frontend configured model and prompt."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -1015,14 +1027,14 @@ class TestMessageServiceSuggestedQuestions:
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_get_suggested_questions_chat_app_invalid_frontend_model_fallback_to_default(
|
||||
self,
|
||||
mock_conversation_service,
|
||||
mock_get_message,
|
||||
mock_trace_manager,
|
||||
mock_llm_gen,
|
||||
mock_memory,
|
||||
mock_model_manager,
|
||||
mock_db,
|
||||
factory,
|
||||
mock_conversation_service: MagicMock,
|
||||
mock_get_message: MagicMock,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_llm_gen: MagicMock,
|
||||
mock_memory: MagicMock,
|
||||
mock_model_manager: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
factory: TestMessageServiceFactory,
|
||||
):
|
||||
"""Test invalid frontend configured model falls back to tenant default model."""
|
||||
app = factory.create_app_mock(mode=AppMode.CHAT)
|
||||
@ -1066,14 +1078,14 @@ class TestMessageServiceSuggestedQuestions:
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_get_suggested_questions_chat_app_uses_compatible_override_model_config(
|
||||
self,
|
||||
mock_conversation_service,
|
||||
mock_get_message,
|
||||
mock_trace_manager,
|
||||
mock_llm_gen,
|
||||
mock_memory,
|
||||
mock_model_manager,
|
||||
mock_db,
|
||||
factory,
|
||||
mock_conversation_service: MagicMock,
|
||||
mock_get_message: MagicMock,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_llm_gen: MagicMock,
|
||||
mock_memory: MagicMock,
|
||||
mock_model_manager: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
factory: TestMessageServiceFactory,
|
||||
):
|
||||
"""Test legacy override configs are normalized before suggested questions reads them."""
|
||||
app = factory.create_app_mock(mode=AppMode.CHAT)
|
||||
@ -1174,7 +1186,12 @@ class TestMessageServiceSuggestedQuestions:
|
||||
@patch.object(MessageService, "get_message")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_get_suggested_questions_disabled_error(
|
||||
self, mock_conversation_service, mock_get_message, mock_config_manager, mock_workflow_service, factory
|
||||
self,
|
||||
mock_conversation_service,
|
||||
mock_get_message,
|
||||
mock_config_manager,
|
||||
mock_workflow_service,
|
||||
factory: TestMessageServiceFactory,
|
||||
):
|
||||
"""Test SuggestedQuestionsAfterAnswerDisabledError is raised when feature is disabled."""
|
||||
# Arrange
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -130,7 +130,7 @@ class TestRagPipelineTaskProxy:
|
||||
assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
def test_features_property(self, mock_feature_service: MagicMock):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
|
||||
@ -149,7 +149,7 @@ class TestRagPipelineTaskProxy:
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
|
||||
def test_upload_invoke_entities(self, mock_db: MagicMock, mock_file_service_class: MagicMock):
|
||||
"""Test _upload_invoke_entities method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
@ -181,7 +181,9 @@ class TestRagPipelineTaskProxy:
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
|
||||
def test_upload_invoke_entities_with_multiple_entities(
|
||||
self, mock_db: MagicMock, mock_file_service_class: MagicMock
|
||||
):
|
||||
"""Test _upload_invoke_entities method with multiple entities."""
|
||||
# Arrange
|
||||
entities = [
|
||||
@ -209,7 +211,7 @@ class TestRagPipelineTaskProxy:
|
||||
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
def test_send_to_direct_queue(self, mock_task: MagicMock):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
@ -229,7 +231,7 @@ class TestRagPipelineTaskProxy:
|
||||
)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task: MagicMock):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
@ -248,7 +250,7 @@ class TestRagPipelineTaskProxy:
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task: MagicMock):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
@ -271,7 +273,7 @@ class TestRagPipelineTaskProxy:
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
def test_send_to_default_tenant_queue(self, mock_task: MagicMock):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
@ -285,7 +287,7 @@ class TestRagPipelineTaskProxy:
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
def test_send_to_priority_tenant_queue(self, mock_task: MagicMock):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
@ -299,7 +301,7 @@ class TestRagPipelineTaskProxy:
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
def test_send_to_priority_direct_queue(self, mock_task: MagicMock):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
@ -315,7 +317,9 @@ class TestRagPipelineTaskProxy:
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(
|
||||
self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock
|
||||
):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
@ -365,7 +369,9 @@ class TestRagPipelineTaskProxy:
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
def test_dispatch_with_billing_disabled(
|
||||
self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock
|
||||
):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||
@ -386,7 +392,7 @@ class TestRagPipelineTaskProxy:
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
|
||||
def test_dispatch_with_empty_upload_file_id(self, mock_db: MagicMock, mock_file_service_class: MagicMock):
|
||||
"""Test _dispatch method when upload_file_id is empty."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
@ -404,7 +410,9 @@ class TestRagPipelineTaskProxy:
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
def test_dispatch_edge_case_empty_plan(
|
||||
self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock
|
||||
):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
|
||||
@ -426,7 +434,9 @@ class TestRagPipelineTaskProxy:
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
def test_dispatch_edge_case_none_plan(
|
||||
self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock
|
||||
):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
|
||||
@ -448,7 +458,9 @@ class TestRagPipelineTaskProxy:
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
def test_delay_method(
|
||||
self, mock_db: MagicMock, mock_file_service_class: MagicMock, mock_feature_service: MagicMock
|
||||
):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
|
||||
@ -73,7 +73,7 @@ def test_import_snippet_rejects_invalid_yaml_url_scheme() -> None:
|
||||
assert result.error == "Invalid URL scheme, only http and https are allowed"
|
||||
|
||||
|
||||
def test_import_snippet_returns_failed_when_yaml_url_fetch_fails(monkeypatch) -> None:
|
||||
def test_import_snippet_returns_failed_when_yaml_url_fetch_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SnippetDslService(session=SimpleNamespace())
|
||||
monkeypatch.setattr(
|
||||
"services.snippet_dsl_service.ssrf_proxy.get",
|
||||
@ -90,7 +90,7 @@ def test_import_snippet_returns_failed_when_yaml_url_fetch_fails(monkeypatch) ->
|
||||
assert result.error == "Failed to fetch YAML from URL: 404"
|
||||
|
||||
|
||||
def test_import_snippet_rejects_oversized_yaml_url_content(monkeypatch) -> None:
|
||||
def test_import_snippet_rejects_oversized_yaml_url_content(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SnippetDslService(session=SimpleNamespace())
|
||||
monkeypatch.setattr("services.snippet_dsl_service.DSL_MAX_SIZE", 3)
|
||||
monkeypatch.setattr(
|
||||
@ -108,7 +108,7 @@ def test_import_snippet_rejects_oversized_yaml_url_content(monkeypatch) -> None:
|
||||
assert "YAML content size exceeds maximum limit" in result.error
|
||||
|
||||
|
||||
def test_import_snippet_returns_failed_when_yaml_url_fetch_raises(monkeypatch) -> None:
|
||||
def test_import_snippet_returns_failed_when_yaml_url_fetch_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SnippetDslService(session=SimpleNamespace())
|
||||
monkeypatch.setattr(
|
||||
"services.snippet_dsl_service.ssrf_proxy.get",
|
||||
@ -125,7 +125,7 @@ def test_import_snippet_returns_failed_when_yaml_url_fetch_raises(monkeypatch) -
|
||||
assert result.error == "Failed to fetch YAML from URL: network down"
|
||||
|
||||
|
||||
def test_import_snippet_rejects_oversized_yaml_content(monkeypatch) -> None:
|
||||
def test_import_snippet_rejects_oversized_yaml_content(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SnippetDslService(session=SimpleNamespace())
|
||||
monkeypatch.setattr("services.snippet_dsl_service.DSL_MAX_SIZE", 3)
|
||||
|
||||
|
||||
@ -529,7 +529,7 @@ def test_delete_snippet_removes_related_records() -> None:
|
||||
session.delete.assert_called_once_with(snippet)
|
||||
|
||||
|
||||
def test_delete_draft_variable_files_removes_storage_objects(monkeypatch) -> None:
|
||||
def test_delete_draft_variable_files_removes_storage_objects(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
@ -554,7 +554,7 @@ def test_delete_draft_variable_files_removes_storage_objects(monkeypatch) -> Non
|
||||
assert "workflow_draft_variable_files" in executed_sql
|
||||
|
||||
|
||||
def test_delete_archived_workflow_run_files_removes_prefixed_objects(monkeypatch) -> None:
|
||||
def test_delete_archived_workflow_run_files_removes_prefixed_objects(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from configs import dify_config
|
||||
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
|
||||
@ -30,10 +30,10 @@ class TestWorkflowGeneratorService:
|
||||
@patch("services.workflow_generator_service.format_tool_catalogue")
|
||||
def test_forwards_model_instance_and_catalogue_text_to_generator(
|
||||
self,
|
||||
mock_format_catalogue,
|
||||
mock_build_catalogue,
|
||||
mock_model_manager,
|
||||
mock_workflow_generator,
|
||||
mock_format_catalogue: MagicMock,
|
||||
mock_build_catalogue: MagicMock,
|
||||
mock_model_manager: MagicMock,
|
||||
mock_workflow_generator: MagicMock,
|
||||
):
|
||||
"""Happy path: model_instance + catalogue text + payload flow through."""
|
||||
# Arrange
|
||||
@ -110,10 +110,10 @@ class TestWorkflowGeneratorService:
|
||||
@patch("services.workflow_generator_service.format_tool_catalogue")
|
||||
def test_defaults_ideal_output_to_empty_string(
|
||||
self,
|
||||
mock_format_catalogue,
|
||||
mock_build_catalogue,
|
||||
mock_model_manager,
|
||||
mock_workflow_generator,
|
||||
mock_format_catalogue: MagicMock,
|
||||
mock_build_catalogue: MagicMock,
|
||||
mock_model_manager: MagicMock,
|
||||
mock_workflow_generator: MagicMock,
|
||||
):
|
||||
"""Callers can omit ideal_output; the runner should still receive ""."""
|
||||
mock_model_manager.for_tenant.return_value.get_model_instance.return_value = MagicMock()
|
||||
@ -142,10 +142,10 @@ class TestWorkflowGeneratorService:
|
||||
@patch("services.workflow_generator_service.format_tool_catalogue")
|
||||
def test_forwards_current_graph_for_refine(
|
||||
self,
|
||||
mock_format_catalogue,
|
||||
mock_build_catalogue,
|
||||
mock_model_manager,
|
||||
mock_workflow_generator,
|
||||
mock_format_catalogue: MagicMock,
|
||||
mock_build_catalogue: MagicMock,
|
||||
mock_model_manager: MagicMock,
|
||||
mock_workflow_generator: MagicMock,
|
||||
):
|
||||
"""The cmd+k `/refine` path passes the existing draft graph through to the runner."""
|
||||
mock_model_manager.for_tenant.return_value.get_model_instance.return_value = MagicMock()
|
||||
@ -175,10 +175,10 @@ class TestWorkflowGeneratorService:
|
||||
@patch("services.workflow_generator_service.format_tool_catalogue")
|
||||
def test_defaults_current_graph_to_none_for_create(
|
||||
self,
|
||||
mock_format_catalogue,
|
||||
mock_build_catalogue,
|
||||
mock_model_manager,
|
||||
mock_workflow_generator,
|
||||
mock_format_catalogue: MagicMock,
|
||||
mock_build_catalogue: MagicMock,
|
||||
mock_model_manager: MagicMock,
|
||||
mock_workflow_generator: MagicMock,
|
||||
):
|
||||
"""Omitting current_graph (the `/create` path) forwards None to the runner."""
|
||||
mock_model_manager.for_tenant.return_value.get_model_instance.return_value = MagicMock()
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
@ -6,7 +7,7 @@ from core.helper.position_helper import get_position_map, is_filtered, pin_posit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
|
||||
def prepare_example_positions_yaml(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
tmp_path.joinpath("example_positions.yaml").write_text(
|
||||
dedent(
|
||||
@ -25,7 +26,7 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
|
||||
def prepare_empty_commented_positions_yaml(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
tmp_path.joinpath("example_positions_all_commented.yaml").write_text(
|
||||
dedent(
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
@ -11,7 +12,7 @@ NON_EXISTING_YAML_FILE = "non_existing_file.yaml"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
def prepare_example_yaml_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
|
||||
file_path.write_text(
|
||||
@ -34,7 +35,7 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
def prepare_invalid_yaml_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path.joinpath(INVALID_YAML_FILE)
|
||||
file_path.write_text(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user