mirror of
https://github.com/langgenius/dify.git
synced 2026-06-26 06:41:10 +08:00
fix(api): document message responses as serialized contracts
This commit is contained in:
parent
e937794831
commit
699b708802
@ -167,12 +167,16 @@ register_schema_models(
|
||||
ChatMessagesQuery,
|
||||
MessageFeedbackPayload,
|
||||
FeedbackExportQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AnnotationCountResponse,
|
||||
SuggestedQuestionsResponse,
|
||||
MessageDetailResponse,
|
||||
MessageInfiniteScrollPaginationResponse,
|
||||
SimpleResultResponse,
|
||||
TextFileResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse, TextFileResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
|
||||
@ -12,6 +12,7 @@ from controllers.common.errors import FilenameNotExistsError, NoFileUploadedErro
|
||||
from controllers.common.fields import GeneratedAppResponse
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
query_params_from_request,
|
||||
register_response_schema_models,
|
||||
register_schema_model,
|
||||
register_schema_models,
|
||||
@ -150,12 +151,11 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Get query parameter to determine published or draft
|
||||
is_published: bool = request.args.get("is_published", default=True, type=bool)
|
||||
query = query_params_from_request(DatasourcePluginsQuery)
|
||||
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
|
||||
tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=is_published
|
||||
tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=query.is_published
|
||||
)
|
||||
return datasource_plugins, 200
|
||||
|
||||
|
||||
@ -13435,7 +13435,6 @@ Soft lifecycle state for Agent records.
|
||||
| created_at | integer | | No |
|
||||
| files | [ string ] | | Yes |
|
||||
| id | string | | Yes |
|
||||
| message_chain_id | string | | No |
|
||||
| message_id | string | | Yes |
|
||||
| observation | string | | No |
|
||||
| position | integer | | Yes |
|
||||
@ -14567,8 +14566,8 @@ Enum class for configurate method of provider model.
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| annotation_create_account | [SimpleAccount](#simpleaccount) | | No |
|
||||
| annotation_id | string | | Yes |
|
||||
| created_at | integer | | No |
|
||||
| id | string | | Yes |
|
||||
|
||||
#### ConversationDetail
|
||||
|
||||
@ -17072,6 +17071,7 @@ Enum class for large language model mode.
|
||||
| agent_thoughts | [ [AgentThought](#agentthought) ] | | No |
|
||||
| annotation | [ConversationAnnotation](#conversationannotation) | | No |
|
||||
| annotation_hit_history | [ConversationAnnotationHitHistory](#conversationannotationhithistory) | | No |
|
||||
| answer | string | | Yes |
|
||||
| answer_tokens | integer | | No |
|
||||
| conversation_id | string | | Yes |
|
||||
| created_at | integer | | No |
|
||||
@ -17085,12 +17085,11 @@ Enum class for large language model mode.
|
||||
| inputs | object | | Yes |
|
||||
| message | [JSONValue](#jsonvalue) | | No |
|
||||
| message_files | [ [MessageFile](#messagefile) ] | | No |
|
||||
| message_metadata_dict | [JSONValue](#jsonvalue) | | No |
|
||||
| message_tokens | integer | | No |
|
||||
| metadata | [JSONValue](#jsonvalue) | | No |
|
||||
| parent_message_id | string | | No |
|
||||
| provider_response_latency | number | | No |
|
||||
| query | string | | Yes |
|
||||
| re_sign_file_url_answer | string | | Yes |
|
||||
| status | string | | Yes |
|
||||
| workflow_run_id | string | | No |
|
||||
|
||||
|
||||
@ -607,7 +607,11 @@ class TestMiscApis:
|
||||
method = unwrap(api.get)
|
||||
|
||||
service = MagicMock()
|
||||
service.get_recommended_plugins.return_value = [{"id": "p1"}]
|
||||
recommended_plugins = {
|
||||
"installed_recommended_plugins": [{"id": "p1"}],
|
||||
"uninstalled_recommended_plugins": [{"id": "p2"}],
|
||||
}
|
||||
service.get_recommended_plugins.return_value = recommended_plugins
|
||||
user = make_account()
|
||||
tenant_id = "tenant-1"
|
||||
|
||||
@ -619,7 +623,7 @@ class TestMiscApis:
|
||||
),
|
||||
):
|
||||
result = method(api, tenant_id, user)
|
||||
assert result == [{"id": "p1"}]
|
||||
assert result == recommended_plugins
|
||||
service.get_recommended_plugins.assert_called_once_with("all", user, tenant_id)
|
||||
|
||||
|
||||
@ -826,7 +830,7 @@ class TestRagPipelineByIdApi:
|
||||
result = method(api, pipeline, "old-workflow")
|
||||
|
||||
workflow_service.delete_workflow.assert_called_once()
|
||||
assert result == (None, 204)
|
||||
assert result == ("", 204)
|
||||
|
||||
def test_delete_active_workflow_rejected(self, app: Flask) -> None:
|
||||
api = RagPipelineByIdApi()
|
||||
|
||||
@ -2,21 +2,22 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.web.site import AppSiteApi, AppSiteInfo
|
||||
from controllers.web.site import AppSiteApi, WebAppSiteResponse, WebModelConfigResponse
|
||||
from models import Tenant, TenantStatus
|
||||
from models.model import App, AppMode, CustomizeTokenStrategy, Site
|
||||
from models.account import TenantCustomConfigDict
|
||||
from models.model import App, AppMode, AppModelConfig, CustomizeTokenStrategy, EndUser, Site
|
||||
from services.feature_service import FeatureModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(flask_app_with_containers) -> Flask:
|
||||
def app(flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
|
||||
@ -41,7 +42,23 @@ def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True
|
||||
|
||||
|
||||
def _create_site(db_session: Session, app_id: str) -> Site:
|
||||
site = Site(
|
||||
site = _site_model(app_id=app_id)
|
||||
db_session.add(site)
|
||||
db_session.commit()
|
||||
return site
|
||||
|
||||
|
||||
def _end_user(tenant_id: str, app_id: str) -> EndUser:
|
||||
return EndUser(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
type="browser",
|
||||
session_id=f"session-{app_id}",
|
||||
)
|
||||
|
||||
|
||||
def _site_model(*, app_id: str) -> Site:
|
||||
return Site(
|
||||
app_id=app_id,
|
||||
title="Site",
|
||||
icon_type="emoji",
|
||||
@ -51,31 +68,30 @@ def _create_site(db_session: Session, app_id: str) -> Site:
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
custom_disclaimer="",
|
||||
customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW,
|
||||
code=f"code-{app_id[-6:]}",
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
db_session.add(site)
|
||||
db_session.commit()
|
||||
return site
|
||||
|
||||
|
||||
class TestAppSiteApi:
|
||||
@patch("controllers.web.site.FeatureService.get_features")
|
||||
def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None:
|
||||
def test_happy_path(self, mock_features: MagicMock, app: Flask, db_session_with_containers: Session) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
app_model = _create_app(db_session_with_containers, tenant.id)
|
||||
_create_site(db_session_with_containers, app_model.id)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
end_user = _end_user(tenant.id, app_model.id)
|
||||
mock_features.return_value = FeatureModel(can_replace_logo=False)
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
result = AppSiteApi().get(app_model, end_user)
|
||||
|
||||
assert result["app_id"] == app_model.id
|
||||
assert result["end_user_id"] == end_user.id
|
||||
assert result["plan"] == "basic"
|
||||
assert result["enable_site"] is True
|
||||
|
||||
@ -83,51 +99,139 @@ class TestAppSiteApi:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
app_model = _create_app(db_session_with_containers, tenant.id)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
end_user = _end_user(tenant.id, app_model.id)
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
@patch("controllers.web.site.FeatureService.get_features")
|
||||
def test_archived_tenant_raises_forbidden(
|
||||
self, mock_features, app: Flask, db_session_with_containers: Session
|
||||
) -> None:
|
||||
def test_archived_tenant_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE)
|
||||
app_model = _create_app(db_session_with_containers, tenant.id)
|
||||
_create_site(db_session_with_containers, app_model.id)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
end_user = _end_user(tenant.id, app_model.id)
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
|
||||
class TestAppSiteInfo:
|
||||
def _tenant_model(*, plan: str = "basic", custom_config: TenantCustomConfigDict | None = None) -> Tenant:
|
||||
tenant = Tenant(name="test-tenant", plan=plan)
|
||||
tenant.custom_config_dict = custom_config or {}
|
||||
return tenant
|
||||
|
||||
|
||||
def _app_model(*, tenant: Tenant, enable_site: bool = True) -> App:
|
||||
app_model = App(
|
||||
tenant_id=tenant.id,
|
||||
mode=AppMode.CHAT,
|
||||
name="test-app",
|
||||
enable_site=enable_site,
|
||||
enable_api=True,
|
||||
)
|
||||
app_model.id = "app-test"
|
||||
return app_model
|
||||
|
||||
|
||||
class TestWebAppSiteResponse:
|
||||
def test_basic_fields(self) -> None:
|
||||
tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={})
|
||||
site_obj = SimpleNamespace()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
|
||||
|
||||
assert info.app_id == "app-1"
|
||||
assert info.end_user_id == "eu-1"
|
||||
assert info.enable_site is True
|
||||
assert info.plan == "basic"
|
||||
assert info.can_replace_logo is False
|
||||
assert info.model_config is None
|
||||
|
||||
@patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com"))
|
||||
def test_can_replace_logo_sets_custom_config(self) -> None:
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
plan="pro",
|
||||
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
|
||||
tenant = _tenant_model()
|
||||
app_model = _app_model(tenant=tenant)
|
||||
response = WebAppSiteResponse.from_app_site(
|
||||
tenant=tenant,
|
||||
app_model=app_model,
|
||||
site=_site_model(app_id=app_model.id),
|
||||
end_user_id="eu-1",
|
||||
can_replace_logo=False,
|
||||
)
|
||||
site_obj = SimpleNamespace()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
|
||||
|
||||
assert info.can_replace_logo is True
|
||||
assert info.custom_config["remove_webapp_brand"] is True
|
||||
assert "webapp-logo" in info.custom_config["replace_webapp_logo"]
|
||||
assert response.app_id == app_model.id
|
||||
assert response.end_user_id == "eu-1"
|
||||
assert response.enable_site is True
|
||||
assert response.plan == "basic"
|
||||
assert response.can_replace_logo is False
|
||||
assert response.model_config_ is None
|
||||
assert response.custom_config is None
|
||||
assert response.site.custom_disclaimer == ""
|
||||
|
||||
def test_nullable_site_fields_preserve_none(self) -> None:
|
||||
tenant = _tenant_model()
|
||||
app_model = _app_model(tenant=tenant)
|
||||
site = _site_model(app_id=app_model.id)
|
||||
site.chat_color_theme = None
|
||||
site.icon_type = None
|
||||
site.icon = None
|
||||
site.icon_background = None
|
||||
site.description = None
|
||||
site.copyright = None
|
||||
site.privacy_policy = None
|
||||
|
||||
response = WebAppSiteResponse.from_app_site(
|
||||
tenant=tenant,
|
||||
app_model=app_model,
|
||||
site=site,
|
||||
end_user_id=None,
|
||||
can_replace_logo=False,
|
||||
)
|
||||
|
||||
dumped = response.model_dump(mode="json")
|
||||
assert dumped["end_user_id"] is None
|
||||
assert dumped["site"]["chat_color_theme"] is None
|
||||
assert dumped["site"]["icon_type"] is None
|
||||
assert dumped["site"]["icon"] is None
|
||||
assert dumped["site"]["icon_background"] is None
|
||||
assert dumped["site"]["description"] is None
|
||||
assert dumped["site"]["copyright"] is None
|
||||
assert dumped["site"]["privacy_policy"] is None
|
||||
assert dumped["site"]["custom_disclaimer"] == ""
|
||||
|
||||
@patch("controllers.web.site.dify_config.FILES_URL", "https://files.example.com")
|
||||
def test_can_replace_logo_sets_custom_config(self) -> None:
|
||||
tenant = _tenant_model(
|
||||
plan="pro",
|
||||
custom_config={"remove_webapp_brand": True, "replace_webapp_logo": "enabled"},
|
||||
)
|
||||
app_model = _app_model(tenant=tenant)
|
||||
response = WebAppSiteResponse.from_app_site(
|
||||
tenant=tenant,
|
||||
app_model=app_model,
|
||||
site=_site_model(app_id=app_model.id),
|
||||
end_user_id="eu-1",
|
||||
can_replace_logo=True,
|
||||
)
|
||||
|
||||
assert response.can_replace_logo is True
|
||||
assert response.custom_config is not None
|
||||
assert response.custom_config.remove_webapp_brand is True
|
||||
assert response.custom_config.replace_webapp_logo is not None
|
||||
assert "webapp-logo" in response.custom_config.replace_webapp_logo
|
||||
|
||||
|
||||
class TestWebModelConfigResponse:
|
||||
def test_serializes_internal_model_config_properties_to_public_keys(self) -> None:
|
||||
model_config = AppModelConfig(
|
||||
app_id="app-test",
|
||||
opening_statement="Hello",
|
||||
suggested_questions='["Question?"]',
|
||||
suggested_questions_after_answer='{"enabled": true}',
|
||||
more_like_this='{"enabled": false}',
|
||||
model='{"provider": "openai", "name": "gpt-4o", "mode": "chat"}',
|
||||
user_input_form='[{"text-input": {"label": "Name", "variable": "name", "required": true}}]',
|
||||
pre_prompt="System prompt",
|
||||
created_by="account-1",
|
||||
updated_by="account-1",
|
||||
)
|
||||
|
||||
dumped = WebModelConfigResponse.model_validate(model_config, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
assert dumped == {
|
||||
"opening_statement": "Hello",
|
||||
"suggested_questions": ["Question?"],
|
||||
"suggested_questions_after_answer": {"enabled": True},
|
||||
"more_like_this": {"enabled": False},
|
||||
"model": {"provider": "openai", "name": "gpt-4o", "mode": "chat"},
|
||||
"user_input_form": [{"text-input": {"label": "Name", "variable": "name", "required": True}}],
|
||||
"pre_prompt": "System prompt",
|
||||
}
|
||||
|
||||
@ -77,9 +77,11 @@ def test_human_input_preview_delegates_to_service(
|
||||
|
||||
preview_payload = {
|
||||
"form_id": "node-42",
|
||||
"node_id": "node-42",
|
||||
"node_title": "Human Input",
|
||||
"form_content": "<div>example</div>",
|
||||
"inputs": [{"name": "topic"}],
|
||||
"actions": [{"id": "continue"}],
|
||||
"actions": [{"id": "continue", "title": "Continue"}],
|
||||
}
|
||||
service_instance = MagicMock()
|
||||
service_instance.get_human_input_form_preview.return_value = preview_payload
|
||||
@ -88,7 +90,15 @@ def test_human_input_preview_delegates_to_service(
|
||||
with app.test_request_context(case.path, method="POST", json={"inputs": {"topic": "tech"}}):
|
||||
response = case.resource_cls().post(app_id=app_model.id, node_id="node-42")
|
||||
|
||||
assert response == preview_payload
|
||||
assert response == {
|
||||
**preview_payload,
|
||||
"TYPE": "human_input_required",
|
||||
"actions": [{"id": "continue", "title": "Continue", "button_style": "default"}],
|
||||
"resolved_default_values": {},
|
||||
"display_in_ui": False,
|
||||
"form_token": None,
|
||||
"expiration_time": None,
|
||||
}
|
||||
service_instance.get_human_input_form_preview.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import inspect
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@ -16,6 +18,7 @@ from controllers.console.datasets.external import (
|
||||
ExternalDatasetCreateApi,
|
||||
ExternalKnowledgeHitTestingApi,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
@ -38,28 +41,183 @@ def current_user() -> Account:
|
||||
return user
|
||||
|
||||
|
||||
def _external_api_dict(api_id: str = "api-1") -> dict:
|
||||
return {
|
||||
"id": api_id,
|
||||
"tenant_id": "tenant-1",
|
||||
"name": f"External API {api_id}",
|
||||
"description": f"Description for {api_id}",
|
||||
"settings": {
|
||||
"endpoint": f"https://external.example.com/{api_id}",
|
||||
"api_key": "secret",
|
||||
"headers": {"X-Source": "unit-test"},
|
||||
"timeout": 30,
|
||||
},
|
||||
"dataset_bindings": [
|
||||
{"id": f"dataset-{api_id}", "name": f"Dataset {api_id}"},
|
||||
],
|
||||
"created_by": "user-1",
|
||||
"created_at": "2024-01-01T00:00:00",
|
||||
}
|
||||
|
||||
|
||||
def _external_api_object(api_id: str = "api-1") -> SimpleNamespace:
|
||||
payload = _external_api_dict(api_id)
|
||||
return SimpleNamespace(
|
||||
**{
|
||||
**payload,
|
||||
"dataset_bindings": [SimpleNamespace(**binding) for binding in payload["dataset_bindings"]],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _expected_dataset_detail_payload() -> dict[str, Any]:
|
||||
return {
|
||||
"id": "dataset-1",
|
||||
"name": "Support knowledge",
|
||||
"description": "External support articles",
|
||||
"provider": "external",
|
||||
"permission": "only_me",
|
||||
"data_source_type": "external",
|
||||
"indexing_technique": "economy",
|
||||
"app_count": 2,
|
||||
"document_count": 7,
|
||||
"word_count": 2048,
|
||||
"created_by": "user-1",
|
||||
"author_name": "Test User",
|
||||
"created_at": 1710000000,
|
||||
"updated_by": "user-2",
|
||||
"updated_at": 1710003600,
|
||||
"embedding_model": None,
|
||||
"embedding_model_provider": None,
|
||||
"embedding_available": False,
|
||||
"retrieval_model_dict": {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"reranking_mode": None,
|
||||
"reranking_model": {"reranking_provider_name": None, "reranking_model_name": None},
|
||||
"weights": None,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": True,
|
||||
"score_threshold": 0.5,
|
||||
},
|
||||
"summary_index_setting": {
|
||||
"enable": True,
|
||||
"model_name": "summary-model",
|
||||
"model_provider_name": "provider-a",
|
||||
"summary_prompt": "Summarize this.",
|
||||
},
|
||||
"tags": [{"id": "tag-1", "name": "Support", "type": "knowledge"}],
|
||||
"doc_form": "text_model",
|
||||
"external_knowledge_info": {
|
||||
"external_knowledge_id": "knowledge-1",
|
||||
"external_knowledge_api_id": "api-1",
|
||||
"external_knowledge_api_name": "External API api-1",
|
||||
"external_knowledge_api_endpoint": "https://external.example.com/api-1",
|
||||
},
|
||||
"external_retrieval_model": {
|
||||
"top_k": 4,
|
||||
"score_threshold": 0.5,
|
||||
"score_threshold_enabled": True,
|
||||
},
|
||||
"doc_metadata": [{"id": "metadata-1", "name": "source", "type": "string"}],
|
||||
"built_in_field_enabled": True,
|
||||
"pipeline_id": None,
|
||||
"runtime_mode": "external",
|
||||
"chunk_structure": "general",
|
||||
"icon_info": {
|
||||
"icon_type": "emoji",
|
||||
"icon": "book",
|
||||
"icon_background": "#FFF4ED",
|
||||
"icon_url": None,
|
||||
},
|
||||
"is_published": True,
|
||||
"total_documents": 7,
|
||||
"total_available_documents": 6,
|
||||
"enable_api": True,
|
||||
"is_multimodal": False,
|
||||
"maintainer": None,
|
||||
"permission_keys": [],
|
||||
}
|
||||
|
||||
|
||||
def _dataset_detail_object() -> SimpleNamespace:
|
||||
payload = _expected_dataset_detail_payload()
|
||||
return SimpleNamespace(
|
||||
**{
|
||||
**payload,
|
||||
"summary_index_setting": SimpleNamespace(**payload["summary_index_setting"]),
|
||||
"tags": [SimpleNamespace(**tag) for tag in payload["tags"]],
|
||||
"external_knowledge_info": SimpleNamespace(**payload["external_knowledge_info"]),
|
||||
"external_retrieval_model": SimpleNamespace(**payload["external_retrieval_model"]),
|
||||
"doc_metadata": [SimpleNamespace(**item) for item in payload["doc_metadata"]],
|
||||
"icon_info": SimpleNamespace(**payload["icon_info"]),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestExternalApiTemplateListApi:
|
||||
def test_get_success(self, app: Flask):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
api_item = MagicMock()
|
||||
api_item.to_dict.return_value = {"id": "1"}
|
||||
api_item = _external_api_object("api-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=20"),
|
||||
app.test_request_context("/?page=2&limit=1&keyword=vector"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"get_external_knowledge_apis",
|
||||
return_value=([api_item], 1),
|
||||
return_value=([api_item], 3),
|
||||
) as get_external_knowledge_apis,
|
||||
):
|
||||
resp, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 1
|
||||
assert resp["data"][0]["id"] == "1"
|
||||
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
|
||||
assert resp == {
|
||||
"data": [_external_api_dict("api-1")],
|
||||
"has_more": True,
|
||||
"limit": 1,
|
||||
"total": 3,
|
||||
"page": 2,
|
||||
}
|
||||
get_external_knowledge_apis.assert_called_once_with(2, 1, "tenant-1", "vector")
|
||||
|
||||
def test_post_success_uses_validated_payload_and_returns_template(self, app: Flask, current_user: Account):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"name": "Vendor Search",
|
||||
"settings": {
|
||||
"endpoint": "https://external.example.com/search",
|
||||
"api_key": "secret",
|
||||
"headers": {"X-Source": "unit-test"},
|
||||
"timeout": 30,
|
||||
},
|
||||
}
|
||||
created = _external_api_object("api-created")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(ExternalDatasetService, "validate_api_list") as validate_api_list,
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"create_external_knowledge_api",
|
||||
return_value=created,
|
||||
) as create_external_knowledge_api,
|
||||
):
|
||||
resp, status = method(api, "tenant-1", current_user)
|
||||
|
||||
assert status == 201
|
||||
assert resp == _external_api_dict("api-created")
|
||||
validate_api_list.assert_called_once_with(payload["settings"])
|
||||
create_external_knowledge_api.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
args=payload,
|
||||
)
|
||||
|
||||
def test_post_forbidden(self, app: Flask, current_user: Account):
|
||||
current_user.role = TenantAccountRole.NORMAL
|
||||
@ -97,6 +255,25 @@ class TestExternalApiTemplateListApi:
|
||||
|
||||
|
||||
class TestExternalApiTemplateApi:
|
||||
def test_get_success_returns_template_contract(self, app: Flask):
|
||||
api = ExternalApiTemplateApi()
|
||||
method = inspect.unwrap(api.get)
|
||||
template = _external_api_object("api-detail")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"get_external_knowledge_api",
|
||||
return_value=template,
|
||||
) as get_external_knowledge_api,
|
||||
):
|
||||
resp, status = method(api, "tenant-1", "api-detail")
|
||||
|
||||
assert status == 200
|
||||
assert resp == _external_api_dict("api-detail")
|
||||
get_external_knowledge_api.assert_called_once_with("api-detail", "tenant-1")
|
||||
|
||||
def test_get_not_found(self, app: Flask):
|
||||
api = ExternalApiTemplateApi()
|
||||
method = inspect.unwrap(api.get)
|
||||
@ -112,6 +289,42 @@ class TestExternalApiTemplateApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "tenant-1", "api-id")
|
||||
|
||||
def test_patch_success_uses_validated_payload_and_returns_template(self, app: Flask, current_user: Account):
|
||||
api = ExternalApiTemplateApi()
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"name": "Updated API",
|
||||
"settings": {
|
||||
"endpoint": "https://external.example.com/updated",
|
||||
"api_key": "new-secret",
|
||||
"headers": {"X-Version": "2"},
|
||||
},
|
||||
}
|
||||
updated = _external_api_object("api-updated")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(ExternalDatasetService, "validate_api_list") as validate_api_list,
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"update_external_knowledge_api",
|
||||
return_value=updated,
|
||||
) as update_external_knowledge_api,
|
||||
):
|
||||
resp, status = method(api, "tenant-1", current_user, "api-updated")
|
||||
|
||||
assert status == 200
|
||||
assert resp == _external_api_dict("api-updated")
|
||||
validate_api_list.assert_called_once_with(payload["settings"])
|
||||
update_external_knowledge_api.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
external_knowledge_api_id="api-updated",
|
||||
args=payload,
|
||||
)
|
||||
|
||||
def test_delete_forbidden(self, app: Flask, current_user: Account):
|
||||
current_user.role = TenantAccountRole.NORMAL
|
||||
|
||||
@ -149,45 +362,37 @@ class TestExternalDatasetCreateApi:
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api",
|
||||
"external_knowledge_id": "kid",
|
||||
"name": "dataset",
|
||||
"external_knowledge_api_id": "api-1",
|
||||
"external_knowledge_id": "knowledge-1",
|
||||
"name": "Support knowledge",
|
||||
"description": "External support articles",
|
||||
"external_retrieval_model": {
|
||||
"top_k": 4,
|
||||
"score_threshold": 0.5,
|
||||
"score_threshold_enabled": True,
|
||||
},
|
||||
}
|
||||
|
||||
dataset = MagicMock()
|
||||
|
||||
dataset.embedding_available = False
|
||||
dataset.built_in_field_enabled = False
|
||||
dataset.is_published = False
|
||||
dataset.enable_api = False
|
||||
dataset.enable_qa = False
|
||||
dataset.enable_vector_store = False
|
||||
dataset.vector_store_setting = None
|
||||
dataset.is_multimodal = False
|
||||
|
||||
dataset.retrieval_model_dict = {}
|
||||
dataset.tags = []
|
||||
dataset.external_knowledge_info = None
|
||||
dataset.external_retrieval_model = None
|
||||
dataset.doc_metadata = []
|
||||
dataset.icon_info = None
|
||||
dataset.permission_keys = []
|
||||
|
||||
dataset.summary_index_setting = MagicMock()
|
||||
dataset.summary_index_setting.enable = False
|
||||
dataset = _dataset_detail_object()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"create_external_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
) as create_external_dataset,
|
||||
):
|
||||
_, status = method(api, "tenant-1", current_user)
|
||||
resp, status = method(api, "tenant-1", current_user)
|
||||
|
||||
assert status == 201
|
||||
assert resp == _expected_dataset_detail_payload()
|
||||
create_external_dataset.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
args=payload,
|
||||
)
|
||||
|
||||
def test_create_forbidden(self, app: Flask, current_user: Account):
|
||||
current_user.role = TenantAccountRole.NORMAL
|
||||
@ -228,24 +433,58 @@ class TestExternalKnowledgeHitTestingApi:
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"query": "hello"}
|
||||
payload = {
|
||||
"query": "hello",
|
||||
"external_retrieval_model": {
|
||||
"top_k": 3,
|
||||
"score_threshold": 0.25,
|
||||
"score_threshold_enabled": True,
|
||||
},
|
||||
"metadata_filtering_conditions": {
|
||||
"logical_operator": "and",
|
||||
"conditions": [{"name": "source", "comparison_operator": "contains", "value": "external"}],
|
||||
},
|
||||
}
|
||||
|
||||
dataset = MagicMock()
|
||||
retrieve_response = {
|
||||
"query": {"content": "hello"},
|
||||
"records": [
|
||||
{
|
||||
"content": "answer",
|
||||
"title": "doc",
|
||||
"score": 0.9,
|
||||
"metadata": {"source": "external", "page": 2},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(DatasetService, "get_dataset", return_value=dataset),
|
||||
patch.object(DatasetService, "check_dataset_permission"),
|
||||
patch.object(DatasetService, "check_dataset_permission") as check_dataset_permission,
|
||||
patch.object(HitTestingService, "hit_testing_args_check") as hit_testing_args_check,
|
||||
patch.object(
|
||||
HitTestingService,
|
||||
"external_retrieve",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
return_value=retrieve_response,
|
||||
) as external_retrieve,
|
||||
patch("controllers.console.datasets.external.dump_response", side_effect=lambda _model, value: value),
|
||||
):
|
||||
resp = method(api, current_user, "dataset-id")
|
||||
|
||||
assert resp["ok"] is True
|
||||
assert resp == retrieve_response
|
||||
check_dataset_permission.assert_called_once_with(dataset, current_user)
|
||||
hit_testing_args_check.assert_called_once_with(payload)
|
||||
external_retrieve.assert_called_once_with(
|
||||
session=db.session,
|
||||
dataset=dataset,
|
||||
query="hello",
|
||||
account=current_user,
|
||||
external_retrieval_model=payload["external_retrieval_model"],
|
||||
metadata_filtering_conditions=payload["metadata_filtering_conditions"],
|
||||
)
|
||||
|
||||
|
||||
class TestBedrockRetrievalApi:
|
||||
@ -254,24 +493,44 @@ class TestBedrockRetrievalApi:
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"retrieval_setting": {},
|
||||
"query": "hello",
|
||||
"knowledge_id": "kid",
|
||||
"retrieval_setting": {"top_k": 5, "score_threshold": 0.72},
|
||||
"query": "hello bedrock",
|
||||
"knowledge_id": "knowledge-base-1",
|
||||
}
|
||||
retrieval_response = {
|
||||
"records": [
|
||||
{
|
||||
"metadata": {"source": "bedrock", "uri": "s3://bucket/doc.txt"},
|
||||
"score": 0.8,
|
||||
"title": "doc",
|
||||
"content": "answer",
|
||||
},
|
||||
{
|
||||
"metadata": {"source": "bedrock", "uri": "s3://bucket/other.txt"},
|
||||
"score": 0.65,
|
||||
"title": None,
|
||||
"content": None,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(
|
||||
ExternalDatasetTestService,
|
||||
"knowledge_retrieval",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
return_value=retrieval_response,
|
||||
) as knowledge_retrieval,
|
||||
):
|
||||
resp, status = method()
|
||||
|
||||
assert status == 200
|
||||
assert resp["ok"] is True
|
||||
assert resp == retrieval_response
|
||||
retrieval_setting, query, knowledge_id = knowledge_retrieval.call_args.args
|
||||
assert retrieval_setting.model_dump() == payload["retrieval_setting"]
|
||||
assert query == "hello bedrock"
|
||||
assert knowledge_id == "knowledge-base-1"
|
||||
|
||||
|
||||
class TestExternalApiTemplateListApiAdvanced:
|
||||
@ -297,10 +556,10 @@ class TestExternalApiTemplateListApiAdvanced:
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
templates = [MagicMock(id=f"api-{i}") for i in range(3)]
|
||||
templates = [_external_api_object(f"api-{i}") for i in range(3)]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=20"),
|
||||
app.test_request_context("/?page=2&limit=3"),
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
|
||||
return_value=(templates, 25),
|
||||
@ -309,9 +568,14 @@ class TestExternalApiTemplateListApiAdvanced:
|
||||
resp, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 25
|
||||
assert len(resp["data"]) == 3
|
||||
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
|
||||
assert resp == {
|
||||
"data": [_external_api_dict(f"api-{i}") for i in range(3)],
|
||||
"has_more": True,
|
||||
"limit": 3,
|
||||
"total": 25,
|
||||
"page": 2,
|
||||
}
|
||||
get_external_knowledge_apis.assert_called_once_with(2, 3, "tenant-1", None)
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApiAdvanced:
|
||||
@ -374,15 +638,46 @@ class TestExternalKnowledgeHitTestingApiAdvanced:
|
||||
"controllers.console.datasets.external.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"),
|
||||
patch("controllers.console.datasets.external.DatasetService.check_dataset_permission") as check_permission,
|
||||
patch("controllers.console.datasets.external.HitTestingService.hit_testing_args_check") as args_check,
|
||||
patch(
|
||||
"controllers.console.datasets.external.HitTestingService.external_retrieve",
|
||||
return_value={"results": []},
|
||||
),
|
||||
return_value={
|
||||
"query": {"content": "test query"},
|
||||
"records": [
|
||||
{
|
||||
"content": None,
|
||||
"title": "metadata-only",
|
||||
"score": None,
|
||||
"metadata": {"status": "active"},
|
||||
}
|
||||
],
|
||||
},
|
||||
) as external_retrieve,
|
||||
):
|
||||
resp = method(api, current_user, "ds-1")
|
||||
|
||||
assert resp["results"] == []
|
||||
assert resp == {
|
||||
"query": {"content": "test query"},
|
||||
"records": [
|
||||
{
|
||||
"content": None,
|
||||
"title": "metadata-only",
|
||||
"score": None,
|
||||
"metadata": {"status": "active"},
|
||||
}
|
||||
],
|
||||
}
|
||||
check_permission.assert_called_once_with(dataset, current_user)
|
||||
args_check.assert_called_once_with(payload)
|
||||
external_retrieve.assert_called_once_with(
|
||||
session=db.session,
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=current_user,
|
||||
external_retrieval_model={"type": "bm25"},
|
||||
metadata_filtering_conditions={"status": "active"},
|
||||
)
|
||||
|
||||
|
||||
class TestBedrockRetrievalApiAdvanced:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from inspect import unwrap as inspect_unwrap
|
||||
from inspect import unwrap
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
@ -35,24 +35,16 @@ from models.model import AppMode
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
unwrap: Any = inspect_unwrap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def account() -> Account:
|
||||
acc = Account(name="User", email="user@example.com")
|
||||
def account():
|
||||
acc = MagicMock(spec=Account)
|
||||
acc.id = "u1"
|
||||
return acc
|
||||
|
||||
|
||||
def _file_data() -> Any:
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
return file_data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trial_app_chat() -> MagicMock:
|
||||
def trial_app_chat():
|
||||
app = MagicMock()
|
||||
app.id = "a-chat"
|
||||
app.mode = AppMode.CHAT
|
||||
@ -60,7 +52,7 @@ def trial_app_chat() -> MagicMock:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trial_app_completion() -> MagicMock:
|
||||
def trial_app_completion():
|
||||
app = MagicMock()
|
||||
app.id = "a-comp"
|
||||
app.mode = AppMode.COMPLETION
|
||||
@ -68,7 +60,7 @@ def trial_app_completion() -> MagicMock:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trial_app_workflow() -> MagicMock:
|
||||
def trial_app_workflow():
|
||||
app = MagicMock()
|
||||
app.id = "a-workflow"
|
||||
app.mode = AppMode.WORKFLOW
|
||||
@ -76,7 +68,7 @@ def trial_app_workflow() -> MagicMock:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_parameters() -> dict[str, object]:
|
||||
def valid_parameters():
|
||||
return {
|
||||
"user_input_form": [],
|
||||
"system_parameters": {},
|
||||
@ -92,13 +84,54 @@ def valid_parameters() -> dict[str, object]:
|
||||
}
|
||||
|
||||
|
||||
def test_trial_workflow_uses_trial_scoped_simple_account_model() -> None:
|
||||
assert module.simple_account_model.name == "TrialSimpleAccount"
|
||||
assert hasattr(module.simple_account_model, "items")
|
||||
def test_trial_workflow_registers_normalized_simple_account_response_model():
|
||||
assert "SimpleAccountResponse" in module.console_ns.models
|
||||
|
||||
|
||||
def _response_model_name(entry: object) -> str:
|
||||
assert isinstance(entry, tuple)
|
||||
assert len(entry) >= 2
|
||||
model = entry[1]
|
||||
name = getattr(model, "name", None)
|
||||
assert isinstance(name, str)
|
||||
return name
|
||||
|
||||
|
||||
def test_trial_endpoints_keep_response_and_query_docs():
|
||||
untyped_generated_response_views = [
|
||||
module.TrialAppWorkflowRunApi.post,
|
||||
module.TrialChatApi.post,
|
||||
module.TrialCompletionApi.post,
|
||||
]
|
||||
for view in untyped_generated_response_views:
|
||||
apidoc = getattr(view, "__apidoc__", {})
|
||||
assert apidoc.get("responses", {})["200"] == ("Success", None, {})
|
||||
|
||||
cases = [
|
||||
(module.TrialMessageSuggestedQuestionApi.get, module.SuggestedQuestionsResponse.__name__),
|
||||
(module.TrialChatAudioApi.post, module.AudioTranscriptResponse.__name__),
|
||||
(module.TrialChatTextApi.post, module.AudioBinaryResponse.__name__),
|
||||
(module.TrialSitApi.get, module.SiteResponse.__name__),
|
||||
(module.TrialAppParameterApi.get, module.ParametersResponse.__name__),
|
||||
(module.AppApi.get, module.AppDetailWithSite.__name__),
|
||||
(module.AppWorkflowApi.get, module.WorkflowResponse.__name__),
|
||||
(module.DatasetListApi.get, module.TrialDatasetListResponse.__name__),
|
||||
]
|
||||
|
||||
for view, model_name in cases:
|
||||
apidoc = getattr(view, "__apidoc__", {})
|
||||
responses = apidoc.get("responses", {})
|
||||
assert _response_model_name(responses["200"]) == model_name
|
||||
|
||||
dataset_params = module.DatasetListApi.get.__apidoc__["params"]
|
||||
assert dataset_params["ids"]["in"] == "query"
|
||||
assert dataset_params["ids"]["type"] == "array"
|
||||
assert dataset_params["page"]["default"] == 1
|
||||
assert dataset_params["limit"]["default"] == 20
|
||||
|
||||
|
||||
class TestTrialAppWorkflowRunApi:
|
||||
def test_not_workflow_app(self, app: Flask, account: Account) -> None:
|
||||
def test_not_workflow_app(self, app: Flask, account):
|
||||
api = module.TrialAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -106,7 +139,7 @@ class TestTrialAppWorkflowRunApi:
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
method(api, account, MagicMock(mode=AppMode.CHAT))
|
||||
|
||||
def test_success(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
|
||||
def test_success(self, app: Flask, trial_app_workflow, account):
|
||||
api = module.TrialAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -119,7 +152,7 @@ class TestTrialAppWorkflowRunApi:
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_workflow_provider_not_init(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
|
||||
def test_workflow_provider_not_init(self, app: Flask, trial_app_workflow, account):
|
||||
api = module.TrialAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -134,7 +167,7 @@ class TestTrialAppWorkflowRunApi:
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, account, trial_app_workflow)
|
||||
|
||||
def test_workflow_quota_exceeded(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
|
||||
def test_workflow_quota_exceeded(self, app: Flask, trial_app_workflow, account):
|
||||
api = module.TrialAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -149,7 +182,7 @@ class TestTrialAppWorkflowRunApi:
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
method(api, account, trial_app_workflow)
|
||||
|
||||
def test_workflow_model_not_support(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
|
||||
def test_workflow_model_not_support(self, app: Flask, trial_app_workflow, account):
|
||||
api = module.TrialAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -164,7 +197,7 @@ class TestTrialAppWorkflowRunApi:
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
method(api, account, trial_app_workflow)
|
||||
|
||||
def test_workflow_invoke_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
|
||||
def test_workflow_invoke_error(self, app: Flask, trial_app_workflow, account):
|
||||
api = module.TrialAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -179,7 +212,7 @@ class TestTrialAppWorkflowRunApi:
|
||||
with pytest.raises(CompletionRequestError):
|
||||
method(api, account, trial_app_workflow)
|
||||
|
||||
def test_workflow_rate_limit_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
|
||||
def test_workflow_rate_limit_error(self, app: Flask, trial_app_workflow, account):
|
||||
api = module.TrialAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -194,7 +227,7 @@ class TestTrialAppWorkflowRunApi:
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, account, trial_app_workflow)
|
||||
|
||||
def test_workflow_value_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
|
||||
def test_workflow_value_error(self, app: Flask, trial_app_workflow, account):
|
||||
api = module.TrialAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -209,7 +242,7 @@ class TestTrialAppWorkflowRunApi:
|
||||
with pytest.raises(ValueError):
|
||||
method(api, account, trial_app_workflow)
|
||||
|
||||
def test_workflow_generic_exception(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None:
|
||||
def test_workflow_generic_exception(self, app: Flask, trial_app_workflow, account):
|
||||
api = module.TrialAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -226,7 +259,7 @@ class TestTrialAppWorkflowRunApi:
|
||||
|
||||
|
||||
class TestTrialChatApi:
|
||||
def test_not_chat_app(self, app: Flask, account: Account) -> None:
|
||||
def test_not_chat_app(self, app: Flask, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -234,7 +267,7 @@ class TestTrialChatApi:
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(api, account, MagicMock(mode="completion"))
|
||||
|
||||
def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_success(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -247,7 +280,7 @@ class TestTrialChatApi:
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_chat_conversation_not_exists(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_chat_conversation_not_exists(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -262,7 +295,7 @@ class TestTrialChatApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_chat_conversation_completed(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_chat_conversation_completed(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -277,7 +310,7 @@ class TestTrialChatApi:
|
||||
with pytest.raises(ConversationCompletedError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_chat_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_chat_app_config_broken(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -292,7 +325,7 @@ class TestTrialChatApi:
|
||||
with pytest.raises(AppUnavailableError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_chat_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_chat_provider_not_init(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -307,7 +340,7 @@ class TestTrialChatApi:
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_chat_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_chat_quota_exceeded(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -322,7 +355,7 @@ class TestTrialChatApi:
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_chat_model_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_chat_model_not_support(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -337,7 +370,7 @@ class TestTrialChatApi:
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_chat_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_chat_invoke_error(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -352,7 +385,7 @@ class TestTrialChatApi:
|
||||
with pytest.raises(CompletionRequestError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_chat_rate_limit_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_chat_rate_limit_error(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -367,7 +400,7 @@ class TestTrialChatApi:
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_chat_value_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_chat_value_error(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -382,7 +415,7 @@ class TestTrialChatApi:
|
||||
with pytest.raises(ValueError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_chat_generic_exception(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_chat_generic_exception(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -399,7 +432,7 @@ class TestTrialChatApi:
|
||||
|
||||
|
||||
class TestTrialCompletionApi:
|
||||
def test_not_completion_app(self, app: Flask, account: Account) -> None:
|
||||
def test_not_completion_app(self, app: Flask, account):
|
||||
api = module.TrialCompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -407,7 +440,7 @@ class TestTrialCompletionApi:
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(api, account, MagicMock(mode=AppMode.CHAT))
|
||||
|
||||
def test_success(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
|
||||
def test_success(self, app: Flask, trial_app_completion, account):
|
||||
api = module.TrialCompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -420,7 +453,7 @@ class TestTrialCompletionApi:
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_completion_app_config_broken(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
|
||||
def test_completion_app_config_broken(self, app: Flask, trial_app_completion, account):
|
||||
api = module.TrialCompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -435,7 +468,7 @@ class TestTrialCompletionApi:
|
||||
with pytest.raises(AppUnavailableError):
|
||||
method(api, account, trial_app_completion)
|
||||
|
||||
def test_completion_provider_not_init(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
|
||||
def test_completion_provider_not_init(self, app: Flask, trial_app_completion, account):
|
||||
api = module.TrialCompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -450,7 +483,7 @@ class TestTrialCompletionApi:
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, account, trial_app_completion)
|
||||
|
||||
def test_completion_quota_exceeded(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
|
||||
def test_completion_quota_exceeded(self, app: Flask, trial_app_completion, account):
|
||||
api = module.TrialCompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -465,7 +498,7 @@ class TestTrialCompletionApi:
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
method(api, account, trial_app_completion)
|
||||
|
||||
def test_completion_model_not_support(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
|
||||
def test_completion_model_not_support(self, app: Flask, trial_app_completion, account):
|
||||
api = module.TrialCompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -480,7 +513,7 @@ class TestTrialCompletionApi:
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
method(api, account, trial_app_completion)
|
||||
|
||||
def test_completion_invoke_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
|
||||
def test_completion_invoke_error(self, app: Flask, trial_app_completion, account):
|
||||
api = module.TrialCompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -495,7 +528,7 @@ class TestTrialCompletionApi:
|
||||
with pytest.raises(CompletionRequestError):
|
||||
method(api, account, trial_app_completion)
|
||||
|
||||
def test_completion_rate_limit_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
|
||||
def test_completion_rate_limit_error(self, app: Flask, trial_app_completion, account):
|
||||
api = module.TrialCompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -510,7 +543,7 @@ class TestTrialCompletionApi:
|
||||
with pytest.raises(InternalServerError):
|
||||
method(api, account, trial_app_completion)
|
||||
|
||||
def test_completion_value_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
|
||||
def test_completion_value_error(self, app: Flask, trial_app_completion, account):
|
||||
api = module.TrialCompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -525,7 +558,7 @@ class TestTrialCompletionApi:
|
||||
with pytest.raises(ValueError):
|
||||
method(api, account, trial_app_completion)
|
||||
|
||||
def test_completion_generic_exception(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None:
|
||||
def test_completion_generic_exception(self, app: Flask, trial_app_completion, account):
|
||||
api = module.TrialCompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -542,7 +575,7 @@ class TestTrialCompletionApi:
|
||||
|
||||
|
||||
class TestTrialMessageSuggestedQuestionApi:
|
||||
def test_not_chat_app(self, app: Flask, account: Account) -> None:
|
||||
def test_not_chat_app(self, app: Flask, account):
|
||||
api = module.TrialMessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -550,7 +583,7 @@ class TestTrialMessageSuggestedQuestionApi:
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(api, account, MagicMock(mode="completion"), str(uuid4()))
|
||||
|
||||
def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_success(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialMessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -566,7 +599,7 @@ class TestTrialMessageSuggestedQuestionApi:
|
||||
|
||||
assert result == {"data": ["q1", "q2"]}
|
||||
|
||||
def test_conversation_not_exists(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_conversation_not_exists(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialMessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -583,14 +616,14 @@ class TestTrialMessageSuggestedQuestionApi:
|
||||
|
||||
|
||||
class TestTrialAppParameterApi:
|
||||
def test_app_unavailable(self) -> None:
|
||||
def test_app_unavailable(self):
|
||||
api = module.TrialAppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with pytest.raises(AppUnavailableError):
|
||||
method(api, None)
|
||||
|
||||
def test_success_non_workflow(self, valid_parameters: dict[str, object]) -> None:
|
||||
def test_success_non_workflow(self, valid_parameters):
|
||||
api = module.TrialAppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -617,11 +650,12 @@ class TestTrialAppParameterApi:
|
||||
|
||||
|
||||
class TestTrialChatAudioApi:
|
||||
def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_success(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -634,11 +668,12 @@ class TestTrialChatAudioApi:
|
||||
|
||||
assert result == {"text": "hello"}
|
||||
|
||||
def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_app_config_broken(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -653,11 +688,12 @@ class TestTrialChatAudioApi:
|
||||
with pytest.raises(module.AppUnavailableError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_no_audio_uploaded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_no_audio_uploaded(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -672,11 +708,12 @@ class TestTrialChatAudioApi:
|
||||
with pytest.raises(module.NoAudioUploadedError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_audio_too_large(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_audio_too_large(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -691,11 +728,12 @@ class TestTrialChatAudioApi:
|
||||
with pytest.raises(module.AudioTooLargeError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_unsupported_audio_type(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_unsupported_audio_type(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -710,11 +748,12 @@ class TestTrialChatAudioApi:
|
||||
with pytest.raises(module.UnsupportedAudioTypeError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_provider_not_support_tts(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_provider_not_support_tts(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -729,11 +768,12 @@ class TestTrialChatAudioApi:
|
||||
with pytest.raises(module.ProviderNotSupportSpeechToTextError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_provider_not_init(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -744,11 +784,12 @@ class TestTrialChatAudioApi:
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_quota_exceeded(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -761,7 +802,7 @@ class TestTrialChatAudioApi:
|
||||
|
||||
|
||||
class TestTrialChatTextApi:
|
||||
def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_success(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -774,7 +815,7 @@ class TestTrialChatTextApi:
|
||||
|
||||
assert result == {"audio": "base64_data"}
|
||||
|
||||
def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_app_config_broken(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -789,7 +830,7 @@ class TestTrialChatTextApi:
|
||||
with pytest.raises(module.AppUnavailableError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_provider_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_provider_not_support(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -804,7 +845,7 @@ class TestTrialChatTextApi:
|
||||
with pytest.raises(module.ProviderNotSupportSpeechToTextError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_audio_too_large(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_audio_too_large(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -819,7 +860,7 @@ class TestTrialChatTextApi:
|
||||
with pytest.raises(module.AudioTooLargeError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_no_audio_uploaded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_no_audio_uploaded(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -834,7 +875,7 @@ class TestTrialChatTextApi:
|
||||
with pytest.raises(module.NoAudioUploadedError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_provider_not_init(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -845,7 +886,7 @@ class TestTrialChatTextApi:
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_quota_exceeded(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -856,7 +897,7 @@ class TestTrialChatTextApi:
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_model_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_model_not_support(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -867,7 +908,7 @@ class TestTrialChatTextApi:
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_invoke_error(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -880,7 +921,7 @@ class TestTrialChatTextApi:
|
||||
|
||||
|
||||
class TestTrialAppWorkflowTaskStopApi:
|
||||
def test_not_workflow_app(self, app: Flask, trial_app_chat: MagicMock) -> None:
|
||||
def test_not_workflow_app(self, app: Flask, trial_app_chat):
|
||||
api = module.TrialAppWorkflowTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -888,7 +929,7 @@ class TestTrialAppWorkflowTaskStopApi:
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
method(api, trial_app_chat, str(uuid4()))
|
||||
|
||||
def test_success(self, app: Flask, trial_app_workflow: MagicMock) -> None:
|
||||
def test_success(self, app: Flask, trial_app_workflow, account):
|
||||
api = module.TrialAppWorkflowTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -906,7 +947,7 @@ class TestTrialAppWorkflowTaskStopApi:
|
||||
|
||||
|
||||
class TestTrialSitApi:
|
||||
def test_no_site(self, app: Flask) -> None:
|
||||
def test_no_site(self, app: Flask):
|
||||
api = module.TrialSitApi()
|
||||
method = unwrap(api.get)
|
||||
app_model = MagicMock()
|
||||
@ -917,7 +958,7 @@ class TestTrialSitApi:
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, app_model)
|
||||
|
||||
def test_archived_tenant(self, app: Flask) -> None:
|
||||
def test_archived_tenant(self, app: Flask):
|
||||
api = module.TrialSitApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -932,7 +973,7 @@ class TestTrialSitApi:
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, app_model)
|
||||
|
||||
def test_success(self, app: Flask) -> None:
|
||||
def test_success(self, app: Flask):
|
||||
api = module.TrialSitApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -957,11 +998,12 @@ class TestTrialSitApi:
|
||||
|
||||
|
||||
class TestTrialChatAudioApiExceptionHandlers:
|
||||
def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_provider_not_init(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -976,11 +1018,12 @@ class TestTrialChatAudioApiExceptionHandlers:
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_quota_exceeded(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -995,11 +1038,12 @@ class TestTrialChatAudioApiExceptionHandlers:
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_invoke_error(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatAudioApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file_data = _file_data()
|
||||
file_data: Any = BytesIO(b"fake audio data")
|
||||
file_data.filename = "test.wav"
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -1016,7 +1060,7 @@ class TestTrialChatAudioApiExceptionHandlers:
|
||||
|
||||
|
||||
class TestTrialChatTextApiExceptionHandlers:
|
||||
def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_app_config_broken(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -1031,7 +1075,7 @@ class TestTrialChatTextApiExceptionHandlers:
|
||||
with pytest.raises(module.AppUnavailableError):
|
||||
method(api, account, trial_app_chat)
|
||||
|
||||
def test_unsupported_audio_type(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None:
|
||||
def test_unsupported_audio_type(self, app: Flask, trial_app_chat, account):
|
||||
api = module.TrialChatTextApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock
|
||||
@ -46,12 +47,73 @@ def _snippet(**overrides) -> CustomizedSnippet:
|
||||
"name": "Snippet",
|
||||
"description": "Description",
|
||||
"type": snippets_module.SnippetType.NODE,
|
||||
"version": 3,
|
||||
"use_count": 7,
|
||||
"is_published": True,
|
||||
"icon_info": {"icon": "star", "icon_background": "#101828", "icon_type": "emoji"},
|
||||
"input_fields": '[{"label": "Question", "variable": "query", "type": "text-input"}]',
|
||||
"created_by": "account-1",
|
||||
"created_at": datetime(2024, 1, 2, 3, 4, 5),
|
||||
"updated_by": None,
|
||||
"updated_at": datetime(2024, 1, 3, 4, 5, 6),
|
||||
}
|
||||
data.update(overrides)
|
||||
return CustomizedSnippet(**data)
|
||||
|
||||
|
||||
def _patch_snippet_response_properties(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
CustomizedSnippet,
|
||||
"tags",
|
||||
property(lambda _snippet: [{"id": "tag-1", "name": "Search", "type": "snippet"}]),
|
||||
)
|
||||
monkeypatch.setattr(CustomizedSnippet, "created_by_account", property(lambda _snippet: _account("account-1")))
|
||||
monkeypatch.setattr(CustomizedSnippet, "updated_by_account", property(lambda _snippet: None))
|
||||
|
||||
|
||||
def _expected_snippet_list_item(snippet: CustomizedSnippet) -> dict:
|
||||
return {
|
||||
"id": snippet.id,
|
||||
"name": snippet.name,
|
||||
"description": snippet.description,
|
||||
"type": snippet.type,
|
||||
"version": snippet.version,
|
||||
"use_count": snippet.use_count,
|
||||
"is_published": snippet.is_published,
|
||||
"icon_info": snippet.icon_info,
|
||||
"tags": [{"id": "tag-1", "name": "Search", "type": "snippet"}],
|
||||
"created_by": snippet.created_by,
|
||||
"author_name": "Test User",
|
||||
"created_at": int(snippet.created_at.timestamp()),
|
||||
"updated_by": snippet.updated_by,
|
||||
"updated_at": int(snippet.updated_at.timestamp()),
|
||||
}
|
||||
|
||||
|
||||
def _expected_snippet_response(snippet: CustomizedSnippet) -> dict:
|
||||
return {
|
||||
"id": snippet.id,
|
||||
"name": snippet.name,
|
||||
"description": snippet.description,
|
||||
"type": snippet.type,
|
||||
"version": snippet.version,
|
||||
"use_count": snippet.use_count,
|
||||
"is_published": snippet.is_published,
|
||||
"icon_info": snippet.icon_info,
|
||||
"graph": {},
|
||||
"input_fields": [{"label": "Question", "variable": "query", "type": "text-input"}],
|
||||
"tags": [{"id": "tag-1", "name": "Search", "type": "snippet"}],
|
||||
"created_by": {
|
||||
"id": "account-1",
|
||||
"name": "Test User",
|
||||
"email": "account-1@example.com",
|
||||
},
|
||||
"created_at": int(snippet.created_at.timestamp()),
|
||||
"updated_by": None,
|
||||
"updated_at": int(snippet.updated_at.timestamp()),
|
||||
}
|
||||
|
||||
|
||||
def test_normalize_snippet_list_query_args_sorts_indexed_values():
|
||||
query_args = snippets_module.MultiDict(
|
||||
[
|
||||
@ -75,7 +137,7 @@ def test_list_snippets_returns_pagination(app: Flask, monkeypatch: pytest.Monkey
|
||||
tag_id = "11111111-1111-1111-1111-111111111111"
|
||||
get_snippets = Mock(return_value=(snippets, 1, False))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippets", get_snippets)
|
||||
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value=[{"id": "snippet-1"}]))
|
||||
_patch_snippet_response_properties(monkeypatch)
|
||||
|
||||
api = snippets_module.CustomizedSnippetsApi()
|
||||
handler = unwrap(api.get)
|
||||
@ -87,7 +149,7 @@ def test_list_snippets_returns_pagination(app: Flask, monkeypatch: pytest.Monkey
|
||||
|
||||
assert status_code == 200
|
||||
assert response == {
|
||||
"data": [{"id": "snippet-1"}],
|
||||
"data": [_expected_snippet_list_item(snippets[0])],
|
||||
"page": 2,
|
||||
"limit": 10,
|
||||
"total": 1,
|
||||
@ -110,6 +172,7 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, mo
|
||||
snippet = _snippet()
|
||||
create_snippet = Mock(return_value=snippet)
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet)
|
||||
_patch_snippet_response_properties(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
snippets_module.CreateSnippetPayload,
|
||||
"model_validate",
|
||||
@ -124,7 +187,6 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, mo
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"}))
|
||||
|
||||
api = snippets_module.CustomizedSnippetsApi()
|
||||
handler = unwrap(api.post)
|
||||
@ -137,7 +199,7 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app: Flask, mo
|
||||
response, status_code = handler(api, "tenant-1", user)
|
||||
|
||||
assert status_code == 201
|
||||
assert response == {"id": "snippet-1"}
|
||||
assert response == _expected_snippet_response(snippet)
|
||||
assert create_snippet.call_args.kwargs["snippet_type"] == snippets_module.SnippetType.NODE
|
||||
|
||||
|
||||
@ -184,7 +246,7 @@ def test_get_snippet_detail_raises_when_missing(app: Flask, monkeypatch: pytest.
|
||||
def test_get_snippet_detail_returns_snippet(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
snippet = _snippet()
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"}))
|
||||
_patch_snippet_response_properties(monkeypatch)
|
||||
|
||||
api = snippets_module.CustomizedSnippetDetailApi()
|
||||
handler = unwrap(api.get)
|
||||
@ -193,7 +255,7 @@ def test_get_snippet_detail_returns_snippet(app: Flask, monkeypatch: pytest.Monk
|
||||
response, status_code = handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
assert status_code == 200
|
||||
assert response == {"id": "snippet-1"}
|
||||
assert response == _expected_snippet_response(snippet)
|
||||
|
||||
|
||||
def test_patch_snippet_returns_400_for_empty_payload(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
@ -230,7 +292,7 @@ def test_patch_snippet_updates_and_commits(app: Flask, monkeypatch: pytest.Monke
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "update_snippet", update_snippet)
|
||||
monkeypatch.setattr(snippets_module, "Session", SessionContext)
|
||||
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1", "name": "New"}))
|
||||
_patch_snippet_response_properties(monkeypatch)
|
||||
|
||||
api = snippets_module.CustomizedSnippetDetailApi()
|
||||
handler = unwrap(api.patch)
|
||||
@ -243,7 +305,7 @@ def test_patch_snippet_updates_and_commits(app: Flask, monkeypatch: pytest.Monke
|
||||
response, status_code = handler(api, "tenant-1", user, snippet_id="snippet-1")
|
||||
|
||||
assert status_code == 200
|
||||
assert response == {"id": "snippet-1", "name": "New"}
|
||||
assert response == _expected_snippet_response(updated_snippet)
|
||||
update_snippet.assert_called_once()
|
||||
assert update_snippet.call_args.kwargs["data"] == {
|
||||
"name": "New",
|
||||
|
||||
@ -26,7 +26,9 @@ from controllers.console.workspace.workspace import (
|
||||
WebappLogoWorkspaceApi,
|
||||
WorkspaceInfoApi,
|
||||
WorkspaceListApi,
|
||||
WorkspaceLogoUploadResponse,
|
||||
WorkspacePermissionApi,
|
||||
WorkspacePermissionResponse,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -587,7 +589,8 @@ class TestWebappLogoWorkspaceApi:
|
||||
result, status = method(api, user)
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "file1"
|
||||
assert result == {"id": "file1"}
|
||||
assert WorkspaceLogoUploadResponse.model_validate(result).model_dump(mode="json") == {"id": "file1"}
|
||||
|
||||
def test_filename_missing(self, app: Flask):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
@ -676,7 +679,7 @@ class TestWorkspaceInfoApi:
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"name": "New Name"},
|
||||
return_value={"id": "t1", "name": "New Name"},
|
||||
),
|
||||
):
|
||||
result = method(api, "t1")
|
||||
@ -717,7 +720,13 @@ class TestWorkspacePermissionApi:
|
||||
result, status = method(api, "t1")
|
||||
|
||||
assert status == 200
|
||||
assert result["workspace_id"] == "t1"
|
||||
expected = {
|
||||
"workspace_id": "t1",
|
||||
"allow_member_invite": True,
|
||||
"allow_owner_transfer": False,
|
||||
}
|
||||
assert result == expected
|
||||
assert WorkspacePermissionResponse.model_validate(result).model_dump(mode="json") == expected
|
||||
|
||||
def test_no_current_tenant(self, app: Flask):
|
||||
api = WorkspacePermissionApi()
|
||||
|
||||
@ -325,10 +325,12 @@ class TestPipelineRunApiEntity:
|
||||
def test_entity_missing_required_field(self):
|
||||
"""Test entity raises on missing required field."""
|
||||
with pytest.raises(ValueError):
|
||||
PipelineRunApiEntity(
|
||||
inputs={},
|
||||
datasource_type="online_document",
|
||||
# missing datasource_info_list, start_node_id, etc.
|
||||
PipelineRunApiEntity.model_validate(
|
||||
{
|
||||
"inputs": {},
|
||||
"datasource_type": "online_document",
|
||||
# missing datasource_info_list, start_node_id, etc.
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -382,8 +384,19 @@ class TestDatasourcePluginsApiGet:
|
||||
mock_dataset = Mock()
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
datasource_plugins = [
|
||||
{
|
||||
"node_id": "node-datasource-1",
|
||||
"plugin_id": "plugin-a",
|
||||
"provider_name": "provider-a",
|
||||
"datasource_type": "online_document",
|
||||
"title": "Online Docs",
|
||||
"user_input_variables": [{"variable": "url", "label": "URL", "type": "text-input", "required": True}],
|
||||
"credentials": [{"id": "cred-1", "name": "Default credential", "type": "oauth2", "is_default": True}],
|
||||
}
|
||||
]
|
||||
mock_svc_instance = Mock()
|
||||
mock_svc_instance.get_datasource_plugins.return_value = [{"name": "plugin_a"}]
|
||||
mock_svc_instance.get_datasource_plugins.return_value = datasource_plugins
|
||||
mock_svc_cls.return_value = mock_svc_instance
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/datasource-plugins?is_published=true"):
|
||||
@ -391,11 +404,33 @@ class TestDatasourcePluginsApiGet:
|
||||
response, status = api.get(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
assert status == 200
|
||||
assert response == [{"name": "plugin_a"}]
|
||||
assert response == datasource_plugins
|
||||
mock_svc_instance.get_datasource_plugins.assert_called_once_with(
|
||||
tenant_id=tenant_id, dataset_id=dataset_id, is_published=True
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
|
||||
def test_get_plugins_parses_false_is_published_query(self, mock_svc_cls, mock_db, app: Flask):
|
||||
"""Test false query string is parsed as boolean False."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
dataset_id = str(uuid.uuid4())
|
||||
|
||||
mock_db.session.scalar.return_value = Mock()
|
||||
mock_svc_instance = Mock()
|
||||
mock_svc_instance.get_datasource_plugins.return_value = []
|
||||
mock_svc_cls.return_value = mock_svc_instance
|
||||
|
||||
with app.test_request_context("/datasets/test/pipeline/datasource-plugins?is_published=false"):
|
||||
api = DatasourcePluginsApi()
|
||||
response, status = api.get(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
assert status == 200
|
||||
assert response == []
|
||||
mock_svc_instance.get_datasource_plugins.assert_called_once_with(
|
||||
tenant_id=tenant_id, dataset_id=dataset_id, is_published=False
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
|
||||
def test_get_plugins_not_found(self, mock_db, app: Flask):
|
||||
"""Test NotFound when dataset check fails."""
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
@ -112,7 +111,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
custom_disclaimer="",
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
@ -138,7 +137,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
response = HumanInputFormApi().get("token-1")
|
||||
|
||||
body = json.loads(response.get_data(as_text=True))
|
||||
body = response
|
||||
assert set(body.keys()) == {
|
||||
"site",
|
||||
"form_content",
|
||||
@ -167,7 +166,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"description": "desc",
|
||||
"copyright": None,
|
||||
"privacy_policy": None,
|
||||
"custom_disclaimer": None,
|
||||
"custom_disclaimer": "",
|
||||
"default_language": "en",
|
||||
"prompt_public": False,
|
||||
"show_workflow_steps": True,
|
||||
@ -256,7 +255,7 @@ def test_get_form_uses_runtime_select_options(monkeypatch: pytest.MonkeyPatch, a
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
custom_disclaimer="",
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
@ -277,7 +276,7 @@ def test_get_form_uses_runtime_select_options(monkeypatch: pytest.MonkeyPatch, a
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
response = HumanInputFormApi().get("token-1")
|
||||
|
||||
body = json.loads(response.get_data(as_text=True))
|
||||
body = response
|
||||
assert body["inputs"] == [input_config.model_dump(mode="json") for input_config in runtime_inputs]
|
||||
service_mock.resolve_form_inputs.assert_called_once_with(form)
|
||||
|
||||
@ -380,7 +379,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
custom_disclaimer="",
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
@ -403,7 +402,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
response = HumanInputFormApi().get("token-1")
|
||||
|
||||
body = json.loads(response.get_data(as_text=True))
|
||||
body = response
|
||||
assert set(body.keys()) == {
|
||||
"site",
|
||||
"form_content",
|
||||
@ -432,7 +431,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F
|
||||
"description": "desc",
|
||||
"copyright": None,
|
||||
"privacy_policy": None,
|
||||
"custom_disclaimer": None,
|
||||
"custom_disclaimer": "",
|
||||
"default_language": "en",
|
||||
"prompt_public": False,
|
||||
"show_workflow_steps": True,
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from flask_restx import marshal
|
||||
|
||||
from fields.snippet_fields import snippet_list_fields
|
||||
from fields.snippet_fields import SnippetListItemResponse
|
||||
from libs.helper import dump_response
|
||||
|
||||
|
||||
def test_snippet_list_fields_include_author_name() -> None:
|
||||
@ -23,6 +22,6 @@ def test_snippet_list_fields_include_author_name() -> None:
|
||||
updated_at=None,
|
||||
)
|
||||
|
||||
result = marshal(snippet, snippet_list_fields)
|
||||
result = dump_response(SnippetListItemResponse, snippet)
|
||||
|
||||
assert result["author_name"] == "Alice"
|
||||
|
||||
@ -269,6 +269,7 @@ export type MessageDetailResponse = {
|
||||
agent_thoughts?: Array<AgentThought>
|
||||
annotation?: ConversationAnnotation | null
|
||||
annotation_hit_history?: ConversationAnnotationHitHistory | null
|
||||
answer: string
|
||||
answer_tokens?: number | null
|
||||
conversation_id: string
|
||||
created_at?: number | null
|
||||
@ -284,12 +285,11 @@ export type MessageDetailResponse = {
|
||||
}
|
||||
message?: JsonValue | null
|
||||
message_files?: Array<MessageFile>
|
||||
message_metadata_dict?: JsonValue | null
|
||||
message_tokens?: number | null
|
||||
metadata?: JsonValue | null
|
||||
parent_message_id?: string | null
|
||||
provider_response_latency?: number | null
|
||||
query: string
|
||||
re_sign_file_url_answer: string
|
||||
status: string
|
||||
workflow_run_id?: string | null
|
||||
}
|
||||
@ -732,7 +732,6 @@ export type AgentThought = {
|
||||
created_at?: number | null
|
||||
files: Array<string>
|
||||
id: string
|
||||
message_chain_id?: string | null
|
||||
message_id: string
|
||||
observation?: string | null
|
||||
position: number
|
||||
@ -752,8 +751,8 @@ export type ConversationAnnotation = {
|
||||
|
||||
export type ConversationAnnotationHitHistory = {
|
||||
annotation_create_account?: SimpleAccount | null
|
||||
annotation_id: string
|
||||
created_at?: number | null
|
||||
id: string
|
||||
}
|
||||
|
||||
export type HumanInputContent = {
|
||||
|
||||
@ -628,7 +628,6 @@ export const zAgentThought = z.object({
|
||||
created_at: z.int().nullish(),
|
||||
files: z.array(z.string()),
|
||||
id: z.string(),
|
||||
message_chain_id: z.string().nullish(),
|
||||
message_id: z.string(),
|
||||
observation: z.string().nullish(),
|
||||
position: z.int(),
|
||||
@ -1060,8 +1059,8 @@ export const zConversationAnnotation = z.object({
|
||||
*/
|
||||
export const zConversationAnnotationHitHistory = z.object({
|
||||
annotation_create_account: zSimpleAccount.nullish(),
|
||||
annotation_id: z.string(),
|
||||
created_at: z.int().nullish(),
|
||||
id: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
@ -2039,6 +2038,7 @@ export const zMessageDetailResponse = z.object({
|
||||
agent_thoughts: z.array(zAgentThought).optional(),
|
||||
annotation: zConversationAnnotation.nullish(),
|
||||
annotation_hit_history: zConversationAnnotationHitHistory.nullish(),
|
||||
answer: z.string(),
|
||||
answer_tokens: z.int().nullish(),
|
||||
conversation_id: z.string(),
|
||||
created_at: z.int().nullish(),
|
||||
@ -2052,12 +2052,11 @@ export const zMessageDetailResponse = z.object({
|
||||
inputs: z.record(z.string(), zJsonValue),
|
||||
message: zJsonValue.nullish(),
|
||||
message_files: z.array(zMessageFile).optional(),
|
||||
message_metadata_dict: zJsonValue.nullish(),
|
||||
message_tokens: z.int().nullish(),
|
||||
metadata: zJsonValue.nullish(),
|
||||
parent_message_id: z.string().nullish(),
|
||||
provider_response_latency: z.number().nullish(),
|
||||
query: z.string(),
|
||||
re_sign_file_url_answer: z.string(),
|
||||
status: z.string(),
|
||||
workflow_run_id: z.string().nullish(),
|
||||
})
|
||||
|
||||
@ -433,6 +433,7 @@ export type MessageDetailResponse = {
|
||||
agent_thoughts?: Array<AgentThought>
|
||||
annotation?: ConversationAnnotation | null
|
||||
annotation_hit_history?: ConversationAnnotationHitHistory | null
|
||||
answer: string
|
||||
answer_tokens?: number | null
|
||||
conversation_id: string
|
||||
created_at?: number | null
|
||||
@ -448,12 +449,11 @@ export type MessageDetailResponse = {
|
||||
}
|
||||
message?: JsonValue | null
|
||||
message_files?: Array<MessageFile>
|
||||
message_metadata_dict?: JsonValue | null
|
||||
message_tokens?: number | null
|
||||
metadata?: JsonValue | null
|
||||
parent_message_id?: string | null
|
||||
provider_response_latency?: number | null
|
||||
query: string
|
||||
re_sign_file_url_answer: string
|
||||
status: string
|
||||
workflow_run_id?: string | null
|
||||
}
|
||||
@ -1440,7 +1440,6 @@ export type AgentThought = {
|
||||
created_at?: number | null
|
||||
files: Array<string>
|
||||
id: string
|
||||
message_chain_id?: string | null
|
||||
message_id: string
|
||||
observation?: string | null
|
||||
position: number
|
||||
@ -1460,8 +1459,8 @@ export type ConversationAnnotation = {
|
||||
|
||||
export type ConversationAnnotationHitHistory = {
|
||||
annotation_create_account?: SimpleAccount | null
|
||||
annotation_id: string
|
||||
created_at?: number | null
|
||||
id: string
|
||||
}
|
||||
|
||||
export type HumanInputContent = {
|
||||
|
||||
@ -1183,7 +1183,6 @@ export const zAgentThought = z.object({
|
||||
created_at: z.int().nullish(),
|
||||
files: z.array(z.string()),
|
||||
id: z.string(),
|
||||
message_chain_id: z.string().nullish(),
|
||||
message_id: z.string(),
|
||||
observation: z.string().nullish(),
|
||||
position: z.int(),
|
||||
@ -2286,8 +2285,8 @@ export const zConversationPagination = z.object({
|
||||
*/
|
||||
export const zConversationAnnotationHitHistory = z.object({
|
||||
annotation_create_account: zSimpleAccount.nullish(),
|
||||
annotation_id: z.string(),
|
||||
created_at: z.int().nullish(),
|
||||
id: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
@ -3401,6 +3400,7 @@ export const zMessageDetailResponse = z.object({
|
||||
agent_thoughts: z.array(zAgentThought).optional(),
|
||||
annotation: zConversationAnnotation.nullish(),
|
||||
annotation_hit_history: zConversationAnnotationHitHistory.nullish(),
|
||||
answer: z.string(),
|
||||
answer_tokens: z.int().nullish(),
|
||||
conversation_id: z.string(),
|
||||
created_at: z.int().nullish(),
|
||||
@ -3414,12 +3414,11 @@ export const zMessageDetailResponse = z.object({
|
||||
inputs: z.record(z.string(), zJsonValue),
|
||||
message: zJsonValue.nullish(),
|
||||
message_files: z.array(zMessageFile).optional(),
|
||||
message_metadata_dict: zJsonValue.nullish(),
|
||||
message_tokens: z.int().nullish(),
|
||||
metadata: zJsonValue.nullish(),
|
||||
parent_message_id: z.string().nullish(),
|
||||
provider_response_latency: z.number().nullish(),
|
||||
query: z.string(),
|
||||
re_sign_file_url_answer: z.string(),
|
||||
status: z.string(),
|
||||
workflow_run_id: z.string().nullish(),
|
||||
})
|
||||
|
||||
@ -237,7 +237,6 @@ export type AgentThought = {
|
||||
created_at?: number | null
|
||||
files: Array<string>
|
||||
id: string
|
||||
message_chain_id?: string | null
|
||||
message_id: string
|
||||
observation?: string | null
|
||||
position: number
|
||||
|
||||
@ -257,7 +257,6 @@ export const zAgentThought = z.object({
|
||||
created_at: z.int().nullish(),
|
||||
files: z.array(z.string()),
|
||||
id: z.string(),
|
||||
message_chain_id: z.string().nullish(),
|
||||
message_id: z.string(),
|
||||
observation: z.string().nullish(),
|
||||
position: z.int(),
|
||||
|
||||
@ -217,14 +217,8 @@ const toFeedback = (feedback: NonNullable<MessageDetailResponse['feedbacks']>[nu
|
||||
}
|
||||
}
|
||||
|
||||
type AgentDebugMessageWithLegacyAnswer = MessageDetailResponse & {
|
||||
answer?: string | null
|
||||
}
|
||||
|
||||
const getAgentDebugMessageAnswer = (message: MessageDetailResponse) => {
|
||||
const legacyAnswer = (message as AgentDebugMessageWithLegacyAnswer).answer
|
||||
|
||||
return message.re_sign_file_url_answer ?? legacyAnswer ?? ''
|
||||
return message.answer ?? ''
|
||||
}
|
||||
|
||||
function getFormattedAgentDebugChatTree(messages: MessageDetailResponse[]): ChatItemInTree[] {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user