mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 13:51:05 +08:00
test: unit test case for controllers.console.app module (#32247)
This commit is contained in:
parent
a59c54b3e7
commit
a480e9beb1
@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.console.app import annotation as annotation_module
|
||||
|
||||
|
||||
def test_annotation_reply_payload_valid():
|
||||
"""Test AnnotationReplyPayload with valid data."""
|
||||
payload = annotation_module.AnnotationReplyPayload(
|
||||
score_threshold=0.5,
|
||||
embedding_provider_name="openai",
|
||||
embedding_model_name="text-embedding-3-small",
|
||||
)
|
||||
assert payload.score_threshold == 0.5
|
||||
assert payload.embedding_provider_name == "openai"
|
||||
assert payload.embedding_model_name == "text-embedding-3-small"
|
||||
|
||||
|
||||
def test_annotation_setting_update_payload_valid():
|
||||
"""Test AnnotationSettingUpdatePayload with valid data."""
|
||||
payload = annotation_module.AnnotationSettingUpdatePayload(
|
||||
score_threshold=0.75,
|
||||
)
|
||||
assert payload.score_threshold == 0.75
|
||||
|
||||
|
||||
def test_annotation_list_query_defaults():
|
||||
"""Test AnnotationListQuery with default parameters."""
|
||||
query = annotation_module.AnnotationListQuery()
|
||||
assert query.page == 1
|
||||
assert query.limit == 20
|
||||
assert query.keyword == ""
|
||||
|
||||
|
||||
def test_annotation_list_query_custom_page():
|
||||
"""Test AnnotationListQuery with custom page."""
|
||||
query = annotation_module.AnnotationListQuery(page=3, limit=50)
|
||||
assert query.page == 3
|
||||
assert query.limit == 50
|
||||
|
||||
|
||||
def test_annotation_list_query_with_keyword():
|
||||
"""Test AnnotationListQuery with keyword."""
|
||||
query = annotation_module.AnnotationListQuery(keyword="test")
|
||||
assert query.keyword == "test"
|
||||
|
||||
|
||||
def test_create_annotation_payload_with_message_id():
|
||||
"""Test CreateAnnotationPayload with message ID."""
|
||||
payload = annotation_module.CreateAnnotationPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
question="What is AI?",
|
||||
)
|
||||
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert payload.question == "What is AI?"
|
||||
|
||||
|
||||
def test_create_annotation_payload_with_text():
|
||||
"""Test CreateAnnotationPayload with text content."""
|
||||
payload = annotation_module.CreateAnnotationPayload(
|
||||
question="What is ML?",
|
||||
answer="Machine learning is...",
|
||||
)
|
||||
assert payload.question == "What is ML?"
|
||||
assert payload.answer == "Machine learning is..."
|
||||
|
||||
|
||||
def test_update_annotation_payload():
|
||||
"""Test UpdateAnnotationPayload."""
|
||||
payload = annotation_module.UpdateAnnotationPayload(
|
||||
question="Updated question",
|
||||
answer="Updated answer",
|
||||
)
|
||||
assert payload.question == "Updated question"
|
||||
assert payload.answer == "Updated answer"
|
||||
|
||||
|
||||
def test_annotation_reply_status_query_enable():
|
||||
"""Test AnnotationReplyStatusQuery with enable action."""
|
||||
query = annotation_module.AnnotationReplyStatusQuery(action="enable")
|
||||
assert query.action == "enable"
|
||||
|
||||
|
||||
def test_annotation_reply_status_query_disable():
|
||||
"""Test AnnotationReplyStatusQuery with disable action."""
|
||||
query = annotation_module.AnnotationReplyStatusQuery(action="disable")
|
||||
assert query.action == "disable"
|
||||
|
||||
|
||||
def test_annotation_file_payload_valid():
|
||||
"""Test AnnotationFilePayload with valid message ID."""
|
||||
payload = annotation_module.AnnotationFilePayload(message_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
@ -13,6 +13,9 @@ from pandas.errors import ParserError
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit, annotation_import_rate_limit
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
|
||||
|
||||
class TestAnnotationImportRateLimiting:
|
||||
@ -33,8 +36,6 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that per-minute rate limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate exceeding per-minute limit
|
||||
mock_redis.zcard.side_effect = [
|
||||
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check
|
||||
@ -54,7 +55,6 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that per-hour rate limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate exceeding per-hour limit
|
||||
mock_redis.zcard.side_effect = [
|
||||
@ -74,7 +74,6 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
|
||||
"""Test that requests within limits are allowed."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate being under both limits
|
||||
mock_redis.zcard.return_value = 2
|
||||
@ -110,7 +109,6 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that concurrent task limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
# Simulate max concurrent tasks already running
|
||||
mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT
|
||||
@ -127,7 +125,6 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
|
||||
"""Test that requests within concurrency limits are allowed."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
# Simulate being under concurrent task limit
|
||||
mock_redis.zcard.return_value = 1
|
||||
@ -142,7 +139,6 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
|
||||
"""Test that old/stale job entries are removed."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
mock_redis.zcard.return_value = 0
|
||||
|
||||
@ -203,7 +199,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_max_records_limit_enforced(self, mock_app, mock_db_session):
|
||||
"""Test that files with too many records are rejected."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create CSV with too many records
|
||||
max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
|
||||
@ -229,7 +224,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_min_records_limit_enforced(self, mock_app, mock_db_session):
|
||||
"""Test that files with too few valid records are rejected."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create CSV with only header (no data rows)
|
||||
csv_content = "question,answer\n"
|
||||
@ -249,7 +243,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_invalid_csv_format_handled(self, mock_app, mock_db_session):
|
||||
"""Test that invalid CSV format is handled gracefully."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Any content is fine once we force ParserError
|
||||
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
|
||||
@ -270,7 +263,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_valid_import_succeeds(self, mock_app, mock_db_session):
|
||||
"""Test that valid import request succeeds."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create valid CSV
|
||||
csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n"
|
||||
@ -300,18 +292,10 @@ class TestAnnotationImportServiceValidation:
|
||||
class TestAnnotationImportTaskOptimization:
|
||||
"""Test optimizations in batch import task."""
|
||||
|
||||
def test_task_has_timeout_configured(self):
|
||||
"""Test that task has proper timeout configuration."""
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
|
||||
# Verify task configuration
|
||||
assert hasattr(batch_import_annotations_task, "time_limit")
|
||||
assert hasattr(batch_import_annotations_task, "soft_time_limit")
|
||||
|
||||
# Check timeout values are reasonable
|
||||
# Hard limit should be 6 minutes (360s)
|
||||
# Soft limit should be 5 minutes (300s)
|
||||
# Note: actual values depend on Celery configuration
|
||||
def test_task_is_registered_with_queue(self):
|
||||
"""Test that task is registered with the correct queue."""
|
||||
assert hasattr(batch_import_annotations_task, "apply_async")
|
||||
assert hasattr(batch_import_annotations_task, "delay")
|
||||
|
||||
|
||||
class TestConfigurationValues:
|
||||
|
||||
585
api/tests/unit_tests/controllers/console/app/test_app_apis.py
Normal file
585
api/tests/unit_tests/controllers/console/app/test_app_apis.py
Normal file
@ -0,0 +1,585 @@
|
||||
"""
|
||||
Additional tests to improve coverage for low-coverage modules in controllers/console/app.
|
||||
Target: increase coverage for files with <75% coverage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.app import (
|
||||
annotation as annotation_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
completion as completion_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
message as message_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
ops_trace as ops_trace_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
site as site_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
statistic as statistic_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_app_log as workflow_app_log_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_draft_variable as workflow_draft_variable_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_statistic as workflow_statistic_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_trigger as workflow_trigger_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
wraps as wraps_module,
|
||||
)
|
||||
from controllers.console.app.completion import ChatMessagePayload, CompletionMessagePayload
|
||||
from controllers.console.app.mcp_server import MCPServerCreatePayload, MCPServerUpdatePayload
|
||||
from controllers.console.app.ops_trace import TraceConfigPayload, TraceProviderQuery
|
||||
from controllers.console.app.site import AppSiteUpdatePayload
|
||||
from controllers.console.app.workflow import AdvancedChatWorkflowRunPayload, SyncDraftWorkflowPayload
|
||||
from controllers.console.app.workflow_app_log import WorkflowAppLogQuery
|
||||
from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload
|
||||
from controllers.console.app.workflow_statistic import WorkflowStatisticQuery
|
||||
from controllers.console.app.workflow_trigger import Parser, ParserEnable
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, _query, _args):
|
||||
return self._rows
|
||||
|
||||
|
||||
# ========== Completion Tests ==========
|
||||
class TestCompletionEndpoints:
|
||||
"""Tests for completion API endpoints."""
|
||||
|
||||
def test_completion_create_payload(self):
|
||||
"""Test completion creation payload."""
|
||||
payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={})
|
||||
assert payload.inputs == {"prompt": "test"}
|
||||
|
||||
def test_chat_message_payload_uuid_validation(self):
|
||||
payload = ChatMessagePayload(
|
||||
inputs={},
|
||||
model_config={},
|
||||
query="hi",
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
parent_message_id=str(uuid.uuid4()),
|
||||
)
|
||||
assert payload.query == "hi"
|
||||
|
||||
def test_completion_api_success(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: {"text": "ok"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
completion_module.helper,
|
||||
"compact_generate_response",
|
||||
lambda response: {"result": response},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
resp = method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
assert resp == {"result": {"text": "ok"}}
|
||||
|
||||
def test_completion_api_conversation_not_exists(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(
|
||||
completion_module.services.errors.conversation.ConversationNotExistsError()
|
||||
),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_provider_not_initialized(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(completion_module.ProviderTokenNotInitError("x")),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_quota_exceeded(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(completion_module.QuotaExceededError()),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
|
||||
# ========== OpsTrace Tests ==========
|
||||
class TestOpsTraceEndpoints:
|
||||
"""Tests for ops_trace endpoint."""
|
||||
|
||||
def test_ops_trace_query_basic(self):
|
||||
"""Test ops_trace query."""
|
||||
query = TraceProviderQuery(tracing_provider="langfuse")
|
||||
assert query.tracing_provider == "langfuse"
|
||||
|
||||
def test_ops_trace_config_payload(self):
|
||||
payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"})
|
||||
assert payload.tracing_config["api_key"] == "k"
|
||||
|
||||
def test_trace_app_config_get_empty(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"get_tracing_app_config",
|
||||
lambda **_kwargs: None,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?tracing_provider=langfuse"):
|
||||
result = method(app_id="app-1")
|
||||
|
||||
assert result == {"has_not_configured": True}
|
||||
|
||||
def test_trace_app_config_post_invalid(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"create_tracing_app_config",
|
||||
lambda **_kwargs: {"error": True},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_id="app-1")
|
||||
|
||||
def test_trace_app_config_delete_not_found(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"delete_tracing_app_config",
|
||||
lambda **_kwargs: False,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?tracing_provider=langfuse"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_id="app-1")
|
||||
|
||||
|
||||
# ========== Site Tests ==========
|
||||
class TestSiteEndpoints:
|
||||
"""Tests for site endpoint."""
|
||||
|
||||
def test_site_response_structure(self):
|
||||
"""Test site response structure."""
|
||||
payload = AppSiteUpdatePayload(title="My Site", description="Test site")
|
||||
assert payload.title == "My Site"
|
||||
|
||||
def test_site_default_language_validation(self):
|
||||
payload = AppSiteUpdatePayload(default_language="en-US")
|
||||
assert payload.default_language == "en-US"
|
||||
|
||||
def test_app_site_update_post(self, app, monkeypatch):
|
||||
api = site_module.AppSite()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/", json={"title": "My Site"}):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is site
|
||||
|
||||
def test_app_site_access_token_reset(self, app, monkeypatch):
|
||||
api = site_module.AppSiteAccessTokenReset()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is site
|
||||
|
||||
|
||||
# ========== Workflow Tests ==========
|
||||
class TestWorkflowEndpoints:
|
||||
"""Tests for workflow endpoints."""
|
||||
|
||||
def test_workflow_copy_payload(self):
|
||||
"""Test workflow copy payload."""
|
||||
payload = SyncDraftWorkflowPayload(graph={}, features={})
|
||||
assert payload.graph == {}
|
||||
|
||||
def test_workflow_mode_query(self):
|
||||
"""Test workflow mode query."""
|
||||
payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi")
|
||||
assert payload.query == "hi"
|
||||
|
||||
|
||||
# ========== Workflow App Log Tests ==========
|
||||
class TestWorkflowAppLogEndpoints:
|
||||
"""Tests for workflow app log endpoints."""
|
||||
|
||||
def test_workflow_app_log_query(self):
|
||||
"""Test workflow app log query."""
|
||||
query = WorkflowAppLogQuery(keyword="test", page=1, limit=20)
|
||||
assert query.keyword == "test"
|
||||
|
||||
def test_workflow_app_log_query_detail_bool(self):
|
||||
query = WorkflowAppLogQuery(detail="true")
|
||||
assert query.detail is True
|
||||
|
||||
def test_workflow_app_log_api_get(self, app, monkeypatch):
|
||||
api = workflow_app_log_module.WorkflowAppLogApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return "session"
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession())
|
||||
|
||||
def fake_get_paginate(self, **_kwargs):
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_app_log_module.WorkflowAppService,
|
||||
"get_paginate_workflow_app_logs",
|
||||
fake_get_paginate,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Draft Variable Tests ==========
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
"""Tests for workflow draft variable endpoints."""
|
||||
|
||||
def test_workflow_variable_creation(self):
|
||||
"""Test workflow variable creation."""
|
||||
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
|
||||
assert payload.name == "var1"
|
||||
|
||||
def test_workflow_variable_collection_get(self, app, monkeypatch):
|
||||
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return "session"
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class DummyDraftService:
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
|
||||
def list_variables_without_values(self, **_kwargs):
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession())
|
||||
|
||||
class DummyWorkflowService:
|
||||
def is_workflow_exist(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowDraftVariableService", DummyDraftService)
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Statistic Tests ==========
|
||||
class TestWorkflowStatisticEndpoints:
|
||||
"""Tests for workflow statistic endpoints."""
|
||||
|
||||
def test_workflow_statistic_time_range(self):
|
||||
"""Test workflow statistic time range query."""
|
||||
query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31")
|
||||
assert query.start == "2024-01-01"
|
||||
|
||||
def test_workflow_statistic_blank_to_none(self):
|
||||
query = WorkflowStatisticQuery(start="", end="")
|
||||
assert query.start is None
|
||||
assert query.end is None
|
||||
|
||||
def test_workflow_daily_runs_statistic(self, app, monkeypatch):
|
||||
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
|
||||
api = workflow_statistic_module.WorkflowDailyRunsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
|
||||
|
||||
def test_workflow_daily_terminals_statistic(self, app, monkeypatch):
|
||||
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(
|
||||
get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}]
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
|
||||
api = workflow_statistic_module.WorkflowDailyTerminalsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
|
||||
|
||||
|
||||
# ========== Workflow Trigger Tests ==========
|
||||
class TestWorkflowTriggerEndpoints:
|
||||
"""Tests for workflow trigger endpoints."""
|
||||
|
||||
def test_webhook_trigger_payload(self):
|
||||
"""Test webhook trigger payload."""
|
||||
payload = Parser(node_id="node-1")
|
||||
assert payload.node_id == "node-1"
|
||||
|
||||
enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True)
|
||||
assert enable_payload.enable_trigger is True
|
||||
|
||||
def test_webhook_trigger_api_get(self, app, monkeypatch):
|
||||
api = workflow_trigger_module.WebhookTriggerApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
trigger = MagicMock()
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = trigger
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession())
|
||||
|
||||
with app.test_request_context("/?node_id=node-1"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is trigger
|
||||
|
||||
|
||||
# ========== Wraps Tests ==========
|
||||
class TestWrapsEndpoints:
|
||||
"""Tests for wraps utility functions."""
|
||||
|
||||
def test_get_app_model_context(self):
|
||||
"""Test get_app_model wrapper context."""
|
||||
# These are decorator functions, so we test their availability
|
||||
assert hasattr(wraps_module, "get_app_model")
|
||||
|
||||
|
||||
# ========== MCP Server Tests ==========
|
||||
class TestMCPServerEndpoints:
|
||||
"""Tests for MCP server endpoints."""
|
||||
|
||||
def test_mcp_server_connection(self):
|
||||
"""Test MCP server connection."""
|
||||
payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"})
|
||||
assert payload.parameters["url"] == "http://localhost:3000"
|
||||
|
||||
def test_mcp_server_update_payload(self):
|
||||
payload = MCPServerUpdatePayload(id="server-1", parameters={"timeout": 30}, status="active")
|
||||
assert payload.status == "active"
|
||||
|
||||
|
||||
# ========== Error Handling Tests ==========
|
||||
class TestErrorHandling:
|
||||
"""Tests for error handling in various endpoints."""
|
||||
|
||||
def test_annotation_list_query_validation(self):
|
||||
"""Test annotation list query validation."""
|
||||
with pytest.raises(ValueError):
|
||||
annotation_module.AnnotationListQuery(page=0)
|
||||
|
||||
|
||||
# ========== Integration-like Tests ==========
|
||||
class TestPayloadIntegration:
|
||||
"""Integration tests for payload handling."""
|
||||
|
||||
def test_multiple_payload_types(self):
|
||||
"""Test handling of multiple payload types."""
|
||||
payloads = [
|
||||
annotation_module.AnnotationReplyPayload(
|
||||
score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small"
|
||||
),
|
||||
message_module.MessageFeedbackPayload(message_id=str(uuid.uuid4()), rating="like"),
|
||||
statistic_module.StatisticTimeRangeQuery(start="2024-01-01"),
|
||||
]
|
||||
assert len(payloads) == 3
|
||||
assert all(p is not None for p in payloads)
|
||||
@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
|
||||
def model_dump(self, mode: str = "json"):
|
||||
return {"status": self.status, "app_id": self.app_id}
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None:
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session))
|
||||
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
|
||||
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
|
||||
|
||||
|
||||
def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportConfirmApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportCheckDependenciesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"check_dependencies",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
|
||||
response, status = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert status == 200
|
||||
assert response["leaked_dependencies"] == []
|
||||
292
api/tests/unit_tests/controllers/console/app/test_audio.py
Normal file
292
api/tests/unit_tests/controllers/console/app/test_audio.py
Normal file
@ -0,0 +1,292 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.console.app.audio import ChatMessageAudioApi, ChatMessageTextApi, TextModesApi
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
ProviderNotSupportTextToSpeechLanageServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _file_data():
|
||||
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
|
||||
|
||||
|
||||
def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == {"text": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected"),
|
||||
[
|
||||
(AppModelConfigBrokenError(), AppUnavailableError),
|
||||
(NoAudioUploadedServiceError(), NoAudioUploadedError),
|
||||
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
|
||||
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
|
||||
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
|
||||
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
|
||||
(QuotaExceededError(), ProviderQuotaExceededError),
|
||||
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
|
||||
(InvokeError("invoke"), CompletionRequestError),
|
||||
],
|
||||
)
|
||||
def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(expected):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(InternalServerError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
api = ChatMessageTextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "voice": "v"},
|
||||
):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()))
|
||||
|
||||
api = ChatMessageTextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
api = TextModesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(tenant_id="t1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(ProviderNotSupportTextToSpeechLanageServiceError()),
|
||||
)
|
||||
|
||||
api = TextModesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(tenant_id="t1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
response_payload = {"text": "hello"}
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
|
||||
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_asr",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
method(app_model=app_model)
|
||||
|
||||
|
||||
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices",
|
||||
method="GET",
|
||||
query_string={"language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
# Should not raise, AudioService is mocked
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"text": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"audio": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
|
||||
method="GET",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert isinstance(response, list)
|
||||
156
api/tests/unit_tests/controllers/console/app/test_audio_api.py
Normal file
156
api/tests/unit_tests/controllers/console/app/test_audio_api.py
Normal file
@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import audio as audio_module
|
||||
from controllers.console.app.error import AudioTooLargeError
|
||||
from services.errors.audio import AudioTooLargeServiceError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
response_payload = {"text": "hello"}
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
|
||||
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
method(app_model=app_model)
|
||||
|
||||
|
||||
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices",
|
||||
method="GET",
|
||||
query_string={"language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
# Should not raise, AudioService is mocked
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"text": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"audio": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
|
||||
method="GET",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert isinstance(response, list)
|
||||
@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.app import conversation as conversation_module
|
||||
from models.model import AppMode
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _make_account():
|
||||
return SimpleNamespace(timezone="UTC", id="u1")
|
||||
|
||||
|
||||
def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
|
||||
paginate_result = MagicMock()
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response is paginate_result
|
||||
|
||||
|
||||
def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(
|
||||
conversation_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad range")),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/completion-conversations",
|
||||
method="GET",
|
||||
query_string={"start": "bad"},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.ChatConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
|
||||
paginate_result = MagicMock()
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
|
||||
|
||||
assert response is paginate_result
|
||||
|
||||
|
||||
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
conversation = SimpleNamespace(id="c1", app_id="app-1")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = conversation
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1")
|
||||
|
||||
assert result is conversation
|
||||
session.execute.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
session.refresh.assert_called_once_with(conversation)
|
||||
|
||||
|
||||
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing")
|
||||
|
||||
|
||||
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationDetailApi()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
|
||||
@ -0,0 +1,260 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import generator as generator_module
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _model_config_payload():
|
||||
return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
|
||||
|
||||
|
||||
def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
|
||||
class _Service:
|
||||
def get_draft_workflow(self, app_model):
|
||||
return workflow
|
||||
|
||||
monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service())
|
||||
|
||||
|
||||
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/rule-generate",
|
||||
method="POST",
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"rules": []}
|
||||
|
||||
|
||||
def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleCodeGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise ProviderTokenNotInitError("missing token")
|
||||
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", _raise)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/rule-code-generate",
|
||||
method="POST",
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method()
|
||||
|
||||
|
||||
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "app app-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
_install_workflow_service(monkeypatch, workflow=None)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "workflow app-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "node node-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{"id": "node-1", "data": {"type": "code"}},
|
||||
]
|
||||
}
|
||||
)
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"})
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"code": "x"}
|
||||
|
||||
|
||||
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(
|
||||
generator_module.LLMGenerator,
|
||||
"instruction_modify_legacy",
|
||||
lambda **_kwargs: {"instruction": "ok"},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "",
|
||||
"current": "old",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"instruction": "ok"}
|
||||
|
||||
|
||||
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "",
|
||||
"current": "",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "incompatible parameters"
|
||||
|
||||
|
||||
def test_instruction_template_prompt(app) -> None:
|
||||
api = generator_module.InstructionGenerationTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate/template",
|
||||
method="POST",
|
||||
json={"type": "prompt"},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert "data" in response
|
||||
|
||||
|
||||
def test_instruction_template_invalid_type(app) -> None:
|
||||
api = generator_module.InstructionGenerationTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate/template",
|
||||
method="POST",
|
||||
json={"type": "unknown"},
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method()
|
||||
122
api/tests/unit_tests/controllers/console/app/test_message_api.py
Normal file
122
api/tests/unit_tests/controllers/console/app/test_message_api.py
Normal file
@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import message as message_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test valid ChatMessagesQuery with all fields."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
conversation_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
first_id="550e8400-e29b-41d4-a716-446655440001",
|
||||
limit=50,
|
||||
)
|
||||
assert query.limit == 50
|
||||
|
||||
|
||||
def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test ChatMessagesQuery with defaults."""
|
||||
query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert query.first_id is None
|
||||
assert query.limit == 20
|
||||
|
||||
|
||||
def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test ChatMessagesQuery converts empty first_id to None."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
conversation_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
first_id="",
|
||||
)
|
||||
assert query.first_id is None
|
||||
|
||||
|
||||
def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload with like rating."""
|
||||
payload = message_module.MessageFeedbackPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
rating="like",
|
||||
content="Good answer",
|
||||
)
|
||||
assert payload.rating == "like"
|
||||
assert payload.content == "Good answer"
|
||||
|
||||
|
||||
def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload with dislike rating."""
|
||||
payload = message_module.MessageFeedbackPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
rating="dislike",
|
||||
)
|
||||
assert payload.rating == "dislike"
|
||||
|
||||
|
||||
def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload without rating."""
|
||||
payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert payload.rating is None
|
||||
|
||||
|
||||
def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with default format."""
|
||||
query = message_module.FeedbackExportQuery()
|
||||
assert query.format == "csv"
|
||||
assert query.from_source is None
|
||||
|
||||
|
||||
def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with JSON format."""
|
||||
query = message_module.FeedbackExportQuery(format="json")
|
||||
assert query.format == "json"
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as true string."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="true")
|
||||
assert query.has_comment is True
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as false string."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="false")
|
||||
assert query.has_comment is False
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as 1."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="1")
|
||||
assert query.has_comment is True
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as 0."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="0")
|
||||
assert query.has_comment is False
|
||||
|
||||
|
||||
def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with rating filter."""
|
||||
query = message_module.FeedbackExportQuery(rating="like")
|
||||
assert query.rating == "like"
|
||||
|
||||
|
||||
def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test AnnotationCountResponse creation."""
|
||||
response = message_module.AnnotationCountResponse(count=10)
|
||||
assert response.count == 10
|
||||
|
||||
|
||||
def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test SuggestedQuestionsResponse creation."""
|
||||
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0] == "What is AI?"
|
||||
@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import model_config as model_config_module
|
||||
from models.model import AppMode, AppModelConfig
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
mode=AppMode.CHAT.value,
|
||||
is_agent=False,
|
||||
app_model_config_id=None,
|
||||
updated_by=None,
|
||||
updated_at=None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
model_config_module.AppModelConfigService,
|
||||
"validate_configuration",
|
||||
lambda **_kwargs: {"pre_prompt": "hi"},
|
||||
)
|
||||
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
session = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
def _from_model_config_dict(self, model_config):
|
||||
self.pre_prompt = model_config["pre_prompt"]
|
||||
self.id = "config-1"
|
||||
return self
|
||||
|
||||
monkeypatch.setattr(AppModelConfig, "from_model_config_dict", _from_model_config_dict)
|
||||
send_mock = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.flush.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
send_mock.assert_called_once()
|
||||
assert app_model.app_model_config_id == "config-1"
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
mode=AppMode.AGENT_CHAT.value,
|
||||
is_agent=True,
|
||||
app_model_config_id="config-0",
|
||||
updated_by=None,
|
||||
updated_at=None,
|
||||
)
|
||||
|
||||
original_config = AppModelConfig(app_id="app-1", created_by="u1", updated_by="u1")
|
||||
original_config.agent_mode = json.dumps(
|
||||
{
|
||||
"enabled": True,
|
||||
"strategy": "function-calling",
|
||||
"tools": [
|
||||
{
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {"secret": "masked"},
|
||||
}
|
||||
],
|
||||
"prompt": None,
|
||||
}
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = original_config
|
||||
session.query.return_value = query
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_config_module.AppModelConfigService,
|
||||
"validate_configuration",
|
||||
lambda **_kwargs: {
|
||||
"pre_prompt": "hi",
|
||||
"agent_mode": {
|
||||
"enabled": True,
|
||||
"strategy": "function-calling",
|
||||
"tools": [
|
||||
{
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {"secret": "masked"},
|
||||
}
|
||||
],
|
||||
"prompt": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object())
|
||||
|
||||
class _ParamManager:
|
||||
def __init__(self, **_kwargs):
|
||||
self.delete_called = False
|
||||
|
||||
def decrypt_tool_parameters(self, _value):
|
||||
return {"secret": "decrypted"}
|
||||
|
||||
def mask_tool_parameters(self, _value):
|
||||
return {"secret": "masked"}
|
||||
|
||||
def encrypt_tool_parameters(self, _value):
|
||||
return {"secret": "encrypted"}
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
self.delete_called = True
|
||||
|
||||
monkeypatch.setattr(model_config_module, "ToolParameterConfigurationManager", _ParamManager)
|
||||
send_mock = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
stored_config = session.add.call_args[0][0]
|
||||
stored_agent_mode = json.loads(stored_config.agent_mode)
|
||||
assert stored_agent_mode["tools"][0]["tool_parameters"]["secret"] == "encrypted"
|
||||
assert response["result"] == "success"
|
||||
@ -0,0 +1,215 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console.app import statistic as statistic_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, _query, _args):
|
||||
return self._rows
|
||||
|
||||
|
||||
def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None:
|
||||
engine = SimpleNamespace(begin=lambda: _ConnContext(rows))
|
||||
monkeypatch.setattr(statistic_module, "db", SimpleNamespace(engine=engine))
|
||||
|
||||
|
||||
def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
|
||||
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-01", message_count=3)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["date"] == "2024-01-03"
|
||||
assert data["data"][0]["token_count"] == 10
|
||||
assert data["data"][0]["total_price"] == 0.25
|
||||
|
||||
|
||||
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTerminalsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
|
||||
|
||||
|
||||
def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that AverageSessionInteractionStatistic is limited to chat/agent modes."""
|
||||
# This just verifies the decorator is applied correctly
|
||||
# Actual endpoint testing would require complex JOIN mocking
|
||||
api = statistic_module.AverageSessionInteractionStatistic()
|
||||
method = _unwrap(api.get)
|
||||
assert callable(method)
|
||||
|
||||
|
||||
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
def mock_parse(*args, **kwargs):
|
||||
raise ValueError("Invalid time range")
|
||||
|
||||
_install_db(monkeypatch, [])
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [
|
||||
SimpleNamespace(date="2024-01-01", message_count=10),
|
||||
SimpleNamespace(date="2024-01-02", message_count=15),
|
||||
SimpleNamespace(date="2024-01-03", message_count=12),
|
||||
]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 3
|
||||
|
||||
|
||||
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, [])
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": []}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_db(monkeypatch, rows)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: ("s", "e"),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [
|
||||
SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"),
|
||||
SimpleNamespace(date="2024-01-02", token_count=200, total_price=Decimal("1.00"), currency="USD"),
|
||||
]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 2
|
||||
163
api/tests/unit_tests/controllers/console/app/test_workflow.py
Normal file
163
api/tests/unit_tests/controllers/console/app/test_workflow.py
Normal file
@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import HTTPException, NotFound
|
||||
|
||||
from controllers.console.app import workflow as workflow_module
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.file.models import File
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def test_parse_file_no_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: None)
|
||||
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
|
||||
|
||||
assert workflow_module._parse_file(workflow, files=[{"id": "f"}]) == []
|
||||
|
||||
|
||||
def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
config = object()
|
||||
file_list = [
|
||||
File(
|
||||
tenant_id="t1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="http://u",
|
||||
)
|
||||
]
|
||||
build_mock = Mock(return_value=file_list)
|
||||
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: config)
|
||||
monkeypatch.setattr(workflow_module.file_factory, "build_from_mappings", build_mock)
|
||||
|
||||
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
|
||||
result = workflow_module._parse_file(workflow, files=[{"id": "f"}])
|
||||
|
||||
assert result == file_list
|
||||
build_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert exc.value.code == 415
|
||||
|
||||
|
||||
def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
data="[]",
|
||||
content_type="application/json",
|
||||
):
|
||||
response, status = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert status == 400
|
||||
assert response["message"] == "Invalid JSON data"
|
||||
|
||||
|
||||
def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = SimpleNamespace(
|
||||
unique_hash="h",
|
||||
updated_at=None,
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
monkeypatch.setattr(
|
||||
workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_module.variable_factory, "build_conversation_variable_from_mapping", lambda *_args: "conv"
|
||||
)
|
||||
|
||||
service = SimpleNamespace(sync_draft_workflow=lambda **_kwargs: workflow)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
json={"graph": {}, "features": {}, "hash": "h"},
|
||||
):
|
||||
response = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise workflow_module.WorkflowHashNotEqualError()
|
||||
|
||||
service = SimpleNamespace(sync_draft_workflow=_raise)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
json={"graph": {}, "features": {}, "hash": "h"},
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(
|
||||
workflow_module.services.errors.conversation.ConversationNotExistsError()
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
api = workflow_module.AdvancedChatDraftWorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/advanced-chat/workflows/draft/run",
|
||||
method="POST",
|
||||
json={"inputs": {}},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
47
api/tests/unit_tests/controllers/console/app/test_wraps.py
Normal file
47
api/tests/unit_tests/controllers/console/app/test_wraps.py
Normal file
@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import wraps as wraps_module
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
assert handler(app_id="app-1") == "app-1"
|
||||
|
||||
|
||||
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
with pytest.raises(AppNotFoundError):
|
||||
handler(app_id="app-1")
|
||||
|
||||
|
||||
def test_get_app_model_requires_app_id() -> None:
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
handler()
|
||||
@ -1,13 +1,483 @@
|
||||
"""Final working unit tests for admin endpoints - tests business logic directly."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.console.admin import InsertExploreAppPayload
|
||||
from models.model import App, RecommendedApp
|
||||
from controllers.console.admin import (
|
||||
DeleteExploreBannerApi,
|
||||
InsertExploreAppApi,
|
||||
InsertExploreAppListApi,
|
||||
InsertExploreAppPayload,
|
||||
InsertExploreBannerApi,
|
||||
InsertExploreBannerPayload,
|
||||
)
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_only_edition_cloud(mocker):
|
||||
"""
|
||||
Bypass only_edition_cloud decorator by setting EDITION to "CLOUD".
|
||||
"""
|
||||
mocker.patch(
|
||||
"controllers.console.wraps.dify_config.EDITION",
|
||||
new="CLOUD",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_auth(mocker):
|
||||
"""
|
||||
Provide valid admin authentication for controller tests.
|
||||
"""
|
||||
mocker.patch(
|
||||
"controllers.console.admin.dify_config.ADMIN_API_KEY",
|
||||
"test-admin-key",
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.admin.extract_access_token",
|
||||
return_value="test-admin-key",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_console_payload(mocker):
|
||||
payload = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"language": "en-US",
|
||||
"category": "Productivity",
|
||||
"position": 1,
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"flask_restx.namespace.Namespace.payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_banner_payload(mocker):
|
||||
mocker.patch(
|
||||
"flask_restx.namespace.Namespace.payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value={
|
||||
"title": "Test Banner",
|
||||
"description": "Banner description",
|
||||
"img-src": "https://example.com/banner.png",
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"category": "homepage",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory(mocker):
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock()
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestDeleteExploreBannerApi:
|
||||
def setup_method(self):
|
||||
self.api = DeleteExploreBannerApi()
|
||||
|
||||
def test_delete_banner_not_found(self, mocker, mock_admin_auth):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: None),
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound, match="is not found"):
|
||||
self.api.delete(uuid.uuid4())
|
||||
|
||||
def test_delete_banner_success(self, mocker, mock_admin_auth):
|
||||
mock_banner = Mock()
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: mock_banner),
|
||||
)
|
||||
mocker.patch("controllers.console.admin.db.session.delete")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.delete(uuid.uuid4())
|
||||
|
||||
assert status == 204
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
class TestInsertExploreBannerApi:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreBannerApi()
|
||||
|
||||
def test_insert_banner_success(self, mocker, mock_admin_auth, mock_banner_payload):
|
||||
mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 201
|
||||
assert response["result"] == "success"
|
||||
|
||||
def test_banner_payload_valid_language(self):
|
||||
payload = {
|
||||
"title": "Test Banner",
|
||||
"description": "Banner description",
|
||||
"img-src": "https://example.com/banner.png",
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"category": "homepage",
|
||||
"language": "en-US",
|
||||
}
|
||||
|
||||
model = InsertExploreBannerPayload.model_validate(payload)
|
||||
assert model.language == "en-US"
|
||||
|
||||
def test_banner_payload_invalid_language(self):
|
||||
payload = {
|
||||
"title": "Test Banner",
|
||||
"description": "Banner description",
|
||||
"img-src": "https://example.com/banner.png",
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"category": "homepage",
|
||||
"language": "invalid-lang",
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
|
||||
InsertExploreBannerPayload.model_validate(payload)
|
||||
|
||||
|
||||
class TestInsertExploreAppApiDelete:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreAppApi()
|
||||
|
||||
def test_delete_when_not_in_explore(self, mocker, mock_admin_auth):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: s,
|
||||
__exit__=Mock(return_value=False),
|
||||
execute=lambda *_: Mock(scalar_one_or_none=lambda: None),
|
||||
),
|
||||
)
|
||||
|
||||
response, status = self.api.delete(uuid.uuid4())
|
||||
|
||||
assert status == 204
|
||||
assert response["result"] == "success"
|
||||
|
||||
def test_delete_when_in_explore_with_trial_app(self, mocker, mock_admin_auth):
|
||||
"""Test deleting an app from explore that has a trial app."""
|
||||
app_id = uuid.uuid4()
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
mock_recommended.app_id = "app-123"
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.is_public = True
|
||||
|
||||
mock_trial = Mock()
|
||||
|
||||
# Mock session context manager and its execute
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock()
|
||||
mock_session.delete = Mock()
|
||||
|
||||
# Set up side effects for execute calls
|
||||
mock_session.execute.side_effect = [
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalars=Mock(return_value=Mock(all=lambda: []))),
|
||||
Mock(scalar_one_or_none=lambda: mock_trial),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.delete")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.delete(app_id)
|
||||
|
||||
assert status == 204
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is False
|
||||
|
||||
def test_delete_with_installed_apps(self, mocker, mock_admin_auth):
|
||||
"""Test deleting an app that has installed apps in other tenants."""
|
||||
app_id = uuid.uuid4()
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
mock_recommended.app_id = "app-123"
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.is_public = True
|
||||
|
||||
mock_installed_app = Mock(spec=InstalledApp)
|
||||
|
||||
# Mock session
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock()
|
||||
mock_session.delete = Mock()
|
||||
|
||||
mock_session.execute.side_effect = [
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalars=Mock(return_value=Mock(all=lambda: [mock_installed_app]))),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.delete")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.delete(app_id)
|
||||
|
||||
assert status == 204
|
||||
assert mock_session.delete.called
|
||||
|
||||
|
||||
class TestInsertExploreAppListApi:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreAppListApi()
|
||||
|
||||
def test_app_not_found(self, mocker, mock_admin_auth, mock_console_payload):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: None),
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound, match="is not found"):
|
||||
self.api.post()
|
||||
|
||||
def test_create_recommended_app(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
):
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.tenant_id = "tenant"
|
||||
mock_app.is_public = False
|
||||
|
||||
# db.session.execute → fetch App
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: mock_app),
|
||||
)
|
||||
|
||||
# session_factory.create_session → recommended_app lookup
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock(return_value=Mock(scalar_one_or_none=lambda: None))
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 201
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_update_recommended_app(self, mocker, mock_admin_auth, mock_console_payload, mock_session_factory):
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.is_public = False
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
],
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_site_data_overrides_payload(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
site = Mock()
|
||||
site.description = "Site Desc"
|
||||
site.copyright = "Site Copyright"
|
||||
site.privacy_policy = "Site Privacy"
|
||||
site.custom_disclaimer = "Site Disclaimer"
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = site
|
||||
mock_app.tenant_id = "tenant"
|
||||
mock_app.is_public = False
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
],
|
||||
)
|
||||
|
||||
commit_spy = mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
commit_spy.assert_called_once()
|
||||
|
||||
def test_create_trial_app_when_can_trial_enabled(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
mock_console_payload["can_trial"] = True
|
||||
mock_console_payload["trial_limit"] = 5
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.tenant_id = "tenant"
|
||||
mock_app.is_public = False
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
],
|
||||
)
|
||||
|
||||
add_spy = mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
self.api.post()
|
||||
|
||||
assert any(call.args[0].__class__.__name__ == "TrialApp" for call in add_spy.call_args_list)
|
||||
|
||||
def test_update_recommended_app_with_trial(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
"""Test updating a recommended app when trial is enabled."""
|
||||
mock_console_payload["can_trial"] = True
|
||||
mock_console_payload["trial_limit"] = 10
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.is_public = False
|
||||
mock_app.tenant_id = "tenant-123"
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
],
|
||||
)
|
||||
|
||||
add_spy = mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_update_recommended_app_without_trial(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
"""Test updating a recommended app without trial enabled."""
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.is_public = False
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
],
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
|
||||
class TestInsertExploreAppPayload:
|
||||
|
||||
138
api/tests/unit_tests/controllers/console/test_apikey.py
Normal file
138
api/tests/unit_tests/controllers/console/test_apikey.py
Normal file
@ -0,0 +1,138 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.apikey import (
|
||||
BaseApiKeyListResource,
|
||||
BaseApiKeyResource,
|
||||
_get_resource,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_context_admin():
|
||||
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
|
||||
user = MagicMock()
|
||||
user.is_admin_or_owner = True
|
||||
mock.return_value = (user, "tenant-123")
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_context_non_admin():
|
||||
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
|
||||
user = MagicMock()
|
||||
user.is_admin_or_owner = False
|
||||
mock.return_value = (user, "tenant-123")
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_mock():
|
||||
with patch("controllers.console.apikey.db") as mock_db:
|
||||
mock_db.session = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_permissions():
|
||||
with patch(
|
||||
"controllers.console.apikey.edit_permission_required",
|
||||
lambda f: f,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class DummyApiKeyListResource(BaseApiKeyListResource):
|
||||
resource_type = "app"
|
||||
resource_model = MagicMock()
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
|
||||
|
||||
class DummyApiKeyResource(BaseApiKeyResource):
|
||||
resource_type = "app"
|
||||
resource_model = MagicMock()
|
||||
resource_id_field = "app_id"
|
||||
|
||||
|
||||
class TestGetResource:
|
||||
def test_get_resource_success(self):
|
||||
fake_resource = MagicMock()
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey.select") as mock_select,
|
||||
patch("controllers.console.apikey.Session") as mock_session,
|
||||
patch("controllers.console.apikey.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_select.return_value.filter_by.return_value = MagicMock()
|
||||
|
||||
session = mock_session.return_value.__enter__.return_value
|
||||
session.execute.return_value.scalar_one_or_none.return_value = fake_resource
|
||||
|
||||
result = _get_resource("rid", "tid", MagicMock)
|
||||
assert result == fake_resource
|
||||
|
||||
def test_get_resource_not_found(self):
|
||||
with (
|
||||
patch("controllers.console.apikey.select") as mock_select,
|
||||
patch("controllers.console.apikey.Session") as mock_session,
|
||||
patch("controllers.console.apikey.db") as mock_db,
|
||||
patch("controllers.console.apikey.flask_restx.abort") as abort,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_select.return_value.filter_by.return_value = MagicMock()
|
||||
|
||||
session = mock_session.return_value.__enter__.return_value
|
||||
session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
_get_resource("rid", "tid", MagicMock)
|
||||
|
||||
abort.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseApiKeyListResource:
|
||||
def test_get_apikeys_success(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyListResource()
|
||||
|
||||
with patch("controllers.console.apikey._get_resource"):
|
||||
db_mock.session.scalars.return_value.all.return_value = [MagicMock(), MagicMock()]
|
||||
|
||||
result = DummyApiKeyListResource.get.__wrapped__(resource, "resource-id")
|
||||
assert "items" in result
|
||||
|
||||
|
||||
class TestBaseApiKeyResource:
|
||||
def test_delete_forbidden(self, tenant_context_non_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
|
||||
with patch("controllers.console.apikey._get_resource"):
|
||||
with pytest.raises(Forbidden):
|
||||
DummyApiKeyResource.delete(resource, "rid", "kid")
|
||||
|
||||
def test_delete_key_not_found(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
db_mock.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with patch("controllers.console.apikey._get_resource"):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
DummyApiKeyResource.delete(resource, "rid", "kid")
|
||||
|
||||
# flask_restx.abort raises HTTPException with message in data attribute
|
||||
assert exc_info.value.data["message"] == "API key not found"
|
||||
|
||||
def test_delete_success(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource"),
|
||||
patch("controllers.console.apikey.ApiTokenCache.delete"),
|
||||
):
|
||||
result, status = DummyApiKeyResource.delete(resource, "rid", "kid")
|
||||
|
||||
assert status == 204
|
||||
assert result == {"result": "success"}
|
||||
db_mock.session.commit.assert_called_once()
|
||||
@ -1,46 +0,0 @@
|
||||
import builtins
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.secret_key = "test-secret-key"
|
||||
return app
|
||||
|
||||
|
||||
def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
ext_fastopenapi.init_app(app)
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
|
||||
with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"):
|
||||
client = app.test_client()
|
||||
response = client.get("/console/api/init")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"status": "finished"}
|
||||
|
||||
|
||||
def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
ext_fastopenapi.init_app(app)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "test-init-password")
|
||||
|
||||
with (
|
||||
patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0),
|
||||
):
|
||||
client = app.test_client()
|
||||
response = client.post("/console/api/init", json={"password": "test-init-password"})
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.get_json() == {"result": "success"}
|
||||
@ -1,286 +0,0 @@
|
||||
"""Tests for remote file upload API endpoints using Flask-RESTX."""
|
||||
|
||||
import contextlib
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Create Flask app for testing."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SECRET_KEY"] = "test-secret-key"
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create test client with console blueprint registered."""
|
||||
from controllers.console import bp
|
||||
|
||||
app.register_blueprint(bp)
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Create a mock account for testing."""
|
||||
from models import Account
|
||||
|
||||
account = Mock(spec=Account)
|
||||
account.id = "test-account-id"
|
||||
account.current_tenant_id = "test-tenant-id"
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_ctx(app, mock_account):
|
||||
"""Context manager to set auth/tenant context in flask.g for a request."""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _ctx():
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_account
|
||||
g._current_tenant = mock_account.current_tenant_id
|
||||
yield
|
||||
|
||||
return _ctx
|
||||
|
||||
|
||||
class TestGetRemoteFileInfo:
|
||||
"""Test GET /console/api/remote-files/<path:url> endpoint."""
|
||||
|
||||
def test_get_remote_file_info_success(self, app, client, mock_account):
|
||||
"""Test successful retrieval of remote file info."""
|
||||
response = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("HEAD", "http://example.com/file.txt"),
|
||||
headers={"Content-Type": "text/plain", "Content-Length": "1024"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_account
|
||||
g._current_tenant = mock_account.current_tenant_id
|
||||
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt"
|
||||
resp = client.get(f"/console/api/remote-files/{encoded_url}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["file_type"] == "text/plain"
|
||||
assert data["file_length"] == 1024
|
||||
|
||||
def test_get_remote_file_info_fallback_to_get_on_head_failure(self, app, client, mock_account):
|
||||
"""Test fallback to GET when HEAD returns non-200 status."""
|
||||
head_response = httpx.Response(
|
||||
404,
|
||||
request=httpx.Request("HEAD", "http://example.com/file.pdf"),
|
||||
)
|
||||
get_response = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("GET", "http://example.com/file.pdf"),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "2048"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_response),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_account
|
||||
g._current_tenant = mock_account.current_tenant_id
|
||||
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.pdf"
|
||||
resp = client.get(f"/console/api/remote-files/{encoded_url}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["file_type"] == "application/pdf"
|
||||
assert data["file_length"] == 2048
|
||||
|
||||
|
||||
class TestRemoteFileUpload:
|
||||
"""Test POST /console/api/remote-files/upload endpoint."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("head_status", "use_get"),
|
||||
[
|
||||
(200, False), # HEAD succeeds
|
||||
(405, True), # HEAD fails -> fallback GET
|
||||
],
|
||||
)
|
||||
def test_upload_remote_file_success_paths(self, client, mock_account, auth_ctx, head_status, use_get):
|
||||
url = "http://example.com/file.pdf"
|
||||
head_resp = httpx.Response(
|
||||
head_status,
|
||||
request=httpx.Request("HEAD", url),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
|
||||
)
|
||||
get_resp = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("GET", url),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
|
||||
content=b"file content",
|
||||
)
|
||||
|
||||
file_info = SimpleNamespace(
|
||||
extension="pdf",
|
||||
size=1024,
|
||||
filename="file.pdf",
|
||||
mimetype="application/pdf",
|
||||
)
|
||||
uploaded_file = SimpleNamespace(
|
||||
id="uploaded-file-id",
|
||||
name="file.pdf",
|
||||
size=1024,
|
||||
extension="pdf",
|
||||
mime_type="application/pdf",
|
||||
created_by="test-account-id",
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp) as p_head,
|
||||
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_resp) as p_get,
|
||||
patch(
|
||||
"controllers.console.remote_files.helpers.guess_file_info_from_response",
|
||||
return_value=file_info,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.remote_files.FileService.is_file_size_within_limit",
|
||||
return_value=True,
|
||||
),
|
||||
patch("controllers.console.remote_files.db", spec=["engine"]),
|
||||
patch("controllers.console.remote_files.FileService") as mock_file_service,
|
||||
patch(
|
||||
"controllers.console.remote_files.file_helpers.get_signed_file_url",
|
||||
return_value="http://example.com/signed-url",
|
||||
),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
mock_file_service.return_value.upload_file.return_value = uploaded_file
|
||||
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": url},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
p_head.assert_called_once()
|
||||
# GET is used either for fallback (HEAD fails) or to fetch content after HEAD succeeds
|
||||
p_get.assert_called_once()
|
||||
mock_file_service.return_value.upload_file.assert_called_once()
|
||||
|
||||
data = resp.get_json()
|
||||
assert data["id"] == "uploaded-file-id"
|
||||
assert data["name"] == "file.pdf"
|
||||
assert data["size"] == 1024
|
||||
assert data["extension"] == "pdf"
|
||||
assert data["url"] == "http://example.com/signed-url"
|
||||
assert data["mime_type"] == "application/pdf"
|
||||
assert data["created_by"] == "test-account-id"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("size_ok", "raises", "expected_status", "expected_msg"),
|
||||
[
|
||||
# When size check fails in controller, API returns 413 with message "File size exceeded..."
|
||||
(False, None, 413, "file size exceeded"),
|
||||
# When service raises unsupported type, controller maps to 415 with message "File type not allowed."
|
||||
(True, "unsupported", 415, "file type not allowed"),
|
||||
],
|
||||
)
|
||||
def test_upload_remote_file_errors(
|
||||
self, client, mock_account, auth_ctx, size_ok, raises, expected_status, expected_msg
|
||||
):
|
||||
url = "http://example.com/x.pdf"
|
||||
head_resp = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("HEAD", url),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "9"},
|
||||
)
|
||||
file_info = SimpleNamespace(extension="pdf", size=9, filename="x.pdf", mimetype="application/pdf")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp),
|
||||
patch(
|
||||
"controllers.console.remote_files.helpers.guess_file_info_from_response",
|
||||
return_value=file_info,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.remote_files.FileService.is_file_size_within_limit",
|
||||
return_value=size_ok,
|
||||
),
|
||||
patch("controllers.console.remote_files.db", spec=["engine"]),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
if raises == "unsupported":
|
||||
from services.errors.file import UnsupportedFileTypeError
|
||||
|
||||
with patch("controllers.console.remote_files.FileService") as mock_file_service:
|
||||
mock_file_service.return_value.upload_file.side_effect = UnsupportedFileTypeError("bad")
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": url},
|
||||
)
|
||||
else:
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": url},
|
||||
)
|
||||
|
||||
assert resp.status_code == expected_status
|
||||
data = resp.get_json()
|
||||
msg = (data.get("error") or {}).get("message") or data.get("message", "")
|
||||
assert expected_msg in msg.lower()
|
||||
|
||||
def test_upload_remote_file_fetch_failure(self, client, mock_account, auth_ctx):
|
||||
"""Test upload when fetching of remote file fails."""
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.remote_files.ssrf_proxy.head",
|
||||
side_effect=httpx.RequestError("Connection failed"),
|
||||
),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": "http://unreachable.com/file.pdf"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
data = resp.get_json()
|
||||
msg = (data.get("error") or {}).get("message") or data.get("message", "")
|
||||
assert "failed to fetch" in msg.lower()
|
||||
81
api/tests/unit_tests/controllers/console/test_feature.py
Normal file
81
api/tests/unit_tests/controllers/console/test_feature.py
Normal file
@ -0,0 +1,81 @@
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""
|
||||
Recursively unwrap decorated functions.
|
||||
"""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestFeatureApi:
|
||||
def test_get_tenant_features_success(self, mocker):
|
||||
from controllers.console.feature import FeatureApi
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_account_with_tenant",
|
||||
return_value=("account_id", "tenant_123"),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = {
|
||||
"features": {"feature_a": True}
|
||||
}
|
||||
|
||||
api = FeatureApi()
|
||||
|
||||
raw_get = unwrap(FeatureApi.get)
|
||||
result = raw_get(api)
|
||||
|
||||
assert result == {"features": {"feature_a": True}}
|
||||
|
||||
|
||||
class TestSystemFeatureApi:
|
||||
def test_get_system_features_authenticated(self, mocker):
|
||||
"""
|
||||
current_user.is_authenticated == True
|
||||
"""
|
||||
|
||||
from controllers.console.feature import SystemFeatureApi
|
||||
|
||||
fake_user = mocker.Mock()
|
||||
fake_user.is_authenticated = True
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_user",
|
||||
fake_user,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.FeatureService.get_system_features"
|
||||
).return_value.model_dump.return_value = {"features": {"sys_feature": True}}
|
||||
|
||||
api = SystemFeatureApi()
|
||||
result = api.get()
|
||||
|
||||
assert result == {"features": {"sys_feature": True}}
|
||||
|
||||
def test_get_system_features_unauthenticated(self, mocker):
|
||||
"""
|
||||
current_user.is_authenticated raises Unauthorized
|
||||
"""
|
||||
|
||||
from controllers.console.feature import SystemFeatureApi
|
||||
|
||||
fake_user = mocker.Mock()
|
||||
type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized())
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_user",
|
||||
fake_user,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.FeatureService.get_system_features"
|
||||
).return_value.model_dump.return_value = {"features": {"sys_feature": False}}
|
||||
|
||||
api = SystemFeatureApi()
|
||||
result = api.get()
|
||||
|
||||
assert result == {"features": {"sys_feature": False}}
|
||||
300
api/tests/unit_tests/controllers/console/test_files.py
Normal file
300
api/tests/unit_tests/controllers/console/test_files.py
Normal file
@ -0,0 +1,300 @@
|
||||
import io
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.files import (
|
||||
FileApi,
|
||||
FilePreviewApi,
|
||||
FileSupportTypeApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""
|
||||
Recursively unwrap decorated functions.
|
||||
"""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.testing = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_decorators():
|
||||
"""
|
||||
Make decorators no-ops so logic is directly testable
|
||||
"""
|
||||
with (
|
||||
patch("controllers.console.files.setup_required", new=lambda f: f),
|
||||
patch("controllers.console.files.login_required", new=lambda f: f),
|
||||
patch("controllers.console.files.account_initialization_required", new=lambda f: f),
|
||||
patch("controllers.console.files.cloud_edition_billing_resource_check", return_value=lambda f: f),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_user():
|
||||
user = MagicMock()
|
||||
user.is_dataset_editor = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account_context(mock_current_user):
|
||||
with patch(
|
||||
"controllers.console.files.current_account_with_tenant",
|
||||
return_value=(mock_current_user, None),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db():
|
||||
with patch("controllers.console.files.db") as db_mock:
|
||||
db_mock.engine = MagicMock()
|
||||
yield db_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_service(mock_db):
|
||||
with patch("controllers.console.files.FileService") as fs:
|
||||
instance = fs.return_value
|
||||
yield instance
|
||||
|
||||
|
||||
class TestFileApiGet:
|
||||
def test_get_upload_config(self, app):
|
||||
api = FileApi()
|
||||
get_method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context():
|
||||
data, status = get_method(api)
|
||||
|
||||
assert status == 200
|
||||
assert "file_size_limit" in data
|
||||
assert "batch_count_limit" in data
|
||||
|
||||
|
||||
class TestFileApiPost:
|
||||
def test_no_file_uploaded(self, app, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(method="POST", data={}):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
post_method(api)
|
||||
|
||||
def test_too_many_files(self, app, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(method="POST"):
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
with patch("controllers.console.files.request") as mock_request:
|
||||
mock_request.files = MagicMock()
|
||||
mock_request.files.__len__.return_value = 2
|
||||
mock_request.files.__contains__.return_value = True
|
||||
mock_request.form = MagicMock()
|
||||
mock_request.form.get.return_value = None
|
||||
|
||||
with pytest.raises(TooManyFilesError):
|
||||
post_method(api)
|
||||
|
||||
def test_filename_missing(self, app, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"abc"), ""),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
post_method(api)
|
||||
|
||||
def test_dataset_upload_without_permission(self, app, mock_current_user):
|
||||
mock_current_user.is_dataset_editor = False
|
||||
|
||||
with patch(
|
||||
"controllers.console.files.current_account_with_tenant",
|
||||
return_value=(mock_current_user, None),
|
||||
):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"abc"), "test.txt"),
|
||||
"source": "datasets",
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(Forbidden):
|
||||
post_method(api)
|
||||
|
||||
def test_successful_upload(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = "file-id-123"
|
||||
mock_file.filename = "test.txt"
|
||||
mock_file.name = "test.txt"
|
||||
mock_file.size = 1024
|
||||
mock_file.extension = "txt"
|
||||
mock_file.mime_type = "text/plain"
|
||||
mock_file.created_by = "user-123"
|
||||
mock_file.created_at = 1234567890
|
||||
mock_file.preview_url = "http://example.com/preview/file-id-123"
|
||||
mock_file.source_url = "http://example.com/source/file-id-123"
|
||||
mock_file.original_url = None
|
||||
mock_file.user_id = "user-123"
|
||||
mock_file.tenant_id = "tenant-123"
|
||||
mock_file.conversation_id = None
|
||||
mock_file.file_key = "file-key-123"
|
||||
|
||||
mock_file_service.upload_file.return_value = mock_file
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"hello"), "test.txt"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-123"
|
||||
assert response["name"] == "test.txt"
|
||||
|
||||
def test_upload_with_invalid_source(self, app, mock_account_context, mock_file_service):
|
||||
"""Test that invalid source parameter gets normalized to None"""
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
# Create a properly structured mock file object
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = "file-id-456"
|
||||
mock_file.filename = "test.txt"
|
||||
mock_file.name = "test.txt"
|
||||
mock_file.size = 512
|
||||
mock_file.extension = "txt"
|
||||
mock_file.mime_type = "text/plain"
|
||||
mock_file.created_by = "user-456"
|
||||
mock_file.created_at = 1234567890
|
||||
mock_file.preview_url = None
|
||||
mock_file.source_url = None
|
||||
mock_file.original_url = None
|
||||
mock_file.user_id = "user-456"
|
||||
mock_file.tenant_id = "tenant-456"
|
||||
mock_file.conversation_id = None
|
||||
mock_file.file_key = "file-key-456"
|
||||
|
||||
mock_file_service.upload_file.return_value = mock_file
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"content"), "test.txt"),
|
||||
"source": "invalid_source", # Should be normalized to None
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-456"
|
||||
# Verify that FileService was called with source=None
|
||||
mock_file_service.upload_file.assert_called_once()
|
||||
call_kwargs = mock_file_service.upload_file.call_args[1]
|
||||
assert call_kwargs["source"] is None
|
||||
|
||||
def test_file_too_large_error(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
|
||||
error = ServiceFileTooLargeError("File is too large")
|
||||
mock_file_service.upload_file.side_effect = error
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"x" * 1000000), "big.txt"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
post_method(api)
|
||||
|
||||
def test_unsupported_file_type(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
error = ServiceUnsupportedFileTypeError()
|
||||
mock_file_service.upload_file.side_effect = error
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"x"), "bad.exe"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
post_method(api)
|
||||
|
||||
def test_blocked_extension(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
from services.errors.file import BlockedFileExtensionError as ServiceBlockedFileExtensionError
|
||||
|
||||
error = ServiceBlockedFileExtensionError("File extension is blocked")
|
||||
mock_file_service.upload_file.side_effect = error
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"x"), "blocked.txt"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
post_method(api)
|
||||
|
||||
|
||||
class TestFilePreviewApi:
|
||||
def test_get_preview(self, app, mock_file_service):
|
||||
api = FilePreviewApi()
|
||||
get_method = unwrap(api.get)
|
||||
mock_file_service.get_file_preview.return_value = "preview text"
|
||||
|
||||
with app.test_request_context():
|
||||
result = get_method(api, "1234")
|
||||
|
||||
assert result == {"content": "preview text"}
|
||||
|
||||
|
||||
class TestFileSupportTypeApi:
|
||||
def test_get_supported_types(self, app):
|
||||
api = FileSupportTypeApi()
|
||||
get_method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context():
|
||||
result = get_method(api)
|
||||
|
||||
assert result == {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}
|
||||
@ -0,0 +1,293 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
|
||||
from controllers.console.human_input_form import (
|
||||
ConsoleHumanInputFormApi,
|
||||
ConsoleWorkflowEventsApi,
|
||||
DifyAPIRepositoryFactory,
|
||||
WorkflowResponseConverter,
|
||||
_jsonify_form_definition,
|
||||
)
|
||||
from controllers.web.error import NotFoundError
|
||||
from models.enums import CreatorUserRole
|
||||
from models.human_input import RecipientType
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def test_jsonify_form_definition() -> None:
|
||||
expiration = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
definition = SimpleNamespace(model_dump=lambda: {"fields": []})
|
||||
form = SimpleNamespace(get_definition=lambda: definition, expiration_time=expiration)
|
||||
|
||||
response = _jsonify_form_definition(form)
|
||||
|
||||
assert isinstance(response, Response)
|
||||
payload = json.loads(response.get_data(as_text=True))
|
||||
assert payload["expiration_time"] == int(expiration.timestamp())
|
||||
|
||||
|
||||
def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1")
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2"))
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
ConsoleHumanInputFormApi._ensure_console_access(form)
|
||||
|
||||
|
||||
def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
expiration = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]})
|
||||
form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_definition_by_token_for_console(self, _token):
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
|
||||
response = handler(api, form_token="token")
|
||||
|
||||
payload = json.loads(response.get_data(as_text=True))
|
||||
assert payload["fields"] == ["a"]
|
||||
|
||||
|
||||
def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_definition_by_token_for_console(self, _token):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_by_token(self, _token):
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
submit_mock = Mock()
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_by_token(self, _token):
|
||||
return form
|
||||
|
||||
def submit_form_by_token(self, **kwargs):
|
||||
submit_mock(**kwargs)
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
response = handler(api, form_token="token")
|
||||
|
||||
assert response.get_json() == {}
|
||||
submit_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
tenant_id="t1",
|
||||
)
|
||||
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return workflow_run
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="user-2",
|
||||
tenant_id="t1",
|
||||
)
|
||||
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return workflow_run
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="user-1",
|
||||
tenant_id="t1",
|
||||
app_id="app-1",
|
||||
finished_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
|
||||
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return workflow_run
|
||||
|
||||
response_obj = SimpleNamespace(
|
||||
event=SimpleNamespace(value="finished"),
|
||||
model_dump=lambda mode="json": {"status": "done"},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form._retrieve_app_for_workflow_run",
|
||||
lambda *_args, **_kwargs: app_model,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
WorkflowResponseConverter,
|
||||
"workflow_run_result_to_finish_response",
|
||||
lambda **_kwargs: response_obj,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
response = handler(api, workflow_run_id="run-1")
|
||||
|
||||
assert response.mimetype == "text/event-stream"
|
||||
assert "data" in response.get_data(as_text=True)
|
||||
108
api/tests/unit_tests/controllers/console/test_init_validate.py
Normal file
108
api/tests/unit_tests/controllers/console/test_init_validate.py
Normal file
@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import init_validate
|
||||
from controllers.console.error import AlreadySetupError, InitValidateFailedError
|
||||
|
||||
|
||||
class _SessionStub:
|
||||
def __init__(self, has_setup: bool):
|
||||
self._has_setup = has_setup
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, *_args, **_kwargs):
|
||||
return SimpleNamespace(scalar_one_or_none=lambda: Mock() if self._has_setup else None)
|
||||
|
||||
|
||||
def test_get_init_status_finished(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: True)
|
||||
result = init_validate.get_init_status()
|
||||
assert result.status == "finished"
|
||||
|
||||
|
||||
def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: False)
|
||||
result = init_validate.get_init_status()
|
||||
assert result.status == "not_started"
|
||||
|
||||
|
||||
def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="POST"):
|
||||
with pytest.raises(AlreadySetupError):
|
||||
init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw"))
|
||||
|
||||
|
||||
def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="POST"):
|
||||
with pytest.raises(InitValidateFailedError):
|
||||
init_validate.validate_init_password(init_validate.InitValidatePayload(password="wrong"))
|
||||
assert init_validate.session.get("is_init_validated") is False
|
||||
|
||||
|
||||
def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="POST"):
|
||||
result = init_validate.validate_init_password(init_validate.InitValidatePayload(password="expected"))
|
||||
assert result.result == "success"
|
||||
assert init_validate.session.get("is_init_validated") is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "CLOUD")
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="GET"):
|
||||
init_validate.session["is_init_validated"] = True
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True))
|
||||
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="GET"):
|
||||
init_validate.session.pop("is_init_validated", None)
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False))
|
||||
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="GET"):
|
||||
init_validate.session.pop("is_init_validated", None)
|
||||
assert init_validate.get_init_validate_status() is False
|
||||
281
api/tests/unit_tests/controllers/console/test_remote_files.py
Normal file
281
api/tests/unit_tests/controllers/console/test_remote_files.py
Normal file
@ -0,0 +1,281 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
|
||||
from controllers.console import remote_files as remote_files_module
|
||||
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
status_code: int = 200,
|
||||
headers: dict[str, str] | None = None,
|
||||
method: str = "GET",
|
||||
content: bytes = b"",
|
||||
text: str = "",
|
||||
error: Exception | None = None,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
self.headers = headers or {}
|
||||
self.request = SimpleNamespace(method=method)
|
||||
self.content = content
|
||||
self.text = text
|
||||
self._error = error
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
if self._error:
|
||||
raise self._error
|
||||
|
||||
|
||||
def _mock_upload_dependencies(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
*,
|
||||
file_size_within_limit: bool = True,
|
||||
):
|
||||
file_info = SimpleNamespace(
|
||||
filename="report.txt",
|
||||
extension=".txt",
|
||||
mimetype="text/plain",
|
||||
size=3,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.helpers,
|
||||
"guess_file_info_from_response",
|
||||
MagicMock(return_value=file_info),
|
||||
)
|
||||
|
||||
file_service_cls = MagicMock()
|
||||
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
|
||||
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
|
||||
monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None))
|
||||
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.file_helpers,
|
||||
"get_signed_file_url",
|
||||
lambda upload_file_id: f"https://signed.example/{upload_file_id}",
|
||||
)
|
||||
|
||||
return file_service_cls
|
||||
|
||||
|
||||
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = _unwrap(api.get)
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
encoded_url = urllib.parse.quote(decoded_url, safe="")
|
||||
|
||||
head_resp = _FakeResponse(
|
||||
status_code=200,
|
||||
headers={"Content-Type": "text/plain", "Content-Length": "128"},
|
||||
method="HEAD",
|
||||
)
|
||||
head_mock = MagicMock(return_value=head_resp)
|
||||
get_mock = MagicMock()
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
with app.test_request_context(method="GET"):
|
||||
payload = handler(api, url=encoded_url)
|
||||
|
||||
assert payload == {"file_type": "text/plain", "file_length": 128}
|
||||
head_mock.assert_called_once_with(decoded_url)
|
||||
get_mock.assert_not_called()
|
||||
|
||||
|
||||
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = _unwrap(api.get)
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
encoded_url = urllib.parse.quote(decoded_url, safe="")
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503)))
|
||||
get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET"))
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
with app.test_request_context(method="GET"):
|
||||
payload = handler(api, url=encoded_url)
|
||||
|
||||
assert payload == {"file_type": "application/octet-stream", "file_length": 0}
|
||||
get_mock.assert_called_once_with(decoded_url, timeout=3)
|
||||
|
||||
|
||||
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/report.txt"
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404)))
|
||||
get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content")
|
||||
get_mock = MagicMock(return_value=get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-1",
|
||||
name="report.txt",
|
||||
size=16,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by="u1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
file_service_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
payload, status = handler(api)
|
||||
|
||||
assert status == 201
|
||||
assert payload["id"] == "file-1"
|
||||
assert payload["url"] == "https://signed.example/file-1"
|
||||
get_mock.assert_called_once_with(url=url, timeout=3, follow_redirects=True)
|
||||
file_service_cls.return_value.upload_file.assert_called_once_with(
|
||||
filename="report.txt",
|
||||
content=b"fallback-content",
|
||||
mimetype="text/plain",
|
||||
user=SimpleNamespace(id="u1"),
|
||||
source_url=url,
|
||||
)
|
||||
|
||||
|
||||
def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/photo.jpg"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="HEAD", content=b"head-content")),
|
||||
)
|
||||
extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content")
|
||||
get_mock = MagicMock(return_value=extra_get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-2",
|
||||
name="photo.jpg",
|
||||
size=18,
|
||||
extension=".jpg",
|
||||
mime_type="image/jpeg",
|
||||
created_by="u1",
|
||||
created_at=datetime(2024, 1, 2, tzinfo=UTC),
|
||||
)
|
||||
file_service_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
payload, status = handler(api)
|
||||
|
||||
assert status == 201
|
||||
assert payload["id"] == "file-2"
|
||||
get_mock.assert_called_once_with(url)
|
||||
assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content"
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500)))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"get",
|
||||
MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")),
|
||||
)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
|
||||
request = httpx.Request("HEAD", url)
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(side_effect=httpx.RequestError("network down", request=request)),
|
||||
)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/large.bin"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
|
||||
_mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/large.bin"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(FileTooLargeError, match="size exceeded"):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/file.exe"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
handler(api)
|
||||
49
api/tests/unit_tests/controllers/console/test_spec.py
Normal file
49
api/tests/unit_tests/controllers/console/test_spec.py
Normal file
@ -0,0 +1,49 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import controllers.console.spec as spec_module
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestSpecSchemaDefinitionsApi:
|
||||
def test_get_success(self):
|
||||
api = spec_module.SpecSchemaDefinitionsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
schema_definitions = [{"type": "string"}]
|
||||
|
||||
with patch.object(
|
||||
spec_module,
|
||||
"SchemaManager",
|
||||
) as schema_manager_cls:
|
||||
schema_manager_cls.return_value.get_all_schema_definitions.return_value = schema_definitions
|
||||
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp == schema_definitions
|
||||
|
||||
def test_get_exception_returns_empty_list(self):
|
||||
api = spec_module.SpecSchemaDefinitionsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
spec_module,
|
||||
"SchemaManager",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
patch.object(
|
||||
spec_module.logger,
|
||||
"exception",
|
||||
) as log_exception,
|
||||
):
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp == []
|
||||
log_exception.assert_called_once()
|
||||
162
api/tests/unit_tests/controllers/console/test_version.py
Normal file
162
api/tests/unit_tests/controllers/console/test_version.py
Normal file
@ -0,0 +1,162 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import controllers.console.version as version_module
|
||||
|
||||
|
||||
class TestHasNewVersion:
|
||||
def test_has_new_version_true(self):
|
||||
result = version_module._has_new_version(
|
||||
latest_version="1.2.0",
|
||||
current_version="1.1.0",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_has_new_version_false(self):
|
||||
result = version_module._has_new_version(
|
||||
latest_version="1.0.0",
|
||||
current_version="1.1.0",
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_has_new_version_invalid_version(self):
|
||||
with patch.object(version_module.logger, "warning") as log_warning:
|
||||
result = version_module._has_new_version(
|
||||
latest_version="invalid",
|
||||
current_version="1.0.0",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
log_warning.assert_called_once()
|
||||
|
||||
|
||||
class TestCheckVersionUpdate:
|
||||
def test_no_check_update_url(self):
|
||||
query = version_module.VersionQuery(current_version="1.0.0")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"",
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config.project,
|
||||
"version",
|
||||
"1.0.0",
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CAN_REPLACE_LOGO",
|
||||
True,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"MODEL_LB_ENABLED",
|
||||
False,
|
||||
),
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.0.0"
|
||||
assert result.can_auto_update is False
|
||||
assert result.features.can_replace_logo is True
|
||||
assert result.features.model_load_balancing_enabled is False
|
||||
|
||||
def test_http_error_fallback(self):
|
||||
query = version_module.VersionQuery(current_version="1.0.0")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"http://example.com",
|
||||
),
|
||||
patch.object(
|
||||
version_module.httpx,
|
||||
"get",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
patch.object(
|
||||
version_module.logger,
|
||||
"warning",
|
||||
) as log_warning,
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.0.0"
|
||||
log_warning.assert_called_once()
|
||||
|
||||
def test_new_version_available(self):
|
||||
query = version_module.VersionQuery(current_version="1.0.0")
|
||||
|
||||
response = MagicMock()
|
||||
response.json.return_value = {
|
||||
"version": "1.2.0",
|
||||
"releaseDate": "2024-01-01",
|
||||
"releaseNotes": "New features",
|
||||
"canAutoUpdate": True,
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"http://example.com",
|
||||
),
|
||||
patch.object(
|
||||
version_module.httpx,
|
||||
"get",
|
||||
return_value=response,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config.project,
|
||||
"version",
|
||||
"1.0.0",
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CAN_REPLACE_LOGO",
|
||||
False,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"MODEL_LB_ENABLED",
|
||||
True,
|
||||
),
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.2.0"
|
||||
assert result.release_date == "2024-01-01"
|
||||
assert result.release_notes == "New features"
|
||||
assert result.can_auto_update is True
|
||||
|
||||
def test_no_new_version(self):
|
||||
query = version_module.VersionQuery(current_version="1.2.0")
|
||||
|
||||
response = MagicMock()
|
||||
response.json.return_value = {
|
||||
"version": "1.1.0",
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"http://example.com",
|
||||
),
|
||||
patch.object(
|
||||
version_module.httpx,
|
||||
"get",
|
||||
return_value=response,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config.project,
|
||||
"version",
|
||||
"1.2.0",
|
||||
),
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.2.0"
|
||||
assert result.can_auto_update is False
|
||||
Loading…
Reference in New Issue
Block a user