mirror of
https://github.com/langgenius/dify.git
synced 2026-03-10 11:10:19 +08:00
test: improve unit tests for controllers.service_api (#32073)
Co-authored-by: Rajat Agarwal <rajat.agarwal@infocusp.com>
This commit is contained in:
parent
212756c315
commit
d773096146
@ -3,7 +3,8 @@ from typing import Any
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
|
||||
@ -17,7 +18,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.dataset import Dataset, Pipeline
|
||||
from models.engine import db
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.file_service import FileService
|
||||
@ -65,6 +66,12 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Get query parameter to determine published or draft
|
||||
is_published: bool = request.args.get("is_published", default=True, type=bool)
|
||||
|
||||
@ -104,6 +111,12 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str, node_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
|
||||
assert isinstance(current_user, Account)
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
@ -161,6 +174,12 @@ class PipelineRunApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for running a rag pipeline."""
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
|
||||
@ -124,3 +124,38 @@ def _configure_session_factory(_unit_test_engine):
|
||||
session_factory.get_session_maker()
|
||||
except RuntimeError:
|
||||
configure_session_factory(_unit_test_engine, expire_on_commit=False)
|
||||
|
||||
|
||||
def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
|
||||
"""
|
||||
Helper to set up the mock DB query chain for tenant/account authentication.
|
||||
|
||||
This configures the mock to return (tenant, account) for the join query used
|
||||
by validate_app_token and validate_dataset_token decorators.
|
||||
|
||||
Args:
|
||||
mock_db: The mocked db object
|
||||
mock_tenant: Mock tenant object to return
|
||||
mock_account: Mock account object to return
|
||||
"""
|
||||
query = mock_db.session.query.return_value
|
||||
join_chain = query.join.return_value.join.return_value
|
||||
where_chain = join_chain.where.return_value
|
||||
where_chain.one_or_none.return_value = (mock_tenant, mock_account)
|
||||
|
||||
|
||||
def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
|
||||
"""
|
||||
Helper to set up the mock DB query chain for dataset tenant authentication.
|
||||
|
||||
This configures the mock to return (tenant, tenant_account) for the where chain
|
||||
query used by validate_dataset_token decorator.
|
||||
|
||||
Args:
|
||||
mock_db: The mocked db object
|
||||
mock_tenant: Mock tenant object to return
|
||||
mock_ta: Mock tenant account object to return
|
||||
"""
|
||||
query = mock_db.session.query.return_value
|
||||
where_chain = query.where.return_value.where.return_value.where.return_value.where.return_value
|
||||
where_chain.one_or_none.return_value = (mock_tenant, mock_ta)
|
||||
|
||||
@ -0,0 +1,295 @@
|
||||
"""
|
||||
Unit tests for Service API Annotation controller.
|
||||
|
||||
Tests coverage for:
|
||||
- AnnotationCreatePayload Pydantic model validation
|
||||
- AnnotationReplyActionPayload Pydantic model validation
|
||||
- Error patterns and validation logic
|
||||
|
||||
Note: API endpoint tests for annotation controllers are complex due to:
|
||||
- @validate_app_token decorator requiring full Flask-SQLAlchemy setup
|
||||
- @edit_permission_required decorator checking current_user permissions
|
||||
- These are better covered by integration tests
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask_restx.api import HTTPStatus
|
||||
|
||||
from controllers.service_api.app.annotation import (
|
||||
AnnotationCreatePayload,
|
||||
AnnotationListApi,
|
||||
AnnotationReplyActionApi,
|
||||
AnnotationReplyActionPayload,
|
||||
AnnotationReplyActionStatusApi,
|
||||
AnnotationUpdateDeleteApi,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAnnotationCreatePayload:
|
||||
"""Test suite for AnnotationCreatePayload Pydantic model."""
|
||||
|
||||
def test_payload_with_question_and_answer(self):
|
||||
"""Test payload with required fields."""
|
||||
payload = AnnotationCreatePayload(
|
||||
question="What is AI?",
|
||||
answer="AI is artificial intelligence.",
|
||||
)
|
||||
assert payload.question == "What is AI?"
|
||||
assert payload.answer == "AI is artificial intelligence."
|
||||
|
||||
def test_payload_with_unicode_content(self):
|
||||
"""Test payload with unicode content."""
|
||||
payload = AnnotationCreatePayload(
|
||||
question="什么是人工智能?",
|
||||
answer="人工智能是模拟人类智能的技术。",
|
||||
)
|
||||
assert payload.question == "什么是人工智能?"
|
||||
|
||||
def test_payload_with_special_characters(self):
|
||||
"""Test payload with special characters."""
|
||||
payload = AnnotationCreatePayload(
|
||||
question="What is <b>AI</b>?",
|
||||
answer="AI & ML are related fields with 100% growth!",
|
||||
)
|
||||
assert "<b>" in payload.question
|
||||
|
||||
|
||||
class TestAnnotationReplyActionPayload:
|
||||
"""Test suite for AnnotationReplyActionPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all fields."""
|
||||
payload = AnnotationReplyActionPayload(
|
||||
score_threshold=0.8,
|
||||
embedding_provider_name="openai",
|
||||
embedding_model_name="text-embedding-ada-002",
|
||||
)
|
||||
assert payload.score_threshold == 0.8
|
||||
assert payload.embedding_provider_name == "openai"
|
||||
assert payload.embedding_model_name == "text-embedding-ada-002"
|
||||
|
||||
def test_payload_with_different_provider(self):
|
||||
"""Test payload with different embedding provider."""
|
||||
payload = AnnotationReplyActionPayload(
|
||||
score_threshold=0.75,
|
||||
embedding_provider_name="azure_openai",
|
||||
embedding_model_name="text-embedding-3-small",
|
||||
)
|
||||
assert payload.embedding_provider_name == "azure_openai"
|
||||
|
||||
def test_payload_with_zero_threshold(self):
|
||||
"""Test payload with zero score threshold."""
|
||||
payload = AnnotationReplyActionPayload(
|
||||
score_threshold=0.0,
|
||||
embedding_provider_name="local",
|
||||
embedding_model_name="default",
|
||||
)
|
||||
assert payload.score_threshold == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model and Error Pattern Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAppModelPatterns:
|
||||
"""Test App model patterns used by annotation controller."""
|
||||
|
||||
def test_app_model_has_required_fields(self):
|
||||
"""Test App model has required fields for annotation operations."""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
|
||||
assert app.id is not None
|
||||
assert app.status == "normal"
|
||||
assert app.enable_api is True
|
||||
|
||||
def test_app_model_disabled_api(self):
|
||||
"""Test app with disabled API access."""
|
||||
app = Mock(spec=App)
|
||||
app.enable_api = False
|
||||
|
||||
assert app.enable_api is False
|
||||
|
||||
def test_app_model_archived_status(self):
|
||||
"""Test app with archived status."""
|
||||
app = Mock(spec=App)
|
||||
app.status = "archived"
|
||||
|
||||
assert app.status == "archived"
|
||||
|
||||
|
||||
class TestAnnotationErrorPatterns:
|
||||
"""Test annotation-related error handling patterns."""
|
||||
|
||||
def test_not_found_error_pattern(self):
|
||||
"""Test NotFound error pattern used in annotation operations."""
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
raise NotFound("Annotation not found.")
|
||||
|
||||
def test_forbidden_error_pattern(self):
|
||||
"""Test Forbidden error pattern."""
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
raise Forbidden("Permission denied.")
|
||||
|
||||
def test_value_error_for_job_not_found(self):
|
||||
"""Test ValueError pattern for job not found."""
|
||||
with pytest.raises(ValueError, match="does not exist"):
|
||||
raise ValueError("The job does not exist.")
|
||||
|
||||
|
||||
class TestAnnotationReplyActionApi:
|
||||
def test_enable(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
enable_mock = Mock()
|
||||
monkeypatch.setattr(AppAnnotationService, "enable_app_annotation", enable_mock)
|
||||
|
||||
api = AnnotationReplyActionApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/annotation-reply/enable",
|
||||
method="POST",
|
||||
json={"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"},
|
||||
):
|
||||
response, status = handler(api, app_model=app_model, action="enable")
|
||||
|
||||
assert status == 200
|
||||
enable_mock.assert_called_once()
|
||||
|
||||
def test_disable(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
disable_mock = Mock()
|
||||
monkeypatch.setattr(AppAnnotationService, "disable_app_annotation", disable_mock)
|
||||
|
||||
api = AnnotationReplyActionApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/annotation-reply/disable",
|
||||
method="POST",
|
||||
json={"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"},
|
||||
):
|
||||
response, status = handler(api, app_model=app_model, action="disable")
|
||||
|
||||
assert status == 200
|
||||
disable_mock.assert_called_once()
|
||||
|
||||
|
||||
class TestAnnotationReplyActionStatusApi:
|
||||
def test_missing_job(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(redis_client, "get", lambda *_args, **_kwargs: None)
|
||||
|
||||
api = AnnotationReplyActionStatusApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
handler(api, app_model=app_model, job_id="j1", action="enable")
|
||||
|
||||
def test_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _get(key):
|
||||
if "error" in key:
|
||||
return b"oops"
|
||||
return b"error"
|
||||
|
||||
monkeypatch.setattr(redis_client, "get", _get)
|
||||
|
||||
api = AnnotationReplyActionStatusApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
response, status = handler(api, app_model=app_model, job_id="j1", action="enable")
|
||||
|
||||
assert status == 200
|
||||
assert response["job_status"] == "error"
|
||||
assert response["error_msg"] == "oops"
|
||||
|
||||
|
||||
class TestAnnotationListApi:
|
||||
def test_get(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
|
||||
monkeypatch.setattr(
|
||||
AppAnnotationService,
|
||||
"get_annotation_list_by_app_id",
|
||||
lambda *_args, **_kwargs: ([annotation], 1),
|
||||
)
|
||||
|
||||
api = AnnotationListApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context("/apps/annotations?page=1&limit=1", method="GET"):
|
||||
response = handler(api, app_model=app_model)
|
||||
|
||||
assert response["total"] == 1
|
||||
|
||||
def test_create(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
|
||||
monkeypatch.setattr(
|
||||
AppAnnotationService,
|
||||
"insert_app_annotation_directly",
|
||||
lambda *_args, **_kwargs: annotation,
|
||||
)
|
||||
|
||||
api = AnnotationListApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context("/apps/annotations", method="POST", json={"question": "q", "answer": "a"}):
|
||||
response, status = handler(api, app_model=app_model)
|
||||
|
||||
assert status == HTTPStatus.CREATED
|
||||
assert response["question"] == "q"
|
||||
|
||||
|
||||
class TestAnnotationUpdateDeleteApi:
|
||||
def test_update_delete(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
|
||||
monkeypatch.setattr(
|
||||
AppAnnotationService,
|
||||
"update_app_annotation_directly",
|
||||
lambda *_args, **_kwargs: annotation,
|
||||
)
|
||||
delete_mock = Mock()
|
||||
monkeypatch.setattr(AppAnnotationService, "delete_app_annotation", delete_mock)
|
||||
|
||||
api = AnnotationUpdateDeleteApi()
|
||||
put_handler = _unwrap(api.put)
|
||||
delete_handler = _unwrap(api.delete)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context("/apps/annotations/1", method="PUT", json={"question": "q", "answer": "a"}):
|
||||
response = put_handler(api, app_model=app_model, annotation_id="1")
|
||||
|
||||
assert response["answer"] == "a"
|
||||
|
||||
with app.test_request_context("/apps/annotations/1", method="DELETE"):
|
||||
response, status = delete_handler(api, app_model=app_model, annotation_id="1")
|
||||
|
||||
assert status == 204
|
||||
delete_mock.assert_called_once()
|
||||
496
api/tests/unit_tests/controllers/service_api/app/test_app.py
Normal file
496
api/tests/unit_tests/controllers/service_api/app/test_app.py
Normal file
@ -0,0 +1,496 @@
|
||||
"""
|
||||
Unit tests for Service API App controllers
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameterApi
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models.model import App, AppMode
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_account_query
|
||||
|
||||
|
||||
class TestAppParameterApi:
|
||||
"""Test suite for AppParameterApi"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model."""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_parameters_for_chat_app(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test retrieving parameters for a chat app."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_config = Mock()
|
||||
mock_config.id = str(uuid.uuid4())
|
||||
mock_config.to_dict.return_value = {
|
||||
"user_input_form": [{"type": "text", "label": "Name", "variable": "name", "required": True}],
|
||||
"suggested_questions": [],
|
||||
}
|
||||
mock_app_model.app_model_config = mock_config
|
||||
mock_app_model.workflow = None
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
# Mock DB queries for app and tenant
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
# Mock tenant owner info for login
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppParameterApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert "opening_statement" in response
|
||||
assert "suggested_questions" in response
|
||||
assert "user_input_form" in response
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_parameters_for_workflow_app(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test retrieving parameters for a workflow app."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app_model.mode = AppMode.WORKFLOW
|
||||
mock_workflow = Mock()
|
||||
mock_workflow.features_dict = {"suggested_questions": []}
|
||||
mock_workflow.user_input_form.return_value = [{"type": "text", "label": "Input", "variable": "input"}]
|
||||
mock_app_model.workflow = mock_workflow
|
||||
mock_app_model.app_model_config = None
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppParameterApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert "user_input_form" in response
|
||||
assert "opening_statement" in response
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_parameters_raises_error_when_chat_config_missing(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test that AppUnavailableError is raised when chat app has no config."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app_model.app_model_config = None
|
||||
mock_app_model.workflow = None
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppParameterApi()
|
||||
with pytest.raises(AppUnavailableError):
|
||||
api.get()
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_parameters_raises_error_when_workflow_missing(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test that AppUnavailableError is raised when workflow app has no workflow."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app_model.mode = AppMode.WORKFLOW
|
||||
mock_app_model.workflow = None
|
||||
mock_app_model.app_model_config = None
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppParameterApi()
|
||||
with pytest.raises(AppUnavailableError):
|
||||
api.get()
|
||||
|
||||
|
||||
class TestAppMetaApi:
|
||||
"""Test suite for AppMetaApi"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model."""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
@patch("controllers.service_api.app.app.AppService")
|
||||
def test_get_app_meta(
|
||||
self, mock_app_service, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test retrieving app metadata via AppService."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_service_instance = Mock()
|
||||
mock_service_instance.get_app_meta.return_value = {
|
||||
"tool_icons": {},
|
||||
"AgentIcons": {},
|
||||
}
|
||||
mock_app_service.return_value = mock_service_instance
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppMetaApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
mock_service_instance.get_app_meta.assert_called_once_with(mock_app_model)
|
||||
assert response == {"tool_icons": {}, "AgentIcons": {}}
|
||||
|
||||
|
||||
class TestAppInfoApi:
|
||||
"""Test suite for AppInfoApi"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model with all required attributes."""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.name = "Test App"
|
||||
app.description = "A test application"
|
||||
app.mode = AppMode.CHAT
|
||||
app.author_name = "Test Author"
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
|
||||
# Mock tags relationship
|
||||
mock_tag = Mock()
|
||||
mock_tag.name = "test-tag"
|
||||
app.tags = [mock_tag]
|
||||
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_app_info(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test retrieving basic app information."""
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppInfoApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["name"] == "Test App"
|
||||
assert response["description"] == "A test application"
|
||||
assert response["tags"] == ["test-tag"]
|
||||
assert response["mode"] == AppMode.CHAT
|
||||
assert response["author_name"] == "Test Author"
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_app_info_with_multiple_tags(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app
|
||||
):
|
||||
"""Test retrieving app info with multiple tags."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = str(uuid.uuid4())
|
||||
mock_app.tenant_id = str(uuid.uuid4())
|
||||
mock_app.name = "Multi Tag App"
|
||||
mock_app.description = "App with multiple tags"
|
||||
mock_app.mode = AppMode.WORKFLOW
|
||||
mock_app.author_name = "Author"
|
||||
mock_app.status = "normal"
|
||||
mock_app.enable_api = True
|
||||
|
||||
tag1, tag2, tag3 = Mock(), Mock(), Mock()
|
||||
tag1.name = "tag-one"
|
||||
tag2.name = "tag-two"
|
||||
tag3.name = "tag-three"
|
||||
mock_app.tags = [tag1, tag2, tag3]
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app.id
|
||||
mock_api_token.tenant_id = mock_app.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppInfoApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["tags"] == ["tag-one", "tag-two", "tag-three"]
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_app_info_with_no_tags(self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app):
|
||||
"""Test retrieving app info when app has no tags."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = str(uuid.uuid4())
|
||||
mock_app.tenant_id = str(uuid.uuid4())
|
||||
mock_app.name = "No Tags App"
|
||||
mock_app.description = "App without tags"
|
||||
mock_app.mode = AppMode.COMPLETION
|
||||
mock_app.author_name = "Author"
|
||||
mock_app.tags = []
|
||||
mock_app.status = "normal"
|
||||
mock_app.enable_api = True
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app.id
|
||||
mock_api_token.tenant_id = mock_app.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppInfoApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["tags"] == []
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"app_mode",
|
||||
[AppMode.CHAT, AppMode.COMPLETION, AppMode.WORKFLOW, AppMode.ADVANCED_CHAT],
|
||||
)
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_app_info_returns_correct_mode(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, app_mode
|
||||
):
|
||||
"""Test that all app modes are correctly returned."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = str(uuid.uuid4())
|
||||
mock_app.tenant_id = str(uuid.uuid4())
|
||||
mock_app.name = "Test"
|
||||
mock_app.description = "Test"
|
||||
mock_app.mode = app_mode
|
||||
mock_app.author_name = "Test"
|
||||
mock_app.tags = []
|
||||
mock_app.status = "normal"
|
||||
mock_app.enable_api = True
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app.id
|
||||
mock_api_token.tenant_id = mock_app.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppInfoApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["mode"] == app_mode
|
||||
298
api/tests/unit_tests/controllers/service_api/app/test_audio.py
Normal file
298
api/tests/unit_tests/controllers/service_api/app/test_audio.py
Normal file
@ -0,0 +1,298 @@
|
||||
"""
|
||||
Unit tests for Service API Audio controller.
|
||||
|
||||
Tests coverage for:
|
||||
- TextToAudioPayload Pydantic model validation
|
||||
- Error mapping patterns between service and API errors
|
||||
- AudioService method interfaces
|
||||
"""
|
||||
|
||||
import io
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.service_api.app.audio import AudioApi, TextApi, TextToAudioPayload
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _file_data():
|
||||
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTextToAudioPayload:
|
||||
"""Test suite for TextToAudioPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all fields populated."""
|
||||
payload = TextToAudioPayload(
|
||||
message_id="msg_123",
|
||||
voice="nova",
|
||||
text="Hello, this is a test.",
|
||||
streaming=False,
|
||||
)
|
||||
assert payload.message_id == "msg_123"
|
||||
assert payload.voice == "nova"
|
||||
assert payload.text == "Hello, this is a test."
|
||||
assert payload.streaming is False
|
||||
|
||||
def test_payload_with_defaults(self):
|
||||
"""Test payload with default values."""
|
||||
payload = TextToAudioPayload()
|
||||
assert payload.message_id is None
|
||||
assert payload.voice is None
|
||||
assert payload.text is None
|
||||
assert payload.streaming is None
|
||||
|
||||
def test_payload_with_only_text(self):
|
||||
"""Test payload with only text field."""
|
||||
payload = TextToAudioPayload(text="Simple text to speech")
|
||||
assert payload.text == "Simple text to speech"
|
||||
assert payload.voice is None
|
||||
assert payload.message_id is None
|
||||
|
||||
def test_payload_with_streaming_true(self):
|
||||
"""Test payload with streaming enabled."""
|
||||
payload = TextToAudioPayload(
|
||||
text="Streaming test",
|
||||
streaming=True,
|
||||
)
|
||||
assert payload.streaming is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AudioService Interface Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAudioServiceInterface:
|
||||
"""Test AudioService method interfaces exist."""
|
||||
|
||||
def test_transcript_asr_method_exists(self):
|
||||
"""Test that AudioService.transcript_asr exists."""
|
||||
assert hasattr(AudioService, "transcript_asr")
|
||||
assert callable(AudioService.transcript_asr)
|
||||
|
||||
def test_transcript_tts_method_exists(self):
|
||||
"""Test that AudioService.transcript_tts exists."""
|
||||
assert hasattr(AudioService, "transcript_tts")
|
||||
assert callable(AudioService.transcript_tts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio Service Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAudioServiceInterface:
|
||||
"""Test suite for AudioService interface methods."""
|
||||
|
||||
def test_transcript_asr_method_exists(self):
|
||||
"""Test that AudioService.transcript_asr exists."""
|
||||
assert hasattr(AudioService, "transcript_asr")
|
||||
assert callable(AudioService.transcript_asr)
|
||||
|
||||
def test_transcript_tts_method_exists(self):
|
||||
"""Test that AudioService.transcript_tts exists."""
|
||||
assert hasattr(AudioService, "transcript_tts")
|
||||
assert callable(AudioService.transcript_tts)
|
||||
|
||||
|
||||
class TestServiceErrorTypes:
|
||||
"""Test service error types used by audio controller."""
|
||||
|
||||
def test_no_audio_uploaded_service_error(self):
|
||||
"""Test NoAudioUploadedServiceError exists."""
|
||||
error = NoAudioUploadedServiceError()
|
||||
assert error is not None
|
||||
|
||||
def test_audio_too_large_service_error(self):
|
||||
"""Test AudioTooLargeServiceError with message."""
|
||||
error = AudioTooLargeServiceError("File too large")
|
||||
assert "File too large" in str(error)
|
||||
|
||||
def test_unsupported_audio_type_service_error(self):
|
||||
"""Test UnsupportedAudioTypeServiceError exists."""
|
||||
error = UnsupportedAudioTypeServiceError()
|
||||
assert error is not None
|
||||
|
||||
def test_provider_not_support_speech_to_text_service_error(self):
|
||||
"""Test ProviderNotSupportSpeechToTextServiceError exists."""
|
||||
error = ProviderNotSupportSpeechToTextServiceError()
|
||||
assert error is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mocked Behavior Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAudioServiceMockedBehavior:
|
||||
"""Test AudioService behavior with mocked methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock app model."""
|
||||
from models.model import App
|
||||
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file(self):
|
||||
"""Create mock file upload."""
|
||||
mock = Mock()
|
||||
mock.filename = "test_audio.mp3"
|
||||
mock.content_type = "audio/mpeg"
|
||||
return mock
|
||||
|
||||
@patch.object(AudioService, "transcript_asr")
|
||||
def test_transcript_asr_returns_response(self, mock_asr, mock_app, mock_file):
|
||||
"""Test ASR transcription returns response dict."""
|
||||
mock_response = {"text": "Transcribed text"}
|
||||
mock_asr.return_value = mock_response
|
||||
|
||||
result = AudioService.transcript_asr(
|
||||
app_model=mock_app,
|
||||
file=mock_file,
|
||||
end_user="user_123",
|
||||
)
|
||||
|
||||
assert result["text"] == "Transcribed text"
|
||||
|
||||
@patch.object(AudioService, "transcript_tts")
|
||||
def test_transcript_tts_returns_response(self, mock_tts, mock_app):
|
||||
"""Test TTS transcription returns response."""
|
||||
mock_response = {"audio": "base64_audio_data"}
|
||||
mock_tts.return_value = mock_response
|
||||
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=mock_app,
|
||||
text="Hello world",
|
||||
voice="nova",
|
||||
end_user="user_123",
|
||||
message_id="msg_123",
|
||||
)
|
||||
|
||||
assert result["audio"] == "base64_audio_data"
|
||||
|
||||
|
||||
class TestAudioApi:
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
|
||||
api = AudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
response = handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
assert response == {"text": "ok"}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected"),
|
||||
[
|
||||
(AppModelConfigBrokenError(), AppUnavailableError),
|
||||
(NoAudioUploadedServiceError(), NoAudioUploadedError),
|
||||
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
|
||||
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
|
||||
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
|
||||
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
|
||||
(QuotaExceededError(), ProviderQuotaExceededError),
|
||||
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
|
||||
(InvokeError("invoke"), CompletionRequestError),
|
||||
],
|
||||
)
|
||||
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
|
||||
api = AudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(expected):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_unhandled_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
)
|
||||
api = AudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(InternalServerError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestTextApi:
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
api = TextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(external_user_id="ext")
|
||||
|
||||
with app.test_request_context(
|
||||
"/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "voice": "v"},
|
||||
):
|
||||
response = handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError())
|
||||
)
|
||||
|
||||
api = TextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(external_user_id="ext")
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST", json={"text": "hello"}):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
@ -0,0 +1,524 @@
|
||||
"""
|
||||
Unit tests for Service API Completion controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- CompletionRequestPayload and ChatRequestPayload Pydantic models
|
||||
- App mode validation logic
|
||||
- Error mapping from service layer to HTTP errors
|
||||
|
||||
Focus on:
|
||||
- Pydantic model validation (especially UUID normalization)
|
||||
- Error types and their mappings
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api.app.completion import (
|
||||
ChatApi,
|
||||
ChatRequestPayload,
|
||||
ChatStopApi,
|
||||
CompletionApi,
|
||||
CompletionRequestPayload,
|
||||
CompletionStopApi,
|
||||
)
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
ConversationCompletedError,
|
||||
NotChatAppError,
|
||||
)
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestCompletionRequestPayload:
|
||||
"""Test suite for CompletionRequestPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_required_fields(self):
|
||||
"""Test payload with only required inputs field."""
|
||||
payload = CompletionRequestPayload(inputs={"name": "test"})
|
||||
assert payload.inputs == {"name": "test"}
|
||||
assert payload.query == ""
|
||||
assert payload.files is None
|
||||
assert payload.response_mode is None
|
||||
assert payload.retriever_from == "dev"
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all fields populated."""
|
||||
payload = CompletionRequestPayload(
|
||||
inputs={"user_input": "Hello"},
|
||||
query="What is AI?",
|
||||
files=[{"type": "image", "url": "http://example.com/image.png"}],
|
||||
response_mode="streaming",
|
||||
retriever_from="api",
|
||||
)
|
||||
assert payload.inputs == {"user_input": "Hello"}
|
||||
assert payload.query == "What is AI?"
|
||||
assert payload.files == [{"type": "image", "url": "http://example.com/image.png"}]
|
||||
assert payload.response_mode == "streaming"
|
||||
assert payload.retriever_from == "api"
|
||||
|
||||
def test_payload_response_mode_blocking(self):
|
||||
"""Test payload with blocking response mode."""
|
||||
payload = CompletionRequestPayload(inputs={}, response_mode="blocking")
|
||||
assert payload.response_mode == "blocking"
|
||||
|
||||
def test_payload_empty_inputs(self):
|
||||
"""Test payload with empty inputs dict."""
|
||||
payload = CompletionRequestPayload(inputs={})
|
||||
assert payload.inputs == {}
|
||||
|
||||
def test_payload_complex_inputs(self):
|
||||
"""Test payload with complex nested inputs."""
|
||||
complex_inputs = {
|
||||
"user": {"name": "Alice", "age": 30},
|
||||
"context": ["item1", "item2"],
|
||||
"settings": {"theme": "dark", "notifications": True},
|
||||
}
|
||||
payload = CompletionRequestPayload(inputs=complex_inputs)
|
||||
assert payload.inputs == complex_inputs
|
||||
|
||||
|
||||
class TestChatRequestPayload:
|
||||
"""Test suite for ChatRequestPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_required_fields(self):
|
||||
"""Test payload with required fields."""
|
||||
payload = ChatRequestPayload(inputs={"key": "value"}, query="Hello")
|
||||
assert payload.inputs == {"key": "value"}
|
||||
assert payload.query == "Hello"
|
||||
assert payload.conversation_id is None
|
||||
assert payload.auto_generate_name is True
|
||||
|
||||
def test_payload_normalizes_valid_uuid_conversation_id(self):
|
||||
"""Test that valid UUID conversation_id is normalized."""
|
||||
valid_uuid = str(uuid.uuid4())
|
||||
payload = ChatRequestPayload(inputs={}, query="test", conversation_id=valid_uuid)
|
||||
assert payload.conversation_id == valid_uuid
|
||||
|
||||
def test_payload_normalizes_empty_string_conversation_id_to_none(self):
|
||||
"""Test that empty string conversation_id becomes None."""
|
||||
payload = ChatRequestPayload(inputs={}, query="test", conversation_id="")
|
||||
assert payload.conversation_id is None
|
||||
|
||||
def test_payload_normalizes_whitespace_conversation_id_to_none(self):
|
||||
"""Test that whitespace-only conversation_id becomes None."""
|
||||
payload = ChatRequestPayload(inputs={}, query="test", conversation_id=" ")
|
||||
assert payload.conversation_id is None
|
||||
|
||||
def test_payload_rejects_invalid_uuid_conversation_id(self):
|
||||
"""Test that invalid UUID format raises ValueError."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ChatRequestPayload(inputs={}, query="test", conversation_id="not-a-uuid")
|
||||
assert "valid UUID" in str(exc_info.value)
|
||||
|
||||
def test_payload_with_workflow_id(self):
|
||||
"""Test payload with workflow_id for advanced chat."""
|
||||
payload = ChatRequestPayload(inputs={}, query="test", workflow_id="workflow_123")
|
||||
assert payload.workflow_id == "workflow_123"
|
||||
|
||||
def test_payload_streaming_mode(self):
|
||||
"""Test payload with streaming response mode."""
|
||||
payload = ChatRequestPayload(inputs={}, query="test", response_mode="streaming")
|
||||
assert payload.response_mode == "streaming"
|
||||
|
||||
def test_payload_auto_generate_name_false(self):
|
||||
"""Test payload with auto_generate_name explicitly false."""
|
||||
payload = ChatRequestPayload(inputs={}, query="test", auto_generate_name=False)
|
||||
assert payload.auto_generate_name is False
|
||||
|
||||
def test_payload_with_files(self):
|
||||
"""Test payload with file attachments."""
|
||||
files = [
|
||||
{"type": "image", "transfer_method": "remote_url", "url": "http://example.com/img.png"},
|
||||
{"type": "document", "transfer_method": "local_file", "upload_file_id": "file_123"},
|
||||
]
|
||||
payload = ChatRequestPayload(inputs={}, query="test", files=files)
|
||||
assert payload.files == files
|
||||
assert len(payload.files) == 2
|
||||
|
||||
|
||||
class TestCompletionErrorMappings:
|
||||
"""Test error type mappings for completion endpoints."""
|
||||
|
||||
def test_conversation_not_exists_error_exists(self):
|
||||
"""Test ConversationNotExistsError can be raised."""
|
||||
error = services.errors.conversation.ConversationNotExistsError()
|
||||
assert isinstance(error, services.errors.conversation.ConversationNotExistsError)
|
||||
|
||||
def test_conversation_completed_error_exists(self):
|
||||
"""Test ConversationCompletedError can be raised."""
|
||||
error = services.errors.conversation.ConversationCompletedError()
|
||||
assert isinstance(error, services.errors.conversation.ConversationCompletedError)
|
||||
|
||||
api_error = ConversationCompletedError()
|
||||
assert api_error is not None
|
||||
|
||||
def test_app_model_config_broken_error_exists(self):
|
||||
"""Test AppModelConfigBrokenError can be raised."""
|
||||
error = services.errors.app_model_config.AppModelConfigBrokenError()
|
||||
assert isinstance(error, services.errors.app_model_config.AppModelConfigBrokenError)
|
||||
|
||||
api_error = AppUnavailableError()
|
||||
assert api_error is not None
|
||||
|
||||
def test_workflow_not_found_error_exists(self):
|
||||
"""Test WorkflowNotFoundError can be raised."""
|
||||
error = WorkflowNotFoundError("Workflow not found")
|
||||
assert isinstance(error, WorkflowNotFoundError)
|
||||
|
||||
def test_is_draft_workflow_error_exists(self):
|
||||
"""Test IsDraftWorkflowError can be raised."""
|
||||
error = IsDraftWorkflowError("Workflow is in draft state")
|
||||
assert isinstance(error, IsDraftWorkflowError)
|
||||
|
||||
def test_workflow_id_format_error_exists(self):
|
||||
"""Test WorkflowIdFormatError can be raised."""
|
||||
error = WorkflowIdFormatError("Invalid workflow ID format")
|
||||
assert isinstance(error, WorkflowIdFormatError)
|
||||
|
||||
def test_invoke_rate_limit_error_exists(self):
|
||||
"""Test InvokeRateLimitError can be raised."""
|
||||
error = InvokeRateLimitError("Rate limit exceeded")
|
||||
assert isinstance(error, InvokeRateLimitError)
|
||||
|
||||
|
||||
class TestAppModeValidation:
|
||||
"""Test app mode validation logic patterns."""
|
||||
|
||||
def test_completion_mode_is_valid_for_completion_endpoint(self):
|
||||
"""Test that COMPLETION mode is valid for completion endpoints."""
|
||||
assert AppMode.COMPLETION == AppMode.COMPLETION
|
||||
|
||||
def test_chat_modes_are_distinct_from_completion(self):
|
||||
"""Test that chat modes are distinct from completion mode."""
|
||||
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
assert AppMode.COMPLETION not in chat_modes
|
||||
|
||||
def test_workflow_mode_is_distinct_from_chat_modes(self):
|
||||
"""Test that WORKFLOW mode is not a chat mode."""
|
||||
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
assert AppMode.WORKFLOW not in chat_modes
|
||||
|
||||
def test_not_chat_app_error_can_be_raised(self):
|
||||
"""Test NotChatAppError can be raised for non-chat apps."""
|
||||
error = NotChatAppError()
|
||||
assert error is not None
|
||||
|
||||
def test_all_app_modes_are_defined(self):
|
||||
"""Test that all expected app modes are defined."""
|
||||
expected_modes = ["COMPLETION", "CHAT", "AGENT_CHAT", "ADVANCED_CHAT", "WORKFLOW", "CHANNEL", "RAG_PIPELINE"]
|
||||
for mode_name in expected_modes:
|
||||
assert hasattr(AppMode, mode_name), f"AppMode.{mode_name} should exist"
|
||||
|
||||
|
||||
class TestAppGenerateService:
|
||||
"""Test AppGenerateService integration patterns."""
|
||||
|
||||
def test_generate_method_exists(self):
|
||||
"""Test that AppGenerateService.generate method exists."""
|
||||
assert hasattr(AppGenerateService, "generate")
|
||||
assert callable(AppGenerateService.generate)
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_returns_response(self, mock_generate):
|
||||
"""Test that generate returns expected response format."""
|
||||
expected = {"answer": "Hello!"}
|
||||
mock_generate.return_value = expected
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={"query": "Hi"}, invoke_from=Mock(), streaming=False
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_raises_conversation_not_exists(self, mock_generate):
|
||||
"""Test generate raises ConversationNotExistsError."""
|
||||
mock_generate.side_effect = services.errors.conversation.ConversationNotExistsError()
|
||||
|
||||
with pytest.raises(services.errors.conversation.ConversationNotExistsError):
|
||||
AppGenerateService.generate(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
|
||||
)
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_raises_quota_exceeded(self, mock_generate):
|
||||
"""Test generate raises QuotaExceededError."""
|
||||
mock_generate.side_effect = QuotaExceededError()
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
AppGenerateService.generate(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
|
||||
)
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_raises_invoke_error(self, mock_generate):
|
||||
"""Test generate raises InvokeError."""
|
||||
mock_generate.side_effect = InvokeError("Model invocation failed")
|
||||
|
||||
with pytest.raises(InvokeError):
|
||||
AppGenerateService.generate(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionControllerLogic:
|
||||
"""Test CompletionApi and ChatApi controller logic directly."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
from flask import Flask
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.app.completion.service_api_ns")
|
||||
@patch("controllers.service_api.app.completion.AppGenerateService")
|
||||
def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
|
||||
"""Test CompletionApi.post success path."""
|
||||
from controllers.service_api.app.completion import CompletionApi
|
||||
|
||||
# Setup mocks
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.COMPLETION
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
|
||||
payload_dict = {"inputs": {"text": "hello"}, "response_mode": "blocking"}
|
||||
mock_service_api_ns.payload = payload_dict
|
||||
mock_generate_service.generate.return_value = {"text": "response"}
|
||||
|
||||
with app.test_request_context():
|
||||
# Helper for compact_generate_response logic check
|
||||
with patch("controllers.service_api.app.completion.helper.compact_generate_response") as mock_compact:
|
||||
mock_compact.return_value = {"text": "compacted"}
|
||||
|
||||
api = CompletionApi()
|
||||
response = api.post.__wrapped__(api, mock_app_model, mock_end_user)
|
||||
|
||||
assert response == {"text": "compacted"}
|
||||
mock_generate_service.generate.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.app.completion.service_api_ns")
|
||||
def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app):
|
||||
"""Test CompletionApi.post with wrong app mode."""
|
||||
from controllers.service_api.app.completion import CompletionApi
|
||||
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.CHAT # Wrong mode
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
|
||||
with app.test_request_context():
|
||||
with pytest.raises(AppUnavailableError):
|
||||
CompletionApi().post.__wrapped__(CompletionApi(), mock_app_model, mock_end_user)
|
||||
|
||||
@patch("controllers.service_api.app.completion.service_api_ns")
|
||||
@patch("controllers.service_api.app.completion.AppGenerateService")
|
||||
def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
|
||||
"""Test ChatApi.post success path."""
|
||||
from controllers.service_api.app.completion import ChatApi
|
||||
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.CHAT
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
|
||||
payload_dict = {"inputs": {}, "query": "hello", "response_mode": "blocking"}
|
||||
mock_service_api_ns.payload = payload_dict
|
||||
mock_generate_service.generate.return_value = {"text": "response"}
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("controllers.service_api.app.completion.helper.compact_generate_response") as mock_compact:
|
||||
mock_compact.return_value = {"text": "compacted"}
|
||||
|
||||
api = ChatApi()
|
||||
response = api.post.__wrapped__(api, mock_app_model, mock_end_user)
|
||||
assert response == {"text": "compacted"}
|
||||
|
||||
@patch("controllers.service_api.app.completion.service_api_ns")
|
||||
def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app):
|
||||
"""Test ChatApi.post with wrong app mode."""
|
||||
from controllers.service_api.app.completion import ChatApi
|
||||
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.COMPLETION # Wrong mode
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
|
||||
with app.test_request_context():
|
||||
with pytest.raises(NotChatAppError):
|
||||
ChatApi().post.__wrapped__(ChatApi(), mock_app_model, mock_end_user)
|
||||
|
||||
@patch("controllers.service_api.app.completion.AppTaskService")
|
||||
def test_completion_stop_api_success(self, mock_task_service, app):
|
||||
"""Test CompletionStopApi.post success."""
|
||||
from controllers.service_api.app.completion import CompletionStopApi
|
||||
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.COMPLETION
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
mock_end_user.id = "user_id"
|
||||
|
||||
with app.test_request_context():
|
||||
api = CompletionStopApi()
|
||||
response = api.post.__wrapped__(api, mock_app_model, mock_end_user, "task_id")
|
||||
|
||||
assert response == ({"result": "success"}, 200)
|
||||
mock_task_service.stop_task.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.app.completion.AppTaskService")
|
||||
def test_chat_stop_api_success(self, mock_task_service, app):
|
||||
"""Test ChatStopApi.post success."""
|
||||
from controllers.service_api.app.completion import ChatStopApi
|
||||
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.CHAT
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
mock_end_user.id = "user_id"
|
||||
|
||||
with app.test_request_context():
|
||||
api = ChatStopApi()
|
||||
response = api.post.__wrapped__(api, mock_app_model, mock_end_user, "task_id")
|
||||
|
||||
assert response == ({"result": "success"}, 200)
|
||||
mock_task_service.stop_task.assert_called_once()
|
||||
|
||||
|
||||
class TestChatRequestPayloadController:
|
||||
def test_normalizes_conversation_id(self) -> None:
|
||||
payload = ChatRequestPayload.model_validate(
|
||||
{"inputs": {}, "query": "hi", "conversation_id": " ", "response_mode": "blocking"}
|
||||
)
|
||||
assert payload.conversation_id is None
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ChatRequestPayload.model_validate({"inputs": {}, "query": "hi", "conversation_id": "bad-id"})
|
||||
|
||||
|
||||
class TestCompletionApiController:
|
||||
def test_wrong_mode(self, app) -> None:
|
||||
api = CompletionApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
api = CompletionApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestCompletionStopApiController:
|
||||
def test_wrong_mode(self, app) -> None:
|
||||
api = CompletionStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/completion-messages/1/stop", method="POST"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
|
||||
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
stop_mock = Mock()
|
||||
monkeypatch.setattr(AppTaskService, "stop_task", stop_mock)
|
||||
|
||||
api = CompletionStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/completion-messages/1/stop", method="POST"):
|
||||
response, status = handler(api, app_model=app_model, end_user=end_user, task_id="t1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
class TestChatApiController:
|
||||
def test_wrong_mode(self, app) -> None:
|
||||
api = ChatApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_workflow_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(WorkflowNotFoundError("missing")),
|
||||
)
|
||||
|
||||
api = ChatApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(IsDraftWorkflowError("draft")),
|
||||
)
|
||||
|
||||
api = ChatApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
|
||||
with pytest.raises(BadRequest):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestChatStopApiController:
|
||||
def test_wrong_mode(self, app) -> None:
|
||||
api = ChatStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/chat-messages/1/stop", method="POST"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
|
||||
@ -0,0 +1,597 @@
|
||||
"""
|
||||
Unit tests for Service API Conversation controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- ConversationListQuery, ConversationRenamePayload Pydantic models
|
||||
- ConversationVariablesQuery with SQL injection prevention
|
||||
- ConversationVariableUpdatePayload
|
||||
- App mode validation for chat-only endpoints
|
||||
|
||||
Focus on:
|
||||
- Pydantic model validation including security checks
|
||||
- SQL injection prevention in variable name filtering
|
||||
- Error types and mappings
|
||||
"""
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api.app.conversation import (
|
||||
ConversationApi,
|
||||
ConversationDetailApi,
|
||||
ConversationListQuery,
|
||||
ConversationRenameApi,
|
||||
ConversationRenamePayload,
|
||||
ConversationVariableDetailApi,
|
||||
ConversationVariablesApi,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableUpdatePayload,
|
||||
)
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import (
|
||||
ConversationNotExistsError,
|
||||
ConversationVariableNotExistsError,
|
||||
ConversationVariableTypeMismatchError,
|
||||
LastConversationNotExistsError,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestConversationListQuery:
|
||||
"""Test suite for ConversationListQuery Pydantic model."""
|
||||
|
||||
def test_query_with_defaults(self):
|
||||
"""Test query with default values."""
|
||||
query = ConversationListQuery()
|
||||
assert query.last_id is None
|
||||
assert query.limit == 20
|
||||
assert query.sort_by == "-updated_at"
|
||||
|
||||
def test_query_with_last_id(self):
|
||||
"""Test query with pagination last_id."""
|
||||
last_id = str(uuid.uuid4())
|
||||
query = ConversationListQuery(last_id=last_id)
|
||||
assert str(query.last_id) == last_id
|
||||
|
||||
def test_query_limit_boundaries(self):
|
||||
"""Test query respects limit boundaries."""
|
||||
query_min = ConversationListQuery(limit=1)
|
||||
assert query_min.limit == 1
|
||||
|
||||
query_max = ConversationListQuery(limit=100)
|
||||
assert query_max.limit == 100
|
||||
|
||||
def test_query_rejects_limit_below_minimum(self):
|
||||
"""Test query rejects limit < 1."""
|
||||
with pytest.raises(ValueError):
|
||||
ConversationListQuery(limit=0)
|
||||
|
||||
def test_query_rejects_limit_above_maximum(self):
|
||||
"""Test query rejects limit > 100."""
|
||||
with pytest.raises(ValueError):
|
||||
ConversationListQuery(limit=101)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sort_by",
|
||||
[
|
||||
"created_at",
|
||||
"-created_at",
|
||||
"updated_at",
|
||||
"-updated_at",
|
||||
],
|
||||
)
|
||||
def test_query_valid_sort_options(self, sort_by):
|
||||
"""Test all valid sort_by options."""
|
||||
query = ConversationListQuery(sort_by=sort_by)
|
||||
assert query.sort_by == sort_by
|
||||
|
||||
|
||||
class TestConversationRenamePayload:
|
||||
"""Test suite for ConversationRenamePayload Pydantic model."""
|
||||
|
||||
def test_payload_with_name(self):
|
||||
"""Test payload with explicit name."""
|
||||
payload = ConversationRenamePayload(name="My New Chat", auto_generate=False)
|
||||
assert payload.name == "My New Chat"
|
||||
assert payload.auto_generate is False
|
||||
|
||||
def test_payload_with_auto_generate(self):
|
||||
"""Test payload with auto_generate enabled."""
|
||||
payload = ConversationRenamePayload(auto_generate=True)
|
||||
assert payload.auto_generate is True
|
||||
assert payload.name is None
|
||||
|
||||
def test_payload_requires_name_when_auto_generate_false(self):
|
||||
"""Test that name is required when auto_generate is False."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationRenamePayload(auto_generate=False)
|
||||
assert "name is required when auto_generate is false" in str(exc_info.value)
|
||||
|
||||
def test_payload_requires_non_empty_name_when_auto_generate_false(self):
|
||||
"""Test that empty string name is rejected."""
|
||||
with pytest.raises(ValueError):
|
||||
ConversationRenamePayload(name="", auto_generate=False)
|
||||
|
||||
def test_payload_requires_non_whitespace_name_when_auto_generate_false(self):
|
||||
"""Test that whitespace-only name is rejected."""
|
||||
with pytest.raises(ValueError):
|
||||
ConversationRenamePayload(name=" ", auto_generate=False)
|
||||
|
||||
def test_payload_name_with_special_characters(self):
|
||||
"""Test payload with name containing special characters."""
|
||||
payload = ConversationRenamePayload(name="Chat #1 - (Test) & More!", auto_generate=False)
|
||||
assert payload.name == "Chat #1 - (Test) & More!"
|
||||
|
||||
def test_payload_name_with_unicode(self):
|
||||
"""Test payload with Unicode characters in name."""
|
||||
payload = ConversationRenamePayload(name="对话 📝 Чат", auto_generate=False)
|
||||
assert payload.name == "对话 📝 Чат"
|
||||
|
||||
|
||||
class TestConversationVariablesQuery:
|
||||
"""Test suite for ConversationVariablesQuery Pydantic model."""
|
||||
|
||||
def test_query_with_defaults(self):
|
||||
"""Test query with default values."""
|
||||
query = ConversationVariablesQuery()
|
||||
assert query.last_id is None
|
||||
assert query.limit == 20
|
||||
assert query.variable_name is None
|
||||
|
||||
def test_query_with_variable_name(self):
|
||||
"""Test query with valid variable_name filter."""
|
||||
query = ConversationVariablesQuery(variable_name="user_preference")
|
||||
assert query.variable_name == "user_preference"
|
||||
|
||||
def test_query_allows_hyphen_in_variable_name(self):
|
||||
"""Test that hyphens are allowed in variable names."""
|
||||
query = ConversationVariablesQuery(variable_name="my-variable")
|
||||
assert query.variable_name == "my-variable"
|
||||
|
||||
def test_query_allows_underscore_in_variable_name(self):
|
||||
"""Test that underscores are allowed in variable names."""
|
||||
query = ConversationVariablesQuery(variable_name="my_variable")
|
||||
assert query.variable_name == "my_variable"
|
||||
|
||||
def test_query_allows_period_in_variable_name(self):
|
||||
"""Test that periods are allowed in variable names."""
|
||||
query = ConversationVariablesQuery(variable_name="config.setting")
|
||||
assert query.variable_name == "config.setting"
|
||||
|
||||
def test_query_rejects_sql_injection_single_quote(self):
|
||||
"""Test that single quotes are rejected (SQL injection prevention)."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name="'; DROP TABLE users;--")
|
||||
assert "can only contain" in str(exc_info.value)
|
||||
|
||||
def test_query_rejects_sql_injection_double_quote(self):
|
||||
"""Test that double quotes are rejected."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name='name"test')
|
||||
assert "can only contain" in str(exc_info.value)
|
||||
|
||||
def test_query_rejects_sql_injection_semicolon(self):
|
||||
"""Test that semicolons are rejected."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name="name;malicious")
|
||||
assert "can only contain" in str(exc_info.value)
|
||||
|
||||
def test_query_rejects_sql_injection_comment(self):
|
||||
"""Test that SQL comments are rejected."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name="name--comment")
|
||||
assert "invalid characters" in str(exc_info.value)
|
||||
|
||||
def test_query_rejects_special_characters(self):
|
||||
"""Test that special characters are rejected."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name="name@domain")
|
||||
assert "can only contain" in str(exc_info.value)
|
||||
|
||||
def test_query_rejects_backticks(self):
|
||||
"""Test that backticks are rejected (SQL injection prevention)."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name="`table`")
|
||||
assert "can only contain" in str(exc_info.value)
|
||||
|
||||
def test_query_pagination_limits(self):
|
||||
"""Test query pagination limit boundaries."""
|
||||
query_min = ConversationVariablesQuery(limit=1)
|
||||
assert query_min.limit == 1
|
||||
|
||||
query_max = ConversationVariablesQuery(limit=100)
|
||||
assert query_max.limit == 100
|
||||
|
||||
|
||||
class TestConversationVariableUpdatePayload:
|
||||
"""Test suite for ConversationVariableUpdatePayload Pydantic model."""
|
||||
|
||||
def test_payload_with_string_value(self):
|
||||
"""Test payload with string value."""
|
||||
payload = ConversationVariableUpdatePayload(value="hello")
|
||||
assert payload.value == "hello"
|
||||
|
||||
def test_payload_with_number_value(self):
|
||||
"""Test payload with number value."""
|
||||
payload = ConversationVariableUpdatePayload(value=42)
|
||||
assert payload.value == 42
|
||||
|
||||
def test_payload_with_float_value(self):
|
||||
"""Test payload with float value."""
|
||||
payload = ConversationVariableUpdatePayload(value=3.14159)
|
||||
assert payload.value == 3.14159
|
||||
|
||||
def test_payload_with_list_value(self):
|
||||
"""Test payload with list value."""
|
||||
payload = ConversationVariableUpdatePayload(value=["a", "b", "c"])
|
||||
assert payload.value == ["a", "b", "c"]
|
||||
|
||||
def test_payload_with_dict_value(self):
|
||||
"""Test payload with dictionary value."""
|
||||
payload = ConversationVariableUpdatePayload(value={"key": "value"})
|
||||
assert payload.value == {"key": "value"}
|
||||
|
||||
def test_payload_with_none_value(self):
|
||||
"""Test payload with None value."""
|
||||
payload = ConversationVariableUpdatePayload(value=None)
|
||||
assert payload.value is None
|
||||
|
||||
def test_payload_with_boolean_value(self):
|
||||
"""Test payload with boolean value."""
|
||||
payload = ConversationVariableUpdatePayload(value=True)
|
||||
assert payload.value is True
|
||||
|
||||
def test_payload_with_nested_structure(self):
|
||||
"""Test payload with deeply nested structure."""
|
||||
nested = {"level1": {"level2": {"level3": ["a", "b", {"c": 123}]}}}
|
||||
payload = ConversationVariableUpdatePayload(value=nested)
|
||||
assert payload.value == nested
|
||||
|
||||
|
||||
class TestConversationAppModeValidation:
|
||||
"""Test app mode validation for conversation endpoints."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode",
|
||||
[
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
],
|
||||
)
|
||||
def test_chat_modes_are_valid_for_conversation_endpoints(self, mode):
|
||||
"""Test that all chat modes are valid for conversation endpoints.
|
||||
|
||||
Verifies that CHAT, AGENT_CHAT, and ADVANCED_CHAT modes pass
|
||||
validation without raising NotChatAppError.
|
||||
"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = mode
|
||||
|
||||
# Validation should pass without raising for chat modes
|
||||
app_mode = AppMode.value_of(app.mode)
|
||||
assert app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
|
||||
def test_completion_mode_is_invalid_for_conversation_endpoints(self):
|
||||
"""Test that COMPLETION mode is invalid for conversation endpoints.
|
||||
|
||||
Verifies that calling a conversation endpoint with a COMPLETION mode
|
||||
app raises NotChatAppError.
|
||||
"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.COMPLETION.value
|
||||
|
||||
app_mode = AppMode.value_of(app.mode)
|
||||
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
with pytest.raises(NotChatAppError):
|
||||
raise NotChatAppError()
|
||||
|
||||
def test_workflow_mode_is_invalid_for_conversation_endpoints(self):
|
||||
"""Test that WORKFLOW mode is invalid for conversation endpoints.
|
||||
|
||||
Verifies that calling a conversation endpoint with a WORKFLOW mode
|
||||
app raises NotChatAppError.
|
||||
"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
|
||||
app_mode = AppMode.value_of(app.mode)
|
||||
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
with pytest.raises(NotChatAppError):
|
||||
raise NotChatAppError()
|
||||
|
||||
|
||||
class TestConversationErrorTypes:
|
||||
"""Test conversation-related error types."""
|
||||
|
||||
def test_conversation_not_exists_error(self):
|
||||
"""Test ConversationNotExistsError exists and can be raised."""
|
||||
error = services.errors.conversation.ConversationNotExistsError()
|
||||
assert isinstance(error, services.errors.conversation.ConversationNotExistsError)
|
||||
|
||||
def test_conversation_completed_error(self):
|
||||
"""Test ConversationCompletedError exists."""
|
||||
error = services.errors.conversation.ConversationCompletedError()
|
||||
assert isinstance(error, services.errors.conversation.ConversationCompletedError)
|
||||
|
||||
def test_last_conversation_not_exists_error(self):
|
||||
"""Test LastConversationNotExistsError exists."""
|
||||
error = services.errors.conversation.LastConversationNotExistsError()
|
||||
assert isinstance(error, services.errors.conversation.LastConversationNotExistsError)
|
||||
|
||||
def test_conversation_variable_not_exists_error(self):
|
||||
"""Test ConversationVariableNotExistsError exists."""
|
||||
error = services.errors.conversation.ConversationVariableNotExistsError()
|
||||
assert isinstance(error, services.errors.conversation.ConversationVariableNotExistsError)
|
||||
|
||||
def test_conversation_variable_type_mismatch_error(self):
|
||||
"""Test ConversationVariableTypeMismatchError exists."""
|
||||
error = services.errors.conversation.ConversationVariableTypeMismatchError("Type mismatch")
|
||||
assert isinstance(error, services.errors.conversation.ConversationVariableTypeMismatchError)
|
||||
|
||||
|
||||
class TestConversationService:
|
||||
"""Test ConversationService integration patterns."""
|
||||
|
||||
def test_pagination_by_last_id_method_exists(self):
|
||||
"""Test that ConversationService.pagination_by_last_id exists."""
|
||||
assert hasattr(ConversationService, "pagination_by_last_id")
|
||||
assert callable(ConversationService.pagination_by_last_id)
|
||||
|
||||
def test_delete_method_exists(self):
|
||||
"""Test that ConversationService.delete exists."""
|
||||
assert hasattr(ConversationService, "delete")
|
||||
assert callable(ConversationService.delete)
|
||||
|
||||
def test_rename_method_exists(self):
|
||||
"""Test that ConversationService.rename exists."""
|
||||
assert hasattr(ConversationService, "rename")
|
||||
assert callable(ConversationService.rename)
|
||||
|
||||
def test_get_conversational_variable_method_exists(self):
|
||||
"""Test that ConversationService.get_conversational_variable exists."""
|
||||
assert hasattr(ConversationService, "get_conversational_variable")
|
||||
assert callable(ConversationService.get_conversational_variable)
|
||||
|
||||
def test_update_conversation_variable_method_exists(self):
|
||||
"""Test that ConversationService.update_conversation_variable exists."""
|
||||
assert hasattr(ConversationService, "update_conversation_variable")
|
||||
assert callable(ConversationService.update_conversation_variable)
|
||||
|
||||
@patch.object(ConversationService, "pagination_by_last_id")
|
||||
def test_pagination_returns_expected_format(self, mock_pagination):
|
||||
"""Test pagination returns expected data format."""
|
||||
mock_result = Mock()
|
||||
mock_result.data = []
|
||||
mock_result.limit = 20
|
||||
mock_result.has_more = False
|
||||
mock_pagination.return_value = mock_result
|
||||
|
||||
result = ConversationService.pagination_by_last_id(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(spec=EndUser),
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=Mock(),
|
||||
sort_by="-updated_at",
|
||||
)
|
||||
|
||||
assert hasattr(result, "data")
|
||||
assert hasattr(result, "limit")
|
||||
assert hasattr(result, "has_more")
|
||||
|
||||
@patch.object(ConversationService, "rename")
|
||||
def test_rename_returns_conversation(self, mock_rename):
|
||||
"""Test rename returns updated conversation."""
|
||||
mock_conversation = Mock()
|
||||
mock_conversation.name = "New Name"
|
||||
mock_rename.return_value = mock_conversation
|
||||
|
||||
result = ConversationService.rename(
|
||||
app_model=Mock(spec=App),
|
||||
conversation_id="conv_123",
|
||||
user=Mock(spec=EndUser),
|
||||
name="New Name",
|
||||
auto_generate=False,
|
||||
)
|
||||
|
||||
assert result.name == "New Name"
|
||||
|
||||
|
||||
class TestConversationPayloadsController:
|
||||
def test_rename_requires_name(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ConversationRenamePayload(auto_generate=False, name="")
|
||||
|
||||
def test_variables_query_invalid_name(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ConversationVariablesQuery(variable_name="bad;")
|
||||
|
||||
|
||||
class TestConversationApiController:
|
||||
def test_list_not_chat(self, app) -> None:
|
||||
api = ConversationApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations", method="GET"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _SessionStub:
|
||||
def __enter__(self):
|
||||
return SimpleNamespace()
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"pagination_by_last_id",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(LastConversationNotExistsError()),
|
||||
)
|
||||
conversation_module = sys.modules["controllers.service_api.app.conversation"]
|
||||
monkeypatch.setattr(conversation_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(conversation_module, "Session", lambda *_args, **_kwargs: _SessionStub())
|
||||
|
||||
api = ConversationApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations?last_id=00000000-0000-0000-0000-000000000001&limit=20",
|
||||
method="GET",
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestConversationDetailApiController:
|
||||
def test_delete_not_chat(self, app) -> None:
|
||||
api = ConversationDetailApi()
|
||||
handler = _unwrap(api.delete)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations/1", method="DELETE"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
def test_delete_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"delete",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
api = ConversationDetailApi()
|
||||
handler = _unwrap(api.delete)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations/1", method="DELETE"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
class TestConversationRenameApiController:
|
||||
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"rename",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
api = ConversationRenameApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations/1/name",
|
||||
method="POST",
|
||||
json={"auto_generate": True},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
class TestConversationVariablesApiController:
|
||||
def test_not_chat(self, app) -> None:
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations/1/variables", method="GET"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"get_conversational_variable",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations/1/variables?limit=20",
|
||||
method="GET",
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
class TestConversationVariableDetailApiController:
|
||||
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"update_conversation_variable",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationVariableTypeMismatchError("bad")),
|
||||
)
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations/1/variables/2",
|
||||
method="PUT",
|
||||
json={"value": "x"},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
handler(
|
||||
api,
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
c_id="00000000-0000-0000-0000-000000000001",
|
||||
variable_id="00000000-0000-0000-0000-000000000002",
|
||||
)
|
||||
|
||||
def test_update_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"update_conversation_variable",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationVariableNotExistsError()),
|
||||
)
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations/1/variables/2",
|
||||
method="PUT",
|
||||
json={"value": "x"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(
|
||||
api,
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
c_id="00000000-0000-0000-0000-000000000001",
|
||||
variable_id="00000000-0000-0000-0000-000000000002",
|
||||
)
|
||||
398
api/tests/unit_tests/controllers/service_api/app/test_file.py
Normal file
398
api/tests/unit_tests/controllers/service_api/app/test_file.py
Normal file
@ -0,0 +1,398 @@
|
||||
"""
|
||||
Unit tests for Service API File controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- File upload validation
|
||||
- Error handling for file operations
|
||||
- FileService integration
|
||||
|
||||
Focus on:
|
||||
- File validation logic (size, type, filename)
|
||||
- Error type mappings
|
||||
- Service method interfaces
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from fields.file_fields import FileResponse
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class TestFileResponse:
|
||||
"""Test suite for FileResponse Pydantic model."""
|
||||
|
||||
def test_file_response_has_required_fields(self):
|
||||
"""Test FileResponse model includes required fields."""
|
||||
# Verify the model exists and can be imported
|
||||
assert FileResponse is not None
|
||||
assert hasattr(FileResponse, "model_fields")
|
||||
|
||||
|
||||
class TestFileUploadErrors:
|
||||
"""Test file upload error types."""
|
||||
|
||||
def test_no_file_uploaded_error_can_be_raised(self):
|
||||
"""Test NoFileUploadedError can be raised."""
|
||||
error = NoFileUploadedError()
|
||||
assert error is not None
|
||||
|
||||
def test_too_many_files_error_can_be_raised(self):
|
||||
"""Test TooManyFilesError can be raised."""
|
||||
error = TooManyFilesError()
|
||||
assert error is not None
|
||||
|
||||
def test_unsupported_file_type_error_can_be_raised(self):
|
||||
"""Test UnsupportedFileTypeError can be raised."""
|
||||
error = UnsupportedFileTypeError()
|
||||
assert error is not None
|
||||
|
||||
def test_filename_not_exists_error_can_be_raised(self):
|
||||
"""Test FilenameNotExistsError can be raised."""
|
||||
error = FilenameNotExistsError()
|
||||
assert error is not None
|
||||
|
||||
def test_file_too_large_error_can_be_raised(self):
|
||||
"""Test FileTooLargeError can be raised."""
|
||||
error = FileTooLargeError("File exceeds maximum size")
|
||||
assert "File exceeds maximum size" in str(error) or error is not None
|
||||
|
||||
|
||||
class TestFileServiceErrors:
|
||||
"""Test FileService error types."""
|
||||
|
||||
def test_file_service_file_too_large_error_exists(self):
|
||||
"""Test FileTooLargeError from services exists."""
|
||||
import services.errors.file
|
||||
|
||||
error = services.errors.file.FileTooLargeError("File too large")
|
||||
assert isinstance(error, services.errors.file.FileTooLargeError)
|
||||
|
||||
def test_file_service_unsupported_file_type_error_exists(self):
|
||||
"""Test UnsupportedFileTypeError from services exists."""
|
||||
import services.errors.file
|
||||
|
||||
error = services.errors.file.UnsupportedFileTypeError()
|
||||
assert isinstance(error, services.errors.file.UnsupportedFileTypeError)
|
||||
|
||||
|
||||
class TestFileService:
|
||||
"""Test FileService interface and methods."""
|
||||
|
||||
def test_upload_file_method_exists(self):
|
||||
"""Test FileService.upload_file method exists."""
|
||||
assert hasattr(FileService, "upload_file")
|
||||
assert callable(FileService.upload_file)
|
||||
|
||||
@patch.object(FileService, "upload_file")
|
||||
def test_upload_file_returns_upload_file_object(self, mock_upload):
|
||||
"""Test upload_file returns an upload file object."""
|
||||
mock_file = Mock()
|
||||
mock_file.id = str(uuid.uuid4())
|
||||
mock_file.name = "test.pdf"
|
||||
mock_file.size = 1024
|
||||
mock_file.extension = "pdf"
|
||||
mock_file.mime_type = "application/pdf"
|
||||
mock_upload.return_value = mock_file
|
||||
|
||||
# Call the method directly without instantiation
|
||||
assert mock_file.name == "test.pdf"
|
||||
assert mock_file.extension == "pdf"
|
||||
|
||||
@patch.object(FileService, "upload_file")
|
||||
def test_upload_file_raises_file_too_large_error(self, mock_upload):
|
||||
"""Test upload_file raises FileTooLargeError."""
|
||||
import services.errors.file
|
||||
|
||||
mock_upload.side_effect = services.errors.file.FileTooLargeError("File exceeds 15MB limit")
|
||||
|
||||
# Verify error type exists
|
||||
with pytest.raises(services.errors.file.FileTooLargeError):
|
||||
mock_upload(Mock(), Mock(), "user_id")
|
||||
|
||||
@patch.object(FileService, "upload_file")
|
||||
def test_upload_file_raises_unsupported_file_type_error(self, mock_upload):
|
||||
"""Test upload_file raises UnsupportedFileTypeError."""
|
||||
import services.errors.file
|
||||
|
||||
mock_upload.side_effect = services.errors.file.UnsupportedFileTypeError()
|
||||
|
||||
# Verify error type exists
|
||||
with pytest.raises(services.errors.file.UnsupportedFileTypeError):
|
||||
mock_upload(Mock(), Mock(), "user_id")
|
||||
|
||||
|
||||
class TestFileValidation:
|
||||
"""Test file validation patterns."""
|
||||
|
||||
def test_valid_image_mimetype(self):
|
||||
"""Test common image MIME types."""
|
||||
valid_mimetypes = ["image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml"]
|
||||
for mimetype in valid_mimetypes:
|
||||
assert mimetype.startswith("image/")
|
||||
|
||||
def test_valid_document_mimetype(self):
|
||||
"""Test common document MIME types."""
|
||||
valid_mimetypes = [
|
||||
"application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"text/plain",
|
||||
"text/csv",
|
||||
]
|
||||
for mimetype in valid_mimetypes:
|
||||
assert mimetype is not None
|
||||
assert len(mimetype) > 0
|
||||
|
||||
def test_filename_has_extension(self):
|
||||
"""Test filename validation for extension presence."""
|
||||
valid_filenames = ["document.pdf", "image.png", "data.csv", "report.docx"]
|
||||
for filename in valid_filenames:
|
||||
assert "." in filename
|
||||
parts = filename.rsplit(".", 1)
|
||||
assert len(parts) == 2
|
||||
assert len(parts[1]) > 0 # Extension exists
|
||||
|
||||
def test_filename_without_extension_is_invalid(self):
|
||||
"""Test that filename without extension can be detected."""
|
||||
filename = "noextension"
|
||||
assert "." not in filename
|
||||
|
||||
|
||||
class TestFileUploadResponse:
|
||||
"""Test file upload response structure."""
|
||||
|
||||
@patch.object(FileService, "upload_file")
|
||||
def test_upload_response_structure(self, mock_upload):
|
||||
"""Test upload response has expected structure."""
|
||||
mock_file = Mock()
|
||||
mock_file.id = str(uuid.uuid4())
|
||||
mock_file.name = "test.pdf"
|
||||
mock_file.size = 2048
|
||||
mock_file.extension = "pdf"
|
||||
mock_file.mime_type = "application/pdf"
|
||||
mock_file.created_by = str(uuid.uuid4())
|
||||
mock_file.created_at = Mock()
|
||||
mock_upload.return_value = mock_file
|
||||
|
||||
# Verify expected fields exist on mock
|
||||
assert hasattr(mock_file, "id")
|
||||
assert hasattr(mock_file, "name")
|
||||
assert hasattr(mock_file, "size")
|
||||
assert hasattr(mock_file, "extension")
|
||||
assert hasattr(mock_file, "mime_type")
|
||||
assert hasattr(mock_file, "created_by")
|
||||
assert hasattr(mock_file, "created_at")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoint Tests
|
||||
#
|
||||
# ``FileApi.post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``
|
||||
# which preserves ``__wrapped__`` via ``functools.wraps``. We call the
|
||||
# unwrapped method directly to bypass the decorator.
|
||||
# =============================================================================
|
||||
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
from models import App
|
||||
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user():
|
||||
from models import EndUser
|
||||
|
||||
user = Mock(spec=EndUser)
|
||||
user.id = str(uuid.uuid4())
|
||||
return user
|
||||
|
||||
|
||||
class TestFileApiPost:
|
||||
"""Test suite for FileApi.post() endpoint.
|
||||
|
||||
``post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``
|
||||
which preserves ``__wrapped__``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.app.file.FileService")
|
||||
@patch("controllers.service_api.app.file.db")
|
||||
def test_upload_file_success(
|
||||
self,
|
||||
mock_db,
|
||||
mock_file_svc_cls,
|
||||
app,
|
||||
mock_app_model,
|
||||
mock_end_user,
|
||||
):
|
||||
"""Test successful file upload."""
|
||||
from io import BytesIO
|
||||
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
mock_upload = Mock()
|
||||
mock_upload.id = str(uuid.uuid4())
|
||||
mock_upload.name = "test.pdf"
|
||||
mock_upload.size = 1024
|
||||
mock_upload.extension = "pdf"
|
||||
mock_upload.mime_type = "application/pdf"
|
||||
mock_upload.created_by = str(mock_end_user.id)
|
||||
mock_upload.created_by_role = "end_user"
|
||||
mock_upload.created_at = 1700000000
|
||||
mock_upload.preview_url = None
|
||||
mock_upload.source_url = None
|
||||
mock_upload.original_url = None
|
||||
mock_upload.user_id = None
|
||||
mock_upload.tenant_id = None
|
||||
mock_upload.conversation_id = None
|
||||
mock_upload.file_key = None
|
||||
mock_file_svc_cls.return_value.upload_file.return_value = mock_upload
|
||||
|
||||
data = {"file": (BytesIO(b"file content"), "test.pdf", "application/pdf")}
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
response, status = _unwrap(api.post)(
|
||||
api,
|
||||
app_model=mock_app_model,
|
||||
end_user=mock_end_user,
|
||||
)
|
||||
|
||||
assert status == 201
|
||||
mock_file_svc_cls.return_value.upload_file.assert_called_once()
|
||||
|
||||
def test_upload_no_file(self, app, mock_app_model, mock_end_user):
|
||||
"""Test NoFileUploadedError when no file in request."""
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data={},
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
def test_upload_too_many_files(self, app, mock_app_model, mock_end_user):
|
||||
"""Test TooManyFilesError when multiple files uploaded."""
|
||||
from io import BytesIO
|
||||
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
data = {
|
||||
"file": (BytesIO(b"content1"), "file1.pdf", "application/pdf"),
|
||||
"extra": (BytesIO(b"content2"), "file2.pdf", "application/pdf"),
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(TooManyFilesError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
def test_upload_no_mimetype(self, app, mock_app_model, mock_end_user):
|
||||
"""Test UnsupportedFileTypeError when file has no mimetype."""
|
||||
from io import BytesIO
|
||||
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
data = {"file": (BytesIO(b"content"), "test.bin", "")}
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
@patch("controllers.service_api.app.file.FileService")
|
||||
@patch("controllers.service_api.app.file.db")
|
||||
def test_upload_file_too_large(
|
||||
self,
|
||||
mock_db,
|
||||
mock_file_svc_cls,
|
||||
app,
|
||||
mock_app_model,
|
||||
mock_end_user,
|
||||
):
|
||||
"""Test FileTooLargeError when file exceeds size limit."""
|
||||
from io import BytesIO
|
||||
|
||||
import services.errors.file
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError(
|
||||
"File exceeds 15MB limit"
|
||||
)
|
||||
|
||||
data = {"file": (BytesIO(b"big content"), "big.pdf", "application/pdf")}
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(FileTooLargeError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
@patch("controllers.service_api.app.file.FileService")
|
||||
@patch("controllers.service_api.app.file.db")
|
||||
def test_upload_unsupported_file_type(
|
||||
self,
|
||||
mock_db,
|
||||
mock_file_svc_cls,
|
||||
app,
|
||||
mock_app_model,
|
||||
mock_end_user,
|
||||
):
|
||||
"""Test UnsupportedFileTypeError from FileService."""
|
||||
from io import BytesIO
|
||||
|
||||
import services.errors.file
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
|
||||
|
||||
data = {"file": (BytesIO(b"content"), "test.xyz", "application/octet-stream")}
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
541
api/tests/unit_tests/controllers/service_api/app/test_message.py
Normal file
541
api/tests/unit_tests/controllers/service_api/app/test_message.py
Normal file
@ -0,0 +1,541 @@
|
||||
"""
|
||||
Unit tests for Service API Message controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- MessageListQuery, MessageFeedbackPayload, FeedbackListQuery Pydantic models
|
||||
- App mode validation for message endpoints
|
||||
- MessageService integration
|
||||
- Error handling for message operations
|
||||
|
||||
Focus on:
|
||||
- Pydantic model validation
|
||||
- UUID normalization
|
||||
- Error type mappings
|
||||
- Service method interfaces
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.app.message import (
|
||||
AppGetFeedbacksApi,
|
||||
FeedbackListQuery,
|
||||
MessageFeedbackApi,
|
||||
MessageFeedbackPayload,
|
||||
MessageListApi,
|
||||
MessageListQuery,
|
||||
MessageSuggestedApi,
|
||||
)
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestMessageListQuery:
|
||||
"""Test suite for MessageListQuery Pydantic model."""
|
||||
|
||||
def test_query_requires_conversation_id(self):
|
||||
"""Test conversation_id is required."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
query = MessageListQuery(conversation_id=conversation_id)
|
||||
assert str(query.conversation_id) == conversation_id
|
||||
|
||||
def test_query_with_defaults(self):
|
||||
"""Test query with default values."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
query = MessageListQuery(conversation_id=conversation_id)
|
||||
assert query.first_id is None
|
||||
assert query.limit == 20
|
||||
|
||||
def test_query_with_first_id(self):
|
||||
"""Test query with first_id for pagination."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
first_id = str(uuid.uuid4())
|
||||
query = MessageListQuery(conversation_id=conversation_id, first_id=first_id)
|
||||
assert str(query.first_id) == first_id
|
||||
|
||||
def test_query_with_custom_limit(self):
|
||||
"""Test query with custom limit."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
query = MessageListQuery(conversation_id=conversation_id, limit=50)
|
||||
assert query.limit == 50
|
||||
|
||||
def test_query_limit_boundaries(self):
|
||||
"""Test query respects limit boundaries."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
query_min = MessageListQuery(conversation_id=conversation_id, limit=1)
|
||||
assert query_min.limit == 1
|
||||
|
||||
query_max = MessageListQuery(conversation_id=conversation_id, limit=100)
|
||||
assert query_max.limit == 100
|
||||
|
||||
def test_query_rejects_limit_below_minimum(self):
|
||||
"""Test query rejects limit < 1."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
with pytest.raises(ValueError):
|
||||
MessageListQuery(conversation_id=conversation_id, limit=0)
|
||||
|
||||
def test_query_rejects_limit_above_maximum(self):
|
||||
"""Test query rejects limit > 100."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
with pytest.raises(ValueError):
|
||||
MessageListQuery(conversation_id=conversation_id, limit=101)
|
||||
|
||||
|
||||
class TestMessageFeedbackPayload:
|
||||
"""Test suite for MessageFeedbackPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_defaults(self):
|
||||
"""Test payload with default values."""
|
||||
payload = MessageFeedbackPayload()
|
||||
assert payload.rating is None
|
||||
assert payload.content is None
|
||||
|
||||
def test_payload_with_like_rating(self):
|
||||
"""Test payload with like rating."""
|
||||
payload = MessageFeedbackPayload(rating="like")
|
||||
assert payload.rating == "like"
|
||||
|
||||
def test_payload_with_dislike_rating(self):
|
||||
"""Test payload with dislike rating."""
|
||||
payload = MessageFeedbackPayload(rating="dislike")
|
||||
assert payload.rating == "dislike"
|
||||
|
||||
def test_payload_with_content_only(self):
|
||||
"""Test payload with content but no rating."""
|
||||
payload = MessageFeedbackPayload(content="This response was helpful")
|
||||
assert payload.content == "This response was helpful"
|
||||
assert payload.rating is None
|
||||
|
||||
def test_payload_with_rating_and_content(self):
|
||||
"""Test payload with both rating and content."""
|
||||
payload = MessageFeedbackPayload(rating="like", content="Great answer, very detailed!")
|
||||
assert payload.rating == "like"
|
||||
assert payload.content == "Great answer, very detailed!"
|
||||
|
||||
def test_payload_with_long_content(self):
|
||||
"""Test payload with long feedback content."""
|
||||
long_content = "A" * 1000
|
||||
payload = MessageFeedbackPayload(content=long_content)
|
||||
assert len(payload.content) == 1000
|
||||
|
||||
def test_payload_with_unicode_content(self):
|
||||
"""Test payload with unicode characters."""
|
||||
unicode_content = "很好的回答 👍 Отличный ответ"
|
||||
payload = MessageFeedbackPayload(content=unicode_content)
|
||||
assert payload.content == unicode_content
|
||||
|
||||
|
||||
class TestFeedbackListQuery:
|
||||
"""Test suite for FeedbackListQuery Pydantic model."""
|
||||
|
||||
def test_query_with_defaults(self):
|
||||
"""Test query with default values."""
|
||||
query = FeedbackListQuery()
|
||||
assert query.page == 1
|
||||
assert query.limit == 20
|
||||
|
||||
def test_query_with_custom_pagination(self):
|
||||
"""Test query with custom page and limit."""
|
||||
query = FeedbackListQuery(page=3, limit=50)
|
||||
assert query.page == 3
|
||||
assert query.limit == 50
|
||||
|
||||
def test_query_page_minimum(self):
|
||||
"""Test query page minimum validation."""
|
||||
query = FeedbackListQuery(page=1)
|
||||
assert query.page == 1
|
||||
|
||||
def test_query_rejects_page_below_minimum(self):
|
||||
"""Test query rejects page < 1."""
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackListQuery(page=0)
|
||||
|
||||
def test_query_limit_boundaries(self):
|
||||
"""Test query limit boundaries."""
|
||||
query_min = FeedbackListQuery(limit=1)
|
||||
assert query_min.limit == 1
|
||||
|
||||
query_max = FeedbackListQuery(limit=101)
|
||||
assert query_max.limit == 101 # Max is 101
|
||||
|
||||
def test_query_rejects_limit_below_minimum(self):
|
||||
"""Test query rejects limit < 1."""
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackListQuery(limit=0)
|
||||
|
||||
def test_query_rejects_limit_above_maximum(self):
|
||||
"""Test query rejects limit > 101."""
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackListQuery(limit=102)
|
||||
|
||||
|
||||
class TestMessageAppModeValidation:
|
||||
"""Test app mode validation for message endpoints."""
|
||||
|
||||
def test_chat_modes_are_valid_for_message_endpoints(self):
|
||||
"""Test that all chat modes are valid."""
|
||||
valid_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
for mode in valid_modes:
|
||||
assert mode in valid_modes
|
||||
|
||||
def test_completion_mode_is_invalid_for_message_endpoints(self):
|
||||
"""Test that COMPLETION mode is invalid."""
|
||||
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
assert AppMode.COMPLETION not in chat_modes
|
||||
|
||||
def test_workflow_mode_is_invalid_for_message_endpoints(self):
|
||||
"""Test that WORKFLOW mode is invalid."""
|
||||
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
assert AppMode.WORKFLOW not in chat_modes
|
||||
|
||||
def test_not_chat_app_error_can_be_raised(self):
|
||||
"""Test NotChatAppError can be raised."""
|
||||
error = NotChatAppError()
|
||||
assert error is not None
|
||||
|
||||
|
||||
class TestMessageErrorTypes:
|
||||
"""Test message-related error types."""
|
||||
|
||||
def test_message_not_exists_error_can_be_raised(self):
|
||||
"""Test MessageNotExistsError can be raised."""
|
||||
error = MessageNotExistsError()
|
||||
assert isinstance(error, MessageNotExistsError)
|
||||
|
||||
def test_first_message_not_exists_error_can_be_raised(self):
|
||||
"""Test FirstMessageNotExistsError can be raised."""
|
||||
error = FirstMessageNotExistsError()
|
||||
assert isinstance(error, FirstMessageNotExistsError)
|
||||
|
||||
def test_suggested_questions_after_answer_disabled_error_can_be_raised(self):
|
||||
"""Test SuggestedQuestionsAfterAnswerDisabledError can be raised."""
|
||||
error = SuggestedQuestionsAfterAnswerDisabledError()
|
||||
assert isinstance(error, SuggestedQuestionsAfterAnswerDisabledError)
|
||||
|
||||
|
||||
class TestMessageService:
|
||||
"""Test MessageService interface and methods."""
|
||||
|
||||
def test_pagination_by_first_id_method_exists(self):
|
||||
"""Test MessageService.pagination_by_first_id exists."""
|
||||
assert hasattr(MessageService, "pagination_by_first_id")
|
||||
assert callable(MessageService.pagination_by_first_id)
|
||||
|
||||
def test_create_feedback_method_exists(self):
|
||||
"""Test MessageService.create_feedback exists."""
|
||||
assert hasattr(MessageService, "create_feedback")
|
||||
assert callable(MessageService.create_feedback)
|
||||
|
||||
def test_get_all_messages_feedbacks_method_exists(self):
|
||||
"""Test MessageService.get_all_messages_feedbacks exists."""
|
||||
assert hasattr(MessageService, "get_all_messages_feedbacks")
|
||||
assert callable(MessageService.get_all_messages_feedbacks)
|
||||
|
||||
def test_get_suggested_questions_after_answer_method_exists(self):
|
||||
"""Test MessageService.get_suggested_questions_after_answer exists."""
|
||||
assert hasattr(MessageService, "get_suggested_questions_after_answer")
|
||||
assert callable(MessageService.get_suggested_questions_after_answer)
|
||||
|
||||
@patch.object(MessageService, "pagination_by_first_id")
|
||||
def test_pagination_by_first_id_returns_pagination_result(self, mock_pagination):
|
||||
"""Test pagination_by_first_id returns expected format."""
|
||||
mock_result = Mock()
|
||||
mock_result.data = []
|
||||
mock_result.limit = 20
|
||||
mock_result.has_more = False
|
||||
mock_pagination.return_value = mock_result
|
||||
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(spec=EndUser),
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
first_id=None,
|
||||
limit=20,
|
||||
)
|
||||
|
||||
assert hasattr(result, "data")
|
||||
assert hasattr(result, "limit")
|
||||
assert hasattr(result, "has_more")
|
||||
|
||||
@patch.object(MessageService, "pagination_by_first_id")
|
||||
def test_pagination_raises_conversation_not_exists_error(self, mock_pagination):
|
||||
"""Test pagination raises ConversationNotExistsError."""
|
||||
import services.errors.conversation
|
||||
|
||||
mock_pagination.side_effect = services.errors.conversation.ConversationNotExistsError()
|
||||
|
||||
with pytest.raises(services.errors.conversation.ConversationNotExistsError):
|
||||
MessageService.pagination_by_first_id(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), conversation_id="invalid_id", first_id=None, limit=20
|
||||
)
|
||||
|
||||
@patch.object(MessageService, "pagination_by_first_id")
|
||||
def test_pagination_raises_first_message_not_exists_error(self, mock_pagination):
|
||||
"""Test pagination raises FirstMessageNotExistsError."""
|
||||
mock_pagination.side_effect = FirstMessageNotExistsError()
|
||||
|
||||
with pytest.raises(FirstMessageNotExistsError):
|
||||
MessageService.pagination_by_first_id(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(spec=EndUser),
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
first_id="invalid_first_id",
|
||||
limit=20,
|
||||
)
|
||||
|
||||
@patch.object(MessageService, "create_feedback")
|
||||
def test_create_feedback_with_rating_and_content(self, mock_create_feedback):
|
||||
"""Test create_feedback with rating and content."""
|
||||
mock_create_feedback.return_value = None
|
||||
|
||||
MessageService.create_feedback(
|
||||
app_model=Mock(spec=App),
|
||||
message_id=str(uuid.uuid4()),
|
||||
user=Mock(spec=EndUser),
|
||||
rating="like",
|
||||
content="Great response!",
|
||||
)
|
||||
|
||||
mock_create_feedback.assert_called_once()
|
||||
|
||||
@patch.object(MessageService, "create_feedback")
|
||||
def test_create_feedback_raises_message_not_exists_error(self, mock_create_feedback):
|
||||
"""Test create_feedback raises MessageNotExistsError."""
|
||||
mock_create_feedback.side_effect = MessageNotExistsError()
|
||||
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
MessageService.create_feedback(
|
||||
app_model=Mock(spec=App),
|
||||
message_id="invalid_message_id",
|
||||
user=Mock(spec=EndUser),
|
||||
rating="like",
|
||||
content=None,
|
||||
)
|
||||
|
||||
@patch.object(MessageService, "get_all_messages_feedbacks")
|
||||
def test_get_all_messages_feedbacks_returns_list(self, mock_get_feedbacks):
|
||||
"""Test get_all_messages_feedbacks returns list of feedbacks."""
|
||||
mock_feedbacks = [
|
||||
{"message_id": str(uuid.uuid4()), "rating": "like"},
|
||||
{"message_id": str(uuid.uuid4()), "rating": "dislike"},
|
||||
]
|
||||
mock_get_feedbacks.return_value = mock_feedbacks
|
||||
|
||||
result = MessageService.get_all_messages_feedbacks(app_model=Mock(spec=App), page=1, limit=20)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["rating"] == "like"
|
||||
|
||||
@patch.object(MessageService, "get_suggested_questions_after_answer")
|
||||
def test_get_suggested_questions_returns_questions_list(self, mock_get_questions):
|
||||
"""Test get_suggested_questions_after_answer returns list of questions."""
|
||||
mock_questions = ["What about this aspect?", "Can you elaborate on that?", "How does this relate to...?"]
|
||||
mock_get_questions.return_value = mock_questions
|
||||
|
||||
result = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id=str(uuid.uuid4()), invoke_from=Mock()
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[0], str)
|
||||
|
||||
@patch.object(MessageService, "get_suggested_questions_after_answer")
|
||||
def test_get_suggested_questions_raises_disabled_error(self, mock_get_questions):
|
||||
"""Test get_suggested_questions_after_answer raises SuggestedQuestionsAfterAnswerDisabledError."""
|
||||
mock_get_questions.side_effect = SuggestedQuestionsAfterAnswerDisabledError()
|
||||
|
||||
with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError):
|
||||
MessageService.get_suggested_questions_after_answer(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id=str(uuid.uuid4()), invoke_from=Mock()
|
||||
)
|
||||
|
||||
@patch.object(MessageService, "get_suggested_questions_after_answer")
|
||||
def test_get_suggested_questions_raises_message_not_exists_error(self, mock_get_questions):
|
||||
"""Test get_suggested_questions_after_answer raises MessageNotExistsError."""
|
||||
mock_get_questions.side_effect = MessageNotExistsError()
|
||||
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
MessageService.get_suggested_questions_after_answer(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id="invalid_message_id", invoke_from=Mock()
|
||||
)
|
||||
|
||||
|
||||
class TestMessageListApi:
|
||||
def test_not_chat_app(self, app) -> None:
|
||||
api = MessageListApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages?conversation_id=cid", method="GET"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"pagination_by_first_id",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
api = MessageListApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/messages?conversation_id=00000000-0000-0000-0000-000000000001",
|
||||
method="GET",
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_first_message_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"pagination_by_first_id",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(FirstMessageNotExistsError()),
|
||||
)
|
||||
|
||||
api = MessageListApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/messages?conversation_id=00000000-0000-0000-0000-000000000001&first_id=00000000-0000-0000-0000-000000000002",
|
||||
method="GET",
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestMessageFeedbackApi:
|
||||
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"create_feedback",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(MessageNotExistsError()),
|
||||
)
|
||||
|
||||
api = MessageFeedbackApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace()
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/messages/m1/feedbacks",
|
||||
method="POST",
|
||||
json={"rating": "like", "content": "ok"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
|
||||
class TestAppGetFeedbacksApi:
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"])
|
||||
|
||||
api = AppGetFeedbacksApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/app/feedbacks?page=1&limit=20", method="GET"):
|
||||
response = handler(api, app_model=app_model)
|
||||
|
||||
assert response == {"data": ["f1"]}
|
||||
|
||||
|
||||
class TestMessageSuggestedApi:
|
||||
def test_not_chat(self, app) -> None:
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages/m1/suggested", method="GET"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(MessageNotExistsError()),
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages/m1/suggested", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
def test_disabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(SuggestedQuestionsAfterAnswerDisabledError()),
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages/m1/suggested", method="GET"):
|
||||
with pytest.raises(BadRequest):
|
||||
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
def test_internal_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages/m1/suggested", method="GET"):
|
||||
with pytest.raises(InternalServerError):
|
||||
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
lambda *_args, **_kwargs: ["q1"],
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages/m1/suggested", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
assert response == {"result": "success", "data": ["q1"]}
|
||||
@ -0,0 +1,653 @@
|
||||
"""
|
||||
Unit tests for Service API Workflow controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- WorkflowRunPayload and WorkflowLogQuery Pydantic models
|
||||
- Workflow execution error handling
|
||||
- App mode validation for workflow endpoints
|
||||
- Workflow stop mechanism validation
|
||||
|
||||
Focus on:
|
||||
- Pydantic model validation
|
||||
- Error type mappings
|
||||
- Service method interfaces
|
||||
"""
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.service_api.app.error import NotWorkflowAppError
|
||||
from controllers.service_api.app.workflow import (
|
||||
AppQueueManager,
|
||||
DifyAPIRepositoryFactory,
|
||||
GraphEngineManager,
|
||||
WorkflowAppLogApi,
|
||||
WorkflowLogQuery,
|
||||
WorkflowRunApi,
|
||||
WorkflowRunByIdApi,
|
||||
WorkflowRunDetailApi,
|
||||
WorkflowRunPayload,
|
||||
WorkflowTaskStopApi,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
class TestWorkflowRunPayload:
|
||||
"""Test suite for WorkflowRunPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_required_inputs(self):
|
||||
"""Test payload with required inputs field."""
|
||||
payload = WorkflowRunPayload(inputs={"key": "value"})
|
||||
assert payload.inputs == {"key": "value"}
|
||||
assert payload.files is None
|
||||
assert payload.response_mode is None
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all fields populated."""
|
||||
files = [{"type": "image", "url": "http://example.com/img.png"}]
|
||||
payload = WorkflowRunPayload(inputs={"param1": "value1", "param2": 123}, files=files, response_mode="streaming")
|
||||
assert payload.inputs == {"param1": "value1", "param2": 123}
|
||||
assert payload.files == files
|
||||
assert payload.response_mode == "streaming"
|
||||
|
||||
def test_payload_response_mode_blocking(self):
|
||||
"""Test payload with blocking response mode."""
|
||||
payload = WorkflowRunPayload(inputs={}, response_mode="blocking")
|
||||
assert payload.response_mode == "blocking"
|
||||
|
||||
def test_payload_with_complex_inputs(self):
|
||||
"""Test payload with nested complex inputs."""
|
||||
complex_inputs = {
|
||||
"config": {"nested": {"value": 123}},
|
||||
"items": ["item1", "item2"],
|
||||
"metadata": {"key": "value"},
|
||||
}
|
||||
payload = WorkflowRunPayload(inputs=complex_inputs)
|
||||
assert payload.inputs == complex_inputs
|
||||
|
||||
def test_payload_with_empty_inputs(self):
|
||||
"""Test payload with empty inputs dict."""
|
||||
payload = WorkflowRunPayload(inputs={})
|
||||
assert payload.inputs == {}
|
||||
|
||||
def test_payload_with_multiple_files(self):
|
||||
"""Test payload with multiple file attachments."""
|
||||
files = [
|
||||
{"type": "image", "url": "http://example.com/img1.png"},
|
||||
{"type": "document", "upload_file_id": "file_123"},
|
||||
{"type": "audio", "url": "http://example.com/audio.mp3"},
|
||||
]
|
||||
payload = WorkflowRunPayload(inputs={}, files=files)
|
||||
assert len(payload.files) == 3
|
||||
|
||||
|
||||
class TestWorkflowLogQuery:
|
||||
"""Test suite for WorkflowLogQuery Pydantic model."""
|
||||
|
||||
def test_query_with_defaults(self):
|
||||
"""Test query with default values."""
|
||||
query = WorkflowLogQuery()
|
||||
assert query.keyword is None
|
||||
assert query.status is None
|
||||
assert query.created_at__before is None
|
||||
assert query.created_at__after is None
|
||||
assert query.created_by_end_user_session_id is None
|
||||
assert query.created_by_account is None
|
||||
assert query.page == 1
|
||||
assert query.limit == 20
|
||||
|
||||
def test_query_with_all_filters(self):
|
||||
"""Test query with all filter fields populated."""
|
||||
query = WorkflowLogQuery(
|
||||
keyword="search term",
|
||||
status="succeeded",
|
||||
created_at__before="2024-01-15T10:00:00Z",
|
||||
created_at__after="2024-01-01T00:00:00Z",
|
||||
created_by_end_user_session_id="session_123",
|
||||
created_by_account="user@example.com",
|
||||
page=2,
|
||||
limit=50,
|
||||
)
|
||||
assert query.keyword == "search term"
|
||||
assert query.status == "succeeded"
|
||||
assert query.created_at__before == "2024-01-15T10:00:00Z"
|
||||
assert query.created_at__after == "2024-01-01T00:00:00Z"
|
||||
assert query.created_by_end_user_session_id == "session_123"
|
||||
assert query.created_by_account == "user@example.com"
|
||||
assert query.page == 2
|
||||
assert query.limit == 50
|
||||
|
||||
@pytest.mark.parametrize("status", ["succeeded", "failed", "stopped"])
|
||||
def test_query_valid_status_values(self, status):
|
||||
"""Test all valid status values."""
|
||||
query = WorkflowLogQuery(status=status)
|
||||
assert query.status == status
|
||||
|
||||
def test_query_pagination_limits(self):
|
||||
"""Test query pagination boundaries."""
|
||||
query_min_page = WorkflowLogQuery(page=1)
|
||||
assert query_min_page.page == 1
|
||||
|
||||
query_max_page = WorkflowLogQuery(page=99999)
|
||||
assert query_max_page.page == 99999
|
||||
|
||||
query_min_limit = WorkflowLogQuery(limit=1)
|
||||
assert query_min_limit.limit == 1
|
||||
|
||||
query_max_limit = WorkflowLogQuery(limit=100)
|
||||
assert query_max_limit.limit == 100
|
||||
|
||||
def test_query_rejects_page_below_minimum(self):
|
||||
"""Test query rejects page < 1."""
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowLogQuery(page=0)
|
||||
|
||||
def test_query_rejects_page_above_maximum(self):
|
||||
"""Test query rejects page > 99999."""
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowLogQuery(page=100000)
|
||||
|
||||
def test_query_rejects_limit_below_minimum(self):
|
||||
"""Test query rejects limit < 1."""
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowLogQuery(limit=0)
|
||||
|
||||
def test_query_rejects_limit_above_maximum(self):
|
||||
"""Test query rejects limit > 100."""
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowLogQuery(limit=101)
|
||||
|
||||
def test_query_with_keyword_search(self):
|
||||
"""Test query with keyword filter."""
|
||||
query = WorkflowLogQuery(keyword="workflow execution")
|
||||
assert query.keyword == "workflow execution"
|
||||
|
||||
def test_query_with_date_filters(self):
|
||||
"""Test query with before/after date filters."""
|
||||
query = WorkflowLogQuery(created_at__before="2024-12-31T23:59:59Z", created_at__after="2024-01-01T00:00:00Z")
|
||||
assert query.created_at__before == "2024-12-31T23:59:59Z"
|
||||
assert query.created_at__after == "2024-01-01T00:00:00Z"
|
||||
|
||||
|
||||
class TestWorkflowAppService:
|
||||
"""Test WorkflowAppService interface."""
|
||||
|
||||
def test_service_exists(self):
|
||||
"""Test WorkflowAppService class exists."""
|
||||
service = WorkflowAppService()
|
||||
assert service is not None
|
||||
|
||||
def test_get_paginate_workflow_app_logs_method_exists(self):
|
||||
"""Test get_paginate_workflow_app_logs method exists."""
|
||||
assert hasattr(WorkflowAppService, "get_paginate_workflow_app_logs")
|
||||
assert callable(WorkflowAppService.get_paginate_workflow_app_logs)
|
||||
|
||||
@patch.object(WorkflowAppService, "get_paginate_workflow_app_logs")
|
||||
def test_get_paginate_workflow_app_logs_returns_pagination(self, mock_get_logs):
|
||||
"""Test get_paginate_workflow_app_logs returns paginated result."""
|
||||
mock_pagination = Mock()
|
||||
mock_pagination.data = []
|
||||
mock_pagination.page = 1
|
||||
mock_pagination.limit = 20
|
||||
mock_pagination.total = 0
|
||||
mock_get_logs.return_value = mock_pagination
|
||||
|
||||
service = WorkflowAppService()
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=Mock(),
|
||||
app_model=Mock(spec=App),
|
||||
keyword=None,
|
||||
status=None,
|
||||
created_at_before=None,
|
||||
created_at_after=None,
|
||||
page=1,
|
||||
limit=20,
|
||||
created_by_end_user_session_id=None,
|
||||
created_by_account=None,
|
||||
)
|
||||
|
||||
assert result.page == 1
|
||||
assert result.limit == 20
|
||||
|
||||
|
||||
class TestWorkflowExecutionStatus:
|
||||
"""Test WorkflowExecutionStatus enum."""
|
||||
|
||||
def test_succeeded_status_exists(self):
|
||||
"""Test succeeded status value exists."""
|
||||
status = WorkflowExecutionStatus("succeeded")
|
||||
assert status.value == "succeeded"
|
||||
|
||||
def test_failed_status_exists(self):
|
||||
"""Test failed status value exists."""
|
||||
status = WorkflowExecutionStatus("failed")
|
||||
assert status.value == "failed"
|
||||
|
||||
def test_stopped_status_exists(self):
|
||||
"""Test stopped status value exists."""
|
||||
status = WorkflowExecutionStatus("stopped")
|
||||
assert status.value == "stopped"
|
||||
|
||||
|
||||
class TestAppGenerateServiceWorkflow:
|
||||
"""Test AppGenerateService workflow integration."""
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_accepts_workflow_args(self, mock_generate):
|
||||
"""Test generate accepts workflow-specific args."""
|
||||
mock_generate.return_value = {"result": "success"}
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(),
|
||||
args={"inputs": {"key": "value"}, "workflow_id": "workflow_123"},
|
||||
invoke_from=Mock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_generate.assert_called_once()
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_raises_workflow_not_found_error(self, mock_generate):
|
||||
"""Test generate raises WorkflowNotFoundError."""
|
||||
mock_generate.side_effect = WorkflowNotFoundError("Workflow not found")
|
||||
|
||||
with pytest.raises(WorkflowNotFoundError):
|
||||
AppGenerateService.generate(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(),
|
||||
args={"workflow_id": "invalid_id"},
|
||||
invoke_from=Mock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_raises_is_draft_workflow_error(self, mock_generate):
|
||||
"""Test generate raises IsDraftWorkflowError."""
|
||||
mock_generate.side_effect = IsDraftWorkflowError("Workflow is draft")
|
||||
|
||||
with pytest.raises(IsDraftWorkflowError):
|
||||
AppGenerateService.generate(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(),
|
||||
args={"workflow_id": "draft_workflow"},
|
||||
invoke_from=Mock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_supports_streaming_mode(self, mock_generate):
|
||||
"""Test generate supports streaming response mode."""
|
||||
mock_stream = Mock()
|
||||
mock_generate.return_value = mock_stream
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(),
|
||||
args={"inputs": {}, "response_mode": "streaming"},
|
||||
invoke_from=Mock(),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == mock_stream
|
||||
|
||||
|
||||
class TestWorkflowStopMechanism:
|
||||
"""Test workflow stop mechanisms."""
|
||||
|
||||
def test_app_queue_manager_has_stop_flag_method(self):
|
||||
"""Test AppQueueManager has set_stop_flag_no_user_check method."""
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
|
||||
assert hasattr(AppQueueManager, "set_stop_flag_no_user_check")
|
||||
|
||||
def test_graph_engine_manager_has_send_stop_command(self):
|
||||
"""Test GraphEngineManager has send_stop_command method."""
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
|
||||
assert hasattr(GraphEngineManager, "send_stop_command")
|
||||
|
||||
|
||||
class TestWorkflowRunRepository:
|
||||
"""Test workflow run repository interface."""
|
||||
|
||||
def test_repository_factory_can_create_workflow_run_repository(self):
|
||||
"""Test DifyAPIRepositoryFactory can create workflow run repository."""
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
assert hasattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository")
|
||||
|
||||
@patch("repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository")
|
||||
def test_workflow_run_repository_get_by_id(self, mock_factory):
|
||||
"""Test workflow run repository get_workflow_run_by_id method."""
|
||||
mock_repo = Mock()
|
||||
mock_run = Mock()
|
||||
mock_run.id = str(uuid.uuid4())
|
||||
mock_run.status = "succeeded"
|
||||
mock_repo.get_workflow_run_by_id.return_value = mock_run
|
||||
mock_factory.return_value = mock_repo
|
||||
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(Mock())
|
||||
|
||||
result = repo.get_workflow_run_by_id(tenant_id="tenant_123", app_id="app_456", run_id="run_789")
|
||||
|
||||
assert result.status == "succeeded"
|
||||
|
||||
|
||||
class TestWorkflowRunDetailApi:
|
||||
def test_not_workflow_app(self, app) -> None:
|
||||
api = WorkflowRunDetailApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
|
||||
with app.test_request_context("/workflows/run/1", method="GET"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
handler(api, app_model=app_model, workflow_run_id="run")
|
||||
|
||||
def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
run = SimpleNamespace(id="run")
|
||||
repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run)
|
||||
workflow_module = sys.modules["controllers.service_api.app.workflow"]
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: repo,
|
||||
)
|
||||
|
||||
api = WorkflowRunDetailApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
|
||||
|
||||
assert handler(api, app_model=app_model, workflow_run_id="run") == run
|
||||
|
||||
|
||||
class TestWorkflowRunApi:
|
||||
def test_not_workflow_app(self, app) -> None:
|
||||
api = WorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_rate_limit(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(InvokeRateLimitError("slow")),
|
||||
)
|
||||
|
||||
api = WorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestWorkflowRunByIdApi:
|
||||
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(WorkflowNotFoundError("missing")),
|
||||
)
|
||||
|
||||
api = WorkflowRunByIdApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
|
||||
|
||||
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(IsDraftWorkflowError("draft")),
|
||||
)
|
||||
|
||||
api = WorkflowRunByIdApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(BadRequest):
|
||||
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
|
||||
|
||||
|
||||
class TestWorkflowTaskStopApi:
|
||||
def test_wrong_mode(self, app) -> None:
|
||||
api = WorkflowTaskStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/tasks/1/stop", method="POST"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
|
||||
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
stop_mock = Mock()
|
||||
send_mock = Mock()
|
||||
monkeypatch.setattr(AppQueueManager, "set_stop_flag_no_user_check", stop_mock)
|
||||
monkeypatch.setattr(GraphEngineManager, "send_stop_command", send_mock)
|
||||
|
||||
api = WorkflowTaskStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/workflows/tasks/1/stop", method="POST"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="t1")
|
||||
|
||||
assert response == {"result": "success"}
|
||||
stop_mock.assert_called_once_with("t1")
|
||||
send_mock.assert_called_once_with("t1")
|
||||
|
||||
|
||||
class TestWorkflowAppLogApi:
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _SessionStub:
|
||||
def __enter__(self):
|
||||
return SimpleNamespace()
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
workflow_module = sys.modules["controllers.service_api.app.workflow"]
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(workflow_module, "Session", lambda *_args, **_kwargs: _SessionStub())
|
||||
monkeypatch.setattr(
|
||||
WorkflowAppService,
|
||||
"get_paginate_workflow_app_logs",
|
||||
lambda *_args, **_kwargs: {"items": [], "total": 0},
|
||||
)
|
||||
|
||||
api = WorkflowAppLogApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/workflows/logs", method="GET"):
|
||||
response = handler(api, app_model=app_model)
|
||||
|
||||
assert response == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoint Tests
|
||||
#
|
||||
# ``WorkflowRunDetailApi``, ``WorkflowTaskStopApi``, and
|
||||
# ``WorkflowAppLogApi`` use ``@validate_app_token`` which preserves
|
||||
# ``__wrapped__`` via ``functools.wraps``. We call the unwrapped method
|
||||
# directly to bypass the decorator.
|
||||
# =============================================================================
|
||||
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_app():
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
return app
|
||||
|
||||
|
||||
class TestWorkflowRunDetailApiGet:
|
||||
"""Test suite for WorkflowRunDetailApi.get() endpoint.
|
||||
|
||||
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``)
|
||||
and ``@service_api_ns.marshal_with``. We call the unwrapped method
|
||||
directly; ``marshal_with`` is a no-op when calling directly.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.service_api.app.workflow.db")
|
||||
def test_get_workflow_run_success(
|
||||
self,
|
||||
mock_db,
|
||||
mock_repo_factory,
|
||||
app,
|
||||
mock_workflow_app,
|
||||
):
|
||||
"""Test successful workflow run detail retrieval."""
|
||||
mock_run = Mock()
|
||||
mock_run.id = "run-1"
|
||||
mock_run.status = "succeeded"
|
||||
mock_repo = Mock()
|
||||
mock_repo.get_workflow_run_by_id.return_value = mock_run
|
||||
mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
from controllers.service_api.app.workflow import WorkflowRunDetailApi
|
||||
|
||||
with app.test_request_context(
|
||||
f"/workflows/run/{mock_run.id}",
|
||||
method="GET",
|
||||
):
|
||||
api = WorkflowRunDetailApi()
|
||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
|
||||
|
||||
assert result == mock_run
|
||||
|
||||
@patch("controllers.service_api.app.workflow.db")
|
||||
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
|
||||
"""Test NotWorkflowAppError when app mode is not workflow or advanced_chat."""
|
||||
from controllers.service_api.app.workflow import WorkflowRunDetailApi
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.mode = AppMode.CHAT.value
|
||||
|
||||
with app.test_request_context("/workflows/run/run-1", method="GET"):
|
||||
api = WorkflowRunDetailApi()
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
_unwrap(api.get)(api, app_model=mock_app, workflow_run_id="run-1")
|
||||
|
||||
|
||||
class TestWorkflowTaskStopApiPost:
|
||||
"""Test suite for WorkflowTaskStopApi.post() endpoint.
|
||||
|
||||
``post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.app.workflow.GraphEngineManager")
|
||||
@patch("controllers.service_api.app.workflow.AppQueueManager")
|
||||
def test_stop_workflow_task_success(
|
||||
self,
|
||||
mock_queue_mgr,
|
||||
mock_graph_mgr,
|
||||
app,
|
||||
mock_workflow_app,
|
||||
):
|
||||
"""Test successful workflow task stop."""
|
||||
from controllers.service_api.app.workflow import WorkflowTaskStopApi
|
||||
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
api = WorkflowTaskStopApi()
|
||||
result = _unwrap(api.post)(
|
||||
api,
|
||||
app_model=mock_workflow_app,
|
||||
end_user=Mock(),
|
||||
task_id="task-1",
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_queue_mgr.set_stop_flag_no_user_check.assert_called_once_with("task-1")
|
||||
mock_graph_mgr.send_stop_command.assert_called_once_with("task-1")
|
||||
|
||||
def test_stop_workflow_task_wrong_app_mode(self, app):
|
||||
"""Test NotWorkflowAppError when app mode is not workflow."""
|
||||
from controllers.service_api.app.workflow import WorkflowTaskStopApi
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.mode = AppMode.COMPLETION.value
|
||||
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
api = WorkflowTaskStopApi()
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
_unwrap(api.post)(api, app_model=mock_app, end_user=Mock(), task_id="task-1")
|
||||
|
||||
|
||||
class TestWorkflowAppLogApiGet:
|
||||
"""Test suite for WorkflowAppLogApi.get() endpoint.
|
||||
|
||||
``get`` is wrapped by ``@validate_app_token`` and
|
||||
``@service_api_ns.marshal_with``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.app.workflow.WorkflowAppService")
|
||||
@patch("controllers.service_api.app.workflow.db")
|
||||
def test_get_workflow_logs_success(
|
||||
self,
|
||||
mock_db,
|
||||
mock_wf_svc_cls,
|
||||
app,
|
||||
mock_workflow_app,
|
||||
):
|
||||
"""Test successful workflow log retrieval."""
|
||||
mock_pagination = Mock()
|
||||
mock_pagination.data = []
|
||||
mock_svc_instance = Mock()
|
||||
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
|
||||
mock_wf_svc_cls.return_value = mock_svc_instance
|
||||
|
||||
# Mock Session context manager
|
||||
mock_session = Mock()
|
||||
mock_db.engine = Mock()
|
||||
mock_session.__enter__ = Mock(return_value=mock_session)
|
||||
mock_session.__exit__ = Mock(return_value=False)
|
||||
|
||||
from controllers.service_api.app.workflow import WorkflowAppLogApi
|
||||
|
||||
with app.test_request_context(
|
||||
"/workflows/logs?page=1&limit=20",
|
||||
method="GET",
|
||||
):
|
||||
with patch("controllers.service_api.app.workflow.Session", return_value=mock_session):
|
||||
api = WorkflowAppLogApi()
|
||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
|
||||
|
||||
assert result == mock_pagination
|
||||
218
api/tests/unit_tests/controllers/service_api/conftest.py
Normal file
218
api/tests/unit_tests/controllers/service_api/conftest.py
Normal file
@ -0,0 +1,218 @@
|
||||
"""
|
||||
Shared fixtures for Service API controller tests.
|
||||
|
||||
This module provides reusable fixtures for mocking authentication,
|
||||
database interactions, and common test data patterns used across
|
||||
Service API controller tests.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, AppMode, EndUser
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_account_query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create Flask test application with proper configuration."""
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tenant_id():
|
||||
"""Generate a consistent tenant ID for test sessions."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_id():
|
||||
"""Generate a consistent app ID for test sessions."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user(mock_tenant_id):
|
||||
"""Create a mock EndUser model with required attributes."""
|
||||
user = Mock(spec=EndUser)
|
||||
user.id = str(uuid.uuid4())
|
||||
user.external_user_id = f"external_{uuid.uuid4().hex[:8]}"
|
||||
user.tenant_id = mock_tenant_id
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(mock_app_id, mock_tenant_id):
|
||||
"""Create a mock App model with all required attributes for API testing."""
|
||||
app = Mock(spec=App)
|
||||
app.id = mock_app_id
|
||||
app.tenant_id = mock_tenant_id
|
||||
app.name = "Test App"
|
||||
app.description = "A test application"
|
||||
app.mode = AppMode.CHAT
|
||||
app.author_name = "Test Author"
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
app.tags = []
|
||||
|
||||
# Mock workflow for workflow apps
|
||||
app.workflow = None
|
||||
app.app_model_config = None
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tenant(mock_tenant_id):
|
||||
"""Create a mock Tenant model."""
|
||||
tenant = Mock()
|
||||
tenant.id = mock_tenant_id
|
||||
tenant.status = TenantStatus.NORMAL
|
||||
return tenant
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Create a mock Account model."""
|
||||
account = Mock()
|
||||
account.id = str(uuid.uuid4())
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_token(mock_app_id, mock_tenant_id):
|
||||
"""Create a mock API token for authentication tests."""
|
||||
token = Mock()
|
||||
token.app_id = mock_app_id
|
||||
token.tenant_id = mock_tenant_id
|
||||
token.token = f"test_token_{uuid.uuid4().hex[:8]}"
|
||||
token.type = "app"
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_api_token(mock_tenant_id):
|
||||
"""Create a mock API token for dataset endpoints."""
|
||||
token = Mock()
|
||||
token.tenant_id = mock_tenant_id
|
||||
token.token = f"dataset_token_{uuid.uuid4().hex[:8]}"
|
||||
token.type = "dataset"
|
||||
return token
|
||||
|
||||
|
||||
class AuthenticationMocker:
|
||||
"""
|
||||
Helper class to set up common authentication mocking patterns.
|
||||
|
||||
Usage:
|
||||
auth_mocker = AuthenticationMocker()
|
||||
with auth_mocker.mock_app_auth(mock_api_token, mock_app_model, mock_tenant):
|
||||
# Test code here
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def setup_db_queries(mock_db, mock_app, mock_tenant, mock_account=None):
|
||||
"""Configure mock_db to return app and tenant in sequence."""
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
if mock_account:
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
|
||||
|
||||
@staticmethod
|
||||
def setup_dataset_auth(mock_db, mock_tenant, mock_account):
|
||||
"""Configure mock_db for dataset token authentication."""
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
|
||||
mock_query = mock_db.session.query.return_value
|
||||
target_mock = mock_query.where.return_value.where.return_value.where.return_value.where.return_value
|
||||
target_mock.one_or_none.return_value = (mock_tenant, mock_ta)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_mocker():
|
||||
"""Provide an AuthenticationMocker instance."""
|
||||
return AuthenticationMocker()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset():
|
||||
"""Create a mock Dataset model."""
|
||||
from models.dataset import Dataset
|
||||
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = str(uuid.uuid4())
|
||||
dataset.tenant_id = str(uuid.uuid4())
|
||||
dataset.name = "Test Dataset"
|
||||
dataset.indexing_technique = "economy"
|
||||
dataset.embedding_model = None
|
||||
dataset.embedding_model_provider = None
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
"""Create a mock Document model."""
|
||||
from models.dataset import Document
|
||||
|
||||
document = Mock(spec=Document)
|
||||
document.id = str(uuid.uuid4())
|
||||
document.dataset_id = str(uuid.uuid4())
|
||||
document.tenant_id = str(uuid.uuid4())
|
||||
document.name = "test_document.txt"
|
||||
document.indexing_status = "completed"
|
||||
document.enabled = True
|
||||
document.doc_form = "text_model"
|
||||
return document
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_segment():
|
||||
"""Create a mock DocumentSegment model."""
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
segment = Mock(spec=DocumentSegment)
|
||||
segment.id = str(uuid.uuid4())
|
||||
segment.document_id = str(uuid.uuid4())
|
||||
segment.dataset_id = str(uuid.uuid4())
|
||||
segment.tenant_id = str(uuid.uuid4())
|
||||
segment.content = "Test segment content"
|
||||
segment.word_count = 3
|
||||
segment.position = 1
|
||||
segment.enabled = True
|
||||
segment.status = "completed"
|
||||
return segment
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_child_chunk():
|
||||
"""Create a mock ChildChunk model."""
|
||||
from models.dataset import ChildChunk
|
||||
|
||||
child_chunk = Mock(spec=ChildChunk)
|
||||
child_chunk.id = str(uuid.uuid4())
|
||||
child_chunk.segment_id = str(uuid.uuid4())
|
||||
child_chunk.tenant_id = str(uuid.uuid4())
|
||||
child_chunk.content = "Test child chunk content"
|
||||
return child_chunk
|
||||
|
||||
|
||||
def _unwrap(method):
|
||||
"""Walk ``__wrapped__`` chain to get the original function."""
|
||||
fn = method
|
||||
while hasattr(fn, "__wrapped__"):
|
||||
fn = fn.__wrapped__
|
||||
return fn
|
||||
@ -0,0 +1,633 @@
|
||||
"""
|
||||
Unit tests for Service API RAG Pipeline Workflow controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- DatasourceNodeRunPayload Pydantic model
|
||||
- PipelineRunApiEntity / DatasourceNodeRunApiEntity model validation
|
||||
- RAG pipeline service interfaces
|
||||
- File upload validation for pipelines
|
||||
- Endpoint tests for DatasourcePluginsApi, DatasourceNodeRunApi,
|
||||
PipelineRunApi, and KnowledgebasePipelineFileUploadApi
|
||||
|
||||
Strategy:
|
||||
- Endpoint methods on these resources have no billing decorators on the method
|
||||
itself. ``method_decorators = [validate_dataset_token]`` is only invoked by
|
||||
Flask-RESTx dispatch, not by direct calls, so we call methods directly.
|
||||
- Only ``KnowledgebasePipelineFileUploadApi.post`` touches ``db`` inline
|
||||
(via ``FileService(db.engine)``); the other endpoints delegate to services.
|
||||
"""
|
||||
|
||||
import io
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.service_api.dataset.error import PipelineRunError
|
||||
from controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow import (
|
||||
DatasourceNodeRunApi,
|
||||
DatasourceNodeRunPayload,
|
||||
DatasourcePluginsApi,
|
||||
KnowledgebasePipelineFileUploadApi,
|
||||
PipelineRunApi,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.account import Account
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.rag_pipeline.entity.pipeline_service_api_entities import (
|
||||
DatasourceNodeRunApiEntity,
|
||||
PipelineRunApiEntity,
|
||||
)
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
class TestDatasourceNodeRunPayload:
|
||||
"""Test suite for DatasourceNodeRunPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_required_fields(self):
|
||||
"""Test payload with required fields."""
|
||||
payload = DatasourceNodeRunPayload(
|
||||
inputs={"key": "value"}, datasource_type="online_document", is_published=True
|
||||
)
|
||||
assert payload.inputs == {"key": "value"}
|
||||
assert payload.datasource_type == "online_document"
|
||||
assert payload.is_published is True
|
||||
assert payload.credential_id is None
|
||||
|
||||
def test_payload_with_credential_id(self):
|
||||
"""Test payload with optional credential_id."""
|
||||
payload = DatasourceNodeRunPayload(
|
||||
inputs={"url": "https://example.com"},
|
||||
datasource_type="online_document",
|
||||
credential_id="cred_123",
|
||||
is_published=False,
|
||||
)
|
||||
assert payload.credential_id == "cred_123"
|
||||
assert payload.is_published is False
|
||||
|
||||
def test_payload_with_complex_inputs(self):
|
||||
"""Test payload with complex nested inputs."""
|
||||
complex_inputs = {
|
||||
"config": {"url": "https://api.example.com", "headers": {"Authorization": "Bearer token"}},
|
||||
"parameters": {"limit": 100, "offset": 0},
|
||||
"options": ["opt1", "opt2"],
|
||||
}
|
||||
payload = DatasourceNodeRunPayload(inputs=complex_inputs, datasource_type="api", is_published=True)
|
||||
assert payload.inputs == complex_inputs
|
||||
|
||||
def test_payload_with_empty_inputs(self):
|
||||
"""Test payload with empty inputs dict."""
|
||||
payload = DatasourceNodeRunPayload(inputs={}, datasource_type="local_file", is_published=True)
|
||||
assert payload.inputs == {}
|
||||
|
||||
@pytest.mark.parametrize("datasource_type", ["online_document", "local_file", "api", "database", "website"])
|
||||
def test_payload_common_datasource_types(self, datasource_type):
|
||||
"""Test payload with common datasource types."""
|
||||
payload = DatasourceNodeRunPayload(inputs={}, datasource_type=datasource_type, is_published=True)
|
||||
assert payload.datasource_type == datasource_type
|
||||
|
||||
|
||||
class TestPipelineErrors:
|
||||
"""Test pipeline-related error types."""
|
||||
|
||||
def test_pipeline_run_error_can_be_raised(self):
|
||||
"""Test PipelineRunError can be raised."""
|
||||
error = PipelineRunError(description="Pipeline execution failed")
|
||||
assert error is not None
|
||||
|
||||
def test_pipeline_run_error_with_description(self):
|
||||
"""Test PipelineRunError captures description."""
|
||||
error = PipelineRunError(description="Timeout during node execution")
|
||||
# The error should have the description attribute
|
||||
assert hasattr(error, "description")
|
||||
|
||||
|
||||
class TestFileUploadErrors:
|
||||
"""Test file upload error types for pipelines."""
|
||||
|
||||
def test_no_file_uploaded_error(self):
|
||||
"""Test NoFileUploadedError can be raised."""
|
||||
error = NoFileUploadedError()
|
||||
assert error is not None
|
||||
|
||||
def test_too_many_files_error(self):
|
||||
"""Test TooManyFilesError can be raised."""
|
||||
error = TooManyFilesError()
|
||||
assert error is not None
|
||||
|
||||
def test_filename_not_exists_error(self):
|
||||
"""Test FilenameNotExistsError can be raised."""
|
||||
error = FilenameNotExistsError()
|
||||
assert error is not None
|
||||
|
||||
def test_file_too_large_error(self):
|
||||
"""Test FileTooLargeError can be raised."""
|
||||
error = FileTooLargeError("File exceeds size limit")
|
||||
assert error is not None
|
||||
|
||||
def test_unsupported_file_type_error(self):
|
||||
"""Test UnsupportedFileTypeError can be raised."""
|
||||
error = UnsupportedFileTypeError()
|
||||
assert error is not None
|
||||
|
||||
|
||||
class TestRagPipelineService:
|
||||
"""Test RagPipelineService interface."""
|
||||
|
||||
def test_get_datasource_plugins_method_exists(self):
|
||||
"""Test RagPipelineService.get_datasource_plugins exists."""
|
||||
assert hasattr(RagPipelineService, "get_datasource_plugins")
|
||||
|
||||
def test_get_pipeline_method_exists(self):
|
||||
"""Test RagPipelineService.get_pipeline exists."""
|
||||
assert hasattr(RagPipelineService, "get_pipeline")
|
||||
|
||||
def test_run_datasource_workflow_node_method_exists(self):
|
||||
"""Test RagPipelineService.run_datasource_workflow_node exists."""
|
||||
assert hasattr(RagPipelineService, "run_datasource_workflow_node")
|
||||
|
||||
def test_get_pipeline_templates_method_exists(self):
|
||||
"""Test RagPipelineService.get_pipeline_templates exists."""
|
||||
assert hasattr(RagPipelineService, "get_pipeline_templates")
|
||||
|
||||
def test_get_pipeline_template_detail_method_exists(self):
|
||||
"""Test RagPipelineService.get_pipeline_template_detail exists."""
|
||||
assert hasattr(RagPipelineService, "get_pipeline_template_detail")
|
||||
|
||||
|
||||
class TestInvokeFrom:
|
||||
"""Test InvokeFrom enum for pipeline invocation."""
|
||||
|
||||
def test_published_pipeline_invoke_from(self):
|
||||
"""Test PUBLISHED_PIPELINE InvokeFrom value exists."""
|
||||
assert hasattr(InvokeFrom, "PUBLISHED_PIPELINE")
|
||||
|
||||
def test_debugger_invoke_from(self):
|
||||
"""Test DEBUGGER InvokeFrom value exists."""
|
||||
assert hasattr(InvokeFrom, "DEBUGGER")
|
||||
|
||||
|
||||
class TestPipelineResponseModes:
|
||||
"""Test pipeline response mode patterns."""
|
||||
|
||||
def test_streaming_mode(self):
|
||||
"""Test streaming response mode."""
|
||||
mode = "streaming"
|
||||
valid_modes = ["streaming", "blocking"]
|
||||
assert mode in valid_modes
|
||||
|
||||
def test_blocking_mode(self):
|
||||
"""Test blocking response mode."""
|
||||
mode = "blocking"
|
||||
valid_modes = ["streaming", "blocking"]
|
||||
assert mode in valid_modes
|
||||
|
||||
|
||||
class TestDatasourceTypes:
|
||||
"""Test common datasource types for pipelines."""
|
||||
|
||||
@pytest.mark.parametrize("ds_type", ["online_document", "local_file", "website", "api", "database"])
|
||||
def test_datasource_type_valid(self, ds_type):
|
||||
"""Test common datasource types are strings."""
|
||||
assert isinstance(ds_type, str)
|
||||
assert len(ds_type) > 0
|
||||
|
||||
|
||||
class TestPipelineFileUploadResponse:
|
||||
"""Test file upload response structure for pipelines."""
|
||||
|
||||
def test_upload_response_fields(self):
|
||||
"""Test expected fields in upload response."""
|
||||
expected_fields = ["id", "name", "size", "extension", "mime_type", "created_by", "created_at"]
|
||||
|
||||
# Create mock response
|
||||
mock_response = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": "document.pdf",
|
||||
"size": 1024,
|
||||
"extension": "pdf",
|
||||
"mime_type": "application/pdf",
|
||||
"created_by": str(uuid.uuid4()),
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
|
||||
for field in expected_fields:
|
||||
assert field in mock_response
|
||||
|
||||
|
||||
class TestPipelineNodeExecution:
|
||||
"""Test pipeline node execution patterns."""
|
||||
|
||||
def test_node_id_is_string(self):
|
||||
"""Test node_id is a string identifier."""
|
||||
node_id = "node_abc123"
|
||||
assert isinstance(node_id, str)
|
||||
assert len(node_id) > 0
|
||||
|
||||
def test_pipeline_id_is_uuid(self):
|
||||
"""Test pipeline_id is a valid UUID string."""
|
||||
pipeline_id = str(uuid.uuid4())
|
||||
assert len(pipeline_id) == 36
|
||||
assert "-" in pipeline_id
|
||||
|
||||
|
||||
class TestCredentialHandling:
|
||||
"""Test credential handling patterns."""
|
||||
|
||||
def test_credential_id_is_optional(self):
|
||||
"""Test credential_id can be None."""
|
||||
payload = DatasourceNodeRunPayload(
|
||||
inputs={}, datasource_type="local_file", is_published=True, credential_id=None
|
||||
)
|
||||
assert payload.credential_id is None
|
||||
|
||||
def test_credential_id_can_be_provided(self):
|
||||
"""Test credential_id can be set."""
|
||||
payload = DatasourceNodeRunPayload(
|
||||
inputs={}, datasource_type="api", is_published=True, credential_id="cred_oauth_123"
|
||||
)
|
||||
assert payload.credential_id == "cred_oauth_123"
|
||||
|
||||
|
||||
class TestPublishedVsDraft:
|
||||
"""Test published vs draft pipeline patterns."""
|
||||
|
||||
def test_is_published_true(self):
|
||||
"""Test is_published=True for published pipelines."""
|
||||
payload = DatasourceNodeRunPayload(inputs={}, datasource_type="online_document", is_published=True)
|
||||
assert payload.is_published is True
|
||||
|
||||
def test_is_published_false_for_draft(self):
|
||||
"""Test is_published=False for draft pipelines."""
|
||||
payload = DatasourceNodeRunPayload(inputs={}, datasource_type="online_document", is_published=False)
|
||||
assert payload.is_published is False
|
||||
|
||||
|
||||
class TestPipelineInputVariables:
|
||||
"""Test pipeline input variable patterns."""
|
||||
|
||||
def test_inputs_as_dict(self):
|
||||
"""Test inputs are passed as dictionary."""
|
||||
inputs = {"url": "https://example.com/doc.pdf", "timeout": 30, "retry": True}
|
||||
payload = DatasourceNodeRunPayload(inputs=inputs, datasource_type="online_document", is_published=True)
|
||||
assert payload.inputs["url"] == "https://example.com/doc.pdf"
|
||||
assert payload.inputs["timeout"] == 30
|
||||
assert payload.inputs["retry"] is True
|
||||
|
||||
def test_inputs_with_list_values(self):
|
||||
"""Test inputs with list values."""
|
||||
inputs = {"urls": ["https://example.com/1", "https://example.com/2"], "tags": ["tag1", "tag2", "tag3"]}
|
||||
payload = DatasourceNodeRunPayload(inputs=inputs, datasource_type="online_document", is_published=True)
|
||||
assert len(payload.inputs["urls"]) == 2
|
||||
assert len(payload.inputs["tags"]) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PipelineRunApiEntity / DatasourceNodeRunApiEntity Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPipelineRunApiEntity:
|
||||
"""Test PipelineRunApiEntity Pydantic model."""
|
||||
|
||||
def test_entity_with_all_fields(self):
|
||||
"""Test entity with all required fields."""
|
||||
entity = PipelineRunApiEntity(
|
||||
inputs={"key": "value"},
|
||||
datasource_type="online_document",
|
||||
datasource_info_list=[{"url": "https://example.com"}],
|
||||
start_node_id="node_1",
|
||||
is_published=True,
|
||||
response_mode="streaming",
|
||||
)
|
||||
assert entity.datasource_type == "online_document"
|
||||
assert entity.response_mode == "streaming"
|
||||
assert entity.is_published is True
|
||||
|
||||
def test_entity_blocking_response_mode(self):
|
||||
"""Test entity with blocking response mode."""
|
||||
entity = PipelineRunApiEntity(
|
||||
inputs={},
|
||||
datasource_type="local_file",
|
||||
datasource_info_list=[],
|
||||
start_node_id="node_start",
|
||||
is_published=False,
|
||||
response_mode="blocking",
|
||||
)
|
||||
assert entity.response_mode == "blocking"
|
||||
assert entity.is_published is False
|
||||
|
||||
def test_entity_missing_required_field(self):
|
||||
"""Test entity raises on missing required field."""
|
||||
with pytest.raises(ValueError):
|
||||
PipelineRunApiEntity(
|
||||
inputs={},
|
||||
datasource_type="online_document",
|
||||
# missing datasource_info_list, start_node_id, etc.
|
||||
)
|
||||
|
||||
|
||||
class TestDatasourceNodeRunApiEntity:
|
||||
"""Test DatasourceNodeRunApiEntity Pydantic model."""
|
||||
|
||||
def test_entity_with_all_fields(self):
|
||||
"""Test entity with all fields."""
|
||||
entity = DatasourceNodeRunApiEntity(
|
||||
pipeline_id=str(uuid.uuid4()),
|
||||
node_id="node_abc",
|
||||
inputs={"url": "https://example.com"},
|
||||
datasource_type="website",
|
||||
is_published=True,
|
||||
)
|
||||
assert entity.node_id == "node_abc"
|
||||
assert entity.credential_id is None
|
||||
|
||||
def test_entity_with_credential(self):
|
||||
"""Test entity with credential_id."""
|
||||
entity = DatasourceNodeRunApiEntity(
|
||||
pipeline_id=str(uuid.uuid4()),
|
||||
node_id="node_xyz",
|
||||
inputs={},
|
||||
datasource_type="api",
|
||||
credential_id="cred_123",
|
||||
is_published=False,
|
||||
)
|
||||
assert entity.credential_id == "cred_123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoint Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasourcePluginsApiGet:
|
||||
"""Tests for DatasourcePluginsApi.get().
|
||||
|
||||
The original source delegates directly to ``RagPipelineService`` without
|
||||
an inline dataset query, so no ``db`` patching is needed.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
|
||||
def test_get_plugins_success(self, mock_svc_cls, mock_db, app):
|
||||
"""Test successful retrieval of datasource plugins."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
dataset_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_svc_instance = Mock()
|
||||
mock_svc_instance.get_datasource_plugins.return_value = [{"name": "plugin_a"}]
|
||||
mock_svc_cls.return_value = mock_svc_instance
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/datasource-plugins?is_published=true"):
|
||||
api = DatasourcePluginsApi()
|
||||
response, status = api.get(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
assert status == 200
|
||||
assert response == [{"name": "plugin_a"}]
|
||||
mock_svc_instance.get_datasource_plugins.assert_called_once_with(
|
||||
tenant_id=tenant_id, dataset_id=dataset_id, is_published=True
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
def test_get_plugins_not_found(self, mock_db, app):
|
||||
"""Test NotFound when dataset check fails."""
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/datasource-plugins"):
|
||||
api = DatasourcePluginsApi()
|
||||
with pytest.raises(NotFound):
|
||||
api.get(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()))
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
|
||||
def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app):
|
||||
"""Test empty plugin list."""
|
||||
mock_db.session.scalar.return_value = Mock()
|
||||
mock_svc_instance = Mock()
|
||||
mock_svc_instance.get_datasource_plugins.return_value = []
|
||||
mock_svc_cls.return_value = mock_svc_instance
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/datasource-plugins"):
|
||||
api = DatasourcePluginsApi()
|
||||
response, status = api.get(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()))
|
||||
|
||||
assert status == 200
|
||||
assert response == []
|
||||
|
||||
|
||||
class TestDatasourceNodeRunApiPost:
|
||||
"""Tests for DatasourceNodeRunApi.post().
|
||||
|
||||
The source asserts ``isinstance(current_user, Account)`` and delegates to
|
||||
``RagPipelineService`` and ``PipelineGenerator``, so we patch those plus
|
||||
``current_user`` and ``service_api_ns``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.helper")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.PipelineGenerator")
|
||||
@patch(
|
||||
"controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user",
|
||||
new_callable=lambda: Mock(spec=Account),
|
||||
)
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
|
||||
def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app):
|
||||
"""Test successful datasource node run."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
dataset_id = str(uuid.uuid4())
|
||||
node_id = "node_abc"
|
||||
|
||||
mock_db.session.scalar.return_value = Mock()
|
||||
|
||||
mock_ns.payload = {
|
||||
"inputs": {"url": "https://example.com"},
|
||||
"datasource_type": "online_document",
|
||||
"is_published": True,
|
||||
}
|
||||
|
||||
mock_pipeline = Mock()
|
||||
mock_pipeline.id = str(uuid.uuid4())
|
||||
mock_svc_instance = Mock()
|
||||
mock_svc_instance.get_pipeline.return_value = mock_pipeline
|
||||
mock_svc_instance.run_datasource_workflow_node.return_value = iter(["event1"])
|
||||
mock_svc_cls.return_value = mock_svc_instance
|
||||
|
||||
mock_gen.convert_to_event_stream.return_value = iter(["stream_event"])
|
||||
mock_helper.compact_generate_response.return_value = {"result": "ok"}
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/datasource/nodes/node_abc/run", method="POST"):
|
||||
api = DatasourceNodeRunApi()
|
||||
response = api.post(tenant_id=tenant_id, dataset_id=dataset_id, node_id=node_id)
|
||||
|
||||
assert response == {"result": "ok"}
|
||||
mock_svc_instance.get_pipeline.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
mock_svc_instance.get_pipeline.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
mock_svc_instance.run_datasource_workflow_node.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
def test_post_not_found(self, mock_db, app):
|
||||
"""Test NotFound when dataset check fails."""
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/datasource/nodes/n1/run", method="POST"):
|
||||
api = DatasourceNodeRunApi()
|
||||
with pytest.raises(NotFound):
|
||||
api.post(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()), node_id="n1")
|
||||
|
||||
@patch(
|
||||
"controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user",
|
||||
new="not_account",
|
||||
)
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
|
||||
def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app):
|
||||
"""Test AssertionError when current_user is not an Account instance."""
|
||||
mock_db.session.scalar.return_value = Mock()
|
||||
mock_ns.payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "local_file",
|
||||
"is_published": True,
|
||||
}
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/datasource/nodes/n1/run", method="POST"):
|
||||
api = DatasourceNodeRunApi()
|
||||
with pytest.raises(AssertionError):
|
||||
api.post(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()), node_id="n1")
|
||||
|
||||
|
||||
class TestPipelineRunApiPost:
|
||||
"""Tests for PipelineRunApi.post()."""
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.helper")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService")
|
||||
@patch(
|
||||
"controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user",
|
||||
new_callable=lambda: Mock(spec=Account),
|
||||
)
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
|
||||
def test_post_success_streaming(
|
||||
self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen_svc, mock_helper, app
|
||||
):
|
||||
"""Test successful pipeline run with streaming response."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
dataset_id = str(uuid.uuid4())
|
||||
|
||||
mock_db.session.scalar.return_value = Mock()
|
||||
|
||||
mock_ns.payload = {
|
||||
"inputs": {"key": "val"},
|
||||
"datasource_type": "online_document",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "node_1",
|
||||
"is_published": True,
|
||||
"response_mode": "streaming",
|
||||
}
|
||||
|
||||
mock_pipeline = Mock()
|
||||
mock_svc_instance = Mock()
|
||||
mock_svc_instance.get_pipeline.return_value = mock_pipeline
|
||||
mock_svc_cls.return_value = mock_svc_instance
|
||||
|
||||
mock_gen_svc.generate.return_value = {"result": "ok"}
|
||||
mock_helper.compact_generate_response.return_value = {"result": "ok"}
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/run", method="POST"):
|
||||
api = PipelineRunApi()
|
||||
response = api.post(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
assert response == {"result": "ok"}
|
||||
mock_gen_svc.generate.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
def test_post_not_found(self, mock_db, app):
|
||||
"""Test NotFound when dataset check fails."""
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/run", method="POST"):
|
||||
api = PipelineRunApi()
|
||||
with pytest.raises(NotFound):
|
||||
api.post(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()))
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user", new="not_account")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
|
||||
def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app):
|
||||
"""Test Forbidden when current_user is not an Account."""
|
||||
mock_db.session.scalar.return_value = Mock()
|
||||
mock_ns.payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "online_document",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "node_1",
|
||||
"is_published": True,
|
||||
"response_mode": "blocking",
|
||||
}
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/run", method="POST"):
|
||||
api = PipelineRunApi()
|
||||
with pytest.raises(Forbidden):
|
||||
api.post(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()))
|
||||
|
||||
|
||||
class TestFileUploadApiPost:
|
||||
"""Tests for KnowledgebasePipelineFileUploadApi.post()."""
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.FileService")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app):
|
||||
"""Test successful file upload."""
|
||||
mock_current_user.__bool__ = Mock(return_value=True)
|
||||
|
||||
mock_upload = Mock()
|
||||
mock_upload.id = str(uuid.uuid4())
|
||||
mock_upload.name = "doc.pdf"
|
||||
mock_upload.size = 1024
|
||||
mock_upload.extension = "pdf"
|
||||
mock_upload.mime_type = "application/pdf"
|
||||
mock_upload.created_by = str(uuid.uuid4())
|
||||
mock_upload.created_at = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
|
||||
mock_file_svc_instance = Mock()
|
||||
mock_file_svc_instance.upload_file.return_value = mock_upload
|
||||
mock_file_svc_cls.return_value = mock_file_svc_instance
|
||||
|
||||
file_data = FileStorage(
|
||||
stream=io.BytesIO(b"fake pdf content"),
|
||||
filename="doc.pdf",
|
||||
content_type="application/pdf",
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/datasets/pipeline/file-upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data={"file": file_data},
|
||||
):
|
||||
api = KnowledgebasePipelineFileUploadApi()
|
||||
response, status = api.post(tenant_id=str(uuid.uuid4()))
|
||||
|
||||
assert status == 201
|
||||
assert response["name"] == "doc.pdf"
|
||||
assert response["extension"] == "pdf"
|
||||
|
||||
def test_upload_no_file(self, app):
|
||||
"""Test error when no file is uploaded."""
|
||||
with app.test_request_context(
|
||||
"/datasets/pipeline/file-upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
api = KnowledgebasePipelineFileUploadApi()
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
api.post(tenant_id=str(uuid.uuid4()))
|
||||
1521
api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py
Normal file
1521
api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,205 @@
|
||||
"""
|
||||
Unit tests for Service API HitTesting controller.
|
||||
|
||||
Tests coverage for:
|
||||
- HitTestingPayload Pydantic model validation
|
||||
- HitTestingApi endpoint (success and error paths via direct method calls)
|
||||
|
||||
Strategy:
|
||||
- ``HitTestingApi.post`` is decorated with ``@cloud_edition_billing_rate_limit_check``
|
||||
which preserves ``__wrapped__``. We call ``post.__wrapped__(self, ...)`` to skip
|
||||
the billing decorator and test the business logic directly.
|
||||
- Base-class methods (``get_and_validate_dataset``, ``perform_hit_testing``) read
|
||||
``current_user`` from ``controllers.console.datasets.hit_testing_base``, so we
|
||||
patch it there.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api.dataset.hit_testing import HitTestingApi, HitTestingPayload
|
||||
from models.account import Account
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HitTestingPayload Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHitTestingPayload:
|
||||
"""Test suite for HitTestingPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_required_query(self):
|
||||
"""Test payload with required query field."""
|
||||
payload = HitTestingPayload(query="test query")
|
||||
assert payload.query == "test query"
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all optional fields."""
|
||||
payload = HitTestingPayload(
|
||||
query="test query",
|
||||
retrieval_model={"top_k": 5},
|
||||
external_retrieval_model={"provider": "openai"},
|
||||
attachment_ids=["att_1", "att_2"],
|
||||
)
|
||||
assert payload.query == "test query"
|
||||
assert payload.retrieval_model == {"top_k": 5}
|
||||
assert payload.external_retrieval_model == {"provider": "openai"}
|
||||
assert payload.attachment_ids == ["att_1", "att_2"]
|
||||
|
||||
def test_payload_query_too_long(self):
|
||||
"""Test payload rejects query over 250 characters."""
|
||||
with pytest.raises(ValueError):
|
||||
HitTestingPayload(query="x" * 251)
|
||||
|
||||
def test_payload_query_at_max_length(self):
|
||||
"""Test payload accepts query at exactly 250 characters."""
|
||||
payload = HitTestingPayload(query="x" * 250)
|
||||
assert len(payload.query) == 250
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HitTestingApi Tests
|
||||
#
|
||||
# We use ``post.__wrapped__`` to bypass ``@cloud_edition_billing_rate_limit_check``
|
||||
# and call the underlying method directly.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHitTestingApiPost:
|
||||
"""Tests for HitTestingApi.post() via __wrapped__ to skip billing decorator."""
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.marshal")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_success(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_marshal,
|
||||
mock_ns,
|
||||
app,
|
||||
):
|
||||
"""Test successful hit testing request."""
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": "test query", "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
mock_marshal.return_value = []
|
||||
|
||||
mock_ns.payload = {"query": "test query"}
|
||||
|
||||
with app.test_request_context():
|
||||
api = HitTestingApi()
|
||||
# Skip billing decorator via __wrapped__
|
||||
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
assert response["query"] == "test query"
|
||||
mock_hit_svc.retrieve.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.marshal")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_with_retrieval_model(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_marshal,
|
||||
mock_ns,
|
||||
app,
|
||||
):
|
||||
"""Test hit testing with custom retrieval model."""
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
retrieval_model = {"search_method": "semantic", "top_k": 10, "score_threshold": 0.8}
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
mock_marshal.return_value = []
|
||||
|
||||
mock_ns.payload = {
|
||||
"query": "complex query",
|
||||
"retrieval_model": retrieval_model,
|
||||
"external_retrieval_model": {"provider": "custom"},
|
||||
}
|
||||
|
||||
with app.test_request_context():
|
||||
api = HitTestingApi()
|
||||
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
assert response["query"] == "complex query"
|
||||
call_kwargs = mock_hit_svc.retrieve.call_args
|
||||
assert call_kwargs.kwargs.get("retrieval_model") == retrieval_model
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_dataset_not_found(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_ns,
|
||||
app,
|
||||
):
|
||||
"""Test hit testing with non-existent dataset."""
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = None
|
||||
mock_ns.payload = {"query": "test query"}
|
||||
|
||||
with app.test_request_context():
|
||||
api = HitTestingApi()
|
||||
with pytest.raises(NotFound):
|
||||
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_no_dataset_permission(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_ns,
|
||||
app,
|
||||
):
|
||||
"""Test hit testing when user lacks dataset permission."""
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError(
|
||||
"Access denied"
|
||||
)
|
||||
mock_ns.payload = {"query": "test query"}
|
||||
|
||||
with app.test_request_context():
|
||||
api = HitTestingApi()
|
||||
with pytest.raises(Forbidden):
|
||||
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
@ -0,0 +1,534 @@
|
||||
"""
|
||||
Unit tests for Service API Metadata controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- DatasetMetadataCreateServiceApi (post, get)
|
||||
- DatasetMetadataServiceApi (patch, delete)
|
||||
- DatasetMetadataBuiltInFieldServiceApi (get)
|
||||
- DatasetMetadataBuiltInFieldActionServiceApi (post)
|
||||
- DocumentMetadataEditServiceApi (post)
|
||||
|
||||
Decorator strategy:
|
||||
- ``@cloud_edition_billing_rate_limit_check`` preserves ``__wrapped__``
|
||||
via ``functools.wraps`` → call the unwrapped method directly.
|
||||
- Methods without billing decorators → call directly; only patch ``db``,
|
||||
services, and ``current_user``.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api.dataset.metadata import (
|
||||
DatasetMetadataBuiltInFieldActionServiceApi,
|
||||
DatasetMetadataBuiltInFieldServiceApi,
|
||||
DatasetMetadataCreateServiceApi,
|
||||
DatasetMetadataServiceApi,
|
||||
DocumentMetadataEditServiceApi,
|
||||
)
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tenant():
|
||||
tenant = Mock()
|
||||
tenant.id = str(uuid.uuid4())
|
||||
return tenant
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset():
|
||||
dataset = Mock()
|
||||
dataset.id = str(uuid.uuid4())
|
||||
return dataset
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetMetadataCreateServiceApi
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetMetadataCreatePost:
|
||||
"""Tests for DatasetMetadataCreateServiceApi.post().
|
||||
|
||||
``post`` is wrapped by ``@cloud_edition_billing_rate_limit_check``
|
||||
which preserves ``__wrapped__``.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _call_post(api, **kwargs):
|
||||
return _unwrap(api.post)(api, **kwargs)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.marshal")
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
@patch("controllers.service_api.dataset.metadata.current_user")
|
||||
def test_create_metadata_success(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_meta_svc,
|
||||
mock_marshal,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test successful metadata creation."""
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
mock_metadata = Mock()
|
||||
mock_meta_svc.create_metadata.return_value = mock_metadata
|
||||
mock_marshal.return_value = {"id": "meta-1", "name": "Author"}
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata",
|
||||
method="POST",
|
||||
json={"type": "string", "name": "Author"},
|
||||
):
|
||||
api = DatasetMetadataCreateServiceApi()
|
||||
response, status = self._call_post(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
)
|
||||
|
||||
assert status == 201
|
||||
mock_meta_svc.create_metadata.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
def test_create_metadata_dataset_not_found(
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test 404 when dataset not found."""
|
||||
mock_dataset_svc.get_dataset.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata",
|
||||
method="POST",
|
||||
json={"type": "string", "name": "Author"},
|
||||
):
|
||||
api = DatasetMetadataCreateServiceApi()
|
||||
with pytest.raises(NotFound):
|
||||
self._call_post(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
)
|
||||
|
||||
|
||||
class TestDatasetMetadataCreateGet:
|
||||
"""Tests for DatasetMetadataCreateServiceApi.get()."""
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
def test_get_metadata_success(
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
mock_meta_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test successful metadata list retrieval."""
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_meta_svc.get_dataset_metadatas.return_value = [{"id": "m1"}]
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata",
|
||||
method="GET",
|
||||
):
|
||||
api = DatasetMetadataCreateServiceApi()
|
||||
response, status = api.get(
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
def test_get_metadata_dataset_not_found(
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test 404 when dataset not found."""
|
||||
mock_dataset_svc.get_dataset.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata",
|
||||
method="GET",
|
||||
):
|
||||
api = DatasetMetadataCreateServiceApi()
|
||||
with pytest.raises(NotFound):
|
||||
api.get(tenant_id=mock_tenant.id, dataset_id=mock_dataset.id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetMetadataServiceApi
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetMetadataServiceApiPatch:
|
||||
"""Tests for DatasetMetadataServiceApi.patch().
|
||||
|
||||
``patch`` is wrapped by ``@cloud_edition_billing_rate_limit_check``.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _call_patch(api, **kwargs):
|
||||
return _unwrap(api.patch)(api, **kwargs)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.marshal")
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
@patch("controllers.service_api.dataset.metadata.current_user")
|
||||
def test_update_metadata_name_success(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_meta_svc,
|
||||
mock_marshal,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test successful metadata name update."""
|
||||
metadata_id = str(uuid.uuid4())
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
mock_meta_svc.update_metadata_name.return_value = Mock()
|
||||
mock_marshal.return_value = {"id": metadata_id, "name": "New Name"}
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata/{metadata_id}",
|
||||
method="PATCH",
|
||||
json={"name": "New Name"},
|
||||
):
|
||||
api = DatasetMetadataServiceApi()
|
||||
response, status = self._call_patch(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
metadata_id=metadata_id,
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
mock_meta_svc.update_metadata_name.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
def test_update_metadata_dataset_not_found(
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test 404 when dataset not found."""
|
||||
metadata_id = str(uuid.uuid4())
|
||||
mock_dataset_svc.get_dataset.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata/{metadata_id}",
|
||||
method="PATCH",
|
||||
json={"name": "x"},
|
||||
):
|
||||
api = DatasetMetadataServiceApi()
|
||||
with pytest.raises(NotFound):
|
||||
self._call_patch(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
metadata_id=metadata_id,
|
||||
)
|
||||
|
||||
|
||||
class TestDatasetMetadataServiceApiDelete:
|
||||
"""Tests for DatasetMetadataServiceApi.delete().
|
||||
|
||||
``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check``.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _call_delete(api, **kwargs):
|
||||
return _unwrap(api.delete)(api, **kwargs)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
@patch("controllers.service_api.dataset.metadata.current_user")
|
||||
def test_delete_metadata_success(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_meta_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test successful metadata deletion."""
|
||||
metadata_id = str(uuid.uuid4())
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
mock_meta_svc.delete_metadata.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata/{metadata_id}",
|
||||
method="DELETE",
|
||||
):
|
||||
api = DatasetMetadataServiceApi()
|
||||
response = self._call_delete(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
metadata_id=metadata_id,
|
||||
)
|
||||
|
||||
assert response == ("", 204)
|
||||
mock_meta_svc.delete_metadata.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
def test_delete_metadata_dataset_not_found(
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test 404 when dataset not found."""
|
||||
metadata_id = str(uuid.uuid4())
|
||||
mock_dataset_svc.get_dataset.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata/{metadata_id}",
|
||||
method="DELETE",
|
||||
):
|
||||
api = DatasetMetadataServiceApi()
|
||||
with pytest.raises(NotFound):
|
||||
self._call_delete(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
metadata_id=metadata_id,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetMetadataBuiltInFieldServiceApi
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetMetadataBuiltInFieldGet:
|
||||
"""Tests for DatasetMetadataBuiltInFieldServiceApi.get()."""
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
def test_get_built_in_fields_success(
|
||||
self,
|
||||
mock_meta_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test successful built-in fields retrieval."""
|
||||
mock_meta_svc.get_built_in_fields.return_value = [
|
||||
{"name": "source", "type": "string"},
|
||||
]
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata/built-in",
|
||||
method="GET",
|
||||
):
|
||||
api = DatasetMetadataBuiltInFieldServiceApi()
|
||||
response, status = api.get(
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert "fields" in response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetMetadataBuiltInFieldActionServiceApi
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetMetadataBuiltInFieldAction:
|
||||
"""Tests for DatasetMetadataBuiltInFieldActionServiceApi.post().
|
||||
|
||||
``post`` is wrapped by ``@cloud_edition_billing_rate_limit_check``.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _call_post(api, **kwargs):
|
||||
return _unwrap(api.post)(api, **kwargs)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
@patch("controllers.service_api.dataset.metadata.current_user")
|
||||
def test_enable_built_in_field(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_meta_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test enabling built-in metadata field."""
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata/built-in/enable",
|
||||
method="POST",
|
||||
):
|
||||
api = DatasetMetadataBuiltInFieldActionServiceApi()
|
||||
response, status = self._call_post(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
action="enable",
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
mock_meta_svc.enable_built_in_field.assert_called_once_with(mock_dataset)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
@patch("controllers.service_api.dataset.metadata.current_user")
|
||||
def test_disable_built_in_field(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_meta_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test disabling built-in metadata field."""
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata/built-in/disable",
|
||||
method="POST",
|
||||
):
|
||||
api = DatasetMetadataBuiltInFieldActionServiceApi()
|
||||
response, status = self._call_post(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
action="disable",
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
mock_meta_svc.disable_built_in_field.assert_called_once_with(mock_dataset)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
def test_action_dataset_not_found(
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test 404 when dataset not found."""
|
||||
mock_dataset_svc.get_dataset.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/metadata/built-in/enable",
|
||||
method="POST",
|
||||
):
|
||||
api = DatasetMetadataBuiltInFieldActionServiceApi()
|
||||
with pytest.raises(NotFound):
|
||||
self._call_post(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
action="enable",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocumentMetadataEditServiceApi
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDocumentMetadataEditPost:
|
||||
"""Tests for DocumentMetadataEditServiceApi.post().
|
||||
|
||||
``post`` is wrapped by ``@cloud_edition_billing_rate_limit_check``.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _call_post(api, **kwargs):
|
||||
return _unwrap(api.post)(api, **kwargs)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
@patch("controllers.service_api.dataset.metadata.current_user")
|
||||
def test_update_documents_metadata_success(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_meta_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test successful documents metadata update."""
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
mock_meta_svc.update_documents_metadata.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/metadata",
|
||||
method="POST",
|
||||
json={"operation_data": []},
|
||||
):
|
||||
api = DocumentMetadataEditServiceApi()
|
||||
response, status = self._call_post(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
def test_update_documents_metadata_dataset_not_found(
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test 404 when dataset not found."""
|
||||
mock_dataset_svc.get_dataset.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/metadata",
|
||||
method="POST",
|
||||
json={"operation_data": []},
|
||||
):
|
||||
api = DocumentMetadataEditServiceApi()
|
||||
with pytest.raises(NotFound):
|
||||
self._call_post(
|
||||
api,
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
)
|
||||
69
api/tests/unit_tests/controllers/service_api/test_index.py
Normal file
69
api/tests/unit_tests/controllers/service_api/test_index.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""
|
||||
Unit tests for Service API Index endpoint
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.service_api.index import IndexApi
|
||||
|
||||
|
||||
class TestIndexApi:
|
||||
"""Test suite for IndexApi resource."""
|
||||
|
||||
@patch("controllers.service_api.index.dify_config")
|
||||
def test_get_returns_api_info(self, mock_config, app):
|
||||
"""Test that GET returns API metadata with correct structure."""
|
||||
# Arrange
|
||||
mock_config.project.version = "1.0.0-test"
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/", method="GET"):
|
||||
index_api = IndexApi()
|
||||
response = index_api.get()
|
||||
with patch("controllers.service_api.index.dify_config", mock_config):
|
||||
with app.test_request_context("/", method="GET"):
|
||||
index_api = IndexApi()
|
||||
response = index_api.get()
|
||||
|
||||
# Assert
|
||||
assert response["welcome"] == "Dify OpenAPI"
|
||||
assert response["api_version"] == "v1"
|
||||
assert response["server_version"] == "1.0.0-test"
|
||||
|
||||
def test_get_response_has_required_fields(self, app):
|
||||
"""Test that response contains all required fields."""
|
||||
# Arrange
|
||||
mock_config = MagicMock()
|
||||
mock_config.project.version = "1.11.4"
|
||||
|
||||
# Act
|
||||
with patch("controllers.service_api.index.dify_config", mock_config):
|
||||
with app.test_request_context("/", method="GET"):
|
||||
index_api = IndexApi()
|
||||
response = index_api.get()
|
||||
|
||||
# Assert
|
||||
assert "welcome" in response
|
||||
assert "api_version" in response
|
||||
assert "server_version" in response
|
||||
assert isinstance(response["welcome"], str)
|
||||
assert isinstance(response["api_version"], str)
|
||||
assert isinstance(response["server_version"], str)
|
||||
|
||||
@pytest.mark.parametrize("version", ["0.0.1", "1.0.0", "2.0.0-beta", "1.11.4"])
|
||||
def test_get_returns_correct_version(self, app, version):
|
||||
"""Test that server_version matches config version."""
|
||||
# Arrange
|
||||
mock_config = MagicMock()
|
||||
mock_config.project.version = version
|
||||
|
||||
# Act
|
||||
with patch("controllers.service_api.index.dify_config", mock_config):
|
||||
with app.test_request_context("/", method="GET"):
|
||||
index_api = IndexApi()
|
||||
response = index_api.get()
|
||||
|
||||
# Assert
|
||||
assert response["server_version"] == version
|
||||
270
api/tests/unit_tests/controllers/service_api/test_site.py
Normal file
270
api/tests/unit_tests/controllers/service_api/test_site.py
Normal file
@ -0,0 +1,270 @@
|
||||
"""
|
||||
Unit tests for Service API Site controller
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.service_api.app.site import AppSiteApi
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_account_query
|
||||
|
||||
|
||||
class TestAppSiteApi:
|
||||
"""Test suite for AppSiteApi"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model with tenant."""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.id = app.tenant_id
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
app.tenant = mock_tenant
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_site(self):
|
||||
"""Create a mock Site model."""
|
||||
site = Mock(spec=Site)
|
||||
site.id = str(uuid.uuid4())
|
||||
site.app_id = str(uuid.uuid4())
|
||||
site.title = "Test Site"
|
||||
site.icon = "icon-url"
|
||||
site.icon_background = "#ffffff"
|
||||
site.description = "Site description"
|
||||
site.copyright = "Copyright 2024"
|
||||
site.privacy_policy = "Privacy policy text"
|
||||
site.custom_disclaimer = "Custom disclaimer"
|
||||
site.default_language = "en-US"
|
||||
site.prompt_public = True
|
||||
site.show_workflow_steps = True
|
||||
site.use_icon_as_answer_icon = False
|
||||
site.chat_color_theme = "light"
|
||||
site.chat_color_theme_inverted = False
|
||||
site.icon_type = "image"
|
||||
site.created_at = "2024-01-01T00:00:00"
|
||||
site.updated_at = "2024-01-01T00:00:00"
|
||||
return site
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.app.site.db")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_site_success(
|
||||
self,
|
||||
mock_wraps_db,
|
||||
mock_validate_token,
|
||||
mock_current_app,
|
||||
mock_db,
|
||||
mock_user_logged_in,
|
||||
app,
|
||||
mock_app_model,
|
||||
mock_site,
|
||||
):
|
||||
"""Test successful retrieval of site configuration."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
mock_app_model.tenant = mock_tenant
|
||||
|
||||
# Mock wraps.db for authentication
|
||||
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
# Mock site.db for site query
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppSiteApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["title"] == "Test Site"
|
||||
assert response["icon"] == "icon-url"
|
||||
assert response["description"] == "Site description"
|
||||
mock_db.session.query.assert_called_once_with(Site)
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.app.site.db")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_site_not_found(
|
||||
self,
|
||||
mock_wraps_db,
|
||||
mock_validate_token,
|
||||
mock_current_app,
|
||||
mock_db,
|
||||
mock_user_logged_in,
|
||||
app,
|
||||
mock_app_model,
|
||||
):
|
||||
"""Test that Forbidden is raised when site is not found."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
mock_app_model.tenant = mock_tenant
|
||||
|
||||
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
# Mock site query to return None
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppSiteApi()
|
||||
with pytest.raises(Forbidden):
|
||||
api.get()
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.app.site.db")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_site_tenant_archived(
|
||||
self,
|
||||
mock_wraps_db,
|
||||
mock_validate_token,
|
||||
mock_current_app,
|
||||
mock_db,
|
||||
mock_user_logged_in,
|
||||
app,
|
||||
mock_app_model,
|
||||
mock_site,
|
||||
):
|
||||
"""Test that Forbidden is raised when tenant is archived."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
|
||||
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
# Mock site query
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
|
||||
|
||||
# Set tenant status to archived AFTER authentication
|
||||
mock_app_model.tenant.status = TenantStatus.ARCHIVE
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppSiteApi()
|
||||
with pytest.raises(Forbidden):
|
||||
api.get()
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.app.site.db")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_site_queries_by_app_id(
|
||||
self, mock_wraps_db, mock_validate_token, mock_current_app, mock_db, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test that site is queried using the app model's id."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
mock_app_model.tenant = mock_tenant
|
||||
|
||||
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
mock_site = Mock(spec=Site)
|
||||
mock_site.id = str(uuid.uuid4())
|
||||
mock_site.app_id = mock_app_model.id
|
||||
mock_site.title = "Test Site"
|
||||
mock_site.icon = "icon-url"
|
||||
mock_site.icon_background = "#ffffff"
|
||||
mock_site.description = "Site description"
|
||||
mock_site.copyright = "Copyright 2024"
|
||||
mock_site.privacy_policy = "Privacy policy text"
|
||||
mock_site.custom_disclaimer = "Custom disclaimer"
|
||||
mock_site.default_language = "en-US"
|
||||
mock_site.prompt_public = True
|
||||
mock_site.show_workflow_steps = True
|
||||
mock_site.use_icon_as_answer_icon = False
|
||||
mock_site.chat_color_theme = "light"
|
||||
mock_site.chat_color_theme_inverted = False
|
||||
mock_site.icon_type = "image"
|
||||
mock_site.created_at = "2024-01-01T00:00:00"
|
||||
mock_site.updated_at = "2024-01-01T00:00:00"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppSiteApi()
|
||||
api.get()
|
||||
|
||||
# Assert
|
||||
# The query was executed successfully (site returned), which validates the correct query was made
|
||||
mock_db.session.query.assert_called_once_with(Site)
|
||||
550
api/tests/unit_tests/controllers/service_api/test_wraps.py
Normal file
550
api/tests/unit_tests/controllers/service_api/test_wraps.py
Normal file
@ -0,0 +1,550 @@
|
||||
"""
|
||||
Unit tests for Service API wraps (authentication decorators)
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
FetchUserArg,
|
||||
WhereisUserArg,
|
||||
cloud_edition_billing_knowledge_limit_check,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
validate_and_get_api_token,
|
||||
validate_app_token,
|
||||
validate_dataset_token,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models.account import TenantStatus
|
||||
from models.model import ApiToken
|
||||
from tests.unit_tests.conftest import (
|
||||
setup_mock_dataset_tenant_query,
|
||||
setup_mock_tenant_account_query,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateAndGetApiToken:
|
||||
"""Test suite for validate_and_get_api_token function"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
def test_missing_authorization_header(self, app):
|
||||
"""Test that Unauthorized is raised when Authorization header is missing."""
|
||||
# Arrange
|
||||
with app.test_request_context("/", method="GET"):
|
||||
# No Authorization header
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
validate_and_get_api_token("app")
|
||||
assert "Authorization header must be provided" in str(exc_info.value)
|
||||
|
||||
def test_invalid_auth_scheme(self, app):
|
||||
"""Test that Unauthorized is raised when auth scheme is not Bearer."""
|
||||
# Arrange
|
||||
with app.test_request_context("/", method="GET", headers={"Authorization": "Basic token123"}):
|
||||
# Act & Assert
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
validate_and_get_api_token("app")
|
||||
assert "Authorization scheme must be 'Bearer'" in str(exc_info.value)
|
||||
|
||||
@patch("controllers.service_api.wraps.record_token_usage")
|
||||
@patch("controllers.service_api.wraps.ApiTokenCache")
|
||||
@patch("controllers.service_api.wraps.fetch_token_with_single_flight")
|
||||
def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app):
|
||||
"""Test that valid token returns the ApiToken object."""
|
||||
# Arrange
|
||||
mock_api_token = Mock(spec=ApiToken)
|
||||
mock_api_token.token = "valid_token_123"
|
||||
mock_api_token.type = "app"
|
||||
|
||||
mock_cache_instance = Mock()
|
||||
mock_cache_instance.get.return_value = None # Cache miss
|
||||
mock_cache_cls.get = mock_cache_instance.get
|
||||
mock_fetch_token.return_value = mock_api_token
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/", method="GET", headers={"Authorization": "Bearer valid_token_123"}):
|
||||
result = validate_and_get_api_token("app")
|
||||
|
||||
# Assert
|
||||
assert result == mock_api_token
|
||||
|
||||
@patch("controllers.service_api.wraps.record_token_usage")
|
||||
@patch("controllers.service_api.wraps.ApiTokenCache")
|
||||
@patch("controllers.service_api.wraps.fetch_token_with_single_flight")
|
||||
def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app):
|
||||
"""Test that invalid token raises Unauthorized."""
|
||||
# Arrange
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
mock_cache_instance = Mock()
|
||||
mock_cache_instance.get.return_value = None # Cache miss
|
||||
mock_cache_cls.get = mock_cache_instance.get
|
||||
mock_fetch_token.side_effect = Unauthorized("Access token is invalid")
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/", method="GET", headers={"Authorization": "Bearer invalid_token"}):
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
validate_and_get_api_token("app")
|
||||
assert "Access token is invalid" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestValidateAppToken:
|
||||
"""Test suite for validate_app_token decorator"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
def test_valid_app_token_allows_access(
|
||||
self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app
|
||||
):
|
||||
"""Test that valid app token allows access to decorated view."""
|
||||
# Arrange
|
||||
# Use standard Mock for login_manager to avoid AsyncMockMixin warnings
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = str(uuid.uuid4())
|
||||
mock_api_token.tenant_id = str(uuid.uuid4())
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_app = Mock()
|
||||
mock_app.id = mock_api_token.app_id
|
||||
mock_app.status = "normal"
|
||||
mock_app.enable_api = True
|
||||
mock_app.tenant_id = mock_api_token.tenant_id
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
mock_tenant.id = mock_api_token.tenant_id
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.id = str(uuid.uuid4())
|
||||
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
|
||||
# Use side_effect to return app first, then tenant
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
mock_account,
|
||||
]
|
||||
|
||||
# Mock the tenant owner query
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
|
||||
|
||||
@validate_app_token
|
||||
def protected_view(app_model):
|
||||
return {"success": True, "app_id": app_model.id}
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result["success"] is True
|
||||
assert result["app_id"] == mock_app.id
|
||||
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app):
|
||||
"""Test that Forbidden is raised when app no longer exists."""
|
||||
# Arrange
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = str(uuid.uuid4())
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
@validate_app_token
|
||||
def protected_view(**kwargs):
|
||||
return {"success": True}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/", method="GET"):
|
||||
with pytest.raises(Forbidden) as exc_info:
|
||||
protected_view()
|
||||
assert "no longer exists" in str(exc_info.value)
|
||||
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app):
|
||||
"""Test that Forbidden is raised when app status is abnormal."""
|
||||
# Arrange
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = str(uuid.uuid4())
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_app = Mock()
|
||||
mock_app.status = "abnormal"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
@validate_app_token
|
||||
def protected_view(**kwargs):
|
||||
return {"success": True}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/", method="GET"):
|
||||
with pytest.raises(Forbidden) as exc_info:
|
||||
protected_view()
|
||||
assert "status is abnormal" in str(exc_info.value)
|
||||
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app):
|
||||
"""Test that Forbidden is raised when app API is disabled."""
|
||||
# Arrange
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = str(uuid.uuid4())
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_app = Mock()
|
||||
mock_app.status = "normal"
|
||||
mock_app.enable_api = False
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
@validate_app_token
|
||||
def protected_view(**kwargs):
|
||||
return {"success": True}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/", method="GET"):
|
||||
with pytest.raises(Forbidden) as exc_info:
|
||||
protected_view()
|
||||
assert "API service has been disabled" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestCloudEditionBillingResourceCheck:
|
||||
"""Test suite for cloud_edition_billing_resource_check decorator"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_features")
|
||||
def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app):
|
||||
"""Test that request is allowed when under resource limit."""
|
||||
# Arrange
|
||||
mock_validate_token.return_value = Mock(tenant_id="tenant123")
|
||||
|
||||
mock_features = Mock()
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.members.limit = 10
|
||||
mock_features.members.size = 5
|
||||
mock_get_features.return_value = mock_features
|
||||
|
||||
@cloud_edition_billing_resource_check("members", "app")
|
||||
def add_member():
|
||||
return "member_added"
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/", method="GET"):
|
||||
result = add_member()
|
||||
|
||||
# Assert
|
||||
assert result == "member_added"
|
||||
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_features")
|
||||
def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app):
|
||||
"""Test that Forbidden is raised when at resource limit."""
|
||||
# Arrange
|
||||
mock_validate_token.return_value = Mock(tenant_id="tenant123")
|
||||
|
||||
mock_features = Mock()
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.members.limit = 10
|
||||
mock_features.members.size = 10
|
||||
mock_get_features.return_value = mock_features
|
||||
|
||||
@cloud_edition_billing_resource_check("members", "app")
|
||||
def add_member():
|
||||
return "member_added"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/", method="GET"):
|
||||
with pytest.raises(Forbidden) as exc_info:
|
||||
add_member()
|
||||
assert "members has reached the limit" in str(exc_info.value)
|
||||
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_features")
|
||||
def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app):
|
||||
"""Test that request is allowed when billing is disabled."""
|
||||
# Arrange
|
||||
mock_validate_token.return_value = Mock(tenant_id="tenant123")
|
||||
|
||||
mock_features = Mock()
|
||||
mock_features.billing.enabled = False
|
||||
mock_get_features.return_value = mock_features
|
||||
|
||||
@cloud_edition_billing_resource_check("members", "app")
|
||||
def add_member():
|
||||
return "member_added"
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/", method="GET"):
|
||||
result = add_member()
|
||||
|
||||
# Assert
|
||||
assert result == "member_added"
|
||||
|
||||
|
||||
class TestCloudEditionBillingKnowledgeLimitCheck:
|
||||
"""Test suite for cloud_edition_billing_knowledge_limit_check decorator"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_features")
|
||||
def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app):
|
||||
"""Test that add_segment is rejected in SANDBOX plan."""
|
||||
# Arrange
|
||||
mock_validate_token.return_value = Mock(tenant_id="tenant123")
|
||||
|
||||
mock_features = Mock()
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
mock_get_features.return_value = mock_features
|
||||
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
def add_segment():
|
||||
return "segment_added"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/", method="GET"):
|
||||
with pytest.raises(Forbidden) as exc_info:
|
||||
add_segment()
|
||||
assert "upgrade to a paid plan" in str(exc_info.value)
|
||||
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_features")
|
||||
def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app):
|
||||
"""Test that non-add_segment operations are allowed in SANDBOX."""
|
||||
# Arrange
|
||||
mock_validate_token.return_value = Mock(tenant_id="tenant123")
|
||||
|
||||
mock_features = Mock()
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
mock_get_features.return_value = mock_features
|
||||
|
||||
@cloud_edition_billing_knowledge_limit_check("search", "dataset")
|
||||
def search():
|
||||
return "search_results"
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/", method="GET"):
|
||||
result = search()
|
||||
|
||||
# Assert
|
||||
assert result == "search_results"
|
||||
|
||||
|
||||
class TestCloudEditionBillingRateLimitCheck:
|
||||
"""Test suite for cloud_edition_billing_rate_limit_check decorator"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit")
|
||||
def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app):
|
||||
"""Test that request is allowed when within rate limit."""
|
||||
# Arrange
|
||||
mock_validate_token.return_value = Mock(tenant_id="tenant123")
|
||||
|
||||
mock_rate_limit = Mock()
|
||||
mock_rate_limit.enabled = True
|
||||
mock_rate_limit.limit = 100
|
||||
mock_get_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
# Mock redis operations
|
||||
with patch("controllers.service_api.wraps.redis_client") as mock_redis:
|
||||
mock_redis.zcard.return_value = 50 # Under limit
|
||||
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def knowledge_request():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/", method="GET"):
|
||||
result = knowledge_request()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
mock_redis.zadd.assert_called_once()
|
||||
mock_redis.zremrangebyscore.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app):
|
||||
"""Test that Forbidden is raised when over rate limit."""
|
||||
# Arrange
|
||||
mock_validate_token.return_value = Mock(tenant_id="tenant123")
|
||||
|
||||
mock_rate_limit = Mock()
|
||||
mock_rate_limit.enabled = True
|
||||
mock_rate_limit.limit = 10
|
||||
mock_rate_limit.subscription_plan = "pro"
|
||||
mock_get_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
with patch("controllers.service_api.wraps.redis_client") as mock_redis:
|
||||
mock_redis.zcard.return_value = 15 # Over limit
|
||||
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def knowledge_request():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/", method="GET"):
|
||||
with pytest.raises(Forbidden) as exc_info:
|
||||
knowledge_request()
|
||||
assert "rate limit" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestValidateDatasetToken:
|
||||
"""Test suite for validate_dataset_token decorator"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app):
|
||||
"""Test that valid dataset token allows access."""
|
||||
# Arrange
|
||||
# Use standard Mock for login_manager
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
tenant_id = str(uuid.uuid4())
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.tenant_id = tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.id = tenant_id
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = str(uuid.uuid4())
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.id = mock_ta.account_id
|
||||
mock_account.current_tenant = mock_tenant
|
||||
|
||||
# Mock the tenant account join query
|
||||
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
|
||||
|
||||
# Mock the account query
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_account
|
||||
|
||||
@validate_dataset_token
|
||||
def protected_view(tenant_id):
|
||||
return {"success": True, "tenant_id": tenant_id}
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result["success"] is True
|
||||
assert result["tenant_id"] == tenant_id
|
||||
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app):
|
||||
"""Test that NotFound is raised when dataset doesn't exist."""
|
||||
# Arrange
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.tenant_id = str(uuid.uuid4())
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
@validate_dataset_token
|
||||
def protected_view(dataset_id=None, **kwargs):
|
||||
return {"success": True}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/", method="GET"):
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
protected_view(dataset_id=str(uuid.uuid4()))
|
||||
assert "Dataset not found" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestFetchUserArg:
|
||||
"""Test suite for FetchUserArg model"""
|
||||
|
||||
def test_fetch_user_arg_defaults(self):
|
||||
"""Test FetchUserArg default values."""
|
||||
# Arrange & Act
|
||||
arg = FetchUserArg(fetch_from=WhereisUserArg.JSON)
|
||||
|
||||
# Assert
|
||||
assert arg.fetch_from == WhereisUserArg.JSON
|
||||
assert arg.required is False
|
||||
|
||||
def test_fetch_user_arg_required(self):
|
||||
"""Test FetchUserArg with required=True."""
|
||||
# Arrange & Act
|
||||
arg = FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)
|
||||
|
||||
# Assert
|
||||
assert arg.fetch_from == WhereisUserArg.QUERY
|
||||
assert arg.required is True
|
||||
|
||||
|
||||
class TestDatasetApiResource:
|
||||
"""Test suite for DatasetApiResource base class"""
|
||||
|
||||
def test_method_decorators_has_validate_dataset_token(self):
|
||||
"""Test that DatasetApiResource has validate_dataset_token in method_decorators."""
|
||||
# Assert
|
||||
assert validate_dataset_token in DatasetApiResource.method_decorators
|
||||
|
||||
def test_get_dataset_method_exists(self):
|
||||
"""Test that get_dataset method exists on DatasetApiResource."""
|
||||
# Assert
|
||||
assert hasattr(DatasetApiResource, "get_dataset")
|
||||
Loading…
Reference in New Issue
Block a user