fix(api): document message responses as serialized contracts

This commit is contained in:
chariri 2026-06-26 03:29:06 +09:00
parent e937794831
commit 699b708802
No known key found for this signature in database
GPG Key ID: 23A554A36F7FF2FD
20 changed files with 810 additions and 258 deletions

View File

@ -167,12 +167,16 @@ register_schema_models(
ChatMessagesQuery,
MessageFeedbackPayload,
FeedbackExportQuery,
)
register_response_schema_models(
console_ns,
AnnotationCountResponse,
SuggestedQuestionsResponse,
MessageDetailResponse,
MessageInfiniteScrollPaginationResponse,
SimpleResultResponse,
TextFileResponse,
)
register_response_schema_models(console_ns, SimpleResultResponse, TextFileResponse)
@console_ns.route("/apps/<uuid:app_id>/chat-messages")

View File

@ -12,6 +12,7 @@ from controllers.common.errors import FilenameNotExistsError, NoFileUploadedErro
from controllers.common.fields import GeneratedAppResponse
from controllers.common.schema import (
query_params_from_model,
query_params_from_request,
register_response_schema_models,
register_schema_model,
register_schema_models,
@ -150,12 +151,11 @@ class DatasourcePluginsApi(DatasetApiResource):
if not dataset:
raise NotFound("Dataset not found.")
# Get query parameter to determine published or draft
is_published: bool = request.args.get("is_published", default=True, type=bool)
query = query_params_from_request(DatasourcePluginsQuery)
rag_pipeline_service: RagPipelineService = RagPipelineService()
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=is_published
tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=query.is_published
)
return datasource_plugins, 200

View File

@ -13435,7 +13435,6 @@ Soft lifecycle state for Agent records.
| created_at | integer | | No |
| files | [ string ] | | Yes |
| id | string | | Yes |
| message_chain_id | string | | No |
| message_id | string | | Yes |
| observation | string | | No |
| position | integer | | Yes |
@ -14567,8 +14566,8 @@ Enum class for configurate method of provider model.
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| annotation_create_account | [SimpleAccount](#simpleaccount) | | No |
| annotation_id | string | | Yes |
| created_at | integer | | No |
| id | string | | Yes |
#### ConversationDetail
@ -17072,6 +17071,7 @@ Enum class for large language model mode.
| agent_thoughts | [ [AgentThought](#agentthought) ] | | No |
| annotation | [ConversationAnnotation](#conversationannotation) | | No |
| annotation_hit_history | [ConversationAnnotationHitHistory](#conversationannotationhithistory) | | No |
| answer | string | | Yes |
| answer_tokens | integer | | No |
| conversation_id | string | | Yes |
| created_at | integer | | No |
@ -17085,12 +17085,11 @@ Enum class for large language model mode.
| inputs | object | | Yes |
| message | [JSONValue](#jsonvalue) | | No |
| message_files | [ [MessageFile](#messagefile) ] | | No |
| message_metadata_dict | [JSONValue](#jsonvalue) | | No |
| message_tokens | integer | | No |
| metadata | [JSONValue](#jsonvalue) | | No |
| parent_message_id | string | | No |
| provider_response_latency | number | | No |
| query | string | | Yes |
| re_sign_file_url_answer | string | | Yes |
| status | string | | Yes |
| workflow_run_id | string | | No |

View File

@ -607,7 +607,11 @@ class TestMiscApis:
method = unwrap(api.get)
service = MagicMock()
service.get_recommended_plugins.return_value = [{"id": "p1"}]
recommended_plugins = {
"installed_recommended_plugins": [{"id": "p1"}],
"uninstalled_recommended_plugins": [{"id": "p2"}],
}
service.get_recommended_plugins.return_value = recommended_plugins
user = make_account()
tenant_id = "tenant-1"
@ -619,7 +623,7 @@ class TestMiscApis:
),
):
result = method(api, tenant_id, user)
assert result == [{"id": "p1"}]
assert result == recommended_plugins
service.get_recommended_plugins.assert_called_once_with("all", user, tenant_id)
@ -826,7 +830,7 @@ class TestRagPipelineByIdApi:
result = method(api, pipeline, "old-workflow")
workflow_service.delete_workflow.assert_called_once()
assert result == (None, 204)
assert result == ("", 204)
def test_delete_active_workflow_rejected(self, app: Flask) -> None:
api = RagPipelineByIdApi()

View File

@ -2,21 +2,22 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.web.site import AppSiteApi, AppSiteInfo
from controllers.web.site import AppSiteApi, WebAppSiteResponse, WebModelConfigResponse
from models import Tenant, TenantStatus
from models.model import App, AppMode, CustomizeTokenStrategy, Site
from models.account import TenantCustomConfigDict
from models.model import App, AppMode, AppModelConfig, CustomizeTokenStrategy, EndUser, Site
from services.feature_service import FeatureModel
@pytest.fixture
def app(flask_app_with_containers) -> Flask:
def app(flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
@ -41,7 +42,23 @@ def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True
def _create_site(db_session: Session, app_id: str) -> Site:
site = Site(
site = _site_model(app_id=app_id)
db_session.add(site)
db_session.commit()
return site
def _end_user(tenant_id: str, app_id: str) -> EndUser:
return EndUser(
tenant_id=tenant_id,
app_id=app_id,
type="browser",
session_id=f"session-{app_id}",
)
def _site_model(*, app_id: str) -> Site:
return Site(
app_id=app_id,
title="Site",
icon_type="emoji",
@ -51,31 +68,30 @@ def _create_site(db_session: Session, app_id: str) -> Site:
default_language="en",
chat_color_theme="light",
chat_color_theme_inverted=False,
custom_disclaimer="",
customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW,
code=f"code-{app_id[-6:]}",
prompt_public=False,
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
db_session.add(site)
db_session.commit()
return site
class TestAppSiteApi:
@patch("controllers.web.site.FeatureService.get_features")
def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None:
def test_happy_path(self, mock_features: MagicMock, app: Flask, db_session_with_containers: Session) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
tenant = _create_tenant(db_session_with_containers)
app_model = _create_app(db_session_with_containers, tenant.id)
_create_site(db_session_with_containers, app_model.id)
end_user = SimpleNamespace(id="eu-1")
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
end_user = _end_user(tenant.id, app_model.id)
mock_features.return_value = FeatureModel(can_replace_logo=False)
with app.test_request_context("/site"):
result = AppSiteApi().get(app_model, end_user)
assert result["app_id"] == app_model.id
assert result["end_user_id"] == end_user.id
assert result["plan"] == "basic"
assert result["enable_site"] is True
@ -83,51 +99,139 @@ class TestAppSiteApi:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
tenant = _create_tenant(db_session_with_containers)
app_model = _create_app(db_session_with_containers, tenant.id)
end_user = SimpleNamespace(id="eu-1")
end_user = _end_user(tenant.id, app_model.id)
with app.test_request_context("/site"):
with pytest.raises(Forbidden):
AppSiteApi().get(app_model, end_user)
@patch("controllers.web.site.FeatureService.get_features")
def test_archived_tenant_raises_forbidden(
self, mock_features, app: Flask, db_session_with_containers: Session
) -> None:
def test_archived_tenant_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE)
app_model = _create_app(db_session_with_containers, tenant.id)
_create_site(db_session_with_containers, app_model.id)
end_user = SimpleNamespace(id="eu-1")
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
end_user = _end_user(tenant.id, app_model.id)
with app.test_request_context("/site"):
with pytest.raises(Forbidden):
AppSiteApi().get(app_model, end_user)
class TestAppSiteInfo:
def _tenant_model(*, plan: str = "basic", custom_config: TenantCustomConfigDict | None = None) -> Tenant:
tenant = Tenant(name="test-tenant", plan=plan)
tenant.custom_config_dict = custom_config or {}
return tenant
def _app_model(*, tenant: Tenant, enable_site: bool = True) -> App:
app_model = App(
tenant_id=tenant.id,
mode=AppMode.CHAT,
name="test-app",
enable_site=enable_site,
enable_api=True,
)
app_model.id = "app-test"
return app_model
class TestWebAppSiteResponse:
def test_basic_fields(self) -> None:
tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={})
site_obj = SimpleNamespace()
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
assert info.app_id == "app-1"
assert info.end_user_id == "eu-1"
assert info.enable_site is True
assert info.plan == "basic"
assert info.can_replace_logo is False
assert info.model_config is None
@patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com"))
def test_can_replace_logo_sets_custom_config(self) -> None:
tenant = SimpleNamespace(
id="tenant-1",
plan="pro",
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
tenant = _tenant_model()
app_model = _app_model(tenant=tenant)
response = WebAppSiteResponse.from_app_site(
tenant=tenant,
app_model=app_model,
site=_site_model(app_id=app_model.id),
end_user_id="eu-1",
can_replace_logo=False,
)
site_obj = SimpleNamespace()
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
assert info.can_replace_logo is True
assert info.custom_config["remove_webapp_brand"] is True
assert "webapp-logo" in info.custom_config["replace_webapp_logo"]
assert response.app_id == app_model.id
assert response.end_user_id == "eu-1"
assert response.enable_site is True
assert response.plan == "basic"
assert response.can_replace_logo is False
assert response.model_config_ is None
assert response.custom_config is None
assert response.site.custom_disclaimer == ""
def test_nullable_site_fields_preserve_none(self) -> None:
tenant = _tenant_model()
app_model = _app_model(tenant=tenant)
site = _site_model(app_id=app_model.id)
site.chat_color_theme = None
site.icon_type = None
site.icon = None
site.icon_background = None
site.description = None
site.copyright = None
site.privacy_policy = None
response = WebAppSiteResponse.from_app_site(
tenant=tenant,
app_model=app_model,
site=site,
end_user_id=None,
can_replace_logo=False,
)
dumped = response.model_dump(mode="json")
assert dumped["end_user_id"] is None
assert dumped["site"]["chat_color_theme"] is None
assert dumped["site"]["icon_type"] is None
assert dumped["site"]["icon"] is None
assert dumped["site"]["icon_background"] is None
assert dumped["site"]["description"] is None
assert dumped["site"]["copyright"] is None
assert dumped["site"]["privacy_policy"] is None
assert dumped["site"]["custom_disclaimer"] == ""
@patch("controllers.web.site.dify_config.FILES_URL", "https://files.example.com")
def test_can_replace_logo_sets_custom_config(self) -> None:
tenant = _tenant_model(
plan="pro",
custom_config={"remove_webapp_brand": True, "replace_webapp_logo": "enabled"},
)
app_model = _app_model(tenant=tenant)
response = WebAppSiteResponse.from_app_site(
tenant=tenant,
app_model=app_model,
site=_site_model(app_id=app_model.id),
end_user_id="eu-1",
can_replace_logo=True,
)
assert response.can_replace_logo is True
assert response.custom_config is not None
assert response.custom_config.remove_webapp_brand is True
assert response.custom_config.replace_webapp_logo is not None
assert "webapp-logo" in response.custom_config.replace_webapp_logo
class TestWebModelConfigResponse:
def test_serializes_internal_model_config_properties_to_public_keys(self) -> None:
model_config = AppModelConfig(
app_id="app-test",
opening_statement="Hello",
suggested_questions='["Question?"]',
suggested_questions_after_answer='{"enabled": true}',
more_like_this='{"enabled": false}',
model='{"provider": "openai", "name": "gpt-4o", "mode": "chat"}',
user_input_form='[{"text-input": {"label": "Name", "variable": "name", "required": true}}]',
pre_prompt="System prompt",
created_by="account-1",
updated_by="account-1",
)
dumped = WebModelConfigResponse.model_validate(model_config, from_attributes=True).model_dump(mode="json")
assert dumped == {
"opening_statement": "Hello",
"suggested_questions": ["Question?"],
"suggested_questions_after_answer": {"enabled": True},
"more_like_this": {"enabled": False},
"model": {"provider": "openai", "name": "gpt-4o", "mode": "chat"},
"user_input_form": [{"text-input": {"label": "Name", "variable": "name", "required": True}}],
"pre_prompt": "System prompt",
}

View File

@ -77,9 +77,11 @@ def test_human_input_preview_delegates_to_service(
preview_payload = {
"form_id": "node-42",
"node_id": "node-42",
"node_title": "Human Input",
"form_content": "<div>example</div>",
"inputs": [{"name": "topic"}],
"actions": [{"id": "continue"}],
"actions": [{"id": "continue", "title": "Continue"}],
}
service_instance = MagicMock()
service_instance.get_human_input_form_preview.return_value = preview_payload
@ -88,7 +90,15 @@ def test_human_input_preview_delegates_to_service(
with app.test_request_context(case.path, method="POST", json={"inputs": {"topic": "tech"}}):
response = case.resource_cls().post(app_id=app_model.id, node_id="node-42")
assert response == preview_payload
assert response == {
**preview_payload,
"TYPE": "human_input_required",
"actions": [{"id": "continue", "title": "Continue", "button_style": "default"}],
"resolved_default_values": {},
"display_in_ui": False,
"form_token": None,
"expiration_time": None,
}
service_instance.get_human_input_form_preview.assert_called_once_with(
app_model=app_model,
account=account,

View File

@ -1,4 +1,6 @@
import inspect
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@ -16,6 +18,7 @@ from controllers.console.datasets.external import (
ExternalDatasetCreateApi,
ExternalKnowledgeHitTestingApi,
)
from extensions.ext_database import db
from models.account import Account, TenantAccountRole
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
@ -38,28 +41,183 @@ def current_user() -> Account:
return user
def _external_api_dict(api_id: str = "api-1") -> dict:
return {
"id": api_id,
"tenant_id": "tenant-1",
"name": f"External API {api_id}",
"description": f"Description for {api_id}",
"settings": {
"endpoint": f"https://external.example.com/{api_id}",
"api_key": "secret",
"headers": {"X-Source": "unit-test"},
"timeout": 30,
},
"dataset_bindings": [
{"id": f"dataset-{api_id}", "name": f"Dataset {api_id}"},
],
"created_by": "user-1",
"created_at": "2024-01-01T00:00:00",
}
def _external_api_object(api_id: str = "api-1") -> SimpleNamespace:
payload = _external_api_dict(api_id)
return SimpleNamespace(
**{
**payload,
"dataset_bindings": [SimpleNamespace(**binding) for binding in payload["dataset_bindings"]],
}
)
def _expected_dataset_detail_payload() -> dict[str, Any]:
return {
"id": "dataset-1",
"name": "Support knowledge",
"description": "External support articles",
"provider": "external",
"permission": "only_me",
"data_source_type": "external",
"indexing_technique": "economy",
"app_count": 2,
"document_count": 7,
"word_count": 2048,
"created_by": "user-1",
"author_name": "Test User",
"created_at": 1710000000,
"updated_by": "user-2",
"updated_at": 1710003600,
"embedding_model": None,
"embedding_model_provider": None,
"embedding_available": False,
"retrieval_model_dict": {
"search_method": "semantic_search",
"reranking_enable": False,
"reranking_mode": None,
"reranking_model": {"reranking_provider_name": None, "reranking_model_name": None},
"weights": None,
"top_k": 4,
"score_threshold_enabled": True,
"score_threshold": 0.5,
},
"summary_index_setting": {
"enable": True,
"model_name": "summary-model",
"model_provider_name": "provider-a",
"summary_prompt": "Summarize this.",
},
"tags": [{"id": "tag-1", "name": "Support", "type": "knowledge"}],
"doc_form": "text_model",
"external_knowledge_info": {
"external_knowledge_id": "knowledge-1",
"external_knowledge_api_id": "api-1",
"external_knowledge_api_name": "External API api-1",
"external_knowledge_api_endpoint": "https://external.example.com/api-1",
},
"external_retrieval_model": {
"top_k": 4,
"score_threshold": 0.5,
"score_threshold_enabled": True,
},
"doc_metadata": [{"id": "metadata-1", "name": "source", "type": "string"}],
"built_in_field_enabled": True,
"pipeline_id": None,
"runtime_mode": "external",
"chunk_structure": "general",
"icon_info": {
"icon_type": "emoji",
"icon": "book",
"icon_background": "#FFF4ED",
"icon_url": None,
},
"is_published": True,
"total_documents": 7,
"total_available_documents": 6,
"enable_api": True,
"is_multimodal": False,
"maintainer": None,
"permission_keys": [],
}
def _dataset_detail_object() -> SimpleNamespace:
payload = _expected_dataset_detail_payload()
return SimpleNamespace(
**{
**payload,
"summary_index_setting": SimpleNamespace(**payload["summary_index_setting"]),
"tags": [SimpleNamespace(**tag) for tag in payload["tags"]],
"external_knowledge_info": SimpleNamespace(**payload["external_knowledge_info"]),
"external_retrieval_model": SimpleNamespace(**payload["external_retrieval_model"]),
"doc_metadata": [SimpleNamespace(**item) for item in payload["doc_metadata"]],
"icon_info": SimpleNamespace(**payload["icon_info"]),
}
)
class TestExternalApiTemplateListApi:
def test_get_success(self, app: Flask):
api = ExternalApiTemplateListApi()
method = inspect.unwrap(api.get)
api_item = MagicMock()
api_item.to_dict.return_value = {"id": "1"}
api_item = _external_api_object("api-1")
with (
app.test_request_context("/?page=1&limit=20"),
app.test_request_context("/?page=2&limit=1&keyword=vector"),
patch.object(
ExternalDatasetService,
"get_external_knowledge_apis",
return_value=([api_item], 1),
return_value=([api_item], 3),
) as get_external_knowledge_apis,
):
resp, status = method(api, "tenant-1")
assert status == 200
assert resp["total"] == 1
assert resp["data"][0]["id"] == "1"
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
assert resp == {
"data": [_external_api_dict("api-1")],
"has_more": True,
"limit": 1,
"total": 3,
"page": 2,
}
get_external_knowledge_apis.assert_called_once_with(2, 1, "tenant-1", "vector")
def test_post_success_uses_validated_payload_and_returns_template(self, app: Flask, current_user: Account):
api = ExternalApiTemplateListApi()
method = inspect.unwrap(api.post)
payload = {
"name": "Vendor Search",
"settings": {
"endpoint": "https://external.example.com/search",
"api_key": "secret",
"headers": {"X-Source": "unit-test"},
"timeout": 30,
},
}
created = _external_api_object("api-created")
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(ExternalDatasetService, "validate_api_list") as validate_api_list,
patch.object(
ExternalDatasetService,
"create_external_knowledge_api",
return_value=created,
) as create_external_knowledge_api,
):
resp, status = method(api, "tenant-1", current_user)
assert status == 201
assert resp == _external_api_dict("api-created")
validate_api_list.assert_called_once_with(payload["settings"])
create_external_knowledge_api.assert_called_once_with(
tenant_id="tenant-1",
user_id="user-1",
args=payload,
)
def test_post_forbidden(self, app: Flask, current_user: Account):
current_user.role = TenantAccountRole.NORMAL
@ -97,6 +255,25 @@ class TestExternalApiTemplateListApi:
class TestExternalApiTemplateApi:
def test_get_success_returns_template_contract(self, app: Flask):
api = ExternalApiTemplateApi()
method = inspect.unwrap(api.get)
template = _external_api_object("api-detail")
with (
app.test_request_context("/"),
patch.object(
ExternalDatasetService,
"get_external_knowledge_api",
return_value=template,
) as get_external_knowledge_api,
):
resp, status = method(api, "tenant-1", "api-detail")
assert status == 200
assert resp == _external_api_dict("api-detail")
get_external_knowledge_api.assert_called_once_with("api-detail", "tenant-1")
def test_get_not_found(self, app: Flask):
api = ExternalApiTemplateApi()
method = inspect.unwrap(api.get)
@ -112,6 +289,42 @@ class TestExternalApiTemplateApi:
with pytest.raises(NotFound):
method(api, "tenant-1", "api-id")
def test_patch_success_uses_validated_payload_and_returns_template(self, app: Flask, current_user: Account):
api = ExternalApiTemplateApi()
method = inspect.unwrap(api.patch)
payload = {
"name": "Updated API",
"settings": {
"endpoint": "https://external.example.com/updated",
"api_key": "new-secret",
"headers": {"X-Version": "2"},
},
}
updated = _external_api_object("api-updated")
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(ExternalDatasetService, "validate_api_list") as validate_api_list,
patch.object(
ExternalDatasetService,
"update_external_knowledge_api",
return_value=updated,
) as update_external_knowledge_api,
):
resp, status = method(api, "tenant-1", current_user, "api-updated")
assert status == 200
assert resp == _external_api_dict("api-updated")
validate_api_list.assert_called_once_with(payload["settings"])
update_external_knowledge_api.assert_called_once_with(
tenant_id="tenant-1",
user_id="user-1",
external_knowledge_api_id="api-updated",
args=payload,
)
def test_delete_forbidden(self, app: Flask, current_user: Account):
current_user.role = TenantAccountRole.NORMAL
@ -149,45 +362,37 @@ class TestExternalDatasetCreateApi:
method = inspect.unwrap(api.post)
payload = {
"external_knowledge_api_id": "api",
"external_knowledge_id": "kid",
"name": "dataset",
"external_knowledge_api_id": "api-1",
"external_knowledge_id": "knowledge-1",
"name": "Support knowledge",
"description": "External support articles",
"external_retrieval_model": {
"top_k": 4,
"score_threshold": 0.5,
"score_threshold_enabled": True,
},
}
dataset = MagicMock()
dataset.embedding_available = False
dataset.built_in_field_enabled = False
dataset.is_published = False
dataset.enable_api = False
dataset.enable_qa = False
dataset.enable_vector_store = False
dataset.vector_store_setting = None
dataset.is_multimodal = False
dataset.retrieval_model_dict = {}
dataset.tags = []
dataset.external_knowledge_info = None
dataset.external_retrieval_model = None
dataset.doc_metadata = []
dataset.icon_info = None
dataset.permission_keys = []
dataset.summary_index_setting = MagicMock()
dataset.summary_index_setting.enable = False
dataset = _dataset_detail_object()
with (
app.test_request_context("/"),
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(
ExternalDatasetService,
"create_external_dataset",
return_value=dataset,
),
) as create_external_dataset,
):
_, status = method(api, "tenant-1", current_user)
resp, status = method(api, "tenant-1", current_user)
assert status == 201
assert resp == _expected_dataset_detail_payload()
create_external_dataset.assert_called_once_with(
tenant_id="tenant-1",
user_id="user-1",
args=payload,
)
def test_create_forbidden(self, app: Flask, current_user: Account):
current_user.role = TenantAccountRole.NORMAL
@ -228,24 +433,58 @@ class TestExternalKnowledgeHitTestingApi:
api = ExternalKnowledgeHitTestingApi()
method = inspect.unwrap(api.post)
payload = {"query": "hello"}
payload = {
"query": "hello",
"external_retrieval_model": {
"top_k": 3,
"score_threshold": 0.25,
"score_threshold_enabled": True,
},
"metadata_filtering_conditions": {
"logical_operator": "and",
"conditions": [{"name": "source", "comparison_operator": "contains", "value": "external"}],
},
}
dataset = MagicMock()
retrieve_response = {
"query": {"content": "hello"},
"records": [
{
"content": "answer",
"title": "doc",
"score": 0.9,
"metadata": {"source": "external", "page": 2},
}
],
}
with (
app.test_request_context("/"),
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(DatasetService, "get_dataset", return_value=dataset),
patch.object(DatasetService, "check_dataset_permission"),
patch.object(DatasetService, "check_dataset_permission") as check_dataset_permission,
patch.object(HitTestingService, "hit_testing_args_check") as hit_testing_args_check,
patch.object(
HitTestingService,
"external_retrieve",
return_value={"ok": True},
),
return_value=retrieve_response,
) as external_retrieve,
patch("controllers.console.datasets.external.dump_response", side_effect=lambda _model, value: value),
):
resp = method(api, current_user, "dataset-id")
assert resp["ok"] is True
assert resp == retrieve_response
check_dataset_permission.assert_called_once_with(dataset, current_user)
hit_testing_args_check.assert_called_once_with(payload)
external_retrieve.assert_called_once_with(
session=db.session,
dataset=dataset,
query="hello",
account=current_user,
external_retrieval_model=payload["external_retrieval_model"],
metadata_filtering_conditions=payload["metadata_filtering_conditions"],
)
class TestBedrockRetrievalApi:
@ -254,24 +493,44 @@ class TestBedrockRetrievalApi:
method = inspect.unwrap(api.post)
payload = {
"retrieval_setting": {},
"query": "hello",
"knowledge_id": "kid",
"retrieval_setting": {"top_k": 5, "score_threshold": 0.72},
"query": "hello bedrock",
"knowledge_id": "knowledge-base-1",
}
retrieval_response = {
"records": [
{
"metadata": {"source": "bedrock", "uri": "s3://bucket/doc.txt"},
"score": 0.8,
"title": "doc",
"content": "answer",
},
{
"metadata": {"source": "bedrock", "uri": "s3://bucket/other.txt"},
"score": 0.65,
"title": None,
"content": None,
},
]
}
with (
app.test_request_context("/"),
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(
ExternalDatasetTestService,
"knowledge_retrieval",
return_value={"ok": True},
),
return_value=retrieval_response,
) as knowledge_retrieval,
):
resp, status = method()
assert status == 200
assert resp["ok"] is True
assert resp == retrieval_response
retrieval_setting, query, knowledge_id = knowledge_retrieval.call_args.args
assert retrieval_setting.model_dump() == payload["retrieval_setting"]
assert query == "hello bedrock"
assert knowledge_id == "knowledge-base-1"
class TestExternalApiTemplateListApiAdvanced:
@ -297,10 +556,10 @@ class TestExternalApiTemplateListApiAdvanced:
api = ExternalApiTemplateListApi()
method = inspect.unwrap(api.get)
templates = [MagicMock(id=f"api-{i}") for i in range(3)]
templates = [_external_api_object(f"api-{i}") for i in range(3)]
with (
app.test_request_context("/?page=1&limit=20"),
app.test_request_context("/?page=2&limit=3"),
patch(
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
return_value=(templates, 25),
@ -309,9 +568,14 @@ class TestExternalApiTemplateListApiAdvanced:
resp, status = method(api, "tenant-1")
assert status == 200
assert resp["total"] == 25
assert len(resp["data"]) == 3
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
assert resp == {
"data": [_external_api_dict(f"api-{i}") for i in range(3)],
"has_more": True,
"limit": 3,
"total": 25,
"page": 2,
}
get_external_knowledge_apis.assert_called_once_with(2, 3, "tenant-1", None)
class TestExternalDatasetCreateApiAdvanced:
@ -374,15 +638,46 @@ class TestExternalKnowledgeHitTestingApiAdvanced:
"controllers.console.datasets.external.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"),
patch("controllers.console.datasets.external.DatasetService.check_dataset_permission") as check_permission,
patch("controllers.console.datasets.external.HitTestingService.hit_testing_args_check") as args_check,
patch(
"controllers.console.datasets.external.HitTestingService.external_retrieve",
return_value={"results": []},
),
return_value={
"query": {"content": "test query"},
"records": [
{
"content": None,
"title": "metadata-only",
"score": None,
"metadata": {"status": "active"},
}
],
},
) as external_retrieve,
):
resp = method(api, current_user, "ds-1")
assert resp["results"] == []
assert resp == {
"query": {"content": "test query"},
"records": [
{
"content": None,
"title": "metadata-only",
"score": None,
"metadata": {"status": "active"},
}
],
}
check_permission.assert_called_once_with(dataset, current_user)
args_check.assert_called_once_with(payload)
external_retrieve.assert_called_once_with(
session=db.session,
dataset=dataset,
query="test query",
account=current_user,
external_retrieval_model={"type": "bm25"},
metadata_filtering_conditions={"status": "active"},
)
class TestBedrockRetrievalApiAdvanced:

View File

@ -1,4 +1,4 @@
from inspect import unwrap as inspect_unwrap
from inspect import unwrap
from io import BytesIO
from typing import Any
from unittest.mock import MagicMock, patch
@ -35,24 +35,16 @@ from models.model import AppMode
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
unwrap: Any = inspect_unwrap
@pytest.fixture
def account() -> Account:
acc = Account(name="User", email="user@example.com")
def account():
acc = MagicMock(spec=Account)
acc.id = "u1"
return acc
def _file_data() -> Any:
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
return file_data
@pytest.fixture
def trial_app_chat() -> MagicMock:
def trial_app_chat():
app = MagicMock()
app.id = "a-chat"
app.mode = AppMode.CHAT
@ -60,7 +52,7 @@ def trial_app_chat() -> MagicMock:
@pytest.fixture
def trial_app_completion() -> MagicMock:
def trial_app_completion():
app = MagicMock()
app.id = "a-comp"
app.mode = AppMode.COMPLETION
@ -68,7 +60,7 @@ def trial_app_completion() -> MagicMock:
@pytest.fixture
def trial_app_workflow() -> MagicMock:
def trial_app_workflow():
app = MagicMock()
app.id = "a-workflow"
app.mode = AppMode.WORKFLOW
@ -76,7 +68,7 @@ def trial_app_workflow() -> MagicMock:
@pytest.fixture
def valid_parameters() -> dict[str, object]:
def valid_parameters():
return {
"user_input_form": [],
"system_parameters": {},
@ -92,13 +84,54 @@ def valid_parameters() -> dict[str, object]:
}
def test_trial_workflow_uses_trial_scoped_simple_account_model() -> None:
assert module.simple_account_model.name == "TrialSimpleAccount"
assert hasattr(module.simple_account_model, "items")
def test_trial_workflow_registers_normalized_simple_account_response_model():
assert "SimpleAccountResponse" in module.console_ns.models
def _response_model_name(entry: object) -> str:
assert isinstance(entry, tuple)
assert len(entry) >= 2
model = entry[1]
name = getattr(model, "name", None)
assert isinstance(name, str)
return name
def test_trial_endpoints_keep_response_and_query_docs():
untyped_generated_response_views = [
module.TrialAppWorkflowRunApi.post,
module.TrialChatApi.post,
module.TrialCompletionApi.post,
]
for view in untyped_generated_response_views:
apidoc = getattr(view, "__apidoc__", {})
assert apidoc.get("responses", {})["200"] == ("Success", None, {})
cases = [
(module.TrialMessageSuggestedQuestionApi.get, module.SuggestedQuestionsResponse.__name__),
(module.TrialChatAudioApi.post, module.AudioTranscriptResponse.__name__),
(module.TrialChatTextApi.post, module.AudioBinaryResponse.__name__),
(module.TrialSitApi.get, module.SiteResponse.__name__),
(module.TrialAppParameterApi.get, module.ParametersResponse.__name__),
(module.AppApi.get, module.AppDetailWithSite.__name__),
(module.AppWorkflowApi.get, module.WorkflowResponse.__name__),
(module.DatasetListApi.get, module.TrialDatasetListResponse.__name__),
]
for view, model_name in cases:
apidoc = getattr(view, "__apidoc__", {})
responses = apidoc.get("responses", {})
assert _response_model_name(responses["200"]) == model_name
dataset_params = module.DatasetListApi.get.__apidoc__["params"]
assert dataset_params["ids"]["in"] == "query"
assert dataset_params["ids"]["type"] == "array"
assert dataset_params["page"]["default"] == 1
assert dataset_params["limit"]["default"] == 20
class TestTrialAppWorkflowRunApi:
def test_not_workflow_app(self, app: Flask, account: Account) -> None:
def test_not_workflow_app(self, app: Flask, account):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -106,7 +139,7 @@ class TestTrialAppWorkflowRunApi:
with pytest.raises(NotWorkflowAppError):
method(api, account, MagicMock(mode=AppMode.CHAT))
def test_success(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
def test_success(self, app: Flask, trial_app_workflow, account):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -119,7 +152,7 @@ class TestTrialAppWorkflowRunApi:
assert result is not None
def test_workflow_provider_not_init(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
def test_workflow_provider_not_init(self, app: Flask, trial_app_workflow, account):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -134,7 +167,7 @@ class TestTrialAppWorkflowRunApi:
with pytest.raises(ProviderNotInitializeError):
method(api, account, trial_app_workflow)
def test_workflow_quota_exceeded(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
def test_workflow_quota_exceeded(self, app: Flask, trial_app_workflow, account):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -149,7 +182,7 @@ class TestTrialAppWorkflowRunApi:
with pytest.raises(ProviderQuotaExceededError):
method(api, account, trial_app_workflow)
def test_workflow_model_not_support(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
def test_workflow_model_not_support(self, app: Flask, trial_app_workflow, account):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -164,7 +197,7 @@ class TestTrialAppWorkflowRunApi:
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(api, account, trial_app_workflow)
def test_workflow_invoke_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
def test_workflow_invoke_error(self, app: Flask, trial_app_workflow, account):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -179,7 +212,7 @@ class TestTrialAppWorkflowRunApi:
with pytest.raises(CompletionRequestError):
method(api, account, trial_app_workflow)
def test_workflow_rate_limit_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
def test_workflow_rate_limit_error(self, app: Flask, trial_app_workflow, account):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -194,7 +227,7 @@ class TestTrialAppWorkflowRunApi:
with pytest.raises(InvokeRateLimitHttpError):
method(api, account, trial_app_workflow)
def test_workflow_value_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
def test_workflow_value_error(self, app: Flask, trial_app_workflow, account):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -209,7 +242,7 @@ class TestTrialAppWorkflowRunApi:
with pytest.raises(ValueError):
method(api, account, trial_app_workflow)
def test_workflow_generic_exception(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
def test_workflow_generic_exception(self, app: Flask, trial_app_workflow, account):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -226,7 +259,7 @@ class TestTrialAppWorkflowRunApi:
class TestTrialChatApi:
def test_not_chat_app(self, app: Flask, account: Account) -> None:
def test_not_chat_app(self, app: Flask, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -234,7 +267,7 @@ class TestTrialChatApi:
with pytest.raises(NotChatAppError):
method(api, account, MagicMock(mode="completion"))
def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_success(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -247,7 +280,7 @@ class TestTrialChatApi:
assert result is not None
def test_chat_conversation_not_exists(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_chat_conversation_not_exists(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -262,7 +295,7 @@ class TestTrialChatApi:
with pytest.raises(NotFound):
method(api, account, trial_app_chat)
def test_chat_conversation_completed(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_chat_conversation_completed(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -277,7 +310,7 @@ class TestTrialChatApi:
with pytest.raises(ConversationCompletedError):
method(api, account, trial_app_chat)
def test_chat_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_chat_app_config_broken(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -292,7 +325,7 @@ class TestTrialChatApi:
with pytest.raises(AppUnavailableError):
method(api, account, trial_app_chat)
def test_chat_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_chat_provider_not_init(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -307,7 +340,7 @@ class TestTrialChatApi:
with pytest.raises(ProviderNotInitializeError):
method(api, account, trial_app_chat)
def test_chat_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_chat_quota_exceeded(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -322,7 +355,7 @@ class TestTrialChatApi:
with pytest.raises(ProviderQuotaExceededError):
method(api, account, trial_app_chat)
def test_chat_model_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_chat_model_not_support(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -337,7 +370,7 @@ class TestTrialChatApi:
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(api, account, trial_app_chat)
def test_chat_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_chat_invoke_error(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -352,7 +385,7 @@ class TestTrialChatApi:
with pytest.raises(CompletionRequestError):
method(api, account, trial_app_chat)
def test_chat_rate_limit_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_chat_rate_limit_error(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -367,7 +400,7 @@ class TestTrialChatApi:
with pytest.raises(InvokeRateLimitHttpError):
method(api, account, trial_app_chat)
def test_chat_value_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_chat_value_error(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -382,7 +415,7 @@ class TestTrialChatApi:
with pytest.raises(ValueError):
method(api, account, trial_app_chat)
def test_chat_generic_exception(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_chat_generic_exception(self, app: Flask, trial_app_chat, account):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -399,7 +432,7 @@ class TestTrialChatApi:
class TestTrialCompletionApi:
def test_not_completion_app(self, app: Flask, account: Account) -> None:
def test_not_completion_app(self, app: Flask, account):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -407,7 +440,7 @@ class TestTrialCompletionApi:
with pytest.raises(NotCompletionAppError):
method(api, account, MagicMock(mode=AppMode.CHAT))
def test_success(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
def test_success(self, app: Flask, trial_app_completion, account):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -420,7 +453,7 @@ class TestTrialCompletionApi:
assert result is not None
def test_completion_app_config_broken(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
def test_completion_app_config_broken(self, app: Flask, trial_app_completion, account):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -435,7 +468,7 @@ class TestTrialCompletionApi:
with pytest.raises(AppUnavailableError):
method(api, account, trial_app_completion)
def test_completion_provider_not_init(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
def test_completion_provider_not_init(self, app: Flask, trial_app_completion, account):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -450,7 +483,7 @@ class TestTrialCompletionApi:
with pytest.raises(ProviderNotInitializeError):
method(api, account, trial_app_completion)
def test_completion_quota_exceeded(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
def test_completion_quota_exceeded(self, app: Flask, trial_app_completion, account):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -465,7 +498,7 @@ class TestTrialCompletionApi:
with pytest.raises(ProviderQuotaExceededError):
method(api, account, trial_app_completion)
def test_completion_model_not_support(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
def test_completion_model_not_support(self, app: Flask, trial_app_completion, account):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -480,7 +513,7 @@ class TestTrialCompletionApi:
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(api, account, trial_app_completion)
def test_completion_invoke_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
def test_completion_invoke_error(self, app: Flask, trial_app_completion, account):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -495,7 +528,7 @@ class TestTrialCompletionApi:
with pytest.raises(CompletionRequestError):
method(api, account, trial_app_completion)
def test_completion_rate_limit_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
def test_completion_rate_limit_error(self, app: Flask, trial_app_completion, account):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -510,7 +543,7 @@ class TestTrialCompletionApi:
with pytest.raises(InternalServerError):
method(api, account, trial_app_completion)
def test_completion_value_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
def test_completion_value_error(self, app: Flask, trial_app_completion, account):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -525,7 +558,7 @@ class TestTrialCompletionApi:
with pytest.raises(ValueError):
method(api, account, trial_app_completion)
def test_completion_generic_exception(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
def test_completion_generic_exception(self, app: Flask, trial_app_completion, account):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -542,7 +575,7 @@ class TestTrialCompletionApi:
class TestTrialMessageSuggestedQuestionApi:
def test_not_chat_app(self, app: Flask, account: Account) -> None:
def test_not_chat_app(self, app: Flask, account):
api = module.TrialMessageSuggestedQuestionApi()
method = unwrap(api.get)
@ -550,7 +583,7 @@ class TestTrialMessageSuggestedQuestionApi:
with pytest.raises(NotChatAppError):
method(api, account, MagicMock(mode="completion"), str(uuid4()))
def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_success(self, app: Flask, trial_app_chat, account):
api = module.TrialMessageSuggestedQuestionApi()
method = unwrap(api.get)
@ -566,7 +599,7 @@ class TestTrialMessageSuggestedQuestionApi:
assert result == {"data": ["q1", "q2"]}
def test_conversation_not_exists(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_conversation_not_exists(self, app: Flask, trial_app_chat, account):
api = module.TrialMessageSuggestedQuestionApi()
method = unwrap(api.get)
@ -583,14 +616,14 @@ class TestTrialMessageSuggestedQuestionApi:
class TestTrialAppParameterApi:
def test_app_unavailable(self) -> None:
def test_app_unavailable(self):
api = module.TrialAppParameterApi()
method = unwrap(api.get)
with pytest.raises(AppUnavailableError):
method(api, None)
def test_success_non_workflow(self, valid_parameters: dict[str, object]) -> None:
def test_success_non_workflow(self, valid_parameters):
api = module.TrialAppParameterApi()
method = unwrap(api.get)
@ -617,11 +650,12 @@ class TestTrialAppParameterApi:
class TestTrialChatAudioApi:
def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_success(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -634,11 +668,12 @@ class TestTrialChatAudioApi:
assert result == {"text": "hello"}
def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_app_config_broken(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -653,11 +688,12 @@ class TestTrialChatAudioApi:
with pytest.raises(module.AppUnavailableError):
method(api, account, trial_app_chat)
def test_no_audio_uploaded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_no_audio_uploaded(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -672,11 +708,12 @@ class TestTrialChatAudioApi:
with pytest.raises(module.NoAudioUploadedError):
method(api, account, trial_app_chat)
def test_audio_too_large(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_audio_too_large(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -691,11 +728,12 @@ class TestTrialChatAudioApi:
with pytest.raises(module.AudioTooLargeError):
method(api, account, trial_app_chat)
def test_unsupported_audio_type(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_unsupported_audio_type(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -710,11 +748,12 @@ class TestTrialChatAudioApi:
with pytest.raises(module.UnsupportedAudioTypeError):
method(api, account, trial_app_chat)
def test_provider_not_support_tts(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_provider_not_support_tts(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -729,11 +768,12 @@ class TestTrialChatAudioApi:
with pytest.raises(module.ProviderNotSupportSpeechToTextError):
method(api, account, trial_app_chat)
def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_provider_not_init(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -744,11 +784,12 @@ class TestTrialChatAudioApi:
with pytest.raises(ProviderNotInitializeError):
method(api, account, trial_app_chat)
def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_quota_exceeded(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -761,7 +802,7 @@ class TestTrialChatAudioApi:
class TestTrialChatTextApi:
def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_success(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)
@ -774,7 +815,7 @@ class TestTrialChatTextApi:
assert result == {"audio": "base64_data"}
def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_app_config_broken(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)
@ -789,7 +830,7 @@ class TestTrialChatTextApi:
with pytest.raises(module.AppUnavailableError):
method(api, account, trial_app_chat)
def test_provider_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_provider_not_support(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)
@ -804,7 +845,7 @@ class TestTrialChatTextApi:
with pytest.raises(module.ProviderNotSupportSpeechToTextError):
method(api, account, trial_app_chat)
def test_audio_too_large(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_audio_too_large(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)
@ -819,7 +860,7 @@ class TestTrialChatTextApi:
with pytest.raises(module.AudioTooLargeError):
method(api, account, trial_app_chat)
def test_no_audio_uploaded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_no_audio_uploaded(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)
@ -834,7 +875,7 @@ class TestTrialChatTextApi:
with pytest.raises(module.NoAudioUploadedError):
method(api, account, trial_app_chat)
def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_provider_not_init(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)
@ -845,7 +886,7 @@ class TestTrialChatTextApi:
with pytest.raises(ProviderNotInitializeError):
method(api, account, trial_app_chat)
def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_quota_exceeded(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)
@ -856,7 +897,7 @@ class TestTrialChatTextApi:
with pytest.raises(ProviderQuotaExceededError):
method(api, account, trial_app_chat)
def test_model_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_model_not_support(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)
@ -867,7 +908,7 @@ class TestTrialChatTextApi:
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(api, account, trial_app_chat)
def test_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_invoke_error(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)
@ -880,7 +921,7 @@ class TestTrialChatTextApi:
class TestTrialAppWorkflowTaskStopApi:
def test_not_workflow_app(self, app: Flask, trial_app_chat: MagicMock) -> None:
def test_not_workflow_app(self, app: Flask, trial_app_chat):
api = module.TrialAppWorkflowTaskStopApi()
method = unwrap(api.post)
@ -888,7 +929,7 @@ class TestTrialAppWorkflowTaskStopApi:
with pytest.raises(NotWorkflowAppError):
method(api, trial_app_chat, str(uuid4()))
def test_success(self, app: Flask, trial_app_workflow: MagicMock) -> None:
def test_success(self, app: Flask, trial_app_workflow, account):
api = module.TrialAppWorkflowTaskStopApi()
method = unwrap(api.post)
@ -906,7 +947,7 @@ class TestTrialAppWorkflowTaskStopApi:
class TestTrialSitApi:
def test_no_site(self, app: Flask) -> None:
def test_no_site(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)
app_model = MagicMock()
@ -917,7 +958,7 @@ class TestTrialSitApi:
with pytest.raises(Forbidden):
method(api, app_model)
def test_archived_tenant(self, app: Flask) -> None:
def test_archived_tenant(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)
@ -932,7 +973,7 @@ class TestTrialSitApi:
with pytest.raises(Forbidden):
method(api, app_model)
def test_success(self, app: Flask) -> None:
def test_success(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)
@ -957,11 +998,12 @@ class TestTrialSitApi:
class TestTrialChatAudioApiExceptionHandlers:
def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_provider_not_init(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -976,11 +1018,12 @@ class TestTrialChatAudioApiExceptionHandlers:
with pytest.raises(ProviderNotInitializeError):
method(api, account, trial_app_chat)
def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_quota_exceeded(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -995,11 +1038,12 @@ class TestTrialChatAudioApiExceptionHandlers:
with pytest.raises(ProviderQuotaExceededError):
method(api, account, trial_app_chat)
def test_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_invoke_error(self, app: Flask, trial_app_chat, account):
api = module.TrialChatAudioApi()
method = unwrap(api.post)
file_data = _file_data()
file_data: Any = BytesIO(b"fake audio data")
file_data.filename = "test.wav"
with (
app.test_request_context(
@ -1016,7 +1060,7 @@ class TestTrialChatAudioApiExceptionHandlers:
class TestTrialChatTextApiExceptionHandlers:
def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_app_config_broken(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)
@ -1031,7 +1075,7 @@ class TestTrialChatTextApiExceptionHandlers:
with pytest.raises(module.AppUnavailableError):
method(api, account, trial_app_chat)
def test_unsupported_audio_type(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
def test_unsupported_audio_type(self, app: Flask, trial_app_chat, account):
api = module.TrialChatTextApi()
method = unwrap(api.post)

View File

@ -1,3 +1,4 @@
from datetime import datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import ANY, Mock
@ -46,12 +47,73 @@ def _snippet(**overrides) -> CustomizedSnippet:
"name": "Snippet",
"description": "Description",
"type": snippets_module.SnippetType.NODE,
"version": 3,
"use_count": 7,
"is_published": True,
"icon_info": {"icon": "star", "icon_background": "#101828", "icon_type": "emoji"},
"input_fields": '[{"label": "Question", "variable": "query", "type": "text-input"}]',
"created_by": "account-1",
"created_at": datetime(2024, 1, 2, 3, 4, 5),
"updated_by": None,
"updated_at": datetime(2024, 1, 3, 4, 5, 6),
}
data.update(overrides)
return CustomizedSnippet(**data)
def _patch_snippet_response_properties(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
CustomizedSnippet,
"tags",
property(lambda _snippet: [{"id": "tag-1", "name": "Search", "type": "snippet"}]),
)
monkeypatch.setattr(CustomizedSnippet, "created_by_account", property(lambda _snippet: _account("account-1")))
monkeypatch.setattr(CustomizedSnippet, "updated_by_account", property(lambda _snippet: None))
def _expected_snippet_list_item(snippet: CustomizedSnippet) -> dict:
return {
"id": snippet.id,
"name": snippet.name,
"description": snippet.description,
"type": snippet.type,
"version": snippet.version,
"use_count": snippet.use_count,
"is_published": snippet.is_published,
"icon_info": snippet.icon_info,
"tags": [{"id": "tag-1", "name": "Search", "type": "snippet"}],
"created_by": snippet.created_by,
"author_name": "Test User",
"created_at": int(snippet.created_at.timestamp()),
"updated_by": snippet.updated_by,
"updated_at": int(snippet.updated_at.timestamp()),
}
def _expected_snippet_response(snippet: CustomizedSnippet) -> dict:
return {
"id": snippet.id,
"name": snippet.name,
"description": snippet.description,
"type": snippet.type,
"version": snippet.version,
"use_count": snippet.use_count,
"is_published": snippet.is_published,
"icon_info": snippet.icon_info,
"graph": {},
"input_fields": [{"label": "Question", "variable": "query", "type": "text-input"}],
"tags": [{"id": "tag-1", "name": "Search", "type": "snippet"}],
"created_by": {
"id": "account-1",
"name": "Test User",
"email": "account-1@example.com",
},
"created_at": int(snippet.created_at.timestamp()),
"updated_by": None,
"updated_at": int(snippet.updated_at.timestamp()),
}
def test_normalize_snippet_list_query_args_sorts_indexed_values():
query_args = snippets_module.MultiDict(
[
@ -75,7 +137,7 @@ def test_list_snippets_returns_pagination(app: Flask, monkeypatch: pytest.Monkey
tag_id = "11111111-1111-1111-1111-111111111111"
get_snippets = Mock(return_value=(snippets, 1, False))
monkeypatch.setattr(snippets_module.SnippetService, "get_snippets", get_snippets)
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value=[{"id": "snippet-1"}]))
_patch_snippet_response_properties(monkeypatch)
api = snippets_module.CustomizedSnippetsApi()
handler = unwrap(api.get)
@ -87,7 +149,7 @@ def test_list_snippets_returns_pagination(app: Flask, monkeypatch: pytest.Monkey
assert status_code == 200
assert response == {
"data": [{"id": "snippet-1"}],
"data": [_expected_snippet_list_item(snippets[0])],
"page": 2,
"limit": 10,
"total": 1,
@ -110,6 +172,7 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, mo
snippet = _snippet()
create_snippet = Mock(return_value=snippet)
monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet)
_patch_snippet_response_properties(monkeypatch)
monkeypatch.setattr(
snippets_module.CreateSnippetPayload,
"model_validate",
@ -124,7 +187,6 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, mo
)
),
)
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"}))
api = snippets_module.CustomizedSnippetsApi()
handler = unwrap(api.post)
@ -137,7 +199,7 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, mo
response, status_code = handler(api, "tenant-1", user)
assert status_code == 201
assert response == {"id": "snippet-1"}
assert response == _expected_snippet_response(snippet)
assert create_snippet.call_args.kwargs["snippet_type"] == snippets_module.SnippetType.NODE
@ -184,7 +246,7 @@ def test_get_snippet_detail_raises_when_missing(app: Flask, monkeypatch: pytest.
def test_get_snippet_detail_returns_snippet(app: Flask, monkeypatch: pytest.MonkeyPatch):
snippet = _snippet()
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"}))
_patch_snippet_response_properties(monkeypatch)
api = snippets_module.CustomizedSnippetDetailApi()
handler = unwrap(api.get)
@ -193,7 +255,7 @@ def test_get_snippet_detail_returns_snippet(app: Flask, monkeypatch: pytest.Monk
response, status_code = handler(api, "tenant-1", snippet_id="snippet-1")
assert status_code == 200
assert response == {"id": "snippet-1"}
assert response == _expected_snippet_response(snippet)
def test_patch_snippet_returns_400_for_empty_payload(app: Flask, monkeypatch: pytest.MonkeyPatch):
@ -230,7 +292,7 @@ def test_patch_snippet_updates_and_commits(app: Flask, monkeypatch: pytest.Monke
monkeypatch.setattr(snippets_module.SnippetService, "update_snippet", update_snippet)
monkeypatch.setattr(snippets_module, "Session", SessionContext)
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1", "name": "New"}))
_patch_snippet_response_properties(monkeypatch)
api = snippets_module.CustomizedSnippetDetailApi()
handler = unwrap(api.patch)
@ -243,7 +305,7 @@ def test_patch_snippet_updates_and_commits(app: Flask, monkeypatch: pytest.Monke
response, status_code = handler(api, "tenant-1", user, snippet_id="snippet-1")
assert status_code == 200
assert response == {"id": "snippet-1", "name": "New"}
assert response == _expected_snippet_response(updated_snippet)
update_snippet.assert_called_once()
assert update_snippet.call_args.kwargs["data"] == {
"name": "New",

View File

@ -26,7 +26,9 @@ from controllers.console.workspace.workspace import (
WebappLogoWorkspaceApi,
WorkspaceInfoApi,
WorkspaceListApi,
WorkspaceLogoUploadResponse,
WorkspacePermissionApi,
WorkspacePermissionResponse,
)
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
@ -587,7 +589,8 @@ class TestWebappLogoWorkspaceApi:
result, status = method(api, user)
assert status == 201
assert result["id"] == "file1"
assert result == {"id": "file1"}
assert WorkspaceLogoUploadResponse.model_validate(result).model_dump(mode="json") == {"id": "file1"}
def test_filename_missing(self, app: Flask):
api = WebappLogoWorkspaceApi()
@ -676,7 +679,7 @@ class TestWorkspaceInfoApi:
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"name": "New Name"},
return_value={"id": "t1", "name": "New Name"},
),
):
result = method(api, "t1")
@ -717,7 +720,13 @@ class TestWorkspacePermissionApi:
result, status = method(api, "t1")
assert status == 200
assert result["workspace_id"] == "t1"
expected = {
"workspace_id": "t1",
"allow_member_invite": True,
"allow_owner_transfer": False,
}
assert result == expected
assert WorkspacePermissionResponse.model_validate(result).model_dump(mode="json") == expected
def test_no_current_tenant(self, app: Flask):
api = WorkspacePermissionApi()

View File

@ -325,10 +325,12 @@ class TestPipelineRunApiEntity:
def test_entity_missing_required_field(self):
"""Test entity raises on missing required field."""
with pytest.raises(ValueError):
PipelineRunApiEntity(
inputs={},
datasource_type="online_document",
# missing datasource_info_list, start_node_id, etc.
PipelineRunApiEntity.model_validate(
{
"inputs": {},
"datasource_type": "online_document",
# missing datasource_info_list, start_node_id, etc.
}
)
@ -382,8 +384,19 @@ class TestDatasourcePluginsApiGet:
mock_dataset = Mock()
mock_db.session.scalar.return_value = mock_dataset
datasource_plugins = [
{
"node_id": "node-datasource-1",
"plugin_id": "plugin-a",
"provider_name": "provider-a",
"datasource_type": "online_document",
"title": "Online Docs",
"user_input_variables": [{"variable": "url", "label": "URL", "type": "text-input", "required": True}],
"credentials": [{"id": "cred-1", "name": "Default credential", "type": "oauth2", "is_default": True}],
}
]
mock_svc_instance = Mock()
mock_svc_instance.get_datasource_plugins.return_value = [{"name": "plugin_a"}]
mock_svc_instance.get_datasource_plugins.return_value = datasource_plugins
mock_svc_cls.return_value = mock_svc_instance
with app.test_request_context("/datasets/test/pipeline/datasource-plugins?is_published=true"):
@ -391,11 +404,33 @@ class TestDatasourcePluginsApiGet:
response, status = api.get(tenant_id=tenant_id, dataset_id=dataset_id)
assert status == 200
assert response == [{"name": "plugin_a"}]
assert response == datasource_plugins
mock_svc_instance.get_datasource_plugins.assert_called_once_with(
tenant_id=tenant_id, dataset_id=dataset_id, is_published=True
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
def test_get_plugins_parses_false_is_published_query(self, mock_svc_cls, mock_db, app: Flask):
"""Test false query string is parsed as boolean False."""
tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
mock_db.session.scalar.return_value = Mock()
mock_svc_instance = Mock()
mock_svc_instance.get_datasource_plugins.return_value = []
mock_svc_cls.return_value = mock_svc_instance
with app.test_request_context("/datasets/test/pipeline/datasource-plugins?is_published=false"):
api = DatasourcePluginsApi()
response, status = api.get(tenant_id=tenant_id, dataset_id=dataset_id)
assert status == 200
assert response == []
mock_svc_instance.get_datasource_plugins.assert_called_once_with(
tenant_id=tenant_id, dataset_id=dataset_id, is_published=False
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_get_plugins_not_found(self, mock_db, app: Flask):
"""Test NotFound when dataset check fails."""

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import json
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
@ -112,7 +111,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
chat_color_theme_inverted=False,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
custom_disclaimer="",
prompt_public=False,
show_workflow_steps=True,
use_icon_as_answer_icon=False,
@ -138,7 +137,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
response = HumanInputFormApi().get("token-1")
body = json.loads(response.get_data(as_text=True))
body = response
assert set(body.keys()) == {
"site",
"form_content",
@ -167,7 +166,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
"description": "desc",
"copyright": None,
"privacy_policy": None,
"custom_disclaimer": None,
"custom_disclaimer": "",
"default_language": "en",
"prompt_public": False,
"show_workflow_steps": True,
@ -256,7 +255,7 @@ def test_get_form_uses_runtime_select_options(monkeypatch: pytest.MonkeyPatch, a
chat_color_theme_inverted=False,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
custom_disclaimer="",
prompt_public=False,
show_workflow_steps=True,
use_icon_as_answer_icon=False,
@ -277,7 +276,7 @@ def test_get_form_uses_runtime_select_options(monkeypatch: pytest.MonkeyPatch, a
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
response = HumanInputFormApi().get("token-1")
body = json.loads(response.get_data(as_text=True))
body = response
assert body["inputs"] == [input_config.model_dump(mode="json") for input_config in runtime_inputs]
service_mock.resolve_form_inputs.assert_called_once_with(form)
@ -380,7 +379,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F
chat_color_theme_inverted=False,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
custom_disclaimer="",
prompt_public=False,
show_workflow_steps=True,
use_icon_as_answer_icon=False,
@ -403,7 +402,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
response = HumanInputFormApi().get("token-1")
body = json.loads(response.get_data(as_text=True))
body = response
assert set(body.keys()) == {
"site",
"form_content",
@ -432,7 +431,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F
"description": "desc",
"copyright": None,
"privacy_policy": None,
"custom_disclaimer": None,
"custom_disclaimer": "",
"default_language": "en",
"prompt_public": False,
"show_workflow_steps": True,

View File

@ -1,8 +1,7 @@
from types import SimpleNamespace
from flask_restx import marshal
from fields.snippet_fields import snippet_list_fields
from fields.snippet_fields import SnippetListItemResponse
from libs.helper import dump_response
def test_snippet_list_fields_include_author_name() -> None:
@ -23,6 +22,6 @@ def test_snippet_list_fields_include_author_name() -> None:
updated_at=None,
)
result = marshal(snippet, snippet_list_fields)
result = dump_response(SnippetListItemResponse, snippet)
assert result["author_name"] == "Alice"

View File

@ -269,6 +269,7 @@ export type MessageDetailResponse = {
agent_thoughts?: Array<AgentThought>
annotation?: ConversationAnnotation | null
annotation_hit_history?: ConversationAnnotationHitHistory | null
answer: string
answer_tokens?: number | null
conversation_id: string
created_at?: number | null
@ -284,12 +285,11 @@ export type MessageDetailResponse = {
}
message?: JsonValue | null
message_files?: Array<MessageFile>
message_metadata_dict?: JsonValue | null
message_tokens?: number | null
metadata?: JsonValue | null
parent_message_id?: string | null
provider_response_latency?: number | null
query: string
re_sign_file_url_answer: string
status: string
workflow_run_id?: string | null
}
@ -732,7 +732,6 @@ export type AgentThought = {
created_at?: number | null
files: Array<string>
id: string
message_chain_id?: string | null
message_id: string
observation?: string | null
position: number
@ -752,8 +751,8 @@ export type ConversationAnnotation = {
export type ConversationAnnotationHitHistory = {
annotation_create_account?: SimpleAccount | null
annotation_id: string
created_at?: number | null
id: string
}
export type HumanInputContent = {

View File

@ -628,7 +628,6 @@ export const zAgentThought = z.object({
created_at: z.int().nullish(),
files: z.array(z.string()),
id: z.string(),
message_chain_id: z.string().nullish(),
message_id: z.string(),
observation: z.string().nullish(),
position: z.int(),
@ -1060,8 +1059,8 @@ export const zConversationAnnotation = z.object({
*/
export const zConversationAnnotationHitHistory = z.object({
annotation_create_account: zSimpleAccount.nullish(),
annotation_id: z.string(),
created_at: z.int().nullish(),
id: z.string(),
})
/**
@ -2039,6 +2038,7 @@ export const zMessageDetailResponse = z.object({
agent_thoughts: z.array(zAgentThought).optional(),
annotation: zConversationAnnotation.nullish(),
annotation_hit_history: zConversationAnnotationHitHistory.nullish(),
answer: z.string(),
answer_tokens: z.int().nullish(),
conversation_id: z.string(),
created_at: z.int().nullish(),
@ -2052,12 +2052,11 @@ export const zMessageDetailResponse = z.object({
inputs: z.record(z.string(), zJsonValue),
message: zJsonValue.nullish(),
message_files: z.array(zMessageFile).optional(),
message_metadata_dict: zJsonValue.nullish(),
message_tokens: z.int().nullish(),
metadata: zJsonValue.nullish(),
parent_message_id: z.string().nullish(),
provider_response_latency: z.number().nullish(),
query: z.string(),
re_sign_file_url_answer: z.string(),
status: z.string(),
workflow_run_id: z.string().nullish(),
})

View File

@ -433,6 +433,7 @@ export type MessageDetailResponse = {
agent_thoughts?: Array<AgentThought>
annotation?: ConversationAnnotation | null
annotation_hit_history?: ConversationAnnotationHitHistory | null
answer: string
answer_tokens?: number | null
conversation_id: string
created_at?: number | null
@ -448,12 +449,11 @@ export type MessageDetailResponse = {
}
message?: JsonValue | null
message_files?: Array<MessageFile>
message_metadata_dict?: JsonValue | null
message_tokens?: number | null
metadata?: JsonValue | null
parent_message_id?: string | null
provider_response_latency?: number | null
query: string
re_sign_file_url_answer: string
status: string
workflow_run_id?: string | null
}
@ -1440,7 +1440,6 @@ export type AgentThought = {
created_at?: number | null
files: Array<string>
id: string
message_chain_id?: string | null
message_id: string
observation?: string | null
position: number
@ -1460,8 +1459,8 @@ export type ConversationAnnotation = {
export type ConversationAnnotationHitHistory = {
annotation_create_account?: SimpleAccount | null
annotation_id: string
created_at?: number | null
id: string
}
export type HumanInputContent = {

View File

@ -1183,7 +1183,6 @@ export const zAgentThought = z.object({
created_at: z.int().nullish(),
files: z.array(z.string()),
id: z.string(),
message_chain_id: z.string().nullish(),
message_id: z.string(),
observation: z.string().nullish(),
position: z.int(),
@ -2286,8 +2285,8 @@ export const zConversationPagination = z.object({
*/
export const zConversationAnnotationHitHistory = z.object({
annotation_create_account: zSimpleAccount.nullish(),
annotation_id: z.string(),
created_at: z.int().nullish(),
id: z.string(),
})
/**
@ -3401,6 +3400,7 @@ export const zMessageDetailResponse = z.object({
agent_thoughts: z.array(zAgentThought).optional(),
annotation: zConversationAnnotation.nullish(),
annotation_hit_history: zConversationAnnotationHitHistory.nullish(),
answer: z.string(),
answer_tokens: z.int().nullish(),
conversation_id: z.string(),
created_at: z.int().nullish(),
@ -3414,12 +3414,11 @@ export const zMessageDetailResponse = z.object({
inputs: z.record(z.string(), zJsonValue),
message: zJsonValue.nullish(),
message_files: z.array(zMessageFile).optional(),
message_metadata_dict: zJsonValue.nullish(),
message_tokens: z.int().nullish(),
metadata: zJsonValue.nullish(),
parent_message_id: z.string().nullish(),
provider_response_latency: z.number().nullish(),
query: z.string(),
re_sign_file_url_answer: z.string(),
status: z.string(),
workflow_run_id: z.string().nullish(),
})

View File

@ -237,7 +237,6 @@ export type AgentThought = {
created_at?: number | null
files: Array<string>
id: string
message_chain_id?: string | null
message_id: string
observation?: string | null
position: number

View File

@ -257,7 +257,6 @@ export const zAgentThought = z.object({
created_at: z.int().nullish(),
files: z.array(z.string()),
id: z.string(),
message_chain_id: z.string().nullish(),
message_id: z.string(),
observation: z.string().nullish(),
position: z.int(),

View File

@ -217,14 +217,8 @@ const toFeedback = (feedback: NonNullable<MessageDetailResponse['feedbacks']>[nu
}
}
type AgentDebugMessageWithLegacyAnswer = MessageDetailResponse & {
answer?: string | null
}
const getAgentDebugMessageAnswer = (message: MessageDetailResponse) => {
const legacyAnswer = (message as AgentDebugMessageWithLegacyAnswer).answer
return message.re_sign_file_url_answer ?? legacyAnswer ?? ''
return message.answer ?? ''
}
function getFormattedAgentDebugChatTree(messages: MessageDetailResponse[]): ChatItemInTree[] {