diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 195a41f2888..f987ecca745 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -167,12 +167,16 @@ register_schema_models( ChatMessagesQuery, MessageFeedbackPayload, FeedbackExportQuery, +) +register_response_schema_models( + console_ns, AnnotationCountResponse, SuggestedQuestionsResponse, MessageDetailResponse, MessageInfiniteScrollPaginationResponse, + SimpleResultResponse, + TextFileResponse, ) -register_response_schema_models(console_ns, SimpleResultResponse, TextFileResponse) @console_ns.route("/apps//chat-messages") diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index a6a61262cdc..63fc16d4352 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -12,6 +12,7 @@ from controllers.common.errors import FilenameNotExistsError, NoFileUploadedErro from controllers.common.fields import GeneratedAppResponse from controllers.common.schema import ( query_params_from_model, + query_params_from_request, register_response_schema_models, register_schema_model, register_schema_models, @@ -150,12 +151,11 @@ class DatasourcePluginsApi(DatasetApiResource): if not dataset: raise NotFound("Dataset not found.") - # Get query parameter to determine published or draft - is_published: bool = request.args.get("is_published", default=True, type=bool) + query = query_params_from_request(DatasourcePluginsQuery) rag_pipeline_service: RagPipelineService = RagPipelineService() datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins( - tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=is_published + tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=query.is_published ) return datasource_plugins, 200 diff --git a/api/openapi/markdown/console-openapi.md b/api/openapi/markdown/console-openapi.md index 4bd828c8df4..760ddd9d747 100644 --- a/api/openapi/markdown/console-openapi.md +++ b/api/openapi/markdown/console-openapi.md @@ -13435,7 +13435,6 @@ Soft lifecycle state for Agent records. | created_at | integer | | No | | files | [ string ] | | Yes | | id | string | | Yes | -| message_chain_id | string | | No | | message_id | string | | Yes | | observation | string | | No | | position | integer | | Yes | @@ -14567,8 +14566,8 @@ Enum class for configurate method of provider model. | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | | annotation_create_account | [SimpleAccount](#simpleaccount) | | No | +| annotation_id | string | | Yes | | created_at | integer | | No | -| id | string | | Yes | #### ConversationDetail @@ -17072,6 +17071,7 @@ Enum class for large language model mode. | agent_thoughts | [ [AgentThought](#agentthought) ] | | No | | annotation | [ConversationAnnotation](#conversationannotation) | | No | | annotation_hit_history | [ConversationAnnotationHitHistory](#conversationannotationhithistory) | | No | +| answer | string | | Yes | | answer_tokens | integer | | No | | conversation_id | string | | Yes | | created_at | integer | | No | @@ -17085,12 +17085,11 @@ Enum class for large language model mode. | inputs | object | | Yes | | message | [JSONValue](#jsonvalue) | | No | | message_files | [ [MessageFile](#messagefile) ] | | No | -| message_metadata_dict | [JSONValue](#jsonvalue) | | No | | message_tokens | integer | | No | +| metadata | [JSONValue](#jsonvalue) | | No | | parent_message_id | string | | No | | provider_response_latency | number | | No | | query | string | | Yes | -| re_sign_file_url_answer | string | | Yes | | status | string | | Yes | | workflow_run_id | string | | No | diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index bdec903ef33..4b8186f3017 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -607,7 +607,11 @@ class TestMiscApis: method = unwrap(api.get) service = MagicMock() - service.get_recommended_plugins.return_value = [{"id": "p1"}] + recommended_plugins = { + "installed_recommended_plugins": [{"id": "p1"}], + "uninstalled_recommended_plugins": [{"id": "p2"}], + } + service.get_recommended_plugins.return_value = recommended_plugins user = make_account() tenant_id = "tenant-1" @@ -619,7 +623,7 @@ class TestMiscApis: ), ): result = method(api, tenant_id, user) - assert result == [{"id": "p1"}] + assert result == recommended_plugins service.get_recommended_plugins.assert_called_once_with("all", user, tenant_id) @@ -826,7 +830,7 @@ class TestRagPipelineByIdApi: result = method(api, pipeline, "old-workflow") workflow_service.delete_workflow.assert_called_once() - assert result == (None, 204) + assert result == ("", 204) def test_delete_active_workflow_rejected(self, app: Flask) -> None: api = RagPipelineByIdApi() diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_site.py b/api/tests/test_containers_integration_tests/controllers/web/test_site.py index 9adb26ff3d2..4fc99cdc74c 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_site.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_site.py @@ -2,21 +2,22 @@ from __future__ import annotations -from types import SimpleNamespace -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden -from controllers.web.site import AppSiteApi, AppSiteInfo +from controllers.web.site import AppSiteApi, WebAppSiteResponse, WebModelConfigResponse from models import Tenant, TenantStatus -from models.model import App, AppMode, CustomizeTokenStrategy, Site +from models.account import TenantCustomConfigDict +from models.model import App, AppMode, AppModelConfig, CustomizeTokenStrategy, EndUser, Site +from services.feature_service import FeatureModel @pytest.fixture -def app(flask_app_with_containers) -> Flask: +def app(flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers @@ -41,7 +42,23 @@ def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True def _create_site(db_session: Session, app_id: str) -> Site: - site = Site( + site = _site_model(app_id=app_id) + db_session.add(site) + db_session.commit() + return site + + +def _end_user(tenant_id: str, app_id: str) -> EndUser: + return EndUser( + tenant_id=tenant_id, + app_id=app_id, + type="browser", + session_id=f"session-{app_id}", + ) + + +def _site_model(*, app_id: str) -> Site: + return Site( app_id=app_id, title="Site", icon_type="emoji", @@ -51,31 +68,30 @@ def _create_site(db_session: Session, app_id: str) -> Site: default_language="en", chat_color_theme="light", chat_color_theme_inverted=False, + custom_disclaimer="", customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW, code=f"code-{app_id[-6:]}", prompt_public=False, show_workflow_steps=True, use_icon_as_answer_icon=False, ) - db_session.add(site) - db_session.commit() - return site class TestAppSiteApi: @patch("controllers.web.site.FeatureService.get_features") - def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None: + def test_happy_path(self, mock_features: MagicMock, app: Flask, db_session_with_containers: Session) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" tenant = _create_tenant(db_session_with_containers) app_model = _create_app(db_session_with_containers, tenant.id) _create_site(db_session_with_containers, app_model.id) - end_user = SimpleNamespace(id="eu-1") - mock_features.return_value = SimpleNamespace(can_replace_logo=False) + end_user = _end_user(tenant.id, app_model.id) + mock_features.return_value = FeatureModel(can_replace_logo=False) with app.test_request_context("/site"): result = AppSiteApi().get(app_model, end_user) assert result["app_id"] == app_model.id + assert result["end_user_id"] == end_user.id assert result["plan"] == "basic" assert result["enable_site"] is True @@ -83,51 +99,139 @@ class TestAppSiteApi: app.config["RESTX_MASK_HEADER"] = "X-Fields" tenant = _create_tenant(db_session_with_containers) app_model = _create_app(db_session_with_containers, tenant.id) - end_user = SimpleNamespace(id="eu-1") + end_user = _end_user(tenant.id, app_model.id) with app.test_request_context("/site"): with pytest.raises(Forbidden): AppSiteApi().get(app_model, end_user) - @patch("controllers.web.site.FeatureService.get_features") - def test_archived_tenant_raises_forbidden( - self, mock_features, app: Flask, db_session_with_containers: Session - ) -> None: + def test_archived_tenant_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE) app_model = _create_app(db_session_with_containers, tenant.id) _create_site(db_session_with_containers, app_model.id) - end_user = SimpleNamespace(id="eu-1") - mock_features.return_value = SimpleNamespace(can_replace_logo=False) + end_user = _end_user(tenant.id, app_model.id) with app.test_request_context("/site"): with pytest.raises(Forbidden): AppSiteApi().get(app_model, end_user) -class TestAppSiteInfo: +def _tenant_model(*, plan: str = "basic", custom_config: TenantCustomConfigDict | None = None) -> Tenant: + tenant = Tenant(name="test-tenant", plan=plan) + tenant.custom_config_dict = custom_config or {} + return tenant + + +def _app_model(*, tenant: Tenant, enable_site: bool = True) -> App: + app_model = App( + tenant_id=tenant.id, + mode=AppMode.CHAT, + name="test-app", + enable_site=enable_site, + enable_api=True, + ) + app_model.id = "app-test" + return app_model + + +class TestWebAppSiteResponse: def test_basic_fields(self) -> None: - tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={}) - site_obj = SimpleNamespace() - info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False) - - assert info.app_id == "app-1" - assert info.end_user_id == "eu-1" - assert info.enable_site is True - assert info.plan == "basic" - assert info.can_replace_logo is False - assert info.model_config is None - - @patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com")) - def test_can_replace_logo_sets_custom_config(self) -> None: - tenant = SimpleNamespace( - id="tenant-1", - plan="pro", - custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True}, + tenant = _tenant_model() + app_model = _app_model(tenant=tenant) + response = WebAppSiteResponse.from_app_site( + tenant=tenant, + app_model=app_model, + site=_site_model(app_id=app_model.id), + end_user_id="eu-1", + can_replace_logo=False, ) - site_obj = SimpleNamespace() - info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True) - assert info.can_replace_logo is True - assert info.custom_config["remove_webapp_brand"] is True - assert "webapp-logo" in info.custom_config["replace_webapp_logo"] + assert response.app_id == app_model.id + assert response.end_user_id == "eu-1" + assert response.enable_site is True + assert response.plan == "basic" + assert response.can_replace_logo is False + assert response.model_config_ is None + assert response.custom_config is None + assert response.site.custom_disclaimer == "" + + def test_nullable_site_fields_preserve_none(self) -> None: + tenant = _tenant_model() + app_model = _app_model(tenant=tenant) + site = _site_model(app_id=app_model.id) + site.chat_color_theme = None + site.icon_type = None + site.icon = None + site.icon_background = None + site.description = None + site.copyright = None + site.privacy_policy = None + + response = WebAppSiteResponse.from_app_site( + tenant=tenant, + app_model=app_model, + site=site, + end_user_id=None, + can_replace_logo=False, + ) + + dumped = response.model_dump(mode="json") + assert dumped["end_user_id"] is None + assert dumped["site"]["chat_color_theme"] is None + assert dumped["site"]["icon_type"] is None + assert dumped["site"]["icon"] is None + assert dumped["site"]["icon_background"] is None + assert dumped["site"]["description"] is None + assert dumped["site"]["copyright"] is None + assert dumped["site"]["privacy_policy"] is None + assert dumped["site"]["custom_disclaimer"] == "" + + @patch("controllers.web.site.dify_config.FILES_URL", "https://files.example.com") + def test_can_replace_logo_sets_custom_config(self) -> None: + tenant = _tenant_model( + plan="pro", + custom_config={"remove_webapp_brand": True, "replace_webapp_logo": "enabled"}, + ) + app_model = _app_model(tenant=tenant) + response = WebAppSiteResponse.from_app_site( + tenant=tenant, + app_model=app_model, + site=_site_model(app_id=app_model.id), + end_user_id="eu-1", + can_replace_logo=True, + ) + + assert response.can_replace_logo is True + assert response.custom_config is not None + assert response.custom_config.remove_webapp_brand is True + assert response.custom_config.replace_webapp_logo is not None + assert "webapp-logo" in response.custom_config.replace_webapp_logo + + +class TestWebModelConfigResponse: + def test_serializes_internal_model_config_properties_to_public_keys(self) -> None: + model_config = AppModelConfig( + app_id="app-test", + opening_statement="Hello", + suggested_questions='["Question?"]', + suggested_questions_after_answer='{"enabled": true}', + more_like_this='{"enabled": false}', + model='{"provider": "openai", "name": "gpt-4o", "mode": "chat"}', + user_input_form='[{"text-input": {"label": "Name", "variable": "name", "required": true}}]', + pre_prompt="System prompt", + created_by="account-1", + updated_by="account-1", + ) + + dumped = WebModelConfigResponse.model_validate(model_config, from_attributes=True).model_dump(mode="json") + + assert dumped == { + "opening_statement": "Hello", + "suggested_questions": ["Question?"], + "suggested_questions_after_answer": {"enabled": True}, + "more_like_this": {"enabled": False}, + "model": {"provider": "openai", "name": "gpt-4o", "mode": "chat"}, + "user_input_form": [{"text-input": {"label": "Name", "variable": "name", "required": True}}], + "pre_prompt": "System prompt", + } diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py index f04ab6d6e7c..9c47f8e5a31 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py @@ -77,9 +77,11 @@ def test_human_input_preview_delegates_to_service( preview_payload = { "form_id": "node-42", + "node_id": "node-42", + "node_title": "Human Input", "form_content": "
example
", "inputs": [{"name": "topic"}], - "actions": [{"id": "continue"}], + "actions": [{"id": "continue", "title": "Continue"}], } service_instance = MagicMock() service_instance.get_human_input_form_preview.return_value = preview_payload @@ -88,7 +90,15 @@ def test_human_input_preview_delegates_to_service( with app.test_request_context(case.path, method="POST", json={"inputs": {"topic": "tech"}}): response = case.resource_cls().post(app_id=app_model.id, node_id="node-42") - assert response == preview_payload + assert response == { + **preview_payload, + "TYPE": "human_input_required", + "actions": [{"id": "continue", "title": "Continue", "button_style": "default"}], + "resolved_default_values": {}, + "display_in_ui": False, + "form_token": None, + "expiration_time": None, + } service_instance.get_human_input_form_preview.assert_called_once_with( app_model=app_model, account=account, diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index b7e16b91fb7..12671458f70 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -1,4 +1,6 @@ import inspect +from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -16,6 +18,7 @@ from controllers.console.datasets.external import ( ExternalDatasetCreateApi, ExternalKnowledgeHitTestingApi, ) +from extensions.ext_database import db from models.account import Account, TenantAccountRole from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService @@ -38,28 +41,183 @@ def current_user() -> Account: return user +def _external_api_dict(api_id: str = "api-1") -> dict: + return { + "id": api_id, + "tenant_id": "tenant-1", + "name": f"External API {api_id}", + "description": f"Description for {api_id}", + "settings": { + "endpoint": f"https://external.example.com/{api_id}", + "api_key": "secret", + "headers": {"X-Source": "unit-test"}, + "timeout": 30, + }, + "dataset_bindings": [ + {"id": f"dataset-{api_id}", "name": f"Dataset {api_id}"}, + ], + "created_by": "user-1", + "created_at": "2024-01-01T00:00:00", + } + + +def _external_api_object(api_id: str = "api-1") -> SimpleNamespace: + payload = _external_api_dict(api_id) + return SimpleNamespace( + **{ + **payload, + "dataset_bindings": [SimpleNamespace(**binding) for binding in payload["dataset_bindings"]], + } + ) + + +def _expected_dataset_detail_payload() -> dict[str, Any]: + return { + "id": "dataset-1", + "name": "Support knowledge", + "description": "External support articles", + "provider": "external", + "permission": "only_me", + "data_source_type": "external", + "indexing_technique": "economy", + "app_count": 2, + "document_count": 7, + "word_count": 2048, + "created_by": "user-1", + "author_name": "Test User", + "created_at": 1710000000, + "updated_by": "user-2", + "updated_at": 1710003600, + "embedding_model": None, + "embedding_model_provider": None, + "embedding_available": False, + "retrieval_model_dict": { + "search_method": "semantic_search", + "reranking_enable": False, + "reranking_mode": None, + "reranking_model": {"reranking_provider_name": None, "reranking_model_name": None}, + "weights": None, + "top_k": 4, + "score_threshold_enabled": True, + "score_threshold": 0.5, + }, + "summary_index_setting": { + "enable": True, + "model_name": "summary-model", + "model_provider_name": "provider-a", + "summary_prompt": "Summarize this.", + }, + "tags": [{"id": "tag-1", "name": "Support", "type": "knowledge"}], + "doc_form": "text_model", + "external_knowledge_info": { + "external_knowledge_id": "knowledge-1", + "external_knowledge_api_id": "api-1", + "external_knowledge_api_name": "External API api-1", + "external_knowledge_api_endpoint": "https://external.example.com/api-1", + }, + "external_retrieval_model": { + "top_k": 4, + "score_threshold": 0.5, + "score_threshold_enabled": True, + }, + "doc_metadata": [{"id": "metadata-1", "name": "source", "type": "string"}], + "built_in_field_enabled": True, + "pipeline_id": None, + "runtime_mode": "external", + "chunk_structure": "general", + "icon_info": { + "icon_type": "emoji", + "icon": "book", + "icon_background": "#FFF4ED", + "icon_url": None, + }, + "is_published": True, + "total_documents": 7, + "total_available_documents": 6, + "enable_api": True, + "is_multimodal": False, + "maintainer": None, + "permission_keys": [], + } + + +def _dataset_detail_object() -> SimpleNamespace: + payload = _expected_dataset_detail_payload() + return SimpleNamespace( + **{ + **payload, + "summary_index_setting": SimpleNamespace(**payload["summary_index_setting"]), + "tags": [SimpleNamespace(**tag) for tag in payload["tags"]], + "external_knowledge_info": SimpleNamespace(**payload["external_knowledge_info"]), + "external_retrieval_model": SimpleNamespace(**payload["external_retrieval_model"]), + "doc_metadata": [SimpleNamespace(**item) for item in payload["doc_metadata"]], + "icon_info": SimpleNamespace(**payload["icon_info"]), + } + ) + + class TestExternalApiTemplateListApi: def test_get_success(self, app: Flask): api = ExternalApiTemplateListApi() method = inspect.unwrap(api.get) - api_item = MagicMock() - api_item.to_dict.return_value = {"id": "1"} + api_item = _external_api_object("api-1") with ( - app.test_request_context("/?page=1&limit=20"), + app.test_request_context("/?page=2&limit=1&keyword=vector"), patch.object( ExternalDatasetService, "get_external_knowledge_apis", - return_value=([api_item], 1), + return_value=([api_item], 3), ) as get_external_knowledge_apis, ): resp, status = method(api, "tenant-1") assert status == 200 - assert resp["total"] == 1 - assert resp["data"][0]["id"] == "1" - get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None) + assert resp == { + "data": [_external_api_dict("api-1")], + "has_more": True, + "limit": 1, + "total": 3, + "page": 2, + } + get_external_knowledge_apis.assert_called_once_with(2, 1, "tenant-1", "vector") + + def test_post_success_uses_validated_payload_and_returns_template(self, app: Flask, current_user: Account): + api = ExternalApiTemplateListApi() + method = inspect.unwrap(api.post) + + payload = { + "name": "Vendor Search", + "settings": { + "endpoint": "https://external.example.com/search", + "api_key": "secret", + "headers": {"X-Source": "unit-test"}, + "timeout": 30, + }, + } + created = _external_api_object("api-created") + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(ExternalDatasetService, "validate_api_list") as validate_api_list, + patch.object( + ExternalDatasetService, + "create_external_knowledge_api", + return_value=created, + ) as create_external_knowledge_api, + ): + resp, status = method(api, "tenant-1", current_user) + + assert status == 201 + assert resp == _external_api_dict("api-created") + validate_api_list.assert_called_once_with(payload["settings"]) + create_external_knowledge_api.assert_called_once_with( + tenant_id="tenant-1", + user_id="user-1", + args=payload, + ) def test_post_forbidden(self, app: Flask, current_user: Account): current_user.role = TenantAccountRole.NORMAL @@ -97,6 +255,25 @@ class TestExternalApiTemplateListApi: class TestExternalApiTemplateApi: + def test_get_success_returns_template_contract(self, app: Flask): + api = ExternalApiTemplateApi() + method = inspect.unwrap(api.get) + template = _external_api_object("api-detail") + + with ( + app.test_request_context("/"), + patch.object( + ExternalDatasetService, + "get_external_knowledge_api", + return_value=template, + ) as get_external_knowledge_api, + ): + resp, status = method(api, "tenant-1", "api-detail") + + assert status == 200 + assert resp == _external_api_dict("api-detail") + get_external_knowledge_api.assert_called_once_with("api-detail", "tenant-1") + def test_get_not_found(self, app: Flask): api = ExternalApiTemplateApi() method = inspect.unwrap(api.get) @@ -112,6 +289,42 @@ class TestExternalApiTemplateApi: with pytest.raises(NotFound): method(api, "tenant-1", "api-id") + def test_patch_success_uses_validated_payload_and_returns_template(self, app: Flask, current_user: Account): + api = ExternalApiTemplateApi() + method = inspect.unwrap(api.patch) + + payload = { + "name": "Updated API", + "settings": { + "endpoint": "https://external.example.com/updated", + "api_key": "new-secret", + "headers": {"X-Version": "2"}, + }, + } + updated = _external_api_object("api-updated") + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(ExternalDatasetService, "validate_api_list") as validate_api_list, + patch.object( + ExternalDatasetService, + "update_external_knowledge_api", + return_value=updated, + ) as update_external_knowledge_api, + ): + resp, status = method(api, "tenant-1", current_user, "api-updated") + + assert status == 200 + assert resp == _external_api_dict("api-updated") + validate_api_list.assert_called_once_with(payload["settings"]) + update_external_knowledge_api.assert_called_once_with( + tenant_id="tenant-1", + user_id="user-1", + external_knowledge_api_id="api-updated", + args=payload, + ) + def test_delete_forbidden(self, app: Flask, current_user: Account): current_user.role = TenantAccountRole.NORMAL @@ -149,45 +362,37 @@ class TestExternalDatasetCreateApi: method = inspect.unwrap(api.post) payload = { - "external_knowledge_api_id": "api", - "external_knowledge_id": "kid", - "name": "dataset", + "external_knowledge_api_id": "api-1", + "external_knowledge_id": "knowledge-1", + "name": "Support knowledge", + "description": "External support articles", + "external_retrieval_model": { + "top_k": 4, + "score_threshold": 0.5, + "score_threshold_enabled": True, + }, } - dataset = MagicMock() - - dataset.embedding_available = False - dataset.built_in_field_enabled = False - dataset.is_published = False - dataset.enable_api = False - dataset.enable_qa = False - dataset.enable_vector_store = False - dataset.vector_store_setting = None - dataset.is_multimodal = False - - dataset.retrieval_model_dict = {} - dataset.tags = [] - dataset.external_knowledge_info = None - dataset.external_retrieval_model = None - dataset.doc_metadata = [] - dataset.icon_info = None - dataset.permission_keys = [] - - dataset.summary_index_setting = MagicMock() - dataset.summary_index_setting.enable = False + dataset = _dataset_detail_object() with ( - app.test_request_context("/"), + app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), patch.object( ExternalDatasetService, "create_external_dataset", return_value=dataset, - ), + ) as create_external_dataset, ): - _, status = method(api, "tenant-1", current_user) + resp, status = method(api, "tenant-1", current_user) assert status == 201 + assert resp == _expected_dataset_detail_payload() + create_external_dataset.assert_called_once_with( + tenant_id="tenant-1", + user_id="user-1", + args=payload, + ) def test_create_forbidden(self, app: Flask, current_user: Account): current_user.role = TenantAccountRole.NORMAL @@ -228,24 +433,58 @@ class TestExternalKnowledgeHitTestingApi: api = ExternalKnowledgeHitTestingApi() method = inspect.unwrap(api.post) - payload = {"query": "hello"} + payload = { + "query": "hello", + "external_retrieval_model": { + "top_k": 3, + "score_threshold": 0.25, + "score_threshold_enabled": True, + }, + "metadata_filtering_conditions": { + "logical_operator": "and", + "conditions": [{"name": "source", "comparison_operator": "contains", "value": "external"}], + }, + } dataset = MagicMock() + retrieve_response = { + "query": {"content": "hello"}, + "records": [ + { + "content": "answer", + "title": "doc", + "score": 0.9, + "metadata": {"source": "external", "page": 2}, + } + ], + } with ( - app.test_request_context("/"), + app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), patch.object(DatasetService, "get_dataset", return_value=dataset), - patch.object(DatasetService, "check_dataset_permission"), + patch.object(DatasetService, "check_dataset_permission") as check_dataset_permission, + patch.object(HitTestingService, "hit_testing_args_check") as hit_testing_args_check, patch.object( HitTestingService, "external_retrieve", - return_value={"ok": True}, - ), + return_value=retrieve_response, + ) as external_retrieve, + patch("controllers.console.datasets.external.dump_response", side_effect=lambda _model, value: value), ): resp = method(api, current_user, "dataset-id") - assert resp["ok"] is True + assert resp == retrieve_response + check_dataset_permission.assert_called_once_with(dataset, current_user) + hit_testing_args_check.assert_called_once_with(payload) + external_retrieve.assert_called_once_with( + session=db.session, + dataset=dataset, + query="hello", + account=current_user, + external_retrieval_model=payload["external_retrieval_model"], + metadata_filtering_conditions=payload["metadata_filtering_conditions"], + ) class TestBedrockRetrievalApi: @@ -254,24 +493,44 @@ class TestBedrockRetrievalApi: method = inspect.unwrap(api.post) payload = { - "retrieval_setting": {}, - "query": "hello", - "knowledge_id": "kid", + "retrieval_setting": {"top_k": 5, "score_threshold": 0.72}, + "query": "hello bedrock", + "knowledge_id": "knowledge-base-1", + } + retrieval_response = { + "records": [ + { + "metadata": {"source": "bedrock", "uri": "s3://bucket/doc.txt"}, + "score": 0.8, + "title": "doc", + "content": "answer", + }, + { + "metadata": {"source": "bedrock", "uri": "s3://bucket/other.txt"}, + "score": 0.65, + "title": None, + "content": None, + }, + ] } with ( - app.test_request_context("/"), + app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), patch.object( ExternalDatasetTestService, "knowledge_retrieval", - return_value={"ok": True}, - ), + return_value=retrieval_response, + ) as knowledge_retrieval, ): resp, status = method() assert status == 200 - assert resp["ok"] is True + assert resp == retrieval_response + retrieval_setting, query, knowledge_id = knowledge_retrieval.call_args.args + assert retrieval_setting.model_dump() == payload["retrieval_setting"] + assert query == "hello bedrock" + assert knowledge_id == "knowledge-base-1" class TestExternalApiTemplateListApiAdvanced: @@ -297,10 +556,10 @@ class TestExternalApiTemplateListApiAdvanced: api = ExternalApiTemplateListApi() method = inspect.unwrap(api.get) - templates = [MagicMock(id=f"api-{i}") for i in range(3)] + templates = [_external_api_object(f"api-{i}") for i in range(3)] with ( - app.test_request_context("/?page=1&limit=20"), + app.test_request_context("/?page=2&limit=3"), patch( "controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis", return_value=(templates, 25), @@ -309,9 +568,14 @@ class TestExternalApiTemplateListApiAdvanced: resp, status = method(api, "tenant-1") assert status == 200 - assert resp["total"] == 25 - assert len(resp["data"]) == 3 - get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None) + assert resp == { + "data": [_external_api_dict(f"api-{i}") for i in range(3)], + "has_more": True, + "limit": 3, + "total": 25, + "page": 2, + } + get_external_knowledge_apis.assert_called_once_with(2, 3, "tenant-1", None) class TestExternalDatasetCreateApiAdvanced: @@ -374,15 +638,46 @@ class TestExternalKnowledgeHitTestingApiAdvanced: "controllers.console.datasets.external.DatasetService.get_dataset", return_value=dataset, ), - patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"), + patch("controllers.console.datasets.external.DatasetService.check_dataset_permission") as check_permission, + patch("controllers.console.datasets.external.HitTestingService.hit_testing_args_check") as args_check, patch( "controllers.console.datasets.external.HitTestingService.external_retrieve", - return_value={"results": []}, - ), + return_value={ + "query": {"content": "test query"}, + "records": [ + { + "content": None, + "title": "metadata-only", + "score": None, + "metadata": {"status": "active"}, + } + ], + }, + ) as external_retrieve, ): resp = method(api, current_user, "ds-1") - assert resp["results"] == [] + assert resp == { + "query": {"content": "test query"}, + "records": [ + { + "content": None, + "title": "metadata-only", + "score": None, + "metadata": {"status": "active"}, + } + ], + } + check_permission.assert_called_once_with(dataset, current_user) + args_check.assert_called_once_with(payload) + external_retrieve.assert_called_once_with( + session=db.session, + dataset=dataset, + query="test query", + account=current_user, + external_retrieval_model={"type": "bm25"}, + metadata_filtering_conditions={"status": "active"}, + ) class TestBedrockRetrievalApiAdvanced: diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index be68a3beed6..2ac9fc978d8 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -1,4 +1,4 @@ -from inspect import unwrap as inspect_unwrap +from inspect import unwrap from io import BytesIO from typing import Any from unittest.mock import MagicMock, patch @@ -35,24 +35,16 @@ from models.model import AppMode from services.errors.conversation import ConversationNotExistsError from services.errors.llm import InvokeRateLimitError -unwrap: Any = inspect_unwrap - @pytest.fixture -def account() -> Account: - acc = Account(name="User", email="user@example.com") +def account(): + acc = MagicMock(spec=Account) acc.id = "u1" return acc -def _file_data() -> Any: - file_data: Any = BytesIO(b"fake audio data") - file_data.filename = "test.wav" - return file_data - - @pytest.fixture -def trial_app_chat() -> MagicMock: +def trial_app_chat(): app = MagicMock() app.id = "a-chat" app.mode = AppMode.CHAT @@ -60,7 +52,7 @@ def trial_app_chat() -> MagicMock: @pytest.fixture -def trial_app_completion() -> MagicMock: +def trial_app_completion(): app = MagicMock() app.id = "a-comp" app.mode = AppMode.COMPLETION @@ -68,7 +60,7 @@ def trial_app_completion() -> MagicMock: @pytest.fixture -def trial_app_workflow() -> MagicMock: +def trial_app_workflow(): app = MagicMock() app.id = "a-workflow" app.mode = AppMode.WORKFLOW @@ -76,7 +68,7 @@ def trial_app_workflow() -> MagicMock: @pytest.fixture -def valid_parameters() -> dict[str, object]: +def valid_parameters(): return { "user_input_form": [], "system_parameters": {}, @@ -92,13 +84,54 @@ def valid_parameters() -> dict[str, object]: } -def test_trial_workflow_uses_trial_scoped_simple_account_model() -> None: - assert module.simple_account_model.name == "TrialSimpleAccount" - assert hasattr(module.simple_account_model, "items") +def test_trial_workflow_registers_normalized_simple_account_response_model(): + assert "SimpleAccountResponse" in module.console_ns.models + + +def _response_model_name(entry: object) -> str: + assert isinstance(entry, tuple) + assert len(entry) >= 2 + model = entry[1] + name = getattr(model, "name", None) + assert isinstance(name, str) + return name + + +def test_trial_endpoints_keep_response_and_query_docs(): + untyped_generated_response_views = [ + module.TrialAppWorkflowRunApi.post, + module.TrialChatApi.post, + module.TrialCompletionApi.post, + ] + for view in untyped_generated_response_views: + apidoc = getattr(view, "__apidoc__", {}) + assert apidoc.get("responses", {})["200"] == ("Success", None, {}) + + cases = [ + (module.TrialMessageSuggestedQuestionApi.get, module.SuggestedQuestionsResponse.__name__), + (module.TrialChatAudioApi.post, module.AudioTranscriptResponse.__name__), + (module.TrialChatTextApi.post, module.AudioBinaryResponse.__name__), + (module.TrialSitApi.get, module.SiteResponse.__name__), + (module.TrialAppParameterApi.get, module.ParametersResponse.__name__), + (module.AppApi.get, module.AppDetailWithSite.__name__), + (module.AppWorkflowApi.get, module.WorkflowResponse.__name__), + (module.DatasetListApi.get, module.TrialDatasetListResponse.__name__), + ] + + for view, model_name in cases: + apidoc = getattr(view, "__apidoc__", {}) + responses = apidoc.get("responses", {}) + assert _response_model_name(responses["200"]) == model_name + + dataset_params = module.DatasetListApi.get.__apidoc__["params"] + assert dataset_params["ids"]["in"] == "query" + assert dataset_params["ids"]["type"] == "array" + assert dataset_params["page"]["default"] == 1 + assert dataset_params["limit"]["default"] == 20 class TestTrialAppWorkflowRunApi: - def test_not_workflow_app(self, app: Flask, account: Account) -> None: + def test_not_workflow_app(self, app: Flask, account): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -106,7 +139,7 @@ class TestTrialAppWorkflowRunApi: with pytest.raises(NotWorkflowAppError): method(api, account, MagicMock(mode=AppMode.CHAT)) - def test_success(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: + def test_success(self, app: Flask, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -119,7 +152,7 @@ class TestTrialAppWorkflowRunApi: assert result is not None - def test_workflow_provider_not_init(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: + def test_workflow_provider_not_init(self, app: Flask, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -134,7 +167,7 @@ class TestTrialAppWorkflowRunApi: with pytest.raises(ProviderNotInitializeError): method(api, account, trial_app_workflow) - def test_workflow_quota_exceeded(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: + def test_workflow_quota_exceeded(self, app: Flask, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -149,7 +182,7 @@ class TestTrialAppWorkflowRunApi: with pytest.raises(ProviderQuotaExceededError): method(api, account, trial_app_workflow) - def test_workflow_model_not_support(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: + def test_workflow_model_not_support(self, app: Flask, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -164,7 +197,7 @@ class TestTrialAppWorkflowRunApi: with pytest.raises(ProviderModelCurrentlyNotSupportError): method(api, account, trial_app_workflow) - def test_workflow_invoke_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: + def test_workflow_invoke_error(self, app: Flask, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -179,7 +212,7 @@ class TestTrialAppWorkflowRunApi: with pytest.raises(CompletionRequestError): method(api, account, trial_app_workflow) - def test_workflow_rate_limit_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: + def test_workflow_rate_limit_error(self, app: Flask, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -194,7 +227,7 @@ class TestTrialAppWorkflowRunApi: with pytest.raises(InvokeRateLimitHttpError): method(api, account, trial_app_workflow) - def test_workflow_value_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: + def test_workflow_value_error(self, app: Flask, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -209,7 +242,7 @@ class TestTrialAppWorkflowRunApi: with pytest.raises(ValueError): method(api, account, trial_app_workflow) - def test_workflow_generic_exception(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: + def test_workflow_generic_exception(self, app: Flask, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -226,7 +259,7 @@ class TestTrialAppWorkflowRunApi: class TestTrialChatApi: - def test_not_chat_app(self, app: Flask, account: Account) -> None: + def test_not_chat_app(self, app: Flask, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -234,7 +267,7 @@ class TestTrialChatApi: with pytest.raises(NotChatAppError): method(api, account, MagicMock(mode="completion")) - def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_success(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -247,7 +280,7 @@ class TestTrialChatApi: assert result is not None - def test_chat_conversation_not_exists(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_chat_conversation_not_exists(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -262,7 +295,7 @@ class TestTrialChatApi: with pytest.raises(NotFound): method(api, account, trial_app_chat) - def test_chat_conversation_completed(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_chat_conversation_completed(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -277,7 +310,7 @@ class TestTrialChatApi: with pytest.raises(ConversationCompletedError): method(api, account, trial_app_chat) - def test_chat_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_chat_app_config_broken(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -292,7 +325,7 @@ class TestTrialChatApi: with pytest.raises(AppUnavailableError): method(api, account, trial_app_chat) - def test_chat_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_chat_provider_not_init(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -307,7 +340,7 @@ class TestTrialChatApi: with pytest.raises(ProviderNotInitializeError): method(api, account, trial_app_chat) - def test_chat_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_chat_quota_exceeded(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -322,7 +355,7 @@ class TestTrialChatApi: with pytest.raises(ProviderQuotaExceededError): method(api, account, trial_app_chat) - def test_chat_model_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_chat_model_not_support(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -337,7 +370,7 @@ class TestTrialChatApi: with pytest.raises(ProviderModelCurrentlyNotSupportError): method(api, account, trial_app_chat) - def test_chat_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_chat_invoke_error(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -352,7 +385,7 @@ class TestTrialChatApi: with pytest.raises(CompletionRequestError): method(api, account, trial_app_chat) - def test_chat_rate_limit_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_chat_rate_limit_error(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -367,7 +400,7 @@ class TestTrialChatApi: with pytest.raises(InvokeRateLimitHttpError): method(api, account, trial_app_chat) - def test_chat_value_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_chat_value_error(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -382,7 +415,7 @@ class TestTrialChatApi: with pytest.raises(ValueError): method(api, account, trial_app_chat) - def test_chat_generic_exception(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_chat_generic_exception(self, app: Flask, trial_app_chat, account): api = module.TrialChatApi() method = unwrap(api.post) @@ -399,7 +432,7 @@ class TestTrialChatApi: class TestTrialCompletionApi: - def test_not_completion_app(self, app: Flask, account: Account) -> None: + def test_not_completion_app(self, app: Flask, account): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -407,7 +440,7 @@ class TestTrialCompletionApi: with pytest.raises(NotCompletionAppError): method(api, account, MagicMock(mode=AppMode.CHAT)) - def test_success(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: + def test_success(self, app: Flask, trial_app_completion, account): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -420,7 +453,7 @@ class TestTrialCompletionApi: assert result is not None - def test_completion_app_config_broken(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: + def test_completion_app_config_broken(self, app: Flask, trial_app_completion, account): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -435,7 +468,7 @@ class TestTrialCompletionApi: with pytest.raises(AppUnavailableError): method(api, account, trial_app_completion) - def test_completion_provider_not_init(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: + def test_completion_provider_not_init(self, app: Flask, trial_app_completion, account): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -450,7 +483,7 @@ class TestTrialCompletionApi: with pytest.raises(ProviderNotInitializeError): method(api, account, trial_app_completion) - def test_completion_quota_exceeded(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: + def test_completion_quota_exceeded(self, app: Flask, trial_app_completion, account): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -465,7 +498,7 @@ class TestTrialCompletionApi: with pytest.raises(ProviderQuotaExceededError): method(api, account, trial_app_completion) - def test_completion_model_not_support(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: + def test_completion_model_not_support(self, app: Flask, trial_app_completion, account): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -480,7 +513,7 @@ class TestTrialCompletionApi: with pytest.raises(ProviderModelCurrentlyNotSupportError): method(api, account, trial_app_completion) - def test_completion_invoke_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: + def test_completion_invoke_error(self, app: Flask, trial_app_completion, account): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -495,7 +528,7 @@ class TestTrialCompletionApi: with pytest.raises(CompletionRequestError): method(api, account, trial_app_completion) - def test_completion_rate_limit_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: + def test_completion_rate_limit_error(self, app: Flask, trial_app_completion, account): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -510,7 +543,7 @@ class TestTrialCompletionApi: with pytest.raises(InternalServerError): method(api, account, trial_app_completion) - def test_completion_value_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: + def test_completion_value_error(self, app: Flask, trial_app_completion, account): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -525,7 +558,7 @@ class TestTrialCompletionApi: with pytest.raises(ValueError): method(api, account, trial_app_completion) - def test_completion_generic_exception(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: + def test_completion_generic_exception(self, app: Flask, trial_app_completion, account): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -542,7 +575,7 @@ class TestTrialCompletionApi: class TestTrialMessageSuggestedQuestionApi: - def test_not_chat_app(self, app: Flask, account: Account) -> None: + def test_not_chat_app(self, app: Flask, account): api = module.TrialMessageSuggestedQuestionApi() method = unwrap(api.get) @@ -550,7 +583,7 @@ class TestTrialMessageSuggestedQuestionApi: with pytest.raises(NotChatAppError): method(api, account, MagicMock(mode="completion"), str(uuid4())) - def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_success(self, app: Flask, trial_app_chat, account): api = module.TrialMessageSuggestedQuestionApi() method = unwrap(api.get) @@ -566,7 +599,7 @@ class TestTrialMessageSuggestedQuestionApi: assert result == {"data": ["q1", "q2"]} - def test_conversation_not_exists(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_conversation_not_exists(self, app: Flask, trial_app_chat, account): api = module.TrialMessageSuggestedQuestionApi() method = unwrap(api.get) @@ -583,14 +616,14 @@ class TestTrialMessageSuggestedQuestionApi: class TestTrialAppParameterApi: - def test_app_unavailable(self) -> None: + def test_app_unavailable(self): api = module.TrialAppParameterApi() method = unwrap(api.get) with pytest.raises(AppUnavailableError): method(api, None) - def test_success_non_workflow(self, valid_parameters: dict[str, object]) -> None: + def test_success_non_workflow(self, valid_parameters): api = module.TrialAppParameterApi() method = unwrap(api.get) @@ -617,11 +650,12 @@ class TestTrialAppParameterApi: class TestTrialChatAudioApi: - def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_success(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -634,11 +668,12 @@ class TestTrialChatAudioApi: assert result == {"text": "hello"} - def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_app_config_broken(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -653,11 +688,12 @@ class TestTrialChatAudioApi: with pytest.raises(module.AppUnavailableError): method(api, account, trial_app_chat) - def test_no_audio_uploaded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_no_audio_uploaded(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -672,11 +708,12 @@ class TestTrialChatAudioApi: with pytest.raises(module.NoAudioUploadedError): method(api, account, trial_app_chat) - def test_audio_too_large(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_audio_too_large(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -691,11 +728,12 @@ class TestTrialChatAudioApi: with pytest.raises(module.AudioTooLargeError): method(api, account, trial_app_chat) - def test_unsupported_audio_type(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_unsupported_audio_type(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -710,11 +748,12 @@ class TestTrialChatAudioApi: with pytest.raises(module.UnsupportedAudioTypeError): method(api, account, trial_app_chat) - def test_provider_not_support_tts(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_provider_not_support_tts(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -729,11 +768,12 @@ class TestTrialChatAudioApi: with pytest.raises(module.ProviderNotSupportSpeechToTextError): method(api, account, trial_app_chat) - def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_provider_not_init(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -744,11 +784,12 @@ class TestTrialChatAudioApi: with pytest.raises(ProviderNotInitializeError): method(api, account, trial_app_chat) - def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_quota_exceeded(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -761,7 +802,7 @@ class TestTrialChatAudioApi: class TestTrialChatTextApi: - def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_success(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) @@ -774,7 +815,7 @@ class TestTrialChatTextApi: assert result == {"audio": "base64_data"} - def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_app_config_broken(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) @@ -789,7 +830,7 @@ class TestTrialChatTextApi: with pytest.raises(module.AppUnavailableError): method(api, account, trial_app_chat) - def test_provider_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_provider_not_support(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) @@ -804,7 +845,7 @@ class TestTrialChatTextApi: with pytest.raises(module.ProviderNotSupportSpeechToTextError): method(api, account, trial_app_chat) - def test_audio_too_large(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_audio_too_large(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) @@ -819,7 +860,7 @@ class TestTrialChatTextApi: with pytest.raises(module.AudioTooLargeError): method(api, account, trial_app_chat) - def test_no_audio_uploaded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_no_audio_uploaded(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) @@ -834,7 +875,7 @@ class TestTrialChatTextApi: with pytest.raises(module.NoAudioUploadedError): method(api, account, trial_app_chat) - def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_provider_not_init(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) @@ -845,7 +886,7 @@ class TestTrialChatTextApi: with pytest.raises(ProviderNotInitializeError): method(api, account, trial_app_chat) - def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_quota_exceeded(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) @@ -856,7 +897,7 @@ class TestTrialChatTextApi: with pytest.raises(ProviderQuotaExceededError): method(api, account, trial_app_chat) - def test_model_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_model_not_support(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) @@ -867,7 +908,7 @@ class TestTrialChatTextApi: with pytest.raises(ProviderModelCurrentlyNotSupportError): method(api, account, trial_app_chat) - def test_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_invoke_error(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) @@ -880,7 +921,7 @@ class TestTrialChatTextApi: class TestTrialAppWorkflowTaskStopApi: - def test_not_workflow_app(self, app: Flask, trial_app_chat: MagicMock) -> None: + def test_not_workflow_app(self, app: Flask, trial_app_chat): api = module.TrialAppWorkflowTaskStopApi() method = unwrap(api.post) @@ -888,7 +929,7 @@ class TestTrialAppWorkflowTaskStopApi: with pytest.raises(NotWorkflowAppError): method(api, trial_app_chat, str(uuid4())) - def test_success(self, app: Flask, trial_app_workflow: MagicMock) -> None: + def test_success(self, app: Flask, trial_app_workflow, account): api = module.TrialAppWorkflowTaskStopApi() method = unwrap(api.post) @@ -906,7 +947,7 @@ class TestTrialAppWorkflowTaskStopApi: class TestTrialSitApi: - def test_no_site(self, app: Flask) -> None: + def test_no_site(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) app_model = MagicMock() @@ -917,7 +958,7 @@ class TestTrialSitApi: with pytest.raises(Forbidden): method(api, app_model) - def test_archived_tenant(self, app: Flask) -> None: + def test_archived_tenant(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) @@ -932,7 +973,7 @@ class TestTrialSitApi: with pytest.raises(Forbidden): method(api, app_model) - def test_success(self, app: Flask) -> None: + def test_success(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) @@ -957,11 +998,12 @@ class TestTrialSitApi: class TestTrialChatAudioApiExceptionHandlers: - def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_provider_not_init(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -976,11 +1018,12 @@ class TestTrialChatAudioApiExceptionHandlers: with pytest.raises(ProviderNotInitializeError): method(api, account, trial_app_chat) - def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_quota_exceeded(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -995,11 +1038,12 @@ class TestTrialChatAudioApiExceptionHandlers: with pytest.raises(ProviderQuotaExceededError): method(api, account, trial_app_chat) - def test_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_invoke_error(self, app: Flask, trial_app_chat, account): api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = _file_data() + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" with ( app.test_request_context( @@ -1016,7 +1060,7 @@ class TestTrialChatAudioApiExceptionHandlers: class TestTrialChatTextApiExceptionHandlers: - def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_app_config_broken(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) @@ -1031,7 +1075,7 @@ class TestTrialChatTextApiExceptionHandlers: with pytest.raises(module.AppUnavailableError): method(api, account, trial_app_chat) - def test_unsupported_audio_type(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: + def test_unsupported_audio_type(self, app: Flask, trial_app_chat, account): api = module.TrialChatTextApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_snippets.py b/api/tests/unit_tests/controllers/console/workspace/test_snippets.py index e8e005a1b83..55eee935606 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_snippets.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_snippets.py @@ -1,3 +1,4 @@ +from datetime import datetime from inspect import unwrap from types import SimpleNamespace from unittest.mock import ANY, Mock @@ -46,12 +47,73 @@ def _snippet(**overrides) -> CustomizedSnippet: "name": "Snippet", "description": "Description", "type": snippets_module.SnippetType.NODE, + "version": 3, + "use_count": 7, + "is_published": True, + "icon_info": {"icon": "star", "icon_background": "#101828", "icon_type": "emoji"}, + "input_fields": '[{"label": "Question", "variable": "query", "type": "text-input"}]', "created_by": "account-1", + "created_at": datetime(2024, 1, 2, 3, 4, 5), + "updated_by": None, + "updated_at": datetime(2024, 1, 3, 4, 5, 6), } data.update(overrides) return CustomizedSnippet(**data) +def _patch_snippet_response_properties(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + CustomizedSnippet, + "tags", + property(lambda _snippet: [{"id": "tag-1", "name": "Search", "type": "snippet"}]), + ) + monkeypatch.setattr(CustomizedSnippet, "created_by_account", property(lambda _snippet: _account("account-1"))) + monkeypatch.setattr(CustomizedSnippet, "updated_by_account", property(lambda _snippet: None)) + + +def _expected_snippet_list_item(snippet: CustomizedSnippet) -> dict: + return { + "id": snippet.id, + "name": snippet.name, + "description": snippet.description, + "type": snippet.type, + "version": snippet.version, + "use_count": snippet.use_count, + "is_published": snippet.is_published, + "icon_info": snippet.icon_info, + "tags": [{"id": "tag-1", "name": "Search", "type": "snippet"}], + "created_by": snippet.created_by, + "author_name": "Test User", + "created_at": int(snippet.created_at.timestamp()), + "updated_by": snippet.updated_by, + "updated_at": int(snippet.updated_at.timestamp()), + } + + +def _expected_snippet_response(snippet: CustomizedSnippet) -> dict: + return { + "id": snippet.id, + "name": snippet.name, + "description": snippet.description, + "type": snippet.type, + "version": snippet.version, + "use_count": snippet.use_count, + "is_published": snippet.is_published, + "icon_info": snippet.icon_info, + "graph": {}, + "input_fields": [{"label": "Question", "variable": "query", "type": "text-input"}], + "tags": [{"id": "tag-1", "name": "Search", "type": "snippet"}], + "created_by": { + "id": "account-1", + "name": "Test User", + "email": "account-1@example.com", + }, + "created_at": int(snippet.created_at.timestamp()), + "updated_by": None, + "updated_at": int(snippet.updated_at.timestamp()), + } + + def test_normalize_snippet_list_query_args_sorts_indexed_values(): query_args = snippets_module.MultiDict( [ @@ -75,7 +137,7 @@ def test_list_snippets_returns_pagination(app: Flask, monkeypatch: pytest.Monkey tag_id = "11111111-1111-1111-1111-111111111111" get_snippets = Mock(return_value=(snippets, 1, False)) monkeypatch.setattr(snippets_module.SnippetService, "get_snippets", get_snippets) - monkeypatch.setattr(snippets_module, "marshal", Mock(return_value=[{"id": "snippet-1"}])) + _patch_snippet_response_properties(monkeypatch) api = snippets_module.CustomizedSnippetsApi() handler = unwrap(api.get) @@ -87,7 +149,7 @@ def test_list_snippets_returns_pagination(app: Flask, monkeypatch: pytest.Monkey assert status_code == 200 assert response == { - "data": [{"id": "snippet-1"}], + "data": [_expected_snippet_list_item(snippets[0])], "page": 2, "limit": 10, "total": 1, @@ -110,6 +172,7 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, mo snippet = _snippet() create_snippet = Mock(return_value=snippet) monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet) + _patch_snippet_response_properties(monkeypatch) monkeypatch.setattr( snippets_module.CreateSnippetPayload, "model_validate", @@ -124,7 +187,6 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, mo ) ), ) - monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"})) api = snippets_module.CustomizedSnippetsApi() handler = unwrap(api.post) @@ -137,7 +199,7 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, mo response, status_code = handler(api, "tenant-1", user) assert status_code == 201 - assert response == {"id": "snippet-1"} + assert response == _expected_snippet_response(snippet) assert create_snippet.call_args.kwargs["snippet_type"] == snippets_module.SnippetType.NODE @@ -184,7 +246,7 @@ def test_get_snippet_detail_raises_when_missing(app: Flask, monkeypatch: pytest. def test_get_snippet_detail_returns_snippet(app: Flask, monkeypatch: pytest.MonkeyPatch): snippet = _snippet() monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) - monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"})) + _patch_snippet_response_properties(monkeypatch) api = snippets_module.CustomizedSnippetDetailApi() handler = unwrap(api.get) @@ -193,7 +255,7 @@ def test_get_snippet_detail_returns_snippet(app: Flask, monkeypatch: pytest.Monk response, status_code = handler(api, "tenant-1", snippet_id="snippet-1") assert status_code == 200 - assert response == {"id": "snippet-1"} + assert response == _expected_snippet_response(snippet) def test_patch_snippet_returns_400_for_empty_payload(app: Flask, monkeypatch: pytest.MonkeyPatch): @@ -230,7 +292,7 @@ def test_patch_snippet_updates_and_commits(app: Flask, monkeypatch: pytest.Monke monkeypatch.setattr(snippets_module.SnippetService, "update_snippet", update_snippet) monkeypatch.setattr(snippets_module, "Session", SessionContext) monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object())) - monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1", "name": "New"})) + _patch_snippet_response_properties(monkeypatch) api = snippets_module.CustomizedSnippetDetailApi() handler = unwrap(api.patch) @@ -243,7 +305,7 @@ def test_patch_snippet_updates_and_commits(app: Flask, monkeypatch: pytest.Monke response, status_code = handler(api, "tenant-1", user, snippet_id="snippet-1") assert status_code == 200 - assert response == {"id": "snippet-1", "name": "New"} + assert response == _expected_snippet_response(updated_snippet) update_snippet.assert_called_once() assert update_snippet.call_args.kwargs["data"] == { "name": "New", diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index 47e9f51fb27..e9cd3410cb9 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -26,7 +26,9 @@ from controllers.console.workspace.workspace import ( WebappLogoWorkspaceApi, WorkspaceInfoApi, WorkspaceListApi, + WorkspaceLogoUploadResponse, WorkspacePermissionApi, + WorkspacePermissionResponse, ) from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now @@ -587,7 +589,8 @@ class TestWebappLogoWorkspaceApi: result, status = method(api, user) assert status == 201 - assert result["id"] == "file1" + assert result == {"id": "file1"} + assert WorkspaceLogoUploadResponse.model_validate(result).model_dump(mode="json") == {"id": "file1"} def test_filename_missing(self, app: Flask): api = WebappLogoWorkspaceApi() @@ -676,7 +679,7 @@ class TestWorkspaceInfoApi: patch("controllers.console.workspace.workspace.db.session.commit"), patch( "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", - return_value={"name": "New Name"}, + return_value={"id": "t1", "name": "New Name"}, ), ): result = method(api, "t1") @@ -717,7 +720,13 @@ class TestWorkspacePermissionApi: result, status = method(api, "t1") assert status == 200 - assert result["workspace_id"] == "t1" + expected = { + "workspace_id": "t1", + "allow_member_invite": True, + "allow_owner_transfer": False, + } + assert result == expected + assert WorkspacePermissionResponse.model_validate(result).model_dump(mode="json") == expected def test_no_current_tenant(self, app: Flask): api = WorkspacePermissionApi() diff --git a/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py index 362af883ed2..43cc2450db5 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py @@ -325,10 +325,12 @@ class TestPipelineRunApiEntity: def test_entity_missing_required_field(self): """Test entity raises on missing required field.""" with pytest.raises(ValueError): - PipelineRunApiEntity( - inputs={}, - datasource_type="online_document", - # missing datasource_info_list, start_node_id, etc. + PipelineRunApiEntity.model_validate( + { + "inputs": {}, + "datasource_type": "online_document", + # missing datasource_info_list, start_node_id, etc. + } ) @@ -382,8 +384,19 @@ class TestDatasourcePluginsApiGet: mock_dataset = Mock() mock_db.session.scalar.return_value = mock_dataset + datasource_plugins = [ + { + "node_id": "node-datasource-1", + "plugin_id": "plugin-a", + "provider_name": "provider-a", + "datasource_type": "online_document", + "title": "Online Docs", + "user_input_variables": [{"variable": "url", "label": "URL", "type": "text-input", "required": True}], + "credentials": [{"id": "cred-1", "name": "Default credential", "type": "oauth2", "is_default": True}], + } + ] mock_svc_instance = Mock() - mock_svc_instance.get_datasource_plugins.return_value = [{"name": "plugin_a"}] + mock_svc_instance.get_datasource_plugins.return_value = datasource_plugins mock_svc_cls.return_value = mock_svc_instance with app.test_request_context("/datasets/test/pipeline/datasource-plugins?is_published=true"): @@ -391,11 +404,33 @@ class TestDatasourcePluginsApiGet: response, status = api.get(tenant_id=tenant_id, dataset_id=dataset_id) assert status == 200 - assert response == [{"name": "plugin_a"}] + assert response == datasource_plugins mock_svc_instance.get_datasource_plugins.assert_called_once_with( tenant_id=tenant_id, dataset_id=dataset_id, is_published=True ) + @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") + @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService") + def test_get_plugins_parses_false_is_published_query(self, mock_svc_cls, mock_db, app: Flask): + """Test false query string is parsed as boolean False.""" + tenant_id = str(uuid.uuid4()) + dataset_id = str(uuid.uuid4()) + + mock_db.session.scalar.return_value = Mock() + mock_svc_instance = Mock() + mock_svc_instance.get_datasource_plugins.return_value = [] + mock_svc_cls.return_value = mock_svc_instance + + with app.test_request_context("/datasets/test/pipeline/datasource-plugins?is_published=false"): + api = DatasourcePluginsApi() + response, status = api.get(tenant_id=tenant_id, dataset_id=dataset_id) + + assert status == 200 + assert response == [] + mock_svc_instance.get_datasource_plugins.assert_called_once_with( + tenant_id=tenant_id, dataset_id=dataset_id, is_published=False + ) + @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") def test_get_plugins_not_found(self, mock_db, app: Flask): """Test NotFound when dataset check fails.""" diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py index 0caeae2cee4..1a602813eaa 100644 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -2,7 +2,6 @@ from __future__ import annotations -import json from datetime import UTC, datetime from types import SimpleNamespace from typing import Any @@ -112,7 +111,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask): chat_color_theme_inverted=False, copyright=None, privacy_policy=None, - custom_disclaimer=None, + custom_disclaimer="", prompt_public=False, show_workflow_steps=True, use_icon_as_answer_icon=False, @@ -138,7 +137,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask): with app.test_request_context("/api/form/human_input/token-1", method="GET"): response = HumanInputFormApi().get("token-1") - body = json.loads(response.get_data(as_text=True)) + body = response assert set(body.keys()) == { "site", "form_content", @@ -167,7 +166,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask): "description": "desc", "copyright": None, "privacy_policy": None, - "custom_disclaimer": None, + "custom_disclaimer": "", "default_language": "en", "prompt_public": False, "show_workflow_steps": True, @@ -256,7 +255,7 @@ def test_get_form_uses_runtime_select_options(monkeypatch: pytest.MonkeyPatch, a chat_color_theme_inverted=False, copyright=None, privacy_policy=None, - custom_disclaimer=None, + custom_disclaimer="", prompt_public=False, show_workflow_steps=True, use_icon_as_answer_icon=False, @@ -277,7 +276,7 @@ def test_get_form_uses_runtime_select_options(monkeypatch: pytest.MonkeyPatch, a with app.test_request_context("/api/form/human_input/token-1", method="GET"): response = HumanInputFormApi().get("token-1") - body = json.loads(response.get_data(as_text=True)) + body = response assert body["inputs"] == [input_config.model_dump(mode="json") for input_config in runtime_inputs] service_mock.resolve_form_inputs.assert_called_once_with(form) @@ -380,7 +379,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F chat_color_theme_inverted=False, copyright=None, privacy_policy=None, - custom_disclaimer=None, + custom_disclaimer="", prompt_public=False, show_workflow_steps=True, use_icon_as_answer_icon=False, @@ -403,7 +402,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F with app.test_request_context("/api/form/human_input/token-1", method="GET"): response = HumanInputFormApi().get("token-1") - body = json.loads(response.get_data(as_text=True)) + body = response assert set(body.keys()) == { "site", "form_content", @@ -432,7 +431,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F "description": "desc", "copyright": None, "privacy_policy": None, - "custom_disclaimer": None, + "custom_disclaimer": "", "default_language": "en", "prompt_public": False, "show_workflow_steps": True, diff --git a/api/tests/unit_tests/fields/test_snippet_fields.py b/api/tests/unit_tests/fields/test_snippet_fields.py index ad8ee6e8f0b..233c4e5da90 100644 --- a/api/tests/unit_tests/fields/test_snippet_fields.py +++ b/api/tests/unit_tests/fields/test_snippet_fields.py @@ -1,8 +1,7 @@ from types import SimpleNamespace -from flask_restx import marshal - -from fields.snippet_fields import snippet_list_fields +from fields.snippet_fields import SnippetListItemResponse +from libs.helper import dump_response def test_snippet_list_fields_include_author_name() -> None: @@ -23,6 +22,6 @@ def test_snippet_list_fields_include_author_name() -> None: updated_at=None, ) - result = marshal(snippet, snippet_list_fields) + result = dump_response(SnippetListItemResponse, snippet) assert result["author_name"] == "Alice" diff --git a/packages/contracts/generated/api/console/agent/types.gen.ts b/packages/contracts/generated/api/console/agent/types.gen.ts index 0607d1662d7..7c8d3d052ad 100644 --- a/packages/contracts/generated/api/console/agent/types.gen.ts +++ b/packages/contracts/generated/api/console/agent/types.gen.ts @@ -269,6 +269,7 @@ export type MessageDetailResponse = { agent_thoughts?: Array annotation?: ConversationAnnotation | null annotation_hit_history?: ConversationAnnotationHitHistory | null + answer: string answer_tokens?: number | null conversation_id: string created_at?: number | null @@ -284,12 +285,11 @@ export type MessageDetailResponse = { } message?: JsonValue | null message_files?: Array - message_metadata_dict?: JsonValue | null message_tokens?: number | null + metadata?: JsonValue | null parent_message_id?: string | null provider_response_latency?: number | null query: string - re_sign_file_url_answer: string status: string workflow_run_id?: string | null } @@ -732,7 +732,6 @@ export type AgentThought = { created_at?: number | null files: Array id: string - message_chain_id?: string | null message_id: string observation?: string | null position: number @@ -752,8 +751,8 @@ export type ConversationAnnotation = { export type ConversationAnnotationHitHistory = { annotation_create_account?: SimpleAccount | null + annotation_id: string created_at?: number | null - id: string } export type HumanInputContent = { diff --git a/packages/contracts/generated/api/console/agent/zod.gen.ts b/packages/contracts/generated/api/console/agent/zod.gen.ts index 7d061e956f4..e0869366919 100644 --- a/packages/contracts/generated/api/console/agent/zod.gen.ts +++ b/packages/contracts/generated/api/console/agent/zod.gen.ts @@ -628,7 +628,6 @@ export const zAgentThought = z.object({ created_at: z.int().nullish(), files: z.array(z.string()), id: z.string(), - message_chain_id: z.string().nullish(), message_id: z.string(), observation: z.string().nullish(), position: z.int(), @@ -1060,8 +1059,8 @@ export const zConversationAnnotation = z.object({ */ export const zConversationAnnotationHitHistory = z.object({ annotation_create_account: zSimpleAccount.nullish(), + annotation_id: z.string(), created_at: z.int().nullish(), - id: z.string(), }) /** @@ -2039,6 +2038,7 @@ export const zMessageDetailResponse = z.object({ agent_thoughts: z.array(zAgentThought).optional(), annotation: zConversationAnnotation.nullish(), annotation_hit_history: zConversationAnnotationHitHistory.nullish(), + answer: z.string(), answer_tokens: z.int().nullish(), conversation_id: z.string(), created_at: z.int().nullish(), @@ -2052,12 +2052,11 @@ export const zMessageDetailResponse = z.object({ inputs: z.record(z.string(), zJsonValue), message: zJsonValue.nullish(), message_files: z.array(zMessageFile).optional(), - message_metadata_dict: zJsonValue.nullish(), message_tokens: z.int().nullish(), + metadata: zJsonValue.nullish(), parent_message_id: z.string().nullish(), provider_response_latency: z.number().nullish(), query: z.string(), - re_sign_file_url_answer: z.string(), status: z.string(), workflow_run_id: z.string().nullish(), }) diff --git a/packages/contracts/generated/api/console/apps/types.gen.ts b/packages/contracts/generated/api/console/apps/types.gen.ts index 0b46ade7963..9dd98f62fbc 100644 --- a/packages/contracts/generated/api/console/apps/types.gen.ts +++ b/packages/contracts/generated/api/console/apps/types.gen.ts @@ -433,6 +433,7 @@ export type MessageDetailResponse = { agent_thoughts?: Array annotation?: ConversationAnnotation | null annotation_hit_history?: ConversationAnnotationHitHistory | null + answer: string answer_tokens?: number | null conversation_id: string created_at?: number | null @@ -448,12 +449,11 @@ export type MessageDetailResponse = { } message?: JsonValue | null message_files?: Array - message_metadata_dict?: JsonValue | null message_tokens?: number | null + metadata?: JsonValue | null parent_message_id?: string | null provider_response_latency?: number | null query: string - re_sign_file_url_answer: string status: string workflow_run_id?: string | null } @@ -1440,7 +1440,6 @@ export type AgentThought = { created_at?: number | null files: Array id: string - message_chain_id?: string | null message_id: string observation?: string | null position: number @@ -1460,8 +1459,8 @@ export type ConversationAnnotation = { export type ConversationAnnotationHitHistory = { annotation_create_account?: SimpleAccount | null + annotation_id: string created_at?: number | null - id: string } export type HumanInputContent = { diff --git a/packages/contracts/generated/api/console/apps/zod.gen.ts b/packages/contracts/generated/api/console/apps/zod.gen.ts index 0103748ec7d..9737b2dfd45 100644 --- a/packages/contracts/generated/api/console/apps/zod.gen.ts +++ b/packages/contracts/generated/api/console/apps/zod.gen.ts @@ -1183,7 +1183,6 @@ export const zAgentThought = z.object({ created_at: z.int().nullish(), files: z.array(z.string()), id: z.string(), - message_chain_id: z.string().nullish(), message_id: z.string(), observation: z.string().nullish(), position: z.int(), @@ -2286,8 +2285,8 @@ export const zConversationPagination = z.object({ */ export const zConversationAnnotationHitHistory = z.object({ annotation_create_account: zSimpleAccount.nullish(), + annotation_id: z.string(), created_at: z.int().nullish(), - id: z.string(), }) /** @@ -3401,6 +3400,7 @@ export const zMessageDetailResponse = z.object({ agent_thoughts: z.array(zAgentThought).optional(), annotation: zConversationAnnotation.nullish(), annotation_hit_history: zConversationAnnotationHitHistory.nullish(), + answer: z.string(), answer_tokens: z.int().nullish(), conversation_id: z.string(), created_at: z.int().nullish(), @@ -3414,12 +3414,11 @@ export const zMessageDetailResponse = z.object({ inputs: z.record(z.string(), zJsonValue), message: zJsonValue.nullish(), message_files: z.array(zMessageFile).optional(), - message_metadata_dict: zJsonValue.nullish(), message_tokens: z.int().nullish(), + metadata: zJsonValue.nullish(), parent_message_id: z.string().nullish(), provider_response_latency: z.number().nullish(), query: z.string(), - re_sign_file_url_answer: z.string(), status: z.string(), workflow_run_id: z.string().nullish(), }) diff --git a/packages/contracts/generated/api/console/installed-apps/types.gen.ts b/packages/contracts/generated/api/console/installed-apps/types.gen.ts index a3e34caf590..d80a2883400 100644 --- a/packages/contracts/generated/api/console/installed-apps/types.gen.ts +++ b/packages/contracts/generated/api/console/installed-apps/types.gen.ts @@ -237,7 +237,6 @@ export type AgentThought = { created_at?: number | null files: Array id: string - message_chain_id?: string | null message_id: string observation?: string | null position: number diff --git a/packages/contracts/generated/api/console/installed-apps/zod.gen.ts b/packages/contracts/generated/api/console/installed-apps/zod.gen.ts index d911a043f3b..445de41d8df 100644 --- a/packages/contracts/generated/api/console/installed-apps/zod.gen.ts +++ b/packages/contracts/generated/api/console/installed-apps/zod.gen.ts @@ -257,7 +257,6 @@ export const zAgentThought = z.object({ created_at: z.int().nullish(), files: z.array(z.string()), id: z.string(), - message_chain_id: z.string().nullish(), message_id: z.string(), observation: z.string().nullish(), position: z.int(), diff --git a/web/features/agent-v2/agent-detail/configure/components/preview/chat.tsx b/web/features/agent-v2/agent-detail/configure/components/preview/chat.tsx index a3127b4bd49..9ad0555d196 100644 --- a/web/features/agent-v2/agent-detail/configure/components/preview/chat.tsx +++ b/web/features/agent-v2/agent-detail/configure/components/preview/chat.tsx @@ -217,14 +217,8 @@ const toFeedback = (feedback: NonNullable[nu } } -type AgentDebugMessageWithLegacyAnswer = MessageDetailResponse & { - answer?: string | null -} - const getAgentDebugMessageAnswer = (message: MessageDetailResponse) => { - const legacyAnswer = (message as AgentDebugMessageWithLegacyAnswer).answer - - return message.re_sign_file_url_answer ?? legacyAnswer ?? '' + return message.answer ?? '' } function getFormattedAgentDebugChatTree(messages: MessageDetailResponse[]): ChatItemInTree[] {