test: improve unit tests for controllers.service_api (#32073)

Co-authored-by: Rajat Agarwal <rajat.agarwal@infocusp.com>
This commit is contained in:
Dev Sharma 2026-02-25 12:15:50 +05:30 committed by GitHub
parent 212756c315
commit d773096146
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 11279 additions and 2 deletions

View File

@ -3,7 +3,8 @@ from typing import Any
from flask import request
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
@ -17,7 +18,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models import Account
from models.dataset import Pipeline
from models.dataset import Dataset, Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
@ -65,6 +66,12 @@ class DatasourcePluginsApi(DatasetApiResource):
)
def get(self, tenant_id: str, dataset_id: str):
"""Resource for getting datasource plugins."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
# Get query parameter to determine published or draft
is_published: bool = request.args.get("is_published", default=True, type=bool)
@ -104,6 +111,12 @@ class DatasourceNodeRunApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
@ -161,6 +174,12 @@ class PipelineRunApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
if not isinstance(current_user, Account):

View File

@ -124,3 +124,38 @@ def _configure_session_factory(_unit_test_engine):
session_factory.get_session_maker()
except RuntimeError:
configure_session_factory(_unit_test_engine, expire_on_commit=False)
def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
"""
Helper to set up the mock DB query chain for tenant/account authentication.
This configures the mock to return (tenant, account) for the join query used
by validate_app_token and validate_dataset_token decorators.
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_account: Mock account object to return
"""
query = mock_db.session.query.return_value
join_chain = query.join.return_value.join.return_value
where_chain = join_chain.where.return_value
where_chain.one_or_none.return_value = (mock_tenant, mock_account)
def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
"""
Helper to set up the mock DB query chain for dataset tenant authentication.
This configures the mock to return (tenant, tenant_account) for the where chain
query used by validate_dataset_token decorator.
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_ta: Mock tenant account object to return
"""
query = mock_db.session.query.return_value
where_chain = query.where.return_value.where.return_value.where.return_value.where.return_value
where_chain.one_or_none.return_value = (mock_tenant, mock_ta)

View File

@ -0,0 +1,295 @@
"""
Unit tests for Service API Annotation controller.
Tests coverage for:
- AnnotationCreatePayload Pydantic model validation
- AnnotationReplyActionPayload Pydantic model validation
- Error patterns and validation logic
Note: API endpoint tests for annotation controllers are complex due to:
- @validate_app_token decorator requiring full Flask-SQLAlchemy setup
- @edit_permission_required decorator checking current_user permissions
- These are better covered by integration tests
"""
import uuid
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask_restx.api import HTTPStatus
from controllers.service_api.app.annotation import (
AnnotationCreatePayload,
AnnotationListApi,
AnnotationReplyActionApi,
AnnotationReplyActionPayload,
AnnotationReplyActionStatusApi,
AnnotationUpdateDeleteApi,
)
from extensions.ext_redis import redis_client
from models.model import App
from services.annotation_service import AppAnnotationService
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
# ---------------------------------------------------------------------------
# Pydantic Model Tests
# ---------------------------------------------------------------------------
class TestAnnotationCreatePayload:
"""Test suite for AnnotationCreatePayload Pydantic model."""
def test_payload_with_question_and_answer(self):
"""Test payload with required fields."""
payload = AnnotationCreatePayload(
question="What is AI?",
answer="AI is artificial intelligence.",
)
assert payload.question == "What is AI?"
assert payload.answer == "AI is artificial intelligence."
def test_payload_with_unicode_content(self):
"""Test payload with unicode content."""
payload = AnnotationCreatePayload(
question="什么是人工智能?",
answer="人工智能是模拟人类智能的技术。",
)
assert payload.question == "什么是人工智能?"
def test_payload_with_special_characters(self):
"""Test payload with special characters."""
payload = AnnotationCreatePayload(
question="What is <b>AI</b>?",
answer="AI & ML are related fields with 100% growth!",
)
assert "<b>" in payload.question
class TestAnnotationReplyActionPayload:
"""Test suite for AnnotationReplyActionPayload Pydantic model."""
def test_payload_with_all_fields(self):
"""Test payload with all fields."""
payload = AnnotationReplyActionPayload(
score_threshold=0.8,
embedding_provider_name="openai",
embedding_model_name="text-embedding-ada-002",
)
assert payload.score_threshold == 0.8
assert payload.embedding_provider_name == "openai"
assert payload.embedding_model_name == "text-embedding-ada-002"
def test_payload_with_different_provider(self):
"""Test payload with different embedding provider."""
payload = AnnotationReplyActionPayload(
score_threshold=0.75,
embedding_provider_name="azure_openai",
embedding_model_name="text-embedding-3-small",
)
assert payload.embedding_provider_name == "azure_openai"
def test_payload_with_zero_threshold(self):
"""Test payload with zero score threshold."""
payload = AnnotationReplyActionPayload(
score_threshold=0.0,
embedding_provider_name="local",
embedding_model_name="default",
)
assert payload.score_threshold == 0.0
# ---------------------------------------------------------------------------
# Model and Error Pattern Tests
# ---------------------------------------------------------------------------
class TestAppModelPatterns:
"""Test App model patterns used by annotation controller."""
def test_app_model_has_required_fields(self):
"""Test App model has required fields for annotation operations."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.status = "normal"
app.enable_api = True
assert app.id is not None
assert app.status == "normal"
assert app.enable_api is True
def test_app_model_disabled_api(self):
"""Test app with disabled API access."""
app = Mock(spec=App)
app.enable_api = False
assert app.enable_api is False
def test_app_model_archived_status(self):
"""Test app with archived status."""
app = Mock(spec=App)
app.status = "archived"
assert app.status == "archived"
class TestAnnotationErrorPatterns:
"""Test annotation-related error handling patterns."""
def test_not_found_error_pattern(self):
"""Test NotFound error pattern used in annotation operations."""
from werkzeug.exceptions import NotFound
with pytest.raises(NotFound):
raise NotFound("Annotation not found.")
def test_forbidden_error_pattern(self):
"""Test Forbidden error pattern."""
from werkzeug.exceptions import Forbidden
with pytest.raises(Forbidden):
raise Forbidden("Permission denied.")
def test_value_error_for_job_not_found(self):
"""Test ValueError pattern for job not found."""
with pytest.raises(ValueError, match="does not exist"):
raise ValueError("The job does not exist.")
class TestAnnotationReplyActionApi:
def test_enable(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
enable_mock = Mock()
monkeypatch.setattr(AppAnnotationService, "enable_app_annotation", enable_mock)
api = AnnotationReplyActionApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app")
with app.test_request_context(
"/apps/annotation-reply/enable",
method="POST",
json={"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"},
):
response, status = handler(api, app_model=app_model, action="enable")
assert status == 200
enable_mock.assert_called_once()
def test_disable(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
disable_mock = Mock()
monkeypatch.setattr(AppAnnotationService, "disable_app_annotation", disable_mock)
api = AnnotationReplyActionApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app")
with app.test_request_context(
"/apps/annotation-reply/disable",
method="POST",
json={"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"},
):
response, status = handler(api, app_model=app_model, action="disable")
assert status == 200
disable_mock.assert_called_once()
class TestAnnotationReplyActionStatusApi:
def test_missing_job(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(redis_client, "get", lambda *_args, **_kwargs: None)
api = AnnotationReplyActionStatusApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app")
with pytest.raises(ValueError):
handler(api, app_model=app_model, job_id="j1", action="enable")
def test_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
def _get(key):
if "error" in key:
return b"oops"
return b"error"
monkeypatch.setattr(redis_client, "get", _get)
api = AnnotationReplyActionStatusApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app")
response, status = handler(api, app_model=app_model, job_id="j1", action="enable")
assert status == 200
assert response["job_status"] == "error"
assert response["error_msg"] == "oops"
class TestAnnotationListApi:
def test_get(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
monkeypatch.setattr(
AppAnnotationService,
"get_annotation_list_by_app_id",
lambda *_args, **_kwargs: ([annotation], 1),
)
api = AnnotationListApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app")
with app.test_request_context("/apps/annotations?page=1&limit=1", method="GET"):
response = handler(api, app_model=app_model)
assert response["total"] == 1
def test_create(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
monkeypatch.setattr(
AppAnnotationService,
"insert_app_annotation_directly",
lambda *_args, **_kwargs: annotation,
)
api = AnnotationListApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app")
with app.test_request_context("/apps/annotations", method="POST", json={"question": "q", "answer": "a"}):
response, status = handler(api, app_model=app_model)
assert status == HTTPStatus.CREATED
assert response["question"] == "q"
class TestAnnotationUpdateDeleteApi:
def test_update_delete(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
monkeypatch.setattr(
AppAnnotationService,
"update_app_annotation_directly",
lambda *_args, **_kwargs: annotation,
)
delete_mock = Mock()
monkeypatch.setattr(AppAnnotationService, "delete_app_annotation", delete_mock)
api = AnnotationUpdateDeleteApi()
put_handler = _unwrap(api.put)
delete_handler = _unwrap(api.delete)
app_model = SimpleNamespace(id="app")
with app.test_request_context("/apps/annotations/1", method="PUT", json={"question": "q", "answer": "a"}):
response = put_handler(api, app_model=app_model, annotation_id="1")
assert response["answer"] == "a"
with app.test_request_context("/apps/annotations/1", method="DELETE"):
response, status = delete_handler(api, app_model=app_model, annotation_id="1")
assert status == 204
delete_mock.assert_called_once()

View File

@ -0,0 +1,496 @@
"""
Unit tests for Service API App controllers
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameterApi
from controllers.service_api.app.error import AppUnavailableError
from models.model import App, AppMode
from tests.unit_tests.conftest import setup_mock_tenant_account_query
class TestAppParameterApi:
"""Test suite for AppParameterApi"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
app.mode = AppMode.CHAT
app.status = "normal"
app.enable_api = True
return app
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_for_chat_app(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test retrieving parameters for a chat app."""
# Arrange
mock_current_app.login_manager = Mock()
mock_config = Mock()
mock_config.id = str(uuid.uuid4())
mock_config.to_dict.return_value = {
"user_input_form": [{"type": "text", "label": "Name", "variable": "name", "required": True}],
"suggested_questions": [],
}
mock_app_model.app_model_config = mock_config
mock_app_model.workflow = None
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
# Mock DB queries for app and tenant
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
# Mock tenant owner info for login
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppParameterApi()
response = api.get()
# Assert
assert "opening_statement" in response
assert "suggested_questions" in response
assert "user_input_form" in response
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_for_workflow_app(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test retrieving parameters for a workflow app."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app_model.mode = AppMode.WORKFLOW
mock_workflow = Mock()
mock_workflow.features_dict = {"suggested_questions": []}
mock_workflow.user_input_form.return_value = [{"type": "text", "label": "Input", "variable": "input"}]
mock_app_model.workflow = mock_workflow
mock_app_model.app_model_config = None
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppParameterApi()
response = api.get()
# Assert
assert "user_input_form" in response
assert "opening_statement" in response
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_raises_error_when_chat_config_missing(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test that AppUnavailableError is raised when chat app has no config."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app_model.app_model_config = None
mock_app_model.workflow = None
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppParameterApi()
with pytest.raises(AppUnavailableError):
api.get()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_raises_error_when_workflow_missing(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test that AppUnavailableError is raised when workflow app has no workflow."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app_model.mode = AppMode.WORKFLOW
mock_app_model.workflow = None
mock_app_model.app_model_config = None
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppParameterApi()
with pytest.raises(AppUnavailableError):
api.get()
class TestAppMetaApi:
"""Test suite for AppMetaApi"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.status = "normal"
app.enable_api = True
return app
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.app.app.AppService")
def test_get_app_meta(
self, mock_app_service, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test retrieving app metadata via AppService."""
# Arrange
mock_current_app.login_manager = Mock()
mock_service_instance = Mock()
mock_service_instance.get_app_meta.return_value = {
"tool_icons": {},
"AgentIcons": {},
}
mock_app_service.return_value = mock_service_instance
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppMetaApi()
response = api.get()
# Assert
mock_service_instance.get_app_meta.assert_called_once_with(mock_app_model)
assert response == {"tool_icons": {}, "AgentIcons": {}}
class TestAppInfoApi:
"""Test suite for AppInfoApi"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model with all required attributes."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
app.name = "Test App"
app.description = "A test application"
app.mode = AppMode.CHAT
app.author_name = "Test Author"
app.status = "normal"
app.enable_api = True
# Mock tags relationship
mock_tag = Mock()
mock_tag.name = "test-tag"
app.tags = [mock_tag]
return app
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test retrieving basic app information."""
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppInfoApi()
response = api.get()
# Assert
assert response["name"] == "Test App"
assert response["description"] == "A test application"
assert response["tags"] == ["test-tag"]
assert response["mode"] == AppMode.CHAT
assert response["author_name"] == "Test Author"
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_with_multiple_tags(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app
):
"""Test retrieving app info with multiple tags."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app = Mock(spec=App)
mock_app.id = str(uuid.uuid4())
mock_app.tenant_id = str(uuid.uuid4())
mock_app.name = "Multi Tag App"
mock_app.description = "App with multiple tags"
mock_app.mode = AppMode.WORKFLOW
mock_app.author_name = "Author"
mock_app.status = "normal"
mock_app.enable_api = True
tag1, tag2, tag3 = Mock(), Mock(), Mock()
tag1.name = "tag-one"
tag2.name = "tag-two"
tag3.name = "tag-three"
mock_app.tags = [tag1, tag2, tag3]
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app.id
mock_api_token.tenant_id = mock_app.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppInfoApi()
response = api.get()
# Assert
assert response["tags"] == ["tag-one", "tag-two", "tag-three"]
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_with_no_tags(self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app):
"""Test retrieving app info when app has no tags."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app = Mock(spec=App)
mock_app.id = str(uuid.uuid4())
mock_app.tenant_id = str(uuid.uuid4())
mock_app.name = "No Tags App"
mock_app.description = "App without tags"
mock_app.mode = AppMode.COMPLETION
mock_app.author_name = "Author"
mock_app.tags = []
mock_app.status = "normal"
mock_app.enable_api = True
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app.id
mock_api_token.tenant_id = mock_app.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppInfoApi()
response = api.get()
# Assert
assert response["tags"] == []
@pytest.mark.parametrize(
"app_mode",
[AppMode.CHAT, AppMode.COMPLETION, AppMode.WORKFLOW, AppMode.ADVANCED_CHAT],
)
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_returns_correct_mode(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, app_mode
):
"""Test that all app modes are correctly returned."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app = Mock(spec=App)
mock_app.id = str(uuid.uuid4())
mock_app.tenant_id = str(uuid.uuid4())
mock_app.name = "Test"
mock_app.description = "Test"
mock_app.mode = app_mode
mock_app.author_name = "Test"
mock_app.tags = []
mock_app.status = "normal"
mock_app.enable_api = True
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app.id
mock_api_token.tenant_id = mock_app.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppInfoApi()
response = api.get()
# Assert
assert response["mode"] == app_mode

View File

@ -0,0 +1,298 @@
"""
Unit tests for Service API Audio controller.
Tests coverage for:
- TextToAudioPayload Pydantic model validation
- Error mapping patterns between service and API errors
- AudioService method interfaces
"""
import io
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import InternalServerError
from controllers.service_api.app.audio import AudioApi, TextApi, TextToAudioPayload
from controllers.service_api.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _file_data():
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
# ---------------------------------------------------------------------------
# Pydantic Model Tests
# ---------------------------------------------------------------------------
class TestTextToAudioPayload:
"""Test suite for TextToAudioPayload Pydantic model."""
def test_payload_with_all_fields(self):
"""Test payload with all fields populated."""
payload = TextToAudioPayload(
message_id="msg_123",
voice="nova",
text="Hello, this is a test.",
streaming=False,
)
assert payload.message_id == "msg_123"
assert payload.voice == "nova"
assert payload.text == "Hello, this is a test."
assert payload.streaming is False
def test_payload_with_defaults(self):
"""Test payload with default values."""
payload = TextToAudioPayload()
assert payload.message_id is None
assert payload.voice is None
assert payload.text is None
assert payload.streaming is None
def test_payload_with_only_text(self):
"""Test payload with only text field."""
payload = TextToAudioPayload(text="Simple text to speech")
assert payload.text == "Simple text to speech"
assert payload.voice is None
assert payload.message_id is None
def test_payload_with_streaming_true(self):
"""Test payload with streaming enabled."""
payload = TextToAudioPayload(
text="Streaming test",
streaming=True,
)
assert payload.streaming is True
# ---------------------------------------------------------------------------
# AudioService Interface Tests
# ---------------------------------------------------------------------------
class TestAudioServiceInterface:
"""Test AudioService method interfaces exist."""
def test_transcript_asr_method_exists(self):
"""Test that AudioService.transcript_asr exists."""
assert hasattr(AudioService, "transcript_asr")
assert callable(AudioService.transcript_asr)
def test_transcript_tts_method_exists(self):
"""Test that AudioService.transcript_tts exists."""
assert hasattr(AudioService, "transcript_tts")
assert callable(AudioService.transcript_tts)
# ---------------------------------------------------------------------------
# Audio Service Tests
# ---------------------------------------------------------------------------
class TestAudioServiceInterface:
"""Test suite for AudioService interface methods."""
def test_transcript_asr_method_exists(self):
"""Test that AudioService.transcript_asr exists."""
assert hasattr(AudioService, "transcript_asr")
assert callable(AudioService.transcript_asr)
def test_transcript_tts_method_exists(self):
"""Test that AudioService.transcript_tts exists."""
assert hasattr(AudioService, "transcript_tts")
assert callable(AudioService.transcript_tts)
class TestServiceErrorTypes:
"""Test service error types used by audio controller."""
def test_no_audio_uploaded_service_error(self):
"""Test NoAudioUploadedServiceError exists."""
error = NoAudioUploadedServiceError()
assert error is not None
def test_audio_too_large_service_error(self):
"""Test AudioTooLargeServiceError with message."""
error = AudioTooLargeServiceError("File too large")
assert "File too large" in str(error)
def test_unsupported_audio_type_service_error(self):
"""Test UnsupportedAudioTypeServiceError exists."""
error = UnsupportedAudioTypeServiceError()
assert error is not None
def test_provider_not_support_speech_to_text_service_error(self):
"""Test ProviderNotSupportSpeechToTextServiceError exists."""
error = ProviderNotSupportSpeechToTextServiceError()
assert error is not None
# ---------------------------------------------------------------------------
# Mocked Behavior Tests
# ---------------------------------------------------------------------------
class TestAudioServiceMockedBehavior:
"""Test AudioService behavior with mocked methods."""
@pytest.fixture
def mock_app(self):
"""Create mock app model."""
from models.model import App
app = Mock(spec=App)
app.id = str(uuid.uuid4())
return app
@pytest.fixture
def mock_file(self):
"""Create mock file upload."""
mock = Mock()
mock.filename = "test_audio.mp3"
mock.content_type = "audio/mpeg"
return mock
@patch.object(AudioService, "transcript_asr")
def test_transcript_asr_returns_response(self, mock_asr, mock_app, mock_file):
"""Test ASR transcription returns response dict."""
mock_response = {"text": "Transcribed text"}
mock_asr.return_value = mock_response
result = AudioService.transcript_asr(
app_model=mock_app,
file=mock_file,
end_user="user_123",
)
assert result["text"] == "Transcribed text"
@patch.object(AudioService, "transcript_tts")
def test_transcript_tts_returns_response(self, mock_tts, mock_app):
"""Test TTS transcription returns response."""
mock_response = {"audio": "base64_audio_data"}
mock_tts.return_value = mock_response
result = AudioService.transcript_tts(
app_model=mock_app,
text="Hello world",
voice="nova",
end_user="user_123",
message_id="msg_123",
)
assert result["audio"] == "base64_audio_data"
class TestAudioApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
api = AudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
response = handler(api, app_model=app_model, end_user=end_user)
assert response == {"text": "ok"}
@pytest.mark.parametrize(
("exc", "expected"),
[
(AppModelConfigBrokenError(), AppUnavailableError),
(NoAudioUploadedServiceError(), NoAudioUploadedError),
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
(QuotaExceededError(), ProviderQuotaExceededError),
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
(InvokeError("invoke"), CompletionRequestError),
],
)
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
api = AudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(expected):
handler(api, app_model=app_model, end_user=end_user)
def test_unhandled_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))
)
api = AudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(InternalServerError):
handler(api, app_model=app_model, end_user=end_user)
class TestTextApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
api = TextApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(external_user_id="ext")
with app.test_request_context(
"/text-to-audio",
method="POST",
json={"text": "hello", "voice": "v"},
):
response = handler(api, app_model=app_model, end_user=end_user)
assert response == {"audio": "ok"}
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError())
)
api = TextApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(external_user_id="ext")
with app.test_request_context("/text-to-audio", method="POST", json={"text": "hello"}):
with pytest.raises(ProviderQuotaExceededError):
handler(api, app_model=app_model, end_user=end_user)

View File

@ -0,0 +1,524 @@
"""
Unit tests for Service API Completion controllers.
Tests coverage for:
- CompletionRequestPayload and ChatRequestPayload Pydantic models
- App mode validation logic
- Error mapping from service layer to HTTP errors
Focus on:
- Pydantic model validation (especially UUID normalization)
- Error types and their mappings
"""
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.service_api.app.completion import (
ChatApi,
ChatRequestPayload,
ChatStopApi,
CompletionApi,
CompletionRequestPayload,
CompletionStopApi,
)
from controllers.service_api.app.error import (
AppUnavailableError,
ConversationCompletedError,
NotChatAppError,
)
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestCompletionRequestPayload:
"""Test suite for CompletionRequestPayload Pydantic model."""
def test_payload_with_required_fields(self):
"""Test payload with only required inputs field."""
payload = CompletionRequestPayload(inputs={"name": "test"})
assert payload.inputs == {"name": "test"}
assert payload.query == ""
assert payload.files is None
assert payload.response_mode is None
assert payload.retriever_from == "dev"
def test_payload_with_all_fields(self):
"""Test payload with all fields populated."""
payload = CompletionRequestPayload(
inputs={"user_input": "Hello"},
query="What is AI?",
files=[{"type": "image", "url": "http://example.com/image.png"}],
response_mode="streaming",
retriever_from="api",
)
assert payload.inputs == {"user_input": "Hello"}
assert payload.query == "What is AI?"
assert payload.files == [{"type": "image", "url": "http://example.com/image.png"}]
assert payload.response_mode == "streaming"
assert payload.retriever_from == "api"
def test_payload_response_mode_blocking(self):
"""Test payload with blocking response mode."""
payload = CompletionRequestPayload(inputs={}, response_mode="blocking")
assert payload.response_mode == "blocking"
def test_payload_empty_inputs(self):
"""Test payload with empty inputs dict."""
payload = CompletionRequestPayload(inputs={})
assert payload.inputs == {}
def test_payload_complex_inputs(self):
"""Test payload with complex nested inputs."""
complex_inputs = {
"user": {"name": "Alice", "age": 30},
"context": ["item1", "item2"],
"settings": {"theme": "dark", "notifications": True},
}
payload = CompletionRequestPayload(inputs=complex_inputs)
assert payload.inputs == complex_inputs
class TestChatRequestPayload:
"""Test suite for ChatRequestPayload Pydantic model."""
def test_payload_with_required_fields(self):
"""Test payload with required fields."""
payload = ChatRequestPayload(inputs={"key": "value"}, query="Hello")
assert payload.inputs == {"key": "value"}
assert payload.query == "Hello"
assert payload.conversation_id is None
assert payload.auto_generate_name is True
def test_payload_normalizes_valid_uuid_conversation_id(self):
"""Test that valid UUID conversation_id is normalized."""
valid_uuid = str(uuid.uuid4())
payload = ChatRequestPayload(inputs={}, query="test", conversation_id=valid_uuid)
assert payload.conversation_id == valid_uuid
def test_payload_normalizes_empty_string_conversation_id_to_none(self):
"""Test that empty string conversation_id becomes None."""
payload = ChatRequestPayload(inputs={}, query="test", conversation_id="")
assert payload.conversation_id is None
def test_payload_normalizes_whitespace_conversation_id_to_none(self):
"""Test that whitespace-only conversation_id becomes None."""
payload = ChatRequestPayload(inputs={}, query="test", conversation_id=" ")
assert payload.conversation_id is None
def test_payload_rejects_invalid_uuid_conversation_id(self):
"""Test that invalid UUID format raises ValueError."""
with pytest.raises(ValueError) as exc_info:
ChatRequestPayload(inputs={}, query="test", conversation_id="not-a-uuid")
assert "valid UUID" in str(exc_info.value)
def test_payload_with_workflow_id(self):
"""Test payload with workflow_id for advanced chat."""
payload = ChatRequestPayload(inputs={}, query="test", workflow_id="workflow_123")
assert payload.workflow_id == "workflow_123"
def test_payload_streaming_mode(self):
"""Test payload with streaming response mode."""
payload = ChatRequestPayload(inputs={}, query="test", response_mode="streaming")
assert payload.response_mode == "streaming"
def test_payload_auto_generate_name_false(self):
"""Test payload with auto_generate_name explicitly false."""
payload = ChatRequestPayload(inputs={}, query="test", auto_generate_name=False)
assert payload.auto_generate_name is False
def test_payload_with_files(self):
"""Test payload with file attachments."""
files = [
{"type": "image", "transfer_method": "remote_url", "url": "http://example.com/img.png"},
{"type": "document", "transfer_method": "local_file", "upload_file_id": "file_123"},
]
payload = ChatRequestPayload(inputs={}, query="test", files=files)
assert payload.files == files
assert len(payload.files) == 2
class TestCompletionErrorMappings:
"""Test error type mappings for completion endpoints."""
def test_conversation_not_exists_error_exists(self):
"""Test ConversationNotExistsError can be raised."""
error = services.errors.conversation.ConversationNotExistsError()
assert isinstance(error, services.errors.conversation.ConversationNotExistsError)
def test_conversation_completed_error_exists(self):
"""Test ConversationCompletedError can be raised."""
error = services.errors.conversation.ConversationCompletedError()
assert isinstance(error, services.errors.conversation.ConversationCompletedError)
api_error = ConversationCompletedError()
assert api_error is not None
def test_app_model_config_broken_error_exists(self):
"""Test AppModelConfigBrokenError can be raised."""
error = services.errors.app_model_config.AppModelConfigBrokenError()
assert isinstance(error, services.errors.app_model_config.AppModelConfigBrokenError)
api_error = AppUnavailableError()
assert api_error is not None
def test_workflow_not_found_error_exists(self):
"""Test WorkflowNotFoundError can be raised."""
error = WorkflowNotFoundError("Workflow not found")
assert isinstance(error, WorkflowNotFoundError)
def test_is_draft_workflow_error_exists(self):
"""Test IsDraftWorkflowError can be raised."""
error = IsDraftWorkflowError("Workflow is in draft state")
assert isinstance(error, IsDraftWorkflowError)
def test_workflow_id_format_error_exists(self):
"""Test WorkflowIdFormatError can be raised."""
error = WorkflowIdFormatError("Invalid workflow ID format")
assert isinstance(error, WorkflowIdFormatError)
def test_invoke_rate_limit_error_exists(self):
"""Test InvokeRateLimitError can be raised."""
error = InvokeRateLimitError("Rate limit exceeded")
assert isinstance(error, InvokeRateLimitError)
class TestAppModeValidation:
"""Test app mode validation logic patterns."""
def test_completion_mode_is_valid_for_completion_endpoint(self):
"""Test that COMPLETION mode is valid for completion endpoints."""
assert AppMode.COMPLETION == AppMode.COMPLETION
def test_chat_modes_are_distinct_from_completion(self):
"""Test that chat modes are distinct from completion mode."""
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
assert AppMode.COMPLETION not in chat_modes
def test_workflow_mode_is_distinct_from_chat_modes(self):
"""Test that WORKFLOW mode is not a chat mode."""
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
assert AppMode.WORKFLOW not in chat_modes
def test_not_chat_app_error_can_be_raised(self):
"""Test NotChatAppError can be raised for non-chat apps."""
error = NotChatAppError()
assert error is not None
def test_all_app_modes_are_defined(self):
"""Test that all expected app modes are defined."""
expected_modes = ["COMPLETION", "CHAT", "AGENT_CHAT", "ADVANCED_CHAT", "WORKFLOW", "CHANNEL", "RAG_PIPELINE"]
for mode_name in expected_modes:
assert hasattr(AppMode, mode_name), f"AppMode.{mode_name} should exist"
class TestAppGenerateService:
"""Test AppGenerateService integration patterns."""
def test_generate_method_exists(self):
"""Test that AppGenerateService.generate method exists."""
assert hasattr(AppGenerateService, "generate")
assert callable(AppGenerateService.generate)
@patch.object(AppGenerateService, "generate")
def test_generate_returns_response(self, mock_generate):
"""Test that generate returns expected response format."""
expected = {"answer": "Hello!"}
mock_generate.return_value = expected
result = AppGenerateService.generate(
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={"query": "Hi"}, invoke_from=Mock(), streaming=False
)
assert result == expected
@patch.object(AppGenerateService, "generate")
def test_generate_raises_conversation_not_exists(self, mock_generate):
"""Test generate raises ConversationNotExistsError."""
mock_generate.side_effect = services.errors.conversation.ConversationNotExistsError()
with pytest.raises(services.errors.conversation.ConversationNotExistsError):
AppGenerateService.generate(
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
)
@patch.object(AppGenerateService, "generate")
def test_generate_raises_quota_exceeded(self, mock_generate):
"""Test generate raises QuotaExceededError."""
mock_generate.side_effect = QuotaExceededError()
with pytest.raises(QuotaExceededError):
AppGenerateService.generate(
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
)
@patch.object(AppGenerateService, "generate")
def test_generate_raises_invoke_error(self, mock_generate):
"""Test generate raises InvokeError."""
mock_generate.side_effect = InvokeError("Model invocation failed")
with pytest.raises(InvokeError):
AppGenerateService.generate(
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
)
class TestCompletionControllerLogic:
"""Test CompletionApi and ChatApi controller logic directly."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
from flask import Flask
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.service_api.app.completion.service_api_ns")
@patch("controllers.service_api.app.completion.AppGenerateService")
def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
"""Test CompletionApi.post success path."""
from controllers.service_api.app.completion import CompletionApi
# Setup mocks
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.COMPLETION
mock_end_user = Mock(spec=EndUser)
payload_dict = {"inputs": {"text": "hello"}, "response_mode": "blocking"}
mock_service_api_ns.payload = payload_dict
mock_generate_service.generate.return_value = {"text": "response"}
with app.test_request_context():
# Helper for compact_generate_response logic check
with patch("controllers.service_api.app.completion.helper.compact_generate_response") as mock_compact:
mock_compact.return_value = {"text": "compacted"}
api = CompletionApi()
response = api.post.__wrapped__(api, mock_app_model, mock_end_user)
assert response == {"text": "compacted"}
mock_generate_service.generate.assert_called_once()
@patch("controllers.service_api.app.completion.service_api_ns")
def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app):
"""Test CompletionApi.post with wrong app mode."""
from controllers.service_api.app.completion import CompletionApi
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.CHAT # Wrong mode
mock_end_user = Mock(spec=EndUser)
with app.test_request_context():
with pytest.raises(AppUnavailableError):
CompletionApi().post.__wrapped__(CompletionApi(), mock_app_model, mock_end_user)
@patch("controllers.service_api.app.completion.service_api_ns")
@patch("controllers.service_api.app.completion.AppGenerateService")
def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
"""Test ChatApi.post success path."""
from controllers.service_api.app.completion import ChatApi
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.CHAT
mock_end_user = Mock(spec=EndUser)
payload_dict = {"inputs": {}, "query": "hello", "response_mode": "blocking"}
mock_service_api_ns.payload = payload_dict
mock_generate_service.generate.return_value = {"text": "response"}
with app.test_request_context():
with patch("controllers.service_api.app.completion.helper.compact_generate_response") as mock_compact:
mock_compact.return_value = {"text": "compacted"}
api = ChatApi()
response = api.post.__wrapped__(api, mock_app_model, mock_end_user)
assert response == {"text": "compacted"}
@patch("controllers.service_api.app.completion.service_api_ns")
def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app):
"""Test ChatApi.post with wrong app mode."""
from controllers.service_api.app.completion import ChatApi
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.COMPLETION # Wrong mode
mock_end_user = Mock(spec=EndUser)
with app.test_request_context():
with pytest.raises(NotChatAppError):
ChatApi().post.__wrapped__(ChatApi(), mock_app_model, mock_end_user)
@patch("controllers.service_api.app.completion.AppTaskService")
def test_completion_stop_api_success(self, mock_task_service, app):
"""Test CompletionStopApi.post success."""
from controllers.service_api.app.completion import CompletionStopApi
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.COMPLETION
mock_end_user = Mock(spec=EndUser)
mock_end_user.id = "user_id"
with app.test_request_context():
api = CompletionStopApi()
response = api.post.__wrapped__(api, mock_app_model, mock_end_user, "task_id")
assert response == ({"result": "success"}, 200)
mock_task_service.stop_task.assert_called_once()
@patch("controllers.service_api.app.completion.AppTaskService")
def test_chat_stop_api_success(self, mock_task_service, app):
"""Test ChatStopApi.post success."""
from controllers.service_api.app.completion import ChatStopApi
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.CHAT
mock_end_user = Mock(spec=EndUser)
mock_end_user.id = "user_id"
with app.test_request_context():
api = ChatStopApi()
response = api.post.__wrapped__(api, mock_app_model, mock_end_user, "task_id")
assert response == ({"result": "success"}, 200)
mock_task_service.stop_task.assert_called_once()
class TestChatRequestPayloadController:
def test_normalizes_conversation_id(self) -> None:
payload = ChatRequestPayload.model_validate(
{"inputs": {}, "query": "hi", "conversation_id": " ", "response_mode": "blocking"}
)
assert payload.conversation_id is None
with pytest.raises(ValidationError):
ChatRequestPayload.model_validate({"inputs": {}, "query": "hi", "conversation_id": "bad-id"})
class TestCompletionApiController:
def test_wrong_mode(self, app) -> None:
api = CompletionApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}):
with pytest.raises(AppUnavailableError):
handler(api, app_model=app_model, end_user=end_user)
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace()
api = CompletionApi()
handler = _unwrap(api.post)
with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
class TestCompletionStopApiController:
def test_wrong_mode(self, app) -> None:
api = CompletionStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/completion-messages/1/stop", method="POST"):
with pytest.raises(AppUnavailableError):
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
stop_mock = Mock()
monkeypatch.setattr(AppTaskService, "stop_task", stop_mock)
api = CompletionStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/completion-messages/1/stop", method="POST"):
response, status = handler(api, app_model=app_model, end_user=end_user, task_id="t1")
assert status == 200
assert response == {"result": "success"}
class TestChatApiController:
def test_wrong_mode(self, app) -> None:
api = ChatApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_workflow_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(WorkflowNotFoundError("missing")),
)
api = ChatApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(IsDraftWorkflowError("draft")),
)
api = ChatApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
with pytest.raises(BadRequest):
handler(api, app_model=app_model, end_user=end_user)
class TestChatStopApiController:
def test_wrong_mode(self, app) -> None:
api = ChatStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/chat-messages/1/stop", method="POST"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="t1")

View File

@ -0,0 +1,597 @@
"""
Unit tests for Service API Conversation controllers.
Tests coverage for:
- ConversationListQuery, ConversationRenamePayload Pydantic models
- ConversationVariablesQuery with SQL injection prevention
- ConversationVariableUpdatePayload
- App mode validation for chat-only endpoints
Focus on:
- Pydantic model validation including security checks
- SQL injection prevention in variable name filtering
- Error types and mappings
"""
import sys
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.service_api.app.conversation import (
ConversationApi,
ConversationDetailApi,
ConversationListQuery,
ConversationRenameApi,
ConversationRenamePayload,
ConversationVariableDetailApi,
ConversationVariablesApi,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
)
from controllers.service_api.app.error import NotChatAppError
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
from services.errors.conversation import (
ConversationNotExistsError,
ConversationVariableNotExistsError,
ConversationVariableTypeMismatchError,
LastConversationNotExistsError,
)
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestConversationListQuery:
"""Test suite for ConversationListQuery Pydantic model."""
def test_query_with_defaults(self):
"""Test query with default values."""
query = ConversationListQuery()
assert query.last_id is None
assert query.limit == 20
assert query.sort_by == "-updated_at"
def test_query_with_last_id(self):
"""Test query with pagination last_id."""
last_id = str(uuid.uuid4())
query = ConversationListQuery(last_id=last_id)
assert str(query.last_id) == last_id
def test_query_limit_boundaries(self):
"""Test query respects limit boundaries."""
query_min = ConversationListQuery(limit=1)
assert query_min.limit == 1
query_max = ConversationListQuery(limit=100)
assert query_max.limit == 100
def test_query_rejects_limit_below_minimum(self):
"""Test query rejects limit < 1."""
with pytest.raises(ValueError):
ConversationListQuery(limit=0)
def test_query_rejects_limit_above_maximum(self):
"""Test query rejects limit > 100."""
with pytest.raises(ValueError):
ConversationListQuery(limit=101)
@pytest.mark.parametrize(
"sort_by",
[
"created_at",
"-created_at",
"updated_at",
"-updated_at",
],
)
def test_query_valid_sort_options(self, sort_by):
"""Test all valid sort_by options."""
query = ConversationListQuery(sort_by=sort_by)
assert query.sort_by == sort_by
class TestConversationRenamePayload:
"""Test suite for ConversationRenamePayload Pydantic model."""
def test_payload_with_name(self):
"""Test payload with explicit name."""
payload = ConversationRenamePayload(name="My New Chat", auto_generate=False)
assert payload.name == "My New Chat"
assert payload.auto_generate is False
def test_payload_with_auto_generate(self):
"""Test payload with auto_generate enabled."""
payload = ConversationRenamePayload(auto_generate=True)
assert payload.auto_generate is True
assert payload.name is None
def test_payload_requires_name_when_auto_generate_false(self):
"""Test that name is required when auto_generate is False."""
with pytest.raises(ValueError) as exc_info:
ConversationRenamePayload(auto_generate=False)
assert "name is required when auto_generate is false" in str(exc_info.value)
def test_payload_requires_non_empty_name_when_auto_generate_false(self):
"""Test that empty string name is rejected."""
with pytest.raises(ValueError):
ConversationRenamePayload(name="", auto_generate=False)
def test_payload_requires_non_whitespace_name_when_auto_generate_false(self):
"""Test that whitespace-only name is rejected."""
with pytest.raises(ValueError):
ConversationRenamePayload(name=" ", auto_generate=False)
def test_payload_name_with_special_characters(self):
"""Test payload with name containing special characters."""
payload = ConversationRenamePayload(name="Chat #1 - (Test) & More!", auto_generate=False)
assert payload.name == "Chat #1 - (Test) & More!"
def test_payload_name_with_unicode(self):
"""Test payload with Unicode characters in name."""
payload = ConversationRenamePayload(name="对话 📝 Чат", auto_generate=False)
assert payload.name == "对话 📝 Чат"
class TestConversationVariablesQuery:
"""Test suite for ConversationVariablesQuery Pydantic model."""
def test_query_with_defaults(self):
"""Test query with default values."""
query = ConversationVariablesQuery()
assert query.last_id is None
assert query.limit == 20
assert query.variable_name is None
def test_query_with_variable_name(self):
"""Test query with valid variable_name filter."""
query = ConversationVariablesQuery(variable_name="user_preference")
assert query.variable_name == "user_preference"
def test_query_allows_hyphen_in_variable_name(self):
"""Test that hyphens are allowed in variable names."""
query = ConversationVariablesQuery(variable_name="my-variable")
assert query.variable_name == "my-variable"
def test_query_allows_underscore_in_variable_name(self):
"""Test that underscores are allowed in variable names."""
query = ConversationVariablesQuery(variable_name="my_variable")
assert query.variable_name == "my_variable"
def test_query_allows_period_in_variable_name(self):
"""Test that periods are allowed in variable names."""
query = ConversationVariablesQuery(variable_name="config.setting")
assert query.variable_name == "config.setting"
def test_query_rejects_sql_injection_single_quote(self):
"""Test that single quotes are rejected (SQL injection prevention)."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name="'; DROP TABLE users;--")
assert "can only contain" in str(exc_info.value)
def test_query_rejects_sql_injection_double_quote(self):
"""Test that double quotes are rejected."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name='name"test')
assert "can only contain" in str(exc_info.value)
def test_query_rejects_sql_injection_semicolon(self):
"""Test that semicolons are rejected."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name="name;malicious")
assert "can only contain" in str(exc_info.value)
def test_query_rejects_sql_injection_comment(self):
"""Test that SQL comments are rejected."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name="name--comment")
assert "invalid characters" in str(exc_info.value)
def test_query_rejects_special_characters(self):
"""Test that special characters are rejected."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name="name@domain")
assert "can only contain" in str(exc_info.value)
def test_query_rejects_backticks(self):
"""Test that backticks are rejected (SQL injection prevention)."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name="`table`")
assert "can only contain" in str(exc_info.value)
def test_query_pagination_limits(self):
"""Test query pagination limit boundaries."""
query_min = ConversationVariablesQuery(limit=1)
assert query_min.limit == 1
query_max = ConversationVariablesQuery(limit=100)
assert query_max.limit == 100
class TestConversationVariableUpdatePayload:
"""Test suite for ConversationVariableUpdatePayload Pydantic model."""
def test_payload_with_string_value(self):
"""Test payload with string value."""
payload = ConversationVariableUpdatePayload(value="hello")
assert payload.value == "hello"
def test_payload_with_number_value(self):
"""Test payload with number value."""
payload = ConversationVariableUpdatePayload(value=42)
assert payload.value == 42
def test_payload_with_float_value(self):
"""Test payload with float value."""
payload = ConversationVariableUpdatePayload(value=3.14159)
assert payload.value == 3.14159
def test_payload_with_list_value(self):
"""Test payload with list value."""
payload = ConversationVariableUpdatePayload(value=["a", "b", "c"])
assert payload.value == ["a", "b", "c"]
def test_payload_with_dict_value(self):
"""Test payload with dictionary value."""
payload = ConversationVariableUpdatePayload(value={"key": "value"})
assert payload.value == {"key": "value"}
def test_payload_with_none_value(self):
"""Test payload with None value."""
payload = ConversationVariableUpdatePayload(value=None)
assert payload.value is None
def test_payload_with_boolean_value(self):
"""Test payload with boolean value."""
payload = ConversationVariableUpdatePayload(value=True)
assert payload.value is True
def test_payload_with_nested_structure(self):
"""Test payload with deeply nested structure."""
nested = {"level1": {"level2": {"level3": ["a", "b", {"c": 123}]}}}
payload = ConversationVariableUpdatePayload(value=nested)
assert payload.value == nested
class TestConversationAppModeValidation:
"""Test app mode validation for conversation endpoints."""
@pytest.mark.parametrize(
"mode",
[
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
AppMode.ADVANCED_CHAT.value,
],
)
def test_chat_modes_are_valid_for_conversation_endpoints(self, mode):
"""Test that all chat modes are valid for conversation endpoints.
Verifies that CHAT, AGENT_CHAT, and ADVANCED_CHAT modes pass
validation without raising NotChatAppError.
"""
app = Mock(spec=App)
app.mode = mode
# Validation should pass without raising for chat modes
app_mode = AppMode.value_of(app.mode)
assert app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
def test_completion_mode_is_invalid_for_conversation_endpoints(self):
"""Test that COMPLETION mode is invalid for conversation endpoints.
Verifies that calling a conversation endpoint with a COMPLETION mode
app raises NotChatAppError.
"""
app = Mock(spec=App)
app.mode = AppMode.COMPLETION.value
app_mode = AppMode.value_of(app.mode)
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
with pytest.raises(NotChatAppError):
raise NotChatAppError()
def test_workflow_mode_is_invalid_for_conversation_endpoints(self):
"""Test that WORKFLOW mode is invalid for conversation endpoints.
Verifies that calling a conversation endpoint with a WORKFLOW mode
app raises NotChatAppError.
"""
app = Mock(spec=App)
app.mode = AppMode.WORKFLOW.value
app_mode = AppMode.value_of(app.mode)
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
with pytest.raises(NotChatAppError):
raise NotChatAppError()
class TestConversationErrorTypes:
"""Test conversation-related error types."""
def test_conversation_not_exists_error(self):
"""Test ConversationNotExistsError exists and can be raised."""
error = services.errors.conversation.ConversationNotExistsError()
assert isinstance(error, services.errors.conversation.ConversationNotExistsError)
def test_conversation_completed_error(self):
"""Test ConversationCompletedError exists."""
error = services.errors.conversation.ConversationCompletedError()
assert isinstance(error, services.errors.conversation.ConversationCompletedError)
def test_last_conversation_not_exists_error(self):
"""Test LastConversationNotExistsError exists."""
error = services.errors.conversation.LastConversationNotExistsError()
assert isinstance(error, services.errors.conversation.LastConversationNotExistsError)
def test_conversation_variable_not_exists_error(self):
"""Test ConversationVariableNotExistsError exists."""
error = services.errors.conversation.ConversationVariableNotExistsError()
assert isinstance(error, services.errors.conversation.ConversationVariableNotExistsError)
def test_conversation_variable_type_mismatch_error(self):
"""Test ConversationVariableTypeMismatchError exists."""
error = services.errors.conversation.ConversationVariableTypeMismatchError("Type mismatch")
assert isinstance(error, services.errors.conversation.ConversationVariableTypeMismatchError)
class TestConversationService:
"""Test ConversationService integration patterns."""
def test_pagination_by_last_id_method_exists(self):
"""Test that ConversationService.pagination_by_last_id exists."""
assert hasattr(ConversationService, "pagination_by_last_id")
assert callable(ConversationService.pagination_by_last_id)
def test_delete_method_exists(self):
"""Test that ConversationService.delete exists."""
assert hasattr(ConversationService, "delete")
assert callable(ConversationService.delete)
def test_rename_method_exists(self):
"""Test that ConversationService.rename exists."""
assert hasattr(ConversationService, "rename")
assert callable(ConversationService.rename)
def test_get_conversational_variable_method_exists(self):
"""Test that ConversationService.get_conversational_variable exists."""
assert hasattr(ConversationService, "get_conversational_variable")
assert callable(ConversationService.get_conversational_variable)
def test_update_conversation_variable_method_exists(self):
"""Test that ConversationService.update_conversation_variable exists."""
assert hasattr(ConversationService, "update_conversation_variable")
assert callable(ConversationService.update_conversation_variable)
@patch.object(ConversationService, "pagination_by_last_id")
def test_pagination_returns_expected_format(self, mock_pagination):
"""Test pagination returns expected data format."""
mock_result = Mock()
mock_result.data = []
mock_result.limit = 20
mock_result.has_more = False
mock_pagination.return_value = mock_result
result = ConversationService.pagination_by_last_id(
app_model=Mock(spec=App),
user=Mock(spec=EndUser),
last_id=None,
limit=20,
invoke_from=Mock(),
sort_by="-updated_at",
)
assert hasattr(result, "data")
assert hasattr(result, "limit")
assert hasattr(result, "has_more")
@patch.object(ConversationService, "rename")
def test_rename_returns_conversation(self, mock_rename):
"""Test rename returns updated conversation."""
mock_conversation = Mock()
mock_conversation.name = "New Name"
mock_rename.return_value = mock_conversation
result = ConversationService.rename(
app_model=Mock(spec=App),
conversation_id="conv_123",
user=Mock(spec=EndUser),
name="New Name",
auto_generate=False,
)
assert result.name == "New Name"
class TestConversationPayloadsController:
def test_rename_requires_name(self) -> None:
with pytest.raises(ValueError):
ConversationRenamePayload(auto_generate=False, name="")
def test_variables_query_invalid_name(self) -> None:
with pytest.raises(ValueError):
ConversationVariablesQuery(variable_name="bad;")
class TestConversationApiController:
def test_list_not_chat(self, app) -> None:
api = ConversationApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/conversations", method="GET"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
class _SessionStub:
def __enter__(self):
return SimpleNamespace()
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(
ConversationService,
"pagination_by_last_id",
lambda *_args, **_kwargs: (_ for _ in ()).throw(LastConversationNotExistsError()),
)
conversation_module = sys.modules["controllers.service_api.app.conversation"]
monkeypatch.setattr(conversation_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(conversation_module, "Session", lambda *_args, **_kwargs: _SessionStub())
api = ConversationApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations?last_id=00000000-0000-0000-0000-000000000001&limit=20",
method="GET",
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
class TestConversationDetailApiController:
def test_delete_not_chat(self, app) -> None:
api = ConversationDetailApi()
handler = _unwrap(api.delete)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/conversations/1", method="DELETE"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_delete_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"delete",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
api = ConversationDetailApi()
handler = _unwrap(api.delete)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/conversations/1", method="DELETE"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
class TestConversationRenameApiController:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"rename",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
api = ConversationRenameApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/name",
method="POST",
json={"auto_generate": True},
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
class TestConversationVariablesApiController:
def test_not_chat(self, app) -> None:
api = ConversationVariablesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/conversations/1/variables", method="GET"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"get_conversational_variable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
api = ConversationVariablesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/variables?limit=20",
method="GET",
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
class TestConversationVariableDetailApiController:
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationVariableTypeMismatchError("bad")),
)
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/variables/2",
method="PUT",
json={"value": "x"},
):
with pytest.raises(BadRequest):
handler(
api,
app_model=app_model,
end_user=end_user,
c_id="00000000-0000-0000-0000-000000000001",
variable_id="00000000-0000-0000-0000-000000000002",
)
def test_update_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationVariableNotExistsError()),
)
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/variables/2",
method="PUT",
json={"value": "x"},
):
with pytest.raises(NotFound):
handler(
api,
app_model=app_model,
end_user=end_user,
c_id="00000000-0000-0000-0000-000000000001",
variable_id="00000000-0000-0000-0000-000000000002",
)

View File

@ -0,0 +1,398 @@
"""
Unit tests for Service API File controllers.
Tests coverage for:
- File upload validation
- Error handling for file operations
- FileService integration
Focus on:
- File validation logic (size, type, filename)
- Error type mappings
- Service method interfaces
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from fields.file_fields import FileResponse
from services.file_service import FileService
class TestFileResponse:
"""Test suite for FileResponse Pydantic model."""
def test_file_response_has_required_fields(self):
"""Test FileResponse model includes required fields."""
# Verify the model exists and can be imported
assert FileResponse is not None
assert hasattr(FileResponse, "model_fields")
class TestFileUploadErrors:
"""Test file upload error types."""
def test_no_file_uploaded_error_can_be_raised(self):
"""Test NoFileUploadedError can be raised."""
error = NoFileUploadedError()
assert error is not None
def test_too_many_files_error_can_be_raised(self):
"""Test TooManyFilesError can be raised."""
error = TooManyFilesError()
assert error is not None
def test_unsupported_file_type_error_can_be_raised(self):
"""Test UnsupportedFileTypeError can be raised."""
error = UnsupportedFileTypeError()
assert error is not None
def test_filename_not_exists_error_can_be_raised(self):
"""Test FilenameNotExistsError can be raised."""
error = FilenameNotExistsError()
assert error is not None
def test_file_too_large_error_can_be_raised(self):
"""Test FileTooLargeError can be raised."""
error = FileTooLargeError("File exceeds maximum size")
assert "File exceeds maximum size" in str(error) or error is not None
class TestFileServiceErrors:
"""Test FileService error types."""
def test_file_service_file_too_large_error_exists(self):
"""Test FileTooLargeError from services exists."""
import services.errors.file
error = services.errors.file.FileTooLargeError("File too large")
assert isinstance(error, services.errors.file.FileTooLargeError)
def test_file_service_unsupported_file_type_error_exists(self):
"""Test UnsupportedFileTypeError from services exists."""
import services.errors.file
error = services.errors.file.UnsupportedFileTypeError()
assert isinstance(error, services.errors.file.UnsupportedFileTypeError)
class TestFileService:
"""Test FileService interface and methods."""
def test_upload_file_method_exists(self):
"""Test FileService.upload_file method exists."""
assert hasattr(FileService, "upload_file")
assert callable(FileService.upload_file)
@patch.object(FileService, "upload_file")
def test_upload_file_returns_upload_file_object(self, mock_upload):
"""Test upload_file returns an upload file object."""
mock_file = Mock()
mock_file.id = str(uuid.uuid4())
mock_file.name = "test.pdf"
mock_file.size = 1024
mock_file.extension = "pdf"
mock_file.mime_type = "application/pdf"
mock_upload.return_value = mock_file
# Call the method directly without instantiation
assert mock_file.name == "test.pdf"
assert mock_file.extension == "pdf"
@patch.object(FileService, "upload_file")
def test_upload_file_raises_file_too_large_error(self, mock_upload):
"""Test upload_file raises FileTooLargeError."""
import services.errors.file
mock_upload.side_effect = services.errors.file.FileTooLargeError("File exceeds 15MB limit")
# Verify error type exists
with pytest.raises(services.errors.file.FileTooLargeError):
mock_upload(Mock(), Mock(), "user_id")
@patch.object(FileService, "upload_file")
def test_upload_file_raises_unsupported_file_type_error(self, mock_upload):
"""Test upload_file raises UnsupportedFileTypeError."""
import services.errors.file
mock_upload.side_effect = services.errors.file.UnsupportedFileTypeError()
# Verify error type exists
with pytest.raises(services.errors.file.UnsupportedFileTypeError):
mock_upload(Mock(), Mock(), "user_id")
class TestFileValidation:
"""Test file validation patterns."""
def test_valid_image_mimetype(self):
"""Test common image MIME types."""
valid_mimetypes = ["image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml"]
for mimetype in valid_mimetypes:
assert mimetype.startswith("image/")
def test_valid_document_mimetype(self):
"""Test common document MIME types."""
valid_mimetypes = [
"application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"text/plain",
"text/csv",
]
for mimetype in valid_mimetypes:
assert mimetype is not None
assert len(mimetype) > 0
def test_filename_has_extension(self):
"""Test filename validation for extension presence."""
valid_filenames = ["document.pdf", "image.png", "data.csv", "report.docx"]
for filename in valid_filenames:
assert "." in filename
parts = filename.rsplit(".", 1)
assert len(parts) == 2
assert len(parts[1]) > 0 # Extension exists
def test_filename_without_extension_is_invalid(self):
"""Test that filename without extension can be detected."""
filename = "noextension"
assert "." not in filename
class TestFileUploadResponse:
"""Test file upload response structure."""
@patch.object(FileService, "upload_file")
def test_upload_response_structure(self, mock_upload):
"""Test upload response has expected structure."""
mock_file = Mock()
mock_file.id = str(uuid.uuid4())
mock_file.name = "test.pdf"
mock_file.size = 2048
mock_file.extension = "pdf"
mock_file.mime_type = "application/pdf"
mock_file.created_by = str(uuid.uuid4())
mock_file.created_at = Mock()
mock_upload.return_value = mock_file
# Verify expected fields exist on mock
assert hasattr(mock_file, "id")
assert hasattr(mock_file, "name")
assert hasattr(mock_file, "size")
assert hasattr(mock_file, "extension")
assert hasattr(mock_file, "mime_type")
assert hasattr(mock_file, "created_by")
assert hasattr(mock_file, "created_at")
# =============================================================================
# API Endpoint Tests
#
# ``FileApi.post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``
# which preserves ``__wrapped__`` via ``functools.wraps``. We call the
# unwrapped method directly to bypass the decorator.
# =============================================================================
from tests.unit_tests.controllers.service_api.conftest import _unwrap
@pytest.fixture
def mock_app_model():
from models import App
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
return app
@pytest.fixture
def mock_end_user():
from models import EndUser
user = Mock(spec=EndUser)
user.id = str(uuid.uuid4())
return user
class TestFileApiPost:
"""Test suite for FileApi.post() endpoint.
``post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``
which preserves ``__wrapped__``.
"""
@patch("controllers.service_api.app.file.FileService")
@patch("controllers.service_api.app.file.db")
def test_upload_file_success(
self,
mock_db,
mock_file_svc_cls,
app,
mock_app_model,
mock_end_user,
):
"""Test successful file upload."""
from io import BytesIO
from controllers.service_api.app.file import FileApi
mock_upload = Mock()
mock_upload.id = str(uuid.uuid4())
mock_upload.name = "test.pdf"
mock_upload.size = 1024
mock_upload.extension = "pdf"
mock_upload.mime_type = "application/pdf"
mock_upload.created_by = str(mock_end_user.id)
mock_upload.created_by_role = "end_user"
mock_upload.created_at = 1700000000
mock_upload.preview_url = None
mock_upload.source_url = None
mock_upload.original_url = None
mock_upload.user_id = None
mock_upload.tenant_id = None
mock_upload.conversation_id = None
mock_upload.file_key = None
mock_file_svc_cls.return_value.upload_file.return_value = mock_upload
data = {"file": (BytesIO(b"file content"), "test.pdf", "application/pdf")}
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data=data,
):
api = FileApi()
response, status = _unwrap(api.post)(
api,
app_model=mock_app_model,
end_user=mock_end_user,
)
assert status == 201
mock_file_svc_cls.return_value.upload_file.assert_called_once()
def test_upload_no_file(self, app, mock_app_model, mock_end_user):
"""Test NoFileUploadedError when no file in request."""
from controllers.service_api.app.file import FileApi
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data={},
):
api = FileApi()
with pytest.raises(NoFileUploadedError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_too_many_files(self, app, mock_app_model, mock_end_user):
"""Test TooManyFilesError when multiple files uploaded."""
from io import BytesIO
from controllers.service_api.app.file import FileApi
data = {
"file": (BytesIO(b"content1"), "file1.pdf", "application/pdf"),
"extra": (BytesIO(b"content2"), "file2.pdf", "application/pdf"),
}
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data=data,
):
api = FileApi()
with pytest.raises(TooManyFilesError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_no_mimetype(self, app, mock_app_model, mock_end_user):
"""Test UnsupportedFileTypeError when file has no mimetype."""
from io import BytesIO
from controllers.service_api.app.file import FileApi
data = {"file": (BytesIO(b"content"), "test.bin", "")}
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data=data,
):
api = FileApi()
with pytest.raises(UnsupportedFileTypeError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
@patch("controllers.service_api.app.file.FileService")
@patch("controllers.service_api.app.file.db")
def test_upload_file_too_large(
self,
mock_db,
mock_file_svc_cls,
app,
mock_app_model,
mock_end_user,
):
"""Test FileTooLargeError when file exceeds size limit."""
from io import BytesIO
import services.errors.file
from controllers.service_api.app.file import FileApi
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError(
"File exceeds 15MB limit"
)
data = {"file": (BytesIO(b"big content"), "big.pdf", "application/pdf")}
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data=data,
):
api = FileApi()
with pytest.raises(FileTooLargeError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
@patch("controllers.service_api.app.file.FileService")
@patch("controllers.service_api.app.file.db")
def test_upload_unsupported_file_type(
self,
mock_db,
mock_file_svc_cls,
app,
mock_app_model,
mock_end_user,
):
"""Test UnsupportedFileTypeError from FileService."""
from io import BytesIO
import services.errors.file
from controllers.service_api.app.file import FileApi
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
data = {"file": (BytesIO(b"content"), "test.xyz", "application/octet-stream")}
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data=data,
):
api = FileApi()
with pytest.raises(UnsupportedFileTypeError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)

View File

@ -0,0 +1,541 @@
"""
Unit tests for Service API Message controllers.
Tests coverage for:
- MessageListQuery, MessageFeedbackPayload, FeedbackListQuery Pydantic models
- App mode validation for message endpoints
- MessageService integration
- Error handling for message operations
Focus on:
- Pydantic model validation
- UUID normalization
- Error type mappings
- Service method interfaces
"""
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.app.message import (
AppGetFeedbacksApi,
FeedbackListQuery,
MessageFeedbackApi,
MessageFeedbackPayload,
MessageListApi,
MessageListQuery,
MessageSuggestedApi,
)
from models.model import App, AppMode, EndUser
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import (
FirstMessageNotExistsError,
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestMessageListQuery:
"""Test suite for MessageListQuery Pydantic model."""
def test_query_requires_conversation_id(self):
"""Test conversation_id is required."""
conversation_id = str(uuid.uuid4())
query = MessageListQuery(conversation_id=conversation_id)
assert str(query.conversation_id) == conversation_id
def test_query_with_defaults(self):
"""Test query with default values."""
conversation_id = str(uuid.uuid4())
query = MessageListQuery(conversation_id=conversation_id)
assert query.first_id is None
assert query.limit == 20
def test_query_with_first_id(self):
"""Test query with first_id for pagination."""
conversation_id = str(uuid.uuid4())
first_id = str(uuid.uuid4())
query = MessageListQuery(conversation_id=conversation_id, first_id=first_id)
assert str(query.first_id) == first_id
def test_query_with_custom_limit(self):
"""Test query with custom limit."""
conversation_id = str(uuid.uuid4())
query = MessageListQuery(conversation_id=conversation_id, limit=50)
assert query.limit == 50
def test_query_limit_boundaries(self):
"""Test query respects limit boundaries."""
conversation_id = str(uuid.uuid4())
query_min = MessageListQuery(conversation_id=conversation_id, limit=1)
assert query_min.limit == 1
query_max = MessageListQuery(conversation_id=conversation_id, limit=100)
assert query_max.limit == 100
def test_query_rejects_limit_below_minimum(self):
"""Test query rejects limit < 1."""
conversation_id = str(uuid.uuid4())
with pytest.raises(ValueError):
MessageListQuery(conversation_id=conversation_id, limit=0)
def test_query_rejects_limit_above_maximum(self):
"""Test query rejects limit > 100."""
conversation_id = str(uuid.uuid4())
with pytest.raises(ValueError):
MessageListQuery(conversation_id=conversation_id, limit=101)
class TestMessageFeedbackPayload:
"""Test suite for MessageFeedbackPayload Pydantic model."""
def test_payload_with_defaults(self):
"""Test payload with default values."""
payload = MessageFeedbackPayload()
assert payload.rating is None
assert payload.content is None
def test_payload_with_like_rating(self):
"""Test payload with like rating."""
payload = MessageFeedbackPayload(rating="like")
assert payload.rating == "like"
def test_payload_with_dislike_rating(self):
"""Test payload with dislike rating."""
payload = MessageFeedbackPayload(rating="dislike")
assert payload.rating == "dislike"
def test_payload_with_content_only(self):
"""Test payload with content but no rating."""
payload = MessageFeedbackPayload(content="This response was helpful")
assert payload.content == "This response was helpful"
assert payload.rating is None
def test_payload_with_rating_and_content(self):
"""Test payload with both rating and content."""
payload = MessageFeedbackPayload(rating="like", content="Great answer, very detailed!")
assert payload.rating == "like"
assert payload.content == "Great answer, very detailed!"
def test_payload_with_long_content(self):
"""Test payload with long feedback content."""
long_content = "A" * 1000
payload = MessageFeedbackPayload(content=long_content)
assert len(payload.content) == 1000
def test_payload_with_unicode_content(self):
"""Test payload with unicode characters."""
unicode_content = "很好的回答 👍 Отличный ответ"
payload = MessageFeedbackPayload(content=unicode_content)
assert payload.content == unicode_content
class TestFeedbackListQuery:
"""Test suite for FeedbackListQuery Pydantic model."""
def test_query_with_defaults(self):
"""Test query with default values."""
query = FeedbackListQuery()
assert query.page == 1
assert query.limit == 20
def test_query_with_custom_pagination(self):
"""Test query with custom page and limit."""
query = FeedbackListQuery(page=3, limit=50)
assert query.page == 3
assert query.limit == 50
def test_query_page_minimum(self):
"""Test query page minimum validation."""
query = FeedbackListQuery(page=1)
assert query.page == 1
def test_query_rejects_page_below_minimum(self):
"""Test query rejects page < 1."""
with pytest.raises(ValueError):
FeedbackListQuery(page=0)
def test_query_limit_boundaries(self):
"""Test query limit boundaries."""
query_min = FeedbackListQuery(limit=1)
assert query_min.limit == 1
query_max = FeedbackListQuery(limit=101)
assert query_max.limit == 101 # Max is 101
def test_query_rejects_limit_below_minimum(self):
"""Test query rejects limit < 1."""
with pytest.raises(ValueError):
FeedbackListQuery(limit=0)
def test_query_rejects_limit_above_maximum(self):
"""Test query rejects limit > 101."""
with pytest.raises(ValueError):
FeedbackListQuery(limit=102)
class TestMessageAppModeValidation:
"""Test app mode validation for message endpoints."""
def test_chat_modes_are_valid_for_message_endpoints(self):
"""Test that all chat modes are valid."""
valid_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
for mode in valid_modes:
assert mode in valid_modes
def test_completion_mode_is_invalid_for_message_endpoints(self):
"""Test that COMPLETION mode is invalid."""
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
assert AppMode.COMPLETION not in chat_modes
def test_workflow_mode_is_invalid_for_message_endpoints(self):
"""Test that WORKFLOW mode is invalid."""
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
assert AppMode.WORKFLOW not in chat_modes
def test_not_chat_app_error_can_be_raised(self):
"""Test NotChatAppError can be raised."""
error = NotChatAppError()
assert error is not None
class TestMessageErrorTypes:
"""Test message-related error types."""
def test_message_not_exists_error_can_be_raised(self):
"""Test MessageNotExistsError can be raised."""
error = MessageNotExistsError()
assert isinstance(error, MessageNotExistsError)
def test_first_message_not_exists_error_can_be_raised(self):
"""Test FirstMessageNotExistsError can be raised."""
error = FirstMessageNotExistsError()
assert isinstance(error, FirstMessageNotExistsError)
def test_suggested_questions_after_answer_disabled_error_can_be_raised(self):
"""Test SuggestedQuestionsAfterAnswerDisabledError can be raised."""
error = SuggestedQuestionsAfterAnswerDisabledError()
assert isinstance(error, SuggestedQuestionsAfterAnswerDisabledError)
class TestMessageService:
"""Test MessageService interface and methods."""
def test_pagination_by_first_id_method_exists(self):
"""Test MessageService.pagination_by_first_id exists."""
assert hasattr(MessageService, "pagination_by_first_id")
assert callable(MessageService.pagination_by_first_id)
def test_create_feedback_method_exists(self):
"""Test MessageService.create_feedback exists."""
assert hasattr(MessageService, "create_feedback")
assert callable(MessageService.create_feedback)
def test_get_all_messages_feedbacks_method_exists(self):
"""Test MessageService.get_all_messages_feedbacks exists."""
assert hasattr(MessageService, "get_all_messages_feedbacks")
assert callable(MessageService.get_all_messages_feedbacks)
def test_get_suggested_questions_after_answer_method_exists(self):
"""Test MessageService.get_suggested_questions_after_answer exists."""
assert hasattr(MessageService, "get_suggested_questions_after_answer")
assert callable(MessageService.get_suggested_questions_after_answer)
@patch.object(MessageService, "pagination_by_first_id")
def test_pagination_by_first_id_returns_pagination_result(self, mock_pagination):
"""Test pagination_by_first_id returns expected format."""
mock_result = Mock()
mock_result.data = []
mock_result.limit = 20
mock_result.has_more = False
mock_pagination.return_value = mock_result
result = MessageService.pagination_by_first_id(
app_model=Mock(spec=App),
user=Mock(spec=EndUser),
conversation_id=str(uuid.uuid4()),
first_id=None,
limit=20,
)
assert hasattr(result, "data")
assert hasattr(result, "limit")
assert hasattr(result, "has_more")
@patch.object(MessageService, "pagination_by_first_id")
def test_pagination_raises_conversation_not_exists_error(self, mock_pagination):
"""Test pagination raises ConversationNotExistsError."""
import services.errors.conversation
mock_pagination.side_effect = services.errors.conversation.ConversationNotExistsError()
with pytest.raises(services.errors.conversation.ConversationNotExistsError):
MessageService.pagination_by_first_id(
app_model=Mock(spec=App), user=Mock(spec=EndUser), conversation_id="invalid_id", first_id=None, limit=20
)
@patch.object(MessageService, "pagination_by_first_id")
def test_pagination_raises_first_message_not_exists_error(self, mock_pagination):
"""Test pagination raises FirstMessageNotExistsError."""
mock_pagination.side_effect = FirstMessageNotExistsError()
with pytest.raises(FirstMessageNotExistsError):
MessageService.pagination_by_first_id(
app_model=Mock(spec=App),
user=Mock(spec=EndUser),
conversation_id=str(uuid.uuid4()),
first_id="invalid_first_id",
limit=20,
)
@patch.object(MessageService, "create_feedback")
def test_create_feedback_with_rating_and_content(self, mock_create_feedback):
"""Test create_feedback with rating and content."""
mock_create_feedback.return_value = None
MessageService.create_feedback(
app_model=Mock(spec=App),
message_id=str(uuid.uuid4()),
user=Mock(spec=EndUser),
rating="like",
content="Great response!",
)
mock_create_feedback.assert_called_once()
@patch.object(MessageService, "create_feedback")
def test_create_feedback_raises_message_not_exists_error(self, mock_create_feedback):
"""Test create_feedback raises MessageNotExistsError."""
mock_create_feedback.side_effect = MessageNotExistsError()
with pytest.raises(MessageNotExistsError):
MessageService.create_feedback(
app_model=Mock(spec=App),
message_id="invalid_message_id",
user=Mock(spec=EndUser),
rating="like",
content=None,
)
@patch.object(MessageService, "get_all_messages_feedbacks")
def test_get_all_messages_feedbacks_returns_list(self, mock_get_feedbacks):
"""Test get_all_messages_feedbacks returns list of feedbacks."""
mock_feedbacks = [
{"message_id": str(uuid.uuid4()), "rating": "like"},
{"message_id": str(uuid.uuid4()), "rating": "dislike"},
]
mock_get_feedbacks.return_value = mock_feedbacks
result = MessageService.get_all_messages_feedbacks(app_model=Mock(spec=App), page=1, limit=20)
assert len(result) == 2
assert result[0]["rating"] == "like"
@patch.object(MessageService, "get_suggested_questions_after_answer")
def test_get_suggested_questions_returns_questions_list(self, mock_get_questions):
"""Test get_suggested_questions_after_answer returns list of questions."""
mock_questions = ["What about this aspect?", "Can you elaborate on that?", "How does this relate to...?"]
mock_get_questions.return_value = mock_questions
result = MessageService.get_suggested_questions_after_answer(
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id=str(uuid.uuid4()), invoke_from=Mock()
)
assert len(result) == 3
assert isinstance(result[0], str)
@patch.object(MessageService, "get_suggested_questions_after_answer")
def test_get_suggested_questions_raises_disabled_error(self, mock_get_questions):
"""Test get_suggested_questions_after_answer raises SuggestedQuestionsAfterAnswerDisabledError."""
mock_get_questions.side_effect = SuggestedQuestionsAfterAnswerDisabledError()
with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError):
MessageService.get_suggested_questions_after_answer(
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id=str(uuid.uuid4()), invoke_from=Mock()
)
@patch.object(MessageService, "get_suggested_questions_after_answer")
def test_get_suggested_questions_raises_message_not_exists_error(self, mock_get_questions):
"""Test get_suggested_questions_after_answer raises MessageNotExistsError."""
mock_get_questions.side_effect = MessageNotExistsError()
with pytest.raises(MessageNotExistsError):
MessageService.get_suggested_questions_after_answer(
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id="invalid_message_id", invoke_from=Mock()
)
class TestMessageListApi:
def test_not_chat_app(self, app) -> None:
api = MessageListApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages?conversation_id=cid", method="GET"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"pagination_by_first_id",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
api = MessageListApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/messages?conversation_id=00000000-0000-0000-0000-000000000001",
method="GET",
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
def test_first_message_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"pagination_by_first_id",
lambda *_args, **_kwargs: (_ for _ in ()).throw(FirstMessageNotExistsError()),
)
api = MessageListApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/messages?conversation_id=00000000-0000-0000-0000-000000000001&first_id=00000000-0000-0000-0000-000000000002",
method="GET",
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
class TestMessageFeedbackApi:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"create_feedback",
lambda *_args, **_kwargs: (_ for _ in ()).throw(MessageNotExistsError()),
)
api = MessageFeedbackApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace()
end_user = SimpleNamespace()
with app.test_request_context(
"/messages/m1/feedbacks",
method="POST",
json={"rating": "like", "content": "ok"},
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
class TestAppGetFeedbacksApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"])
api = AppGetFeedbacksApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace()
with app.test_request_context("/app/feedbacks?page=1&limit=20", method="GET"):
response = handler(api, app_model=app_model)
assert response == {"data": ["f1"]}
class TestMessageSuggestedApi:
def test_not_chat(self, app) -> None:
api = MessageSuggestedApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages/m1/suggested", method="GET"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
lambda *_args, **_kwargs: (_ for _ in ()).throw(MessageNotExistsError()),
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages/m1/suggested", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_disabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
lambda *_args, **_kwargs: (_ for _ in ()).throw(SuggestedQuestionsAfterAnswerDisabledError()),
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages/m1/suggested", method="GET"):
with pytest.raises(BadRequest):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_internal_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages/m1/suggested", method="GET"):
with pytest.raises(InternalServerError):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
lambda *_args, **_kwargs: ["q1"],
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages/m1/suggested", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, message_id="m1")
assert response == {"result": "success", "data": ["q1"]}

View File

@ -0,0 +1,653 @@
"""
Unit tests for Service API Workflow controllers.
Tests coverage for:
- WorkflowRunPayload and WorkflowLogQuery Pydantic models
- Workflow execution error handling
- App mode validation for workflow endpoints
- Workflow stop mechanism validation
Focus on:
- Pydantic model validation
- Error type mappings
- Service method interfaces
"""
import sys
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.app.workflow import (
AppQueueManager,
DifyAPIRepositoryFactory,
GraphEngineManager,
WorkflowAppLogApi,
WorkflowLogQuery,
WorkflowRunApi,
WorkflowRunByIdApi,
WorkflowRunDetailApi,
WorkflowRunPayload,
WorkflowTaskStopApi,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.workflow.enums import WorkflowExecutionStatus
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
class TestWorkflowRunPayload:
"""Test suite for WorkflowRunPayload Pydantic model."""
def test_payload_with_required_inputs(self):
"""Test payload with required inputs field."""
payload = WorkflowRunPayload(inputs={"key": "value"})
assert payload.inputs == {"key": "value"}
assert payload.files is None
assert payload.response_mode is None
def test_payload_with_all_fields(self):
"""Test payload with all fields populated."""
files = [{"type": "image", "url": "http://example.com/img.png"}]
payload = WorkflowRunPayload(inputs={"param1": "value1", "param2": 123}, files=files, response_mode="streaming")
assert payload.inputs == {"param1": "value1", "param2": 123}
assert payload.files == files
assert payload.response_mode == "streaming"
def test_payload_response_mode_blocking(self):
"""Test payload with blocking response mode."""
payload = WorkflowRunPayload(inputs={}, response_mode="blocking")
assert payload.response_mode == "blocking"
def test_payload_with_complex_inputs(self):
"""Test payload with nested complex inputs."""
complex_inputs = {
"config": {"nested": {"value": 123}},
"items": ["item1", "item2"],
"metadata": {"key": "value"},
}
payload = WorkflowRunPayload(inputs=complex_inputs)
assert payload.inputs == complex_inputs
def test_payload_with_empty_inputs(self):
"""Test payload with empty inputs dict."""
payload = WorkflowRunPayload(inputs={})
assert payload.inputs == {}
def test_payload_with_multiple_files(self):
"""Test payload with multiple file attachments."""
files = [
{"type": "image", "url": "http://example.com/img1.png"},
{"type": "document", "upload_file_id": "file_123"},
{"type": "audio", "url": "http://example.com/audio.mp3"},
]
payload = WorkflowRunPayload(inputs={}, files=files)
assert len(payload.files) == 3
class TestWorkflowLogQuery:
"""Test suite for WorkflowLogQuery Pydantic model."""
def test_query_with_defaults(self):
"""Test query with default values."""
query = WorkflowLogQuery()
assert query.keyword is None
assert query.status is None
assert query.created_at__before is None
assert query.created_at__after is None
assert query.created_by_end_user_session_id is None
assert query.created_by_account is None
assert query.page == 1
assert query.limit == 20
def test_query_with_all_filters(self):
"""Test query with all filter fields populated."""
query = WorkflowLogQuery(
keyword="search term",
status="succeeded",
created_at__before="2024-01-15T10:00:00Z",
created_at__after="2024-01-01T00:00:00Z",
created_by_end_user_session_id="session_123",
created_by_account="user@example.com",
page=2,
limit=50,
)
assert query.keyword == "search term"
assert query.status == "succeeded"
assert query.created_at__before == "2024-01-15T10:00:00Z"
assert query.created_at__after == "2024-01-01T00:00:00Z"
assert query.created_by_end_user_session_id == "session_123"
assert query.created_by_account == "user@example.com"
assert query.page == 2
assert query.limit == 50
@pytest.mark.parametrize("status", ["succeeded", "failed", "stopped"])
def test_query_valid_status_values(self, status):
"""Test all valid status values."""
query = WorkflowLogQuery(status=status)
assert query.status == status
def test_query_pagination_limits(self):
"""Test query pagination boundaries."""
query_min_page = WorkflowLogQuery(page=1)
assert query_min_page.page == 1
query_max_page = WorkflowLogQuery(page=99999)
assert query_max_page.page == 99999
query_min_limit = WorkflowLogQuery(limit=1)
assert query_min_limit.limit == 1
query_max_limit = WorkflowLogQuery(limit=100)
assert query_max_limit.limit == 100
def test_query_rejects_page_below_minimum(self):
"""Test query rejects page < 1."""
with pytest.raises(ValueError):
WorkflowLogQuery(page=0)
def test_query_rejects_page_above_maximum(self):
"""Test query rejects page > 99999."""
with pytest.raises(ValueError):
WorkflowLogQuery(page=100000)
def test_query_rejects_limit_below_minimum(self):
"""Test query rejects limit < 1."""
with pytest.raises(ValueError):
WorkflowLogQuery(limit=0)
def test_query_rejects_limit_above_maximum(self):
"""Test query rejects limit > 100."""
with pytest.raises(ValueError):
WorkflowLogQuery(limit=101)
def test_query_with_keyword_search(self):
"""Test query with keyword filter."""
query = WorkflowLogQuery(keyword="workflow execution")
assert query.keyword == "workflow execution"
def test_query_with_date_filters(self):
"""Test query with before/after date filters."""
query = WorkflowLogQuery(created_at__before="2024-12-31T23:59:59Z", created_at__after="2024-01-01T00:00:00Z")
assert query.created_at__before == "2024-12-31T23:59:59Z"
assert query.created_at__after == "2024-01-01T00:00:00Z"
class TestWorkflowAppService:
"""Test WorkflowAppService interface."""
def test_service_exists(self):
"""Test WorkflowAppService class exists."""
service = WorkflowAppService()
assert service is not None
def test_get_paginate_workflow_app_logs_method_exists(self):
"""Test get_paginate_workflow_app_logs method exists."""
assert hasattr(WorkflowAppService, "get_paginate_workflow_app_logs")
assert callable(WorkflowAppService.get_paginate_workflow_app_logs)
@patch.object(WorkflowAppService, "get_paginate_workflow_app_logs")
def test_get_paginate_workflow_app_logs_returns_pagination(self, mock_get_logs):
"""Test get_paginate_workflow_app_logs returns paginated result."""
mock_pagination = Mock()
mock_pagination.data = []
mock_pagination.page = 1
mock_pagination.limit = 20
mock_pagination.total = 0
mock_get_logs.return_value = mock_pagination
service = WorkflowAppService()
result = service.get_paginate_workflow_app_logs(
session=Mock(),
app_model=Mock(spec=App),
keyword=None,
status=None,
created_at_before=None,
created_at_after=None,
page=1,
limit=20,
created_by_end_user_session_id=None,
created_by_account=None,
)
assert result.page == 1
assert result.limit == 20
class TestWorkflowExecutionStatus:
"""Test WorkflowExecutionStatus enum."""
def test_succeeded_status_exists(self):
"""Test succeeded status value exists."""
status = WorkflowExecutionStatus("succeeded")
assert status.value == "succeeded"
def test_failed_status_exists(self):
"""Test failed status value exists."""
status = WorkflowExecutionStatus("failed")
assert status.value == "failed"
def test_stopped_status_exists(self):
"""Test stopped status value exists."""
status = WorkflowExecutionStatus("stopped")
assert status.value == "stopped"
class TestAppGenerateServiceWorkflow:
"""Test AppGenerateService workflow integration."""
@patch.object(AppGenerateService, "generate")
def test_generate_accepts_workflow_args(self, mock_generate):
"""Test generate accepts workflow-specific args."""
mock_generate.return_value = {"result": "success"}
result = AppGenerateService.generate(
app_model=Mock(spec=App),
user=Mock(),
args={"inputs": {"key": "value"}, "workflow_id": "workflow_123"},
invoke_from=Mock(),
streaming=False,
)
assert result == {"result": "success"}
mock_generate.assert_called_once()
@patch.object(AppGenerateService, "generate")
def test_generate_raises_workflow_not_found_error(self, mock_generate):
"""Test generate raises WorkflowNotFoundError."""
mock_generate.side_effect = WorkflowNotFoundError("Workflow not found")
with pytest.raises(WorkflowNotFoundError):
AppGenerateService.generate(
app_model=Mock(spec=App),
user=Mock(),
args={"workflow_id": "invalid_id"},
invoke_from=Mock(),
streaming=False,
)
@patch.object(AppGenerateService, "generate")
def test_generate_raises_is_draft_workflow_error(self, mock_generate):
"""Test generate raises IsDraftWorkflowError."""
mock_generate.side_effect = IsDraftWorkflowError("Workflow is draft")
with pytest.raises(IsDraftWorkflowError):
AppGenerateService.generate(
app_model=Mock(spec=App),
user=Mock(),
args={"workflow_id": "draft_workflow"},
invoke_from=Mock(),
streaming=False,
)
@patch.object(AppGenerateService, "generate")
def test_generate_supports_streaming_mode(self, mock_generate):
"""Test generate supports streaming response mode."""
mock_stream = Mock()
mock_generate.return_value = mock_stream
result = AppGenerateService.generate(
app_model=Mock(spec=App),
user=Mock(),
args={"inputs": {}, "response_mode": "streaming"},
invoke_from=Mock(),
streaming=True,
)
assert result == mock_stream
class TestWorkflowStopMechanism:
"""Test workflow stop mechanisms."""
def test_app_queue_manager_has_stop_flag_method(self):
"""Test AppQueueManager has set_stop_flag_no_user_check method."""
from core.app.apps.base_app_queue_manager import AppQueueManager
assert hasattr(AppQueueManager, "set_stop_flag_no_user_check")
def test_graph_engine_manager_has_send_stop_command(self):
"""Test GraphEngineManager has send_stop_command method."""
from core.workflow.graph_engine.manager import GraphEngineManager
assert hasattr(GraphEngineManager, "send_stop_command")
class TestWorkflowRunRepository:
"""Test workflow run repository interface."""
def test_repository_factory_can_create_workflow_run_repository(self):
"""Test DifyAPIRepositoryFactory can create workflow run repository."""
from repositories.factory import DifyAPIRepositoryFactory
assert hasattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository")
@patch("repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository")
def test_workflow_run_repository_get_by_id(self, mock_factory):
"""Test workflow run repository get_workflow_run_by_id method."""
mock_repo = Mock()
mock_run = Mock()
mock_run.id = str(uuid.uuid4())
mock_run.status = "succeeded"
mock_repo.get_workflow_run_by_id.return_value = mock_run
mock_factory.return_value = mock_repo
from repositories.factory import DifyAPIRepositoryFactory
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(Mock())
result = repo.get_workflow_run_by_id(tenant_id="tenant_123", app_id="app_456", run_id="run_789")
assert result.status == "succeeded"
class TestWorkflowRunDetailApi:
def test_not_workflow_app(self, app) -> None:
api = WorkflowRunDetailApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
with app.test_request_context("/workflows/run/1", method="GET"):
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, workflow_run_id="run")
def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
run = SimpleNamespace(id="run")
repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run)
workflow_module = sys.modules["controllers.service_api.app.workflow"]
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: repo,
)
api = WorkflowRunDetailApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
assert handler(api, app_model=app_model, workflow_run_id="run") == run
class TestWorkflowRunApi:
def test_not_workflow_app(self, app) -> None:
api = WorkflowRunApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}):
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_rate_limit(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(InvokeRateLimitError("slow")),
)
api = WorkflowRunApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}):
with pytest.raises(InvokeRateLimitHttpError):
handler(api, app_model=app_model, end_user=end_user)
class TestWorkflowRunByIdApi:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(WorkflowNotFoundError("missing")),
)
api = WorkflowRunByIdApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(IsDraftWorkflowError("draft")),
)
api = WorkflowRunByIdApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
with pytest.raises(BadRequest):
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
class TestWorkflowTaskStopApi:
def test_wrong_mode(self, app) -> None:
api = WorkflowTaskStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/tasks/1/stop", method="POST"):
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
stop_mock = Mock()
send_mock = Mock()
monkeypatch.setattr(AppQueueManager, "set_stop_flag_no_user_check", stop_mock)
monkeypatch.setattr(GraphEngineManager, "send_stop_command", send_mock)
api = WorkflowTaskStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/workflows/tasks/1/stop", method="POST"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="t1")
assert response == {"result": "success"}
stop_mock.assert_called_once_with("t1")
send_mock.assert_called_once_with("t1")
class TestWorkflowAppLogApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
class _SessionStub:
def __enter__(self):
return SimpleNamespace()
def __exit__(self, exc_type, exc, tb):
return False
workflow_module = sys.modules["controllers.service_api.app.workflow"]
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(workflow_module, "Session", lambda *_args, **_kwargs: _SessionStub())
monkeypatch.setattr(
WorkflowAppService,
"get_paginate_workflow_app_logs",
lambda *_args, **_kwargs: {"items": [], "total": 0},
)
api = WorkflowAppLogApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/workflows/logs", method="GET"):
response = handler(api, app_model=app_model)
assert response == {"items": [], "total": 0}
# =============================================================================
# API Endpoint Tests
#
# ``WorkflowRunDetailApi``, ``WorkflowTaskStopApi``, and
# ``WorkflowAppLogApi`` use ``@validate_app_token`` which preserves
# ``__wrapped__`` via ``functools.wraps``. We call the unwrapped method
# directly to bypass the decorator.
# =============================================================================
from tests.unit_tests.controllers.service_api.conftest import _unwrap
@pytest.fixture
def mock_workflow_app():
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
app.mode = AppMode.WORKFLOW.value
return app
class TestWorkflowRunDetailApiGet:
"""Test suite for WorkflowRunDetailApi.get() endpoint.
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``)
and ``@service_api_ns.marshal_with``. We call the unwrapped method
directly; ``marshal_with`` is a no-op when calling directly.
"""
@patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory")
@patch("controllers.service_api.app.workflow.db")
def test_get_workflow_run_success(
self,
mock_db,
mock_repo_factory,
app,
mock_workflow_app,
):
"""Test successful workflow run detail retrieval."""
mock_run = Mock()
mock_run.id = "run-1"
mock_run.status = "succeeded"
mock_repo = Mock()
mock_repo.get_workflow_run_by_id.return_value = mock_run
mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo
from controllers.service_api.app.workflow import WorkflowRunDetailApi
with app.test_request_context(
f"/workflows/run/{mock_run.id}",
method="GET",
):
api = WorkflowRunDetailApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
assert result == mock_run
@patch("controllers.service_api.app.workflow.db")
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
"""Test NotWorkflowAppError when app mode is not workflow or advanced_chat."""
from controllers.service_api.app.workflow import WorkflowRunDetailApi
mock_app = Mock(spec=App)
mock_app.mode = AppMode.CHAT.value
with app.test_request_context("/workflows/run/run-1", method="GET"):
api = WorkflowRunDetailApi()
with pytest.raises(NotWorkflowAppError):
_unwrap(api.get)(api, app_model=mock_app, workflow_run_id="run-1")
class TestWorkflowTaskStopApiPost:
"""Test suite for WorkflowTaskStopApi.post() endpoint.
``post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``.
"""
@patch("controllers.service_api.app.workflow.GraphEngineManager")
@patch("controllers.service_api.app.workflow.AppQueueManager")
def test_stop_workflow_task_success(
self,
mock_queue_mgr,
mock_graph_mgr,
app,
mock_workflow_app,
):
"""Test successful workflow task stop."""
from controllers.service_api.app.workflow import WorkflowTaskStopApi
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
api = WorkflowTaskStopApi()
result = _unwrap(api.post)(
api,
app_model=mock_workflow_app,
end_user=Mock(),
task_id="task-1",
)
assert result == {"result": "success"}
mock_queue_mgr.set_stop_flag_no_user_check.assert_called_once_with("task-1")
mock_graph_mgr.send_stop_command.assert_called_once_with("task-1")
def test_stop_workflow_task_wrong_app_mode(self, app):
"""Test NotWorkflowAppError when app mode is not workflow."""
from controllers.service_api.app.workflow import WorkflowTaskStopApi
mock_app = Mock(spec=App)
mock_app.mode = AppMode.COMPLETION.value
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
api = WorkflowTaskStopApi()
with pytest.raises(NotWorkflowAppError):
_unwrap(api.post)(api, app_model=mock_app, end_user=Mock(), task_id="task-1")
class TestWorkflowAppLogApiGet:
"""Test suite for WorkflowAppLogApi.get() endpoint.
``get`` is wrapped by ``@validate_app_token`` and
``@service_api_ns.marshal_with``.
"""
@patch("controllers.service_api.app.workflow.WorkflowAppService")
@patch("controllers.service_api.app.workflow.db")
def test_get_workflow_logs_success(
self,
mock_db,
mock_wf_svc_cls,
app,
mock_workflow_app,
):
"""Test successful workflow log retrieval."""
mock_pagination = Mock()
mock_pagination.data = []
mock_svc_instance = Mock()
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
mock_wf_svc_cls.return_value = mock_svc_instance
# Mock Session context manager
mock_session = Mock()
mock_db.engine = Mock()
mock_session.__enter__ = Mock(return_value=mock_session)
mock_session.__exit__ = Mock(return_value=False)
from controllers.service_api.app.workflow import WorkflowAppLogApi
with app.test_request_context(
"/workflows/logs?page=1&limit=20",
method="GET",
):
with patch("controllers.service_api.app.workflow.Session", return_value=mock_session):
api = WorkflowAppLogApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
assert result == mock_pagination

View File

@ -0,0 +1,218 @@
"""
Shared fixtures for Service API controller tests.
This module provides reusable fixtures for mocking authentication,
database interactions, and common test data patterns used across
Service API controller tests.
"""
import uuid
from unittest.mock import Mock
import pytest
from flask import Flask
from models.account import TenantStatus
from models.model import App, AppMode, EndUser
from tests.unit_tests.conftest import setup_mock_tenant_account_query
@pytest.fixture
def app():
"""Create Flask test application with proper configuration."""
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
@pytest.fixture
def mock_tenant_id():
"""Generate a consistent tenant ID for test sessions."""
return str(uuid.uuid4())
@pytest.fixture
def mock_app_id():
"""Generate a consistent app ID for test sessions."""
return str(uuid.uuid4())
@pytest.fixture
def mock_end_user(mock_tenant_id):
"""Create a mock EndUser model with required attributes."""
user = Mock(spec=EndUser)
user.id = str(uuid.uuid4())
user.external_user_id = f"external_{uuid.uuid4().hex[:8]}"
user.tenant_id = mock_tenant_id
return user
@pytest.fixture
def mock_app_model(mock_app_id, mock_tenant_id):
"""Create a mock App model with all required attributes for API testing."""
app = Mock(spec=App)
app.id = mock_app_id
app.tenant_id = mock_tenant_id
app.name = "Test App"
app.description = "A test application"
app.mode = AppMode.CHAT
app.author_name = "Test Author"
app.status = "normal"
app.enable_api = True
app.tags = []
# Mock workflow for workflow apps
app.workflow = None
app.app_model_config = None
return app
@pytest.fixture
def mock_tenant(mock_tenant_id):
"""Create a mock Tenant model."""
tenant = Mock()
tenant.id = mock_tenant_id
tenant.status = TenantStatus.NORMAL
return tenant
@pytest.fixture
def mock_account():
"""Create a mock Account model."""
account = Mock()
account.id = str(uuid.uuid4())
return account
@pytest.fixture
def mock_api_token(mock_app_id, mock_tenant_id):
"""Create a mock API token for authentication tests."""
token = Mock()
token.app_id = mock_app_id
token.tenant_id = mock_tenant_id
token.token = f"test_token_{uuid.uuid4().hex[:8]}"
token.type = "app"
return token
@pytest.fixture
def mock_dataset_api_token(mock_tenant_id):
"""Create a mock API token for dataset endpoints."""
token = Mock()
token.tenant_id = mock_tenant_id
token.token = f"dataset_token_{uuid.uuid4().hex[:8]}"
token.type = "dataset"
return token
class AuthenticationMocker:
"""
Helper class to set up common authentication mocking patterns.
Usage:
auth_mocker = AuthenticationMocker()
with auth_mocker.mock_app_auth(mock_api_token, mock_app_model, mock_tenant):
# Test code here
"""
@staticmethod
def setup_db_queries(mock_db, mock_app, mock_tenant, mock_account=None):
"""Configure mock_db to return app and tenant in sequence."""
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app,
mock_tenant,
]
if mock_account:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
@staticmethod
def setup_dataset_auth(mock_db, mock_tenant, mock_account):
"""Configure mock_db for dataset token authentication."""
mock_ta = Mock()
mock_ta.account_id = mock_account.id
mock_query = mock_db.session.query.return_value
target_mock = mock_query.where.return_value.where.return_value.where.return_value.where.return_value
target_mock.one_or_none.return_value = (mock_tenant, mock_ta)
mock_db.session.query.return_value.where.return_value.first.return_value = mock_account
@pytest.fixture
def auth_mocker():
"""Provide an AuthenticationMocker instance."""
return AuthenticationMocker()
@pytest.fixture
def mock_dataset():
"""Create a mock Dataset model."""
from models.dataset import Dataset
dataset = Mock(spec=Dataset)
dataset.id = str(uuid.uuid4())
dataset.tenant_id = str(uuid.uuid4())
dataset.name = "Test Dataset"
dataset.indexing_technique = "economy"
dataset.embedding_model = None
dataset.embedding_model_provider = None
return dataset
@pytest.fixture
def mock_document():
"""Create a mock Document model."""
from models.dataset import Document
document = Mock(spec=Document)
document.id = str(uuid.uuid4())
document.dataset_id = str(uuid.uuid4())
document.tenant_id = str(uuid.uuid4())
document.name = "test_document.txt"
document.indexing_status = "completed"
document.enabled = True
document.doc_form = "text_model"
return document
@pytest.fixture
def mock_segment():
"""Create a mock DocumentSegment model."""
from models.dataset import DocumentSegment
segment = Mock(spec=DocumentSegment)
segment.id = str(uuid.uuid4())
segment.document_id = str(uuid.uuid4())
segment.dataset_id = str(uuid.uuid4())
segment.tenant_id = str(uuid.uuid4())
segment.content = "Test segment content"
segment.word_count = 3
segment.position = 1
segment.enabled = True
segment.status = "completed"
return segment
@pytest.fixture
def mock_child_chunk():
"""Create a mock ChildChunk model."""
from models.dataset import ChildChunk
child_chunk = Mock(spec=ChildChunk)
child_chunk.id = str(uuid.uuid4())
child_chunk.segment_id = str(uuid.uuid4())
child_chunk.tenant_id = str(uuid.uuid4())
child_chunk.content = "Test child chunk content"
return child_chunk
def _unwrap(method):
"""Walk ``__wrapped__`` chain to get the original function."""
fn = method
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
return fn

View File

@ -0,0 +1,633 @@
"""
Unit tests for Service API RAG Pipeline Workflow controllers.
Tests coverage for:
- DatasourceNodeRunPayload Pydantic model
- PipelineRunApiEntity / DatasourceNodeRunApiEntity model validation
- RAG pipeline service interfaces
- File upload validation for pipelines
- Endpoint tests for DatasourcePluginsApi, DatasourceNodeRunApi,
PipelineRunApi, and KnowledgebasePipelineFileUploadApi
Strategy:
- Endpoint methods on these resources have no billing decorators on the method
itself. ``method_decorators = [validate_dataset_token]`` is only invoked by
Flask-RESTx dispatch, not by direct calls, so we call methods directly.
- Only ``KnowledgebasePipelineFileUploadApi.post`` touches ``db`` inline
(via ``FileService(db.engine)``); the other endpoints delegate to services.
"""
import io
import uuid
from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow import (
DatasourceNodeRunApi,
DatasourceNodeRunPayload,
DatasourcePluginsApi,
KnowledgebasePipelineFileUploadApi,
PipelineRunApi,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.rag_pipeline.entity.pipeline_service_api_entities import (
DatasourceNodeRunApiEntity,
PipelineRunApiEntity,
)
from services.rag_pipeline.rag_pipeline import RagPipelineService
class TestDatasourceNodeRunPayload:
"""Test suite for DatasourceNodeRunPayload Pydantic model."""
def test_payload_with_required_fields(self):
"""Test payload with required fields."""
payload = DatasourceNodeRunPayload(
inputs={"key": "value"}, datasource_type="online_document", is_published=True
)
assert payload.inputs == {"key": "value"}
assert payload.datasource_type == "online_document"
assert payload.is_published is True
assert payload.credential_id is None
def test_payload_with_credential_id(self):
"""Test payload with optional credential_id."""
payload = DatasourceNodeRunPayload(
inputs={"url": "https://example.com"},
datasource_type="online_document",
credential_id="cred_123",
is_published=False,
)
assert payload.credential_id == "cred_123"
assert payload.is_published is False
def test_payload_with_complex_inputs(self):
"""Test payload with complex nested inputs."""
complex_inputs = {
"config": {"url": "https://api.example.com", "headers": {"Authorization": "Bearer token"}},
"parameters": {"limit": 100, "offset": 0},
"options": ["opt1", "opt2"],
}
payload = DatasourceNodeRunPayload(inputs=complex_inputs, datasource_type="api", is_published=True)
assert payload.inputs == complex_inputs
def test_payload_with_empty_inputs(self):
"""Test payload with empty inputs dict."""
payload = DatasourceNodeRunPayload(inputs={}, datasource_type="local_file", is_published=True)
assert payload.inputs == {}
@pytest.mark.parametrize("datasource_type", ["online_document", "local_file", "api", "database", "website"])
def test_payload_common_datasource_types(self, datasource_type):
"""Test payload with common datasource types."""
payload = DatasourceNodeRunPayload(inputs={}, datasource_type=datasource_type, is_published=True)
assert payload.datasource_type == datasource_type
class TestPipelineErrors:
"""Test pipeline-related error types."""
def test_pipeline_run_error_can_be_raised(self):
"""Test PipelineRunError can be raised."""
error = PipelineRunError(description="Pipeline execution failed")
assert error is not None
def test_pipeline_run_error_with_description(self):
"""Test PipelineRunError captures description."""
error = PipelineRunError(description="Timeout during node execution")
# The error should have the description attribute
assert hasattr(error, "description")
class TestFileUploadErrors:
"""Test file upload error types for pipelines."""
def test_no_file_uploaded_error(self):
"""Test NoFileUploadedError can be raised."""
error = NoFileUploadedError()
assert error is not None
def test_too_many_files_error(self):
"""Test TooManyFilesError can be raised."""
error = TooManyFilesError()
assert error is not None
def test_filename_not_exists_error(self):
"""Test FilenameNotExistsError can be raised."""
error = FilenameNotExistsError()
assert error is not None
def test_file_too_large_error(self):
"""Test FileTooLargeError can be raised."""
error = FileTooLargeError("File exceeds size limit")
assert error is not None
def test_unsupported_file_type_error(self):
"""Test UnsupportedFileTypeError can be raised."""
error = UnsupportedFileTypeError()
assert error is not None
class TestRagPipelineService:
"""Test RagPipelineService interface."""
def test_get_datasource_plugins_method_exists(self):
"""Test RagPipelineService.get_datasource_plugins exists."""
assert hasattr(RagPipelineService, "get_datasource_plugins")
def test_get_pipeline_method_exists(self):
"""Test RagPipelineService.get_pipeline exists."""
assert hasattr(RagPipelineService, "get_pipeline")
def test_run_datasource_workflow_node_method_exists(self):
"""Test RagPipelineService.run_datasource_workflow_node exists."""
assert hasattr(RagPipelineService, "run_datasource_workflow_node")
def test_get_pipeline_templates_method_exists(self):
"""Test RagPipelineService.get_pipeline_templates exists."""
assert hasattr(RagPipelineService, "get_pipeline_templates")
def test_get_pipeline_template_detail_method_exists(self):
"""Test RagPipelineService.get_pipeline_template_detail exists."""
assert hasattr(RagPipelineService, "get_pipeline_template_detail")
class TestInvokeFrom:
"""Test InvokeFrom enum for pipeline invocation."""
def test_published_pipeline_invoke_from(self):
"""Test PUBLISHED_PIPELINE InvokeFrom value exists."""
assert hasattr(InvokeFrom, "PUBLISHED_PIPELINE")
def test_debugger_invoke_from(self):
"""Test DEBUGGER InvokeFrom value exists."""
assert hasattr(InvokeFrom, "DEBUGGER")
class TestPipelineResponseModes:
"""Test pipeline response mode patterns."""
def test_streaming_mode(self):
"""Test streaming response mode."""
mode = "streaming"
valid_modes = ["streaming", "blocking"]
assert mode in valid_modes
def test_blocking_mode(self):
"""Test blocking response mode."""
mode = "blocking"
valid_modes = ["streaming", "blocking"]
assert mode in valid_modes
class TestDatasourceTypes:
"""Test common datasource types for pipelines."""
@pytest.mark.parametrize("ds_type", ["online_document", "local_file", "website", "api", "database"])
def test_datasource_type_valid(self, ds_type):
"""Test common datasource types are strings."""
assert isinstance(ds_type, str)
assert len(ds_type) > 0
class TestPipelineFileUploadResponse:
"""Test file upload response structure for pipelines."""
def test_upload_response_fields(self):
"""Test expected fields in upload response."""
expected_fields = ["id", "name", "size", "extension", "mime_type", "created_by", "created_at"]
# Create mock response
mock_response = {
"id": str(uuid.uuid4()),
"name": "document.pdf",
"size": 1024,
"extension": "pdf",
"mime_type": "application/pdf",
"created_by": str(uuid.uuid4()),
"created_at": "2024-01-01T00:00:00Z",
}
for field in expected_fields:
assert field in mock_response
class TestPipelineNodeExecution:
"""Test pipeline node execution patterns."""
def test_node_id_is_string(self):
"""Test node_id is a string identifier."""
node_id = "node_abc123"
assert isinstance(node_id, str)
assert len(node_id) > 0
def test_pipeline_id_is_uuid(self):
"""Test pipeline_id is a valid UUID string."""
pipeline_id = str(uuid.uuid4())
assert len(pipeline_id) == 36
assert "-" in pipeline_id
class TestCredentialHandling:
"""Test credential handling patterns."""
def test_credential_id_is_optional(self):
"""Test credential_id can be None."""
payload = DatasourceNodeRunPayload(
inputs={}, datasource_type="local_file", is_published=True, credential_id=None
)
assert payload.credential_id is None
def test_credential_id_can_be_provided(self):
"""Test credential_id can be set."""
payload = DatasourceNodeRunPayload(
inputs={}, datasource_type="api", is_published=True, credential_id="cred_oauth_123"
)
assert payload.credential_id == "cred_oauth_123"
class TestPublishedVsDraft:
"""Test published vs draft pipeline patterns."""
def test_is_published_true(self):
"""Test is_published=True for published pipelines."""
payload = DatasourceNodeRunPayload(inputs={}, datasource_type="online_document", is_published=True)
assert payload.is_published is True
def test_is_published_false_for_draft(self):
"""Test is_published=False for draft pipelines."""
payload = DatasourceNodeRunPayload(inputs={}, datasource_type="online_document", is_published=False)
assert payload.is_published is False
class TestPipelineInputVariables:
"""Test pipeline input variable patterns."""
def test_inputs_as_dict(self):
"""Test inputs are passed as dictionary."""
inputs = {"url": "https://example.com/doc.pdf", "timeout": 30, "retry": True}
payload = DatasourceNodeRunPayload(inputs=inputs, datasource_type="online_document", is_published=True)
assert payload.inputs["url"] == "https://example.com/doc.pdf"
assert payload.inputs["timeout"] == 30
assert payload.inputs["retry"] is True
def test_inputs_with_list_values(self):
"""Test inputs with list values."""
inputs = {"urls": ["https://example.com/1", "https://example.com/2"], "tags": ["tag1", "tag2", "tag3"]}
payload = DatasourceNodeRunPayload(inputs=inputs, datasource_type="online_document", is_published=True)
assert len(payload.inputs["urls"]) == 2
assert len(payload.inputs["tags"]) == 3
# ---------------------------------------------------------------------------
# PipelineRunApiEntity / DatasourceNodeRunApiEntity Model Tests
# ---------------------------------------------------------------------------
class TestPipelineRunApiEntity:
"""Test PipelineRunApiEntity Pydantic model."""
def test_entity_with_all_fields(self):
"""Test entity with all required fields."""
entity = PipelineRunApiEntity(
inputs={"key": "value"},
datasource_type="online_document",
datasource_info_list=[{"url": "https://example.com"}],
start_node_id="node_1",
is_published=True,
response_mode="streaming",
)
assert entity.datasource_type == "online_document"
assert entity.response_mode == "streaming"
assert entity.is_published is True
def test_entity_blocking_response_mode(self):
"""Test entity with blocking response mode."""
entity = PipelineRunApiEntity(
inputs={},
datasource_type="local_file",
datasource_info_list=[],
start_node_id="node_start",
is_published=False,
response_mode="blocking",
)
assert entity.response_mode == "blocking"
assert entity.is_published is False
def test_entity_missing_required_field(self):
"""Test entity raises on missing required field."""
with pytest.raises(ValueError):
PipelineRunApiEntity(
inputs={},
datasource_type="online_document",
# missing datasource_info_list, start_node_id, etc.
)
class TestDatasourceNodeRunApiEntity:
"""Test DatasourceNodeRunApiEntity Pydantic model."""
def test_entity_with_all_fields(self):
"""Test entity with all fields."""
entity = DatasourceNodeRunApiEntity(
pipeline_id=str(uuid.uuid4()),
node_id="node_abc",
inputs={"url": "https://example.com"},
datasource_type="website",
is_published=True,
)
assert entity.node_id == "node_abc"
assert entity.credential_id is None
def test_entity_with_credential(self):
"""Test entity with credential_id."""
entity = DatasourceNodeRunApiEntity(
pipeline_id=str(uuid.uuid4()),
node_id="node_xyz",
inputs={},
datasource_type="api",
credential_id="cred_123",
is_published=False,
)
assert entity.credential_id == "cred_123"
# ---------------------------------------------------------------------------
# Endpoint Tests
# ---------------------------------------------------------------------------
class TestDatasourcePluginsApiGet:
"""Tests for DatasourcePluginsApi.get().
The original source delegates directly to ``RagPipelineService`` without
an inline dataset query, so no ``db`` patching is needed.
"""
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
def test_get_plugins_success(self, mock_svc_cls, mock_db, app):
"""Test successful retrieval of datasource plugins."""
tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_db.session.scalar.return_value = mock_dataset
mock_svc_instance = Mock()
mock_svc_instance.get_datasource_plugins.return_value = [{"name": "plugin_a"}]
mock_svc_cls.return_value = mock_svc_instance
with app.test_request_context("/datasets/test/pipeline/datasource-plugins?is_published=true"):
api = DatasourcePluginsApi()
response, status = api.get(tenant_id=tenant_id, dataset_id=dataset_id)
assert status == 200
assert response == [{"name": "plugin_a"}]
mock_svc_instance.get_datasource_plugins.assert_called_once_with(
tenant_id=tenant_id, dataset_id=dataset_id, is_published=True
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_get_plugins_not_found(self, mock_db, app):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
with app.test_request_context("/datasets/test/pipeline/datasource-plugins"):
api = DatasourcePluginsApi()
with pytest.raises(NotFound):
api.get(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()))
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app):
"""Test empty plugin list."""
mock_db.session.scalar.return_value = Mock()
mock_svc_instance = Mock()
mock_svc_instance.get_datasource_plugins.return_value = []
mock_svc_cls.return_value = mock_svc_instance
with app.test_request_context("/datasets/test/pipeline/datasource-plugins"):
api = DatasourcePluginsApi()
response, status = api.get(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()))
assert status == 200
assert response == []
class TestDatasourceNodeRunApiPost:
"""Tests for DatasourceNodeRunApi.post().
The source asserts ``isinstance(current_user, Account)`` and delegates to
``RagPipelineService`` and ``PipelineGenerator``, so we patch those plus
``current_user`` and ``service_api_ns``.
"""
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.helper")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.PipelineGenerator")
@patch(
"controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user",
new_callable=lambda: Mock(spec=Account),
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app):
"""Test successful datasource node run."""
tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
node_id = "node_abc"
mock_db.session.scalar.return_value = Mock()
mock_ns.payload = {
"inputs": {"url": "https://example.com"},
"datasource_type": "online_document",
"is_published": True,
}
mock_pipeline = Mock()
mock_pipeline.id = str(uuid.uuid4())
mock_svc_instance = Mock()
mock_svc_instance.get_pipeline.return_value = mock_pipeline
mock_svc_instance.run_datasource_workflow_node.return_value = iter(["event1"])
mock_svc_cls.return_value = mock_svc_instance
mock_gen.convert_to_event_stream.return_value = iter(["stream_event"])
mock_helper.compact_generate_response.return_value = {"result": "ok"}
with app.test_request_context("/datasets/test/pipeline/datasource/nodes/node_abc/run", method="POST"):
api = DatasourceNodeRunApi()
response = api.post(tenant_id=tenant_id, dataset_id=dataset_id, node_id=node_id)
assert response == {"result": "ok"}
mock_svc_instance.get_pipeline.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id)
mock_svc_instance.get_pipeline.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id)
mock_svc_instance.run_datasource_workflow_node.assert_called_once()
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_post_not_found(self, mock_db, app):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
with app.test_request_context("/datasets/test/pipeline/datasource/nodes/n1/run", method="POST"):
api = DatasourceNodeRunApi()
with pytest.raises(NotFound):
api.post(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()), node_id="n1")
@patch(
"controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user",
new="not_account",
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app):
"""Test AssertionError when current_user is not an Account instance."""
mock_db.session.scalar.return_value = Mock()
mock_ns.payload = {
"inputs": {},
"datasource_type": "local_file",
"is_published": True,
}
with app.test_request_context("/datasets/test/pipeline/datasource/nodes/n1/run", method="POST"):
api = DatasourceNodeRunApi()
with pytest.raises(AssertionError):
api.post(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()), node_id="n1")
class TestPipelineRunApiPost:
"""Tests for PipelineRunApi.post()."""
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.helper")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService")
@patch(
"controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user",
new_callable=lambda: Mock(spec=Account),
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_success_streaming(
self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen_svc, mock_helper, app
):
"""Test successful pipeline run with streaming response."""
tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
mock_db.session.scalar.return_value = Mock()
mock_ns.payload = {
"inputs": {"key": "val"},
"datasource_type": "online_document",
"datasource_info_list": [],
"start_node_id": "node_1",
"is_published": True,
"response_mode": "streaming",
}
mock_pipeline = Mock()
mock_svc_instance = Mock()
mock_svc_instance.get_pipeline.return_value = mock_pipeline
mock_svc_cls.return_value = mock_svc_instance
mock_gen_svc.generate.return_value = {"result": "ok"}
mock_helper.compact_generate_response.return_value = {"result": "ok"}
with app.test_request_context("/datasets/test/pipeline/run", method="POST"):
api = PipelineRunApi()
response = api.post(tenant_id=tenant_id, dataset_id=dataset_id)
assert response == {"result": "ok"}
mock_gen_svc.generate.assert_called_once()
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_post_not_found(self, mock_db, app):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
with app.test_request_context("/datasets/test/pipeline/run", method="POST"):
api = PipelineRunApi()
with pytest.raises(NotFound):
api.post(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()))
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user", new="not_account")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app):
"""Test Forbidden when current_user is not an Account."""
mock_db.session.scalar.return_value = Mock()
mock_ns.payload = {
"inputs": {},
"datasource_type": "online_document",
"datasource_info_list": [],
"start_node_id": "node_1",
"is_published": True,
"response_mode": "blocking",
}
with app.test_request_context("/datasets/test/pipeline/run", method="POST"):
api = PipelineRunApi()
with pytest.raises(Forbidden):
api.post(tenant_id=str(uuid.uuid4()), dataset_id=str(uuid.uuid4()))
class TestFileUploadApiPost:
"""Tests for KnowledgebasePipelineFileUploadApi.post()."""
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.FileService")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app):
"""Test successful file upload."""
mock_current_user.__bool__ = Mock(return_value=True)
mock_upload = Mock()
mock_upload.id = str(uuid.uuid4())
mock_upload.name = "doc.pdf"
mock_upload.size = 1024
mock_upload.extension = "pdf"
mock_upload.mime_type = "application/pdf"
mock_upload.created_by = str(uuid.uuid4())
mock_upload.created_at = datetime(2024, 1, 1, tzinfo=UTC)
mock_file_svc_instance = Mock()
mock_file_svc_instance.upload_file.return_value = mock_upload
mock_file_svc_cls.return_value = mock_file_svc_instance
file_data = FileStorage(
stream=io.BytesIO(b"fake pdf content"),
filename="doc.pdf",
content_type="application/pdf",
)
with app.test_request_context(
"/datasets/pipeline/file-upload",
method="POST",
content_type="multipart/form-data",
data={"file": file_data},
):
api = KnowledgebasePipelineFileUploadApi()
response, status = api.post(tenant_id=str(uuid.uuid4()))
assert status == 201
assert response["name"] == "doc.pdf"
assert response["extension"] == "pdf"
def test_upload_no_file(self, app):
"""Test error when no file is uploaded."""
with app.test_request_context(
"/datasets/pipeline/file-upload",
method="POST",
content_type="multipart/form-data",
):
api = KnowledgebasePipelineFileUploadApi()
with pytest.raises(NoFileUploadedError):
api.post(tenant_id=str(uuid.uuid4()))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,205 @@
"""
Unit tests for Service API HitTesting controller.
Tests coverage for:
- HitTestingPayload Pydantic model validation
- HitTestingApi endpoint (success and error paths via direct method calls)
Strategy:
- ``HitTestingApi.post`` is decorated with ``@cloud_edition_billing_rate_limit_check``
which preserves ``__wrapped__``. We call ``post.__wrapped__(self, ...)`` to skip
the billing decorator and test the business logic directly.
- Base-class methods (``get_and_validate_dataset``, ``perform_hit_testing``) read
``current_user`` from ``controllers.console.datasets.hit_testing_base``, so we
patch it there.
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.service_api.dataset.hit_testing import HitTestingApi, HitTestingPayload
from models.account import Account
# ---------------------------------------------------------------------------
# HitTestingPayload Model Tests
# ---------------------------------------------------------------------------
class TestHitTestingPayload:
"""Test suite for HitTestingPayload Pydantic model."""
def test_payload_with_required_query(self):
"""Test payload with required query field."""
payload = HitTestingPayload(query="test query")
assert payload.query == "test query"
def test_payload_with_all_fields(self):
"""Test payload with all optional fields."""
payload = HitTestingPayload(
query="test query",
retrieval_model={"top_k": 5},
external_retrieval_model={"provider": "openai"},
attachment_ids=["att_1", "att_2"],
)
assert payload.query == "test query"
assert payload.retrieval_model == {"top_k": 5}
assert payload.external_retrieval_model == {"provider": "openai"}
assert payload.attachment_ids == ["att_1", "att_2"]
def test_payload_query_too_long(self):
"""Test payload rejects query over 250 characters."""
with pytest.raises(ValueError):
HitTestingPayload(query="x" * 251)
def test_payload_query_at_max_length(self):
"""Test payload accepts query at exactly 250 characters."""
payload = HitTestingPayload(query="x" * 250)
assert len(payload.query) == 250
# ---------------------------------------------------------------------------
# HitTestingApi Tests
#
# We use ``post.__wrapped__`` to bypass ``@cloud_edition_billing_rate_limit_check``
# and call the underlying method directly.
# ---------------------------------------------------------------------------
class TestHitTestingApiPost:
"""Tests for HitTestingApi.post() via __wrapped__ to skip billing decorator."""
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.marshal")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_success(
self,
mock_current_user,
mock_dataset_svc,
mock_hit_svc,
mock_marshal,
mock_ns,
app,
):
"""Test successful hit testing request."""
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
mock_hit_svc.retrieve.return_value = {"query": "test query", "records": []}
mock_hit_svc.hit_testing_args_check.return_value = None
mock_marshal.return_value = []
mock_ns.payload = {"query": "test query"}
with app.test_request_context():
api = HitTestingApi()
# Skip billing decorator via __wrapped__
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
assert response["query"] == "test query"
mock_hit_svc.retrieve.assert_called_once()
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.marshal")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_with_retrieval_model(
self,
mock_current_user,
mock_dataset_svc,
mock_hit_svc,
mock_marshal,
mock_ns,
app,
):
"""Test hit testing with custom retrieval model."""
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
retrieval_model = {"search_method": "semantic", "top_k": 10, "score_threshold": 0.8}
mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []}
mock_hit_svc.hit_testing_args_check.return_value = None
mock_marshal.return_value = []
mock_ns.payload = {
"query": "complex query",
"retrieval_model": retrieval_model,
"external_retrieval_model": {"provider": "custom"},
}
with app.test_request_context():
api = HitTestingApi()
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
assert response["query"] == "complex query"
call_kwargs = mock_hit_svc.retrieve.call_args
assert call_kwargs.kwargs.get("retrieval_model") == retrieval_model
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_dataset_not_found(
self,
mock_current_user,
mock_dataset_svc,
mock_ns,
app,
):
"""Test hit testing with non-existent dataset."""
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
mock_dataset_svc.get_dataset.return_value = None
mock_ns.payload = {"query": "test query"}
with app.test_request_context():
api = HitTestingApi()
with pytest.raises(NotFound):
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_no_dataset_permission(
self,
mock_current_user,
mock_dataset_svc,
mock_ns,
app,
):
"""Test hit testing when user lacks dataset permission."""
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError(
"Access denied"
)
mock_ns.payload = {"query": "test query"}
with app.test_request_context():
api = HitTestingApi()
with pytest.raises(Forbidden):
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)

View File

@ -0,0 +1,534 @@
"""
Unit tests for Service API Metadata controllers.
Tests coverage for:
- DatasetMetadataCreateServiceApi (post, get)
- DatasetMetadataServiceApi (patch, delete)
- DatasetMetadataBuiltInFieldServiceApi (get)
- DatasetMetadataBuiltInFieldActionServiceApi (post)
- DocumentMetadataEditServiceApi (post)
Decorator strategy:
- ``@cloud_edition_billing_rate_limit_check`` preserves ``__wrapped__``
via ``functools.wraps`` call the unwrapped method directly.
- Methods without billing decorators call directly; only patch ``db``,
services, and ``current_user``.
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import NotFound
from controllers.service_api.dataset.metadata import (
DatasetMetadataBuiltInFieldActionServiceApi,
DatasetMetadataBuiltInFieldServiceApi,
DatasetMetadataCreateServiceApi,
DatasetMetadataServiceApi,
DocumentMetadataEditServiceApi,
)
from tests.unit_tests.controllers.service_api.conftest import _unwrap
@pytest.fixture
def mock_tenant():
tenant = Mock()
tenant.id = str(uuid.uuid4())
return tenant
@pytest.fixture
def mock_dataset():
dataset = Mock()
dataset.id = str(uuid.uuid4())
return dataset
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# DatasetMetadataCreateServiceApi
# ---------------------------------------------------------------------------
class TestDatasetMetadataCreatePost:
"""Tests for DatasetMetadataCreateServiceApi.post().
``post`` is wrapped by ``@cloud_edition_billing_rate_limit_check``
which preserves ``__wrapped__``.
"""
@staticmethod
def _call_post(api, **kwargs):
return _unwrap(api.post)(api, **kwargs)
@patch("controllers.service_api.dataset.metadata.marshal")
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
@patch("controllers.service_api.dataset.metadata.current_user")
def test_create_metadata_success(
self,
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
mock_marshal,
app,
mock_tenant,
mock_dataset,
):
"""Test successful metadata creation."""
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
mock_metadata = Mock()
mock_meta_svc.create_metadata.return_value = mock_metadata
mock_marshal.return_value = {"id": "meta-1", "name": "Author"}
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata",
method="POST",
json={"type": "string", "name": "Author"},
):
api = DatasetMetadataCreateServiceApi()
response, status = self._call_post(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
)
assert status == 201
mock_meta_svc.create_metadata.assert_called_once()
@patch("controllers.service_api.dataset.metadata.DatasetService")
def test_create_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test 404 when dataset not found."""
mock_dataset_svc.get_dataset.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata",
method="POST",
json={"type": "string", "name": "Author"},
):
api = DatasetMetadataCreateServiceApi()
with pytest.raises(NotFound):
self._call_post(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
)
class TestDatasetMetadataCreateGet:
"""Tests for DatasetMetadataCreateServiceApi.get()."""
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
def test_get_metadata_success(
self,
mock_dataset_svc,
mock_meta_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test successful metadata list retrieval."""
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_meta_svc.get_dataset_metadatas.return_value = [{"id": "m1"}]
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata",
method="GET",
):
api = DatasetMetadataCreateServiceApi()
response, status = api.get(
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
)
assert status == 200
@patch("controllers.service_api.dataset.metadata.DatasetService")
def test_get_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test 404 when dataset not found."""
mock_dataset_svc.get_dataset.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata",
method="GET",
):
api = DatasetMetadataCreateServiceApi()
with pytest.raises(NotFound):
api.get(tenant_id=mock_tenant.id, dataset_id=mock_dataset.id)
# ---------------------------------------------------------------------------
# DatasetMetadataServiceApi
# ---------------------------------------------------------------------------
class TestDatasetMetadataServiceApiPatch:
"""Tests for DatasetMetadataServiceApi.patch().
``patch`` is wrapped by ``@cloud_edition_billing_rate_limit_check``.
"""
@staticmethod
def _call_patch(api, **kwargs):
return _unwrap(api.patch)(api, **kwargs)
@patch("controllers.service_api.dataset.metadata.marshal")
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
@patch("controllers.service_api.dataset.metadata.current_user")
def test_update_metadata_name_success(
self,
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
mock_marshal,
app,
mock_tenant,
mock_dataset,
):
"""Test successful metadata name update."""
metadata_id = str(uuid.uuid4())
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
mock_meta_svc.update_metadata_name.return_value = Mock()
mock_marshal.return_value = {"id": metadata_id, "name": "New Name"}
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata/{metadata_id}",
method="PATCH",
json={"name": "New Name"},
):
api = DatasetMetadataServiceApi()
response, status = self._call_patch(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
metadata_id=metadata_id,
)
assert status == 200
mock_meta_svc.update_metadata_name.assert_called_once()
@patch("controllers.service_api.dataset.metadata.DatasetService")
def test_update_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test 404 when dataset not found."""
metadata_id = str(uuid.uuid4())
mock_dataset_svc.get_dataset.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata/{metadata_id}",
method="PATCH",
json={"name": "x"},
):
api = DatasetMetadataServiceApi()
with pytest.raises(NotFound):
self._call_patch(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
metadata_id=metadata_id,
)
class TestDatasetMetadataServiceApiDelete:
"""Tests for DatasetMetadataServiceApi.delete().
``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check``.
"""
@staticmethod
def _call_delete(api, **kwargs):
return _unwrap(api.delete)(api, **kwargs)
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
@patch("controllers.service_api.dataset.metadata.current_user")
def test_delete_metadata_success(
self,
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test successful metadata deletion."""
metadata_id = str(uuid.uuid4())
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
mock_meta_svc.delete_metadata.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata/{metadata_id}",
method="DELETE",
):
api = DatasetMetadataServiceApi()
response = self._call_delete(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
metadata_id=metadata_id,
)
assert response == ("", 204)
mock_meta_svc.delete_metadata.assert_called_once()
@patch("controllers.service_api.dataset.metadata.DatasetService")
def test_delete_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test 404 when dataset not found."""
metadata_id = str(uuid.uuid4())
mock_dataset_svc.get_dataset.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata/{metadata_id}",
method="DELETE",
):
api = DatasetMetadataServiceApi()
with pytest.raises(NotFound):
self._call_delete(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
metadata_id=metadata_id,
)
# ---------------------------------------------------------------------------
# DatasetMetadataBuiltInFieldServiceApi
# ---------------------------------------------------------------------------
class TestDatasetMetadataBuiltInFieldGet:
"""Tests for DatasetMetadataBuiltInFieldServiceApi.get()."""
@patch("controllers.service_api.dataset.metadata.MetadataService")
def test_get_built_in_fields_success(
self,
mock_meta_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test successful built-in fields retrieval."""
mock_meta_svc.get_built_in_fields.return_value = [
{"name": "source", "type": "string"},
]
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata/built-in",
method="GET",
):
api = DatasetMetadataBuiltInFieldServiceApi()
response, status = api.get(
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
)
assert status == 200
assert "fields" in response
# ---------------------------------------------------------------------------
# DatasetMetadataBuiltInFieldActionServiceApi
# ---------------------------------------------------------------------------
class TestDatasetMetadataBuiltInFieldAction:
"""Tests for DatasetMetadataBuiltInFieldActionServiceApi.post().
``post`` is wrapped by ``@cloud_edition_billing_rate_limit_check``.
"""
@staticmethod
def _call_post(api, **kwargs):
return _unwrap(api.post)(api, **kwargs)
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
@patch("controllers.service_api.dataset.metadata.current_user")
def test_enable_built_in_field(
self,
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test enabling built-in metadata field."""
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata/built-in/enable",
method="POST",
):
api = DatasetMetadataBuiltInFieldActionServiceApi()
response, status = self._call_post(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
action="enable",
)
assert status == 200
assert response["result"] == "success"
mock_meta_svc.enable_built_in_field.assert_called_once_with(mock_dataset)
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
@patch("controllers.service_api.dataset.metadata.current_user")
def test_disable_built_in_field(
self,
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test disabling built-in metadata field."""
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata/built-in/disable",
method="POST",
):
api = DatasetMetadataBuiltInFieldActionServiceApi()
response, status = self._call_post(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
action="disable",
)
assert status == 200
mock_meta_svc.disable_built_in_field.assert_called_once_with(mock_dataset)
@patch("controllers.service_api.dataset.metadata.DatasetService")
def test_action_dataset_not_found(
self,
mock_dataset_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test 404 when dataset not found."""
mock_dataset_svc.get_dataset.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/metadata/built-in/enable",
method="POST",
):
api = DatasetMetadataBuiltInFieldActionServiceApi()
with pytest.raises(NotFound):
self._call_post(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
action="enable",
)
# ---------------------------------------------------------------------------
# DocumentMetadataEditServiceApi
# ---------------------------------------------------------------------------
class TestDocumentMetadataEditPost:
"""Tests for DocumentMetadataEditServiceApi.post().
``post`` is wrapped by ``@cloud_edition_billing_rate_limit_check``.
"""
@staticmethod
def _call_post(api, **kwargs):
return _unwrap(api.post)(api, **kwargs)
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
@patch("controllers.service_api.dataset.metadata.current_user")
def test_update_documents_metadata_success(
self,
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test successful documents metadata update."""
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
mock_meta_svc.update_documents_metadata.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/metadata",
method="POST",
json={"operation_data": []},
):
api = DocumentMetadataEditServiceApi()
response, status = self._call_post(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
)
assert status == 200
assert response["result"] == "success"
@patch("controllers.service_api.dataset.metadata.DatasetService")
def test_update_documents_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
mock_tenant,
mock_dataset,
):
"""Test 404 when dataset not found."""
mock_dataset_svc.get_dataset.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/metadata",
method="POST",
json={"operation_data": []},
):
api = DocumentMetadataEditServiceApi()
with pytest.raises(NotFound):
self._call_post(
api,
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
)

View File

@ -0,0 +1,69 @@
"""
Unit tests for Service API Index endpoint
"""
from unittest.mock import MagicMock, patch
import pytest
from controllers.service_api.index import IndexApi
class TestIndexApi:
"""Test suite for IndexApi resource."""
@patch("controllers.service_api.index.dify_config")
def test_get_returns_api_info(self, mock_config, app):
"""Test that GET returns API metadata with correct structure."""
# Arrange
mock_config.project.version = "1.0.0-test"
# Act
with app.test_request_context("/", method="GET"):
index_api = IndexApi()
response = index_api.get()
with patch("controllers.service_api.index.dify_config", mock_config):
with app.test_request_context("/", method="GET"):
index_api = IndexApi()
response = index_api.get()
# Assert
assert response["welcome"] == "Dify OpenAPI"
assert response["api_version"] == "v1"
assert response["server_version"] == "1.0.0-test"
def test_get_response_has_required_fields(self, app):
"""Test that response contains all required fields."""
# Arrange
mock_config = MagicMock()
mock_config.project.version = "1.11.4"
# Act
with patch("controllers.service_api.index.dify_config", mock_config):
with app.test_request_context("/", method="GET"):
index_api = IndexApi()
response = index_api.get()
# Assert
assert "welcome" in response
assert "api_version" in response
assert "server_version" in response
assert isinstance(response["welcome"], str)
assert isinstance(response["api_version"], str)
assert isinstance(response["server_version"], str)
@pytest.mark.parametrize("version", ["0.0.1", "1.0.0", "2.0.0-beta", "1.11.4"])
def test_get_returns_correct_version(self, app, version):
"""Test that server_version matches config version."""
# Arrange
mock_config = MagicMock()
mock_config.project.version = version
# Act
with patch("controllers.service_api.index.dify_config", mock_config):
with app.test_request_context("/", method="GET"):
index_api = IndexApi()
response = index_api.get()
# Assert
assert response["server_version"] == version

View File

@ -0,0 +1,270 @@
"""
Unit tests for Service API Site controller
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.service_api.app.site import AppSiteApi
from models.account import TenantStatus
from models.model import App, Site
from tests.unit_tests.conftest import setup_mock_tenant_account_query
class TestAppSiteApi:
"""Test suite for AppSiteApi"""
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model with tenant."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.enable_api = True
mock_tenant = Mock()
mock_tenant.id = app.tenant_id
mock_tenant.status = TenantStatus.NORMAL
app.tenant = mock_tenant
return app
@pytest.fixture
def mock_site(self):
"""Create a mock Site model."""
site = Mock(spec=Site)
site.id = str(uuid.uuid4())
site.app_id = str(uuid.uuid4())
site.title = "Test Site"
site.icon = "icon-url"
site.icon_background = "#ffffff"
site.description = "Site description"
site.copyright = "Copyright 2024"
site.privacy_policy = "Privacy policy text"
site.custom_disclaimer = "Custom disclaimer"
site.default_language = "en-US"
site.prompt_public = True
site.show_workflow_steps = True
site.use_icon_as_answer_icon = False
site.chat_color_theme = "light"
site.chat_color_theme_inverted = False
site.icon_type = "image"
site.created_at = "2024-01-01T00:00:00"
site.updated_at = "2024-01-01T00:00:00"
return site
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_success(
self,
mock_wraps_db,
mock_validate_token,
mock_current_app,
mock_db,
mock_user_logged_in,
app,
mock_app_model,
mock_site,
):
"""Test successful retrieval of site configuration."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
# Mock wraps.db for authentication
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site.db for site query
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
# Act
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
response = api.get()
# Assert
assert response["title"] == "Test Site"
assert response["icon"] == "icon-url"
assert response["description"] == "Site description"
mock_db.session.query.assert_called_once_with(Site)
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_not_found(
self,
mock_wraps_db,
mock_validate_token,
mock_current_app,
mock_db,
mock_user_logged_in,
app,
mock_app_model,
):
"""Test that Forbidden is raised when site is not found."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site query to return None
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act & Assert
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
with pytest.raises(Forbidden):
api.get()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_tenant_archived(
self,
mock_wraps_db,
mock_validate_token,
mock_current_app,
mock_db,
mock_user_logged_in,
app,
mock_app_model,
mock_site,
):
"""Test that Forbidden is raised when tenant is archived."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site query
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
# Set tenant status to archived AFTER authentication
mock_app_model.tenant.status = TenantStatus.ARCHIVE
# Act & Assert
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
with pytest.raises(Forbidden):
api.get()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_queries_by_app_id(
self, mock_wraps_db, mock_validate_token, mock_current_app, mock_db, mock_user_logged_in, app, mock_app_model
):
"""Test that site is queried using the app model's id."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
mock_site = Mock(spec=Site)
mock_site.id = str(uuid.uuid4())
mock_site.app_id = mock_app_model.id
mock_site.title = "Test Site"
mock_site.icon = "icon-url"
mock_site.icon_background = "#ffffff"
mock_site.description = "Site description"
mock_site.copyright = "Copyright 2024"
mock_site.privacy_policy = "Privacy policy text"
mock_site.custom_disclaimer = "Custom disclaimer"
mock_site.default_language = "en-US"
mock_site.prompt_public = True
mock_site.show_workflow_steps = True
mock_site.use_icon_as_answer_icon = False
mock_site.chat_color_theme = "light"
mock_site.chat_color_theme_inverted = False
mock_site.icon_type = "image"
mock_site.created_at = "2024-01-01T00:00:00"
mock_site.updated_at = "2024-01-01T00:00:00"
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
# Act
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
api.get()
# Assert
# The query was executed successfully (site returned), which validates the correct query was made
mock_db.session.query.assert_called_once_with(Site)

View File

@ -0,0 +1,550 @@
"""
Unit tests for Service API wraps (authentication decorators)
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from controllers.service_api.wraps import (
DatasetApiResource,
FetchUserArg,
WhereisUserArg,
cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
validate_and_get_api_token,
validate_app_token,
validate_dataset_token,
)
from enums.cloud_plan import CloudPlan
from models.account import TenantStatus
from models.model import ApiToken
from tests.unit_tests.conftest import (
setup_mock_dataset_tenant_query,
setup_mock_tenant_account_query,
)
class TestValidateAndGetApiToken:
"""Test suite for validate_and_get_api_token function"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def test_missing_authorization_header(self, app):
"""Test that Unauthorized is raised when Authorization header is missing."""
# Arrange
with app.test_request_context("/", method="GET"):
# No Authorization header
# Act & Assert
with pytest.raises(Unauthorized) as exc_info:
validate_and_get_api_token("app")
assert "Authorization header must be provided" in str(exc_info.value)
def test_invalid_auth_scheme(self, app):
"""Test that Unauthorized is raised when auth scheme is not Bearer."""
# Arrange
with app.test_request_context("/", method="GET", headers={"Authorization": "Basic token123"}):
# Act & Assert
with pytest.raises(Unauthorized) as exc_info:
validate_and_get_api_token("app")
assert "Authorization scheme must be 'Bearer'" in str(exc_info.value)
@patch("controllers.service_api.wraps.record_token_usage")
@patch("controllers.service_api.wraps.ApiTokenCache")
@patch("controllers.service_api.wraps.fetch_token_with_single_flight")
def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app):
"""Test that valid token returns the ApiToken object."""
# Arrange
mock_api_token = Mock(spec=ApiToken)
mock_api_token.token = "valid_token_123"
mock_api_token.type = "app"
mock_cache_instance = Mock()
mock_cache_instance.get.return_value = None # Cache miss
mock_cache_cls.get = mock_cache_instance.get
mock_fetch_token.return_value = mock_api_token
# Act
with app.test_request_context("/", method="GET", headers={"Authorization": "Bearer valid_token_123"}):
result = validate_and_get_api_token("app")
# Assert
assert result == mock_api_token
@patch("controllers.service_api.wraps.record_token_usage")
@patch("controllers.service_api.wraps.ApiTokenCache")
@patch("controllers.service_api.wraps.fetch_token_with_single_flight")
def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app):
"""Test that invalid token raises Unauthorized."""
# Arrange
from werkzeug.exceptions import Unauthorized
mock_cache_instance = Mock()
mock_cache_instance.get.return_value = None # Cache miss
mock_cache_cls.get = mock_cache_instance.get
mock_fetch_token.side_effect = Unauthorized("Access token is invalid")
# Act & Assert
with app.test_request_context("/", method="GET", headers={"Authorization": "Bearer invalid_token"}):
with pytest.raises(Unauthorized) as exc_info:
validate_and_get_api_token("app")
assert "Access token is invalid" in str(exc_info.value)
class TestValidateAppToken:
"""Test suite for validate_app_token decorator"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.current_app")
def test_valid_app_token_allows_access(
self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app
):
"""Test that valid app token allows access to decorated view."""
# Arrange
# Use standard Mock for login_manager to avoid AsyncMockMixin warnings
mock_current_app.login_manager = Mock()
mock_api_token = Mock()
mock_api_token.app_id = str(uuid.uuid4())
mock_api_token.tenant_id = str(uuid.uuid4())
mock_validate_token.return_value = mock_api_token
mock_app = Mock()
mock_app.id = mock_api_token.app_id
mock_app.status = "normal"
mock_app.enable_api = True
mock_app.tenant_id = mock_api_token.tenant_id
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_tenant.id = mock_api_token.tenant_id
mock_account = Mock()
mock_account.id = str(uuid.uuid4())
mock_ta = Mock()
mock_ta.account_id = mock_account.id
# Use side_effect to return app first, then tenant
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app,
mock_tenant,
mock_account,
]
# Mock the tenant owner query
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
@validate_app_token
def protected_view(app_model):
return {"success": True, "app_id": app_model.id}
# Act
with app.test_request_context("/", method="GET", headers={"Authorization": "Bearer test_token"}):
result = protected_view()
# Assert
assert result["success"] is True
assert result["app_id"] == mock_app.id
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app):
"""Test that Forbidden is raised when app no longer exists."""
# Arrange
mock_api_token = Mock()
mock_api_token.app_id = str(uuid.uuid4())
mock_validate_token.return_value = mock_api_token
mock_db.session.query.return_value.where.return_value.first.return_value = None
@validate_app_token
def protected_view(**kwargs):
return {"success": True}
# Act & Assert
with app.test_request_context("/", method="GET"):
with pytest.raises(Forbidden) as exc_info:
protected_view()
assert "no longer exists" in str(exc_info.value)
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app):
"""Test that Forbidden is raised when app status is abnormal."""
# Arrange
mock_api_token = Mock()
mock_api_token.app_id = str(uuid.uuid4())
mock_validate_token.return_value = mock_api_token
mock_app = Mock()
mock_app.status = "abnormal"
mock_db.session.query.return_value.where.return_value.first.return_value = mock_app
@validate_app_token
def protected_view(**kwargs):
return {"success": True}
# Act & Assert
with app.test_request_context("/", method="GET"):
with pytest.raises(Forbidden) as exc_info:
protected_view()
assert "status is abnormal" in str(exc_info.value)
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app):
"""Test that Forbidden is raised when app API is disabled."""
# Arrange
mock_api_token = Mock()
mock_api_token.app_id = str(uuid.uuid4())
mock_validate_token.return_value = mock_api_token
mock_app = Mock()
mock_app.status = "normal"
mock_app.enable_api = False
mock_db.session.query.return_value.where.return_value.first.return_value = mock_app
@validate_app_token
def protected_view(**kwargs):
return {"success": True}
# Act & Assert
with app.test_request_context("/", method="GET"):
with pytest.raises(Forbidden) as exc_info:
protected_view()
assert "API service has been disabled" in str(exc_info.value)
class TestCloudEditionBillingResourceCheck:
"""Test suite for cloud_edition_billing_resource_check decorator"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app):
"""Test that request is allowed when under resource limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
mock_features = Mock()
mock_features.billing.enabled = True
mock_features.members.limit = 10
mock_features.members.size = 5
mock_get_features.return_value = mock_features
@cloud_edition_billing_resource_check("members", "app")
def add_member():
return "member_added"
# Act
with app.test_request_context("/", method="GET"):
result = add_member()
# Assert
assert result == "member_added"
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app):
"""Test that Forbidden is raised when at resource limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
mock_features = Mock()
mock_features.billing.enabled = True
mock_features.members.limit = 10
mock_features.members.size = 10
mock_get_features.return_value = mock_features
@cloud_edition_billing_resource_check("members", "app")
def add_member():
return "member_added"
# Act & Assert
with app.test_request_context("/", method="GET"):
with pytest.raises(Forbidden) as exc_info:
add_member()
assert "members has reached the limit" in str(exc_info.value)
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app):
"""Test that request is allowed when billing is disabled."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
mock_features = Mock()
mock_features.billing.enabled = False
mock_get_features.return_value = mock_features
@cloud_edition_billing_resource_check("members", "app")
def add_member():
return "member_added"
# Act
with app.test_request_context("/", method="GET"):
result = add_member()
# Assert
assert result == "member_added"
class TestCloudEditionBillingKnowledgeLimitCheck:
"""Test suite for cloud_edition_billing_knowledge_limit_check decorator"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app):
"""Test that add_segment is rejected in SANDBOX plan."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
mock_features = Mock()
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
mock_get_features.return_value = mock_features
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
def add_segment():
return "segment_added"
# Act & Assert
with app.test_request_context("/", method="GET"):
with pytest.raises(Forbidden) as exc_info:
add_segment()
assert "upgrade to a paid plan" in str(exc_info.value)
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app):
"""Test that non-add_segment operations are allowed in SANDBOX."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
mock_features = Mock()
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
mock_get_features.return_value = mock_features
@cloud_edition_billing_knowledge_limit_check("search", "dataset")
def search():
return "search_results"
# Act
with app.test_request_context("/", method="GET"):
result = search()
# Assert
assert result == "search_results"
class TestCloudEditionBillingRateLimitCheck:
"""Test suite for cloud_edition_billing_rate_limit_check decorator"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit")
def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app):
"""Test that request is allowed when within rate limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
mock_rate_limit = Mock()
mock_rate_limit.enabled = True
mock_rate_limit.limit = 100
mock_get_rate_limit.return_value = mock_rate_limit
# Mock redis operations
with patch("controllers.service_api.wraps.redis_client") as mock_redis:
mock_redis.zcard.return_value = 50 # Under limit
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def knowledge_request():
return "success"
# Act
with app.test_request_context("/", method="GET"):
result = knowledge_request()
# Assert
assert result == "success"
mock_redis.zadd.assert_called_once()
mock_redis.zremrangebyscore.assert_called_once()
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit")
@patch("controllers.service_api.wraps.db")
def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app):
"""Test that Forbidden is raised when over rate limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
mock_rate_limit = Mock()
mock_rate_limit.enabled = True
mock_rate_limit.limit = 10
mock_rate_limit.subscription_plan = "pro"
mock_get_rate_limit.return_value = mock_rate_limit
with patch("controllers.service_api.wraps.redis_client") as mock_redis:
mock_redis.zcard.return_value = 15 # Over limit
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def knowledge_request():
return "success"
# Act & Assert
with app.test_request_context("/", method="GET"):
with pytest.raises(Forbidden) as exc_info:
knowledge_request()
assert "rate limit" in str(exc_info.value)
class TestValidateDatasetToken:
"""Test suite for validate_dataset_token decorator"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.current_app")
def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app):
"""Test that valid dataset token allows access."""
# Arrange
# Use standard Mock for login_manager
mock_current_app.login_manager = Mock()
tenant_id = str(uuid.uuid4())
mock_api_token = Mock()
mock_api_token.tenant_id = tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.id = tenant_id
mock_tenant.status = TenantStatus.NORMAL
mock_ta = Mock()
mock_ta.account_id = str(uuid.uuid4())
mock_account = Mock()
mock_account.id = mock_ta.account_id
mock_account.current_tenant = mock_tenant
# Mock the tenant account join query
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
# Mock the account query
mock_db.session.query.return_value.where.return_value.first.return_value = mock_account
@validate_dataset_token
def protected_view(tenant_id):
return {"success": True, "tenant_id": tenant_id}
# Act
with app.test_request_context("/", method="GET", headers={"Authorization": "Bearer test_token"}):
result = protected_view()
# Assert
assert result["success"] is True
assert result["tenant_id"] == tenant_id
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app):
"""Test that NotFound is raised when dataset doesn't exist."""
# Arrange
mock_api_token = Mock()
mock_api_token.tenant_id = str(uuid.uuid4())
mock_validate_token.return_value = mock_api_token
mock_db.session.query.return_value.where.return_value.first.return_value = None
@validate_dataset_token
def protected_view(dataset_id=None, **kwargs):
return {"success": True}
# Act & Assert
with app.test_request_context("/", method="GET"):
with pytest.raises(NotFound) as exc_info:
protected_view(dataset_id=str(uuid.uuid4()))
assert "Dataset not found" in str(exc_info.value)
class TestFetchUserArg:
"""Test suite for FetchUserArg model"""
def test_fetch_user_arg_defaults(self):
"""Test FetchUserArg default values."""
# Arrange & Act
arg = FetchUserArg(fetch_from=WhereisUserArg.JSON)
# Assert
assert arg.fetch_from == WhereisUserArg.JSON
assert arg.required is False
def test_fetch_user_arg_required(self):
"""Test FetchUserArg with required=True."""
# Arrange & Act
arg = FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)
# Assert
assert arg.fetch_from == WhereisUserArg.QUERY
assert arg.required is True
class TestDatasetApiResource:
"""Test suite for DatasetApiResource base class"""
def test_method_decorators_has_validate_dataset_token(self):
"""Test that DatasetApiResource has validate_dataset_token in method_decorators."""
# Assert
assert validate_dataset_token in DatasetApiResource.method_decorators
def test_get_dataset_method_exists(self):
"""Test that get_dataset method exists on DatasetApiResource."""
# Assert
assert hasattr(DatasetApiResource, "get_dataset")