diff --git a/docs/ar-SA/README.md b/docs/ar-SA/README.md
index 30920ed983..99e3e3567e 100644
--- a/docs/ar-SA/README.md
+++ b/docs/ar-SA/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/bn-BD/README.md b/docs/bn-BD/README.md
index 5430364ef9..f3fa68b466 100644
--- a/docs/bn-BD/README.md
+++ b/docs/bn-BD/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/de-DE/README.md b/docs/de-DE/README.md
index 6c49fbdfc3..c71a0bfccf 100644
--- a/docs/de-DE/README.md
+++ b/docs/de-DE/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/es-ES/README.md b/docs/es-ES/README.md
index ae83d416e3..da81b51d6a 100644
--- a/docs/es-ES/README.md
+++ b/docs/es-ES/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/fr-FR/README.md b/docs/fr-FR/README.md
index b7d006a927..03f3221798 100644
--- a/docs/fr-FR/README.md
+++ b/docs/fr-FR/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/hi-IN/README.md b/docs/hi-IN/README.md
index 7c4fc70db0..bedeaa6246 100644
--- a/docs/hi-IN/README.md
+++ b/docs/hi-IN/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/it-IT/README.md b/docs/it-IT/README.md
index 598e87ec25..2e96335d3e 100644
--- a/docs/it-IT/README.md
+++ b/docs/it-IT/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/ja-JP/README.md b/docs/ja-JP/README.md
index f9e700d1df..659ffbda51 100644
--- a/docs/ja-JP/README.md
+++ b/docs/ja-JP/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/ko-KR/README.md b/docs/ko-KR/README.md
index 4e4b82e920..2f6c526ef2 100644
--- a/docs/ko-KR/README.md
+++ b/docs/ko-KR/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/pt-BR/README.md b/docs/pt-BR/README.md
index 444faa0a67..ed29ec0294 100644
--- a/docs/pt-BR/README.md
+++ b/docs/pt-BR/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/sl-SI/README.md b/docs/sl-SI/README.md
index 04dc3b5dff..caef2c303c 100644
--- a/docs/sl-SI/README.md
+++ b/docs/sl-SI/README.md
@@ -33,6 +33,12 @@
+
+
+
+
+
+
diff --git a/docs/tlh/README.md b/docs/tlh/README.md
index b1e3016efd..a25849c443 100644
--- a/docs/tlh/README.md
+++ b/docs/tlh/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/tr-TR/README.md b/docs/tr-TR/README.md
index 965a1704be..6361ca5dd9 100644
--- a/docs/tr-TR/README.md
+++ b/docs/tr-TR/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/vi-VN/README.md b/docs/vi-VN/README.md
index 07329e84cd..3042a98d95 100644
--- a/docs/vi-VN/README.md
+++ b/docs/vi-VN/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/zh-CN/README.md b/docs/zh-CN/README.md
index 888a0d7f12..15bb447ad8 100644
--- a/docs/zh-CN/README.md
+++ b/docs/zh-CN/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/zh-TW/README.md b/docs/zh-TW/README.md
index d8c484a6d4..14b343ba29 100644
--- a/docs/zh-TW/README.md
+++ b/docs/zh-TW/README.md
@@ -36,6 +36,12 @@

+
+
+
+
+
+
From 4ccc150fd190a9151f0e9d674f18ff5773fb068c Mon Sep 17 00:00:00 2001
From: aka James4u
Date: Wed, 26 Nov 2025 07:33:46 -0800
Subject: [PATCH 37/63] test: add comprehensive unit tests for
ExternalDatasetService (external knowledge API integration) (#28716)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../services/test_external_dataset_service.py | 920 ++++++++++++++++++
1 file changed, 920 insertions(+)
create mode 100644 api/tests/unit_tests/services/test_external_dataset_service.py
diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py
new file mode 100644
index 0000000000..1647eb3e85
--- /dev/null
+++ b/api/tests/unit_tests/services/test_external_dataset_service.py
@@ -0,0 +1,920 @@
+"""
+Extensive unit tests for ``ExternalDatasetService``.
+
+This module focuses on the *external dataset service* surface area, which is responsible
+for integrating with **external knowledge APIs** and wiring them into Dify datasets.
+
+The goal of this test suite is twofold:
+
+- Provide **high‑confidence regression coverage** for all public helpers on
+ ``ExternalDatasetService``.
+- Serve as **executable documentation** for how external API integration is expected
+ to behave in different scenarios (happy paths, validation failures, and error codes).
+
+The file intentionally contains **rich comments and generous spacing** in order to make
+each scenario easy to scan during reviews.
+"""
+
+from __future__ import annotations
+import json
+from types import SimpleNamespace
+from typing import Any
+from unittest.mock import MagicMock, Mock, patch
+
+import httpx
+import pytest
+
+from constants import HIDDEN_VALUE
+from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings
+from services.entities.external_knowledge_entities.external_knowledge_entities import (
+    Authorization,
+    AuthorizationConfig,
+    ExternalKnowledgeApiSetting,
+)
+from services.errors.dataset import DatasetNameDuplicateError
+from services.external_knowledge_service import ExternalDatasetService
+
+
+class ExternalDatasetTestDataFactory:
+    """
+    Factory helpers for building *lightweight* mocks for external knowledge tests.
+
+    These helpers are intentionally small and explicit:
+
+    - They avoid pulling in unnecessary fixtures.
+    - They reflect the minimal contract that the service under test cares about.
+    """
+
+    @staticmethod
+    def create_external_api(
+        api_id: str = "api-123",
+        tenant_id: str = "tenant-1",
+        name: str = "Test API",
+        description: str = "Description",
+        settings: dict | None = None,
+    ) -> ExternalKnowledgeApis:
+        """
+        Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields.
+
+        Using the real SQLAlchemy model (instead of a pure Mock) makes it easier to
+        exercise ``settings_dict`` and other convenience properties if needed.
+        """
+
+        instance = ExternalKnowledgeApis(
+            tenant_id=tenant_id,
+            name=name,
+            description=description,
+            settings=None if settings is None else json.dumps(settings),
+ )
+
+ # Overwrite generated id for determinism in assertions.
+ instance.id = api_id
+ return instance
+
+ @staticmethod
+ def create_dataset(
+ dataset_id: str = "ds-1",
+ tenant_id: str = "tenant-1",
+ name: str = "External Dataset",
+ provider: str = "external",
+ ) -> Dataset:
+ """
+ Build a small ``Dataset`` instance representing an external dataset.
+ """
+
+ dataset = Dataset(
+ tenant_id=tenant_id,
+ name=name,
+ description="",
+ provider=provider,
+ created_by="user-1",
+ )
+ dataset.id = dataset_id
+ return dataset
+
+ @staticmethod
+ def create_external_binding(
+ tenant_id: str = "tenant-1",
+ dataset_id: str = "ds-1",
+ api_id: str = "api-1",
+ external_knowledge_id: str = "knowledge-1",
+ ) -> ExternalKnowledgeBindings:
+ """
+ Small helper for a binding between dataset and external knowledge API.
+ """
+
+ binding = ExternalKnowledgeBindings(
+ tenant_id=tenant_id,
+ dataset_id=dataset_id,
+ external_knowledge_api_id=api_id,
+ external_knowledge_id=external_knowledge_id,
+ created_by="user-1",
+ )
+ return binding
+
+
+# ---------------------------------------------------------------------------
+# get_external_knowledge_apis
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceGetExternalKnowledgeApis:
+ """
+ Tests for ``ExternalDatasetService.get_external_knowledge_apis``.
+
+ These tests focus on:
+
+ - Basic pagination wiring via ``db.paginate``.
+ - Optional search keyword behaviour.
+ """
+
+ @pytest.fixture
+ def mock_db_paginate(self):
+ """
+ Patch ``db.paginate`` so we do not touch the real database layer.
+ """
+
+ with (
+ patch("services.external_knowledge_service.db.paginate") as mock_paginate,
+ patch("services.external_knowledge_service.select"),
+ ):
+ yield mock_paginate
+
+ def test_get_external_knowledge_apis_basic_pagination(self, mock_db_paginate: MagicMock):
+ """
+ It should return ``items`` and ``total`` coming from the paginate object.
+ """
+
+ # Arrange
+ tenant_id = "tenant-1"
+ page = 1
+ per_page = 20
+
+ mock_items = [Mock(spec=ExternalKnowledgeApis), Mock(spec=ExternalKnowledgeApis)]
+ mock_pagination = SimpleNamespace(items=mock_items, total=42)
+ mock_db_paginate.return_value = mock_pagination
+
+ # Act
+ items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id)
+
+ # Assert
+ assert items is mock_items
+ assert total == 42
+
+ mock_db_paginate.assert_called_once()
+ call_kwargs = mock_db_paginate.call_args.kwargs
+ assert call_kwargs["page"] == page
+ assert call_kwargs["per_page"] == per_page
+ assert call_kwargs["max_per_page"] == 100
+ assert call_kwargs["error_out"] is False
+
+ def test_get_external_knowledge_apis_with_search_keyword(self, mock_db_paginate: MagicMock):
+ """
+ When a search keyword is provided, the query should be adjusted
+ (we simply assert that paginate is still called and does not explode).
+ """
+
+ # Arrange
+ tenant_id = "tenant-1"
+ page = 2
+ per_page = 10
+ search = "foo"
+
+ mock_pagination = SimpleNamespace(items=[], total=0)
+ mock_db_paginate.return_value = mock_pagination
+
+ # Act
+ items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id, search=search)
+
+ # Assert
+ assert items == []
+ assert total == 0
+ mock_db_paginate.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# validate_api_list
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceValidateApiList:
+ """
+ Lightweight validation tests for ``validate_api_list``.
+ """
+
+ def test_validate_api_list_success(self):
+ """
+ A minimal valid configuration (endpoint + api_key) should pass.
+ """
+
+ config = {"endpoint": "https://example.com", "api_key": "secret"}
+
+ # Act & Assert – no exception expected
+ ExternalDatasetService.validate_api_list(config)
+
+ @pytest.mark.parametrize(
+ ("config", "expected_message"),
+ [
+ ({}, "api list is empty"),
+ ({"api_key": "k"}, "endpoint is required"),
+ ({"endpoint": "https://example.com"}, "api_key is required"),
+ ],
+ )
+ def test_validate_api_list_failures(self, config: dict, expected_message: str):
+ """
+ Invalid configs should raise ``ValueError`` with a clear message.
+ """
+
+ with pytest.raises(ValueError, match=expected_message):
+ ExternalDatasetService.validate_api_list(config)
+
+
+# ---------------------------------------------------------------------------
+# create_external_knowledge_api & get/update/delete
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceCrudExternalKnowledgeApi:
+ """
+ CRUD tests for external knowledge API templates.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Patch ``db.session`` for all CRUD tests in this class.
+ """
+
+ with patch("services.external_knowledge_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
+ """
+ ``create_external_knowledge_api`` should persist a new record
+ when settings are present and valid.
+ """
+
+ tenant_id = "tenant-1"
+ user_id = "user-1"
+ args = {
+ "name": "API",
+ "description": "desc",
+ "settings": {"endpoint": "https://api.example.com", "api_key": "secret"},
+ }
+
+ # We do not want to actually call the remote endpoint here, so we patch the validator.
+ with patch.object(ExternalDatasetService, "check_endpoint_and_api_key") as mock_check:
+ result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
+
+ assert isinstance(result, ExternalKnowledgeApis)
+ mock_check.assert_called_once_with(args["settings"])
+ mock_db_session.add.assert_called_once()
+ mock_db_session.commit.assert_called_once()
+
+ def test_create_external_knowledge_api_missing_settings_raises(self, mock_db_session: MagicMock):
+ """
+ Missing ``settings`` should result in a ``ValueError``.
+ """
+
+ tenant_id = "tenant-1"
+ user_id = "user-1"
+ args = {"name": "API", "description": "desc"}
+
+ with pytest.raises(ValueError, match="settings is required"):
+ ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
+
+ mock_db_session.add.assert_not_called()
+ mock_db_session.commit.assert_not_called()
+
+ def test_get_external_knowledge_api_found(self, mock_db_session: MagicMock):
+ """
+ ``get_external_knowledge_api`` should return the first matching record.
+ """
+
+ api = Mock(spec=ExternalKnowledgeApis)
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
+
+ result = ExternalDatasetService.get_external_knowledge_api("api-id")
+ assert result is api
+
+ def test_get_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
+ """
+ When the record is absent, a ``ValueError`` is raised.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="api template not found"):
+ ExternalDatasetService.get_external_knowledge_api("missing-id")
+
+ def test_update_external_knowledge_api_success_with_hidden_api_key(self, mock_db_session: MagicMock):
+ """
+ Updating an API should keep the existing API key when the special hidden
+ value placeholder is sent from the UI.
+ """
+
+ tenant_id = "tenant-1"
+ user_id = "user-1"
+ api_id = "api-1"
+
+ existing_api = Mock(spec=ExternalKnowledgeApis)
+ existing_api.settings_dict = {"api_key": "stored-key"}
+ existing_api.settings = '{"api_key":"stored-key"}'
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_api
+
+ args = {
+ "name": "New Name",
+ "description": "New Desc",
+ "settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
+ }
+
+ result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
+
+ assert result is existing_api
+ # The placeholder should be replaced with stored key.
+ assert args["settings"]["api_key"] == "stored-key"
+ mock_db_session.commit.assert_called_once()
+
+ def test_update_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
+ """
+ Updating a non‑existent API template should raise ``ValueError``.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="api template not found"):
+ ExternalDatasetService.update_external_knowledge_api(
+ tenant_id="tenant-1",
+ user_id="user-1",
+ external_knowledge_api_id="missing-id",
+ args={"name": "n", "description": "d", "settings": {}},
+ )
+
+ def test_delete_external_knowledge_api_success(self, mock_db_session: MagicMock):
+ """
+ ``delete_external_knowledge_api`` should delete and commit when found.
+ """
+
+ api = Mock(spec=ExternalKnowledgeApis)
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
+
+ ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
+
+ mock_db_session.delete.assert_called_once_with(api)
+ mock_db_session.commit.assert_called_once()
+
+ def test_delete_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
+ """
+ Deletion of a missing template should raise ``ValueError``.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="api template not found"):
+ ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
+
+
+# ---------------------------------------------------------------------------
+# external_knowledge_api_use_check & binding lookups
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceUsageAndBindings:
+ """
+ Tests for usage checks and dataset binding retrieval.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ with patch("services.external_knowledge_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
+ """
+ When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.count.return_value = 3
+
+ in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
+
+ assert in_use is True
+ assert count == 3
+
+ def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
+ """
+ Zero bindings should return ``(False, 0)``.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.count.return_value = 0
+
+ in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
+
+ assert in_use is False
+ assert count == 0
+
+ def test_get_external_knowledge_binding_with_dataset_id_found(self, mock_db_session: MagicMock):
+ """
+ Binding lookup should return the first record when present.
+ """
+
+ binding = Mock(spec=ExternalKnowledgeBindings)
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = binding
+
+ result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
+ assert result is binding
+
+ def test_get_external_knowledge_binding_with_dataset_id_not_found_raises(self, mock_db_session: MagicMock):
+ """
+ Missing binding should result in a ``ValueError``.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="external knowledge binding not found"):
+ ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
+
+
+# ---------------------------------------------------------------------------
+# document_create_args_validate
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceDocumentCreateArgsValidate:
+ """
+ Tests for ``document_create_args_validate``.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ with patch("services.external_knowledge_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
+ """
+ All required custom parameters present – validation should pass.
+ """
+
+ external_api = Mock(spec=ExternalKnowledgeApis)
+ external_api.settings = json_settings = (
+ '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
+ )
+ # Raw string; the service itself calls json.loads on it
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
+
+ process_parameter = {"foo": "value", "bar": "optional"}
+
+ # Act & Assert – no exception
+ ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
+
+        assert "document_process_setting" in json_settings  # simple sanity check on our test data
+
+ def test_document_create_args_validate_missing_template_raises(self, mock_db_session: MagicMock):
+ """
+ When the referenced API template is missing, a ``ValueError`` is raised.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="api template not found"):
+ ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})
+
+ def test_document_create_args_validate_missing_required_parameter_raises(self, mock_db_session: MagicMock):
+ """
+ Required document process parameters must be supplied.
+ """
+
+ external_api = Mock(spec=ExternalKnowledgeApis)
+ external_api.settings = (
+ '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
+ )
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
+
+ process_parameter = {"bar": "present"} # missing "foo"
+
+ with pytest.raises(ValueError, match="foo is required"):
+ ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
+
+
+# ---------------------------------------------------------------------------
+# process_external_api
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceProcessExternalApi:
+ """
+ Tests focused on the HTTP request assembly and method mapping behaviour.
+ """
+
+ def test_process_external_api_valid_method_post(self):
+ """
+ For a supported HTTP verb we should delegate to the correct ``ssrf_proxy`` function.
+ """
+
+ settings = ExternalKnowledgeApiSetting(
+ url="https://example.com/path",
+ request_method="POST",
+ headers={"X-Test": "1"},
+ params={"foo": "bar"},
+ )
+
+ fake_response = httpx.Response(200)
+
+ with patch("services.external_knowledge_service.ssrf_proxy.post") as mock_post:
+ mock_post.return_value = fake_response
+
+ result = ExternalDatasetService.process_external_api(settings, files=None)
+
+ assert result is fake_response
+ mock_post.assert_called_once()
+ kwargs = mock_post.call_args.kwargs
+ assert kwargs["url"] == settings.url
+ assert kwargs["headers"] == settings.headers
+ assert kwargs["follow_redirects"] is True
+ assert "data" in kwargs
+
+ def test_process_external_api_invalid_method_raises(self):
+ """
+ An unsupported HTTP verb should raise ``InvalidHttpMethodError``.
+ """
+
+ settings = ExternalKnowledgeApiSetting(
+ url="https://example.com",
+ request_method="INVALID",
+ headers=None,
+ params={},
+ )
+
+ from core.workflow.nodes.http_request.exc import InvalidHttpMethodError
+
+ with pytest.raises(InvalidHttpMethodError):
+ ExternalDatasetService.process_external_api(settings, files=None)
+
+
+# ---------------------------------------------------------------------------
+# assembling_headers
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceAssemblingHeaders:
+ """
+ Tests for header assembly based on different authentication flavours.
+ """
+
+ def test_assembling_headers_bearer_token(self):
+ """
+ For bearer auth we expect ``Authorization: Bearer `` by default.
+ """
+
+ auth = Authorization(
+ type="api-key",
+ config=AuthorizationConfig(type="bearer", api_key="secret", header=None),
+ )
+
+ headers = ExternalDatasetService.assembling_headers(auth)
+
+ assert headers["Authorization"] == "Bearer secret"
+
+ def test_assembling_headers_basic_token_with_custom_header(self):
+ """
+ For basic auth we honour the configured header name.
+ """
+
+ auth = Authorization(
+ type="api-key",
+ config=AuthorizationConfig(type="basic", api_key="abc123", header="X-Auth"),
+ )
+
+ headers = ExternalDatasetService.assembling_headers(auth, headers={"Existing": "1"})
+
+ assert headers["Existing"] == "1"
+ assert headers["X-Auth"] == "Basic abc123"
+
+ def test_assembling_headers_custom_type(self):
+ """
+ Custom auth type should inject the raw API key.
+ """
+
+ auth = Authorization(
+ type="api-key",
+ config=AuthorizationConfig(type="custom", api_key="raw-key", header="X-API-KEY"),
+ )
+
+ headers = ExternalDatasetService.assembling_headers(auth, headers=None)
+
+ assert headers["X-API-KEY"] == "raw-key"
+
+ def test_assembling_headers_missing_config_raises(self):
+ """
+ Missing config object should be rejected.
+ """
+
+ auth = Authorization(type="api-key", config=None)
+
+ with pytest.raises(ValueError, match="authorization config is required"):
+ ExternalDatasetService.assembling_headers(auth)
+
+ def test_assembling_headers_missing_api_key_raises(self):
+ """
+ ``api_key`` is required when type is ``api-key``.
+ """
+
+ auth = Authorization(
+ type="api-key",
+ config=AuthorizationConfig(type="bearer", api_key=None, header="Authorization"),
+ )
+
+ with pytest.raises(ValueError, match="api_key is required"):
+ ExternalDatasetService.assembling_headers(auth)
+
+ def test_assembling_headers_no_auth_type_leaves_headers_unchanged(self):
+ """
+ For ``no-auth`` we should not modify the headers mapping.
+ """
+
+ auth = Authorization(type="no-auth", config=None)
+
+ base_headers = {"X": "1"}
+ result = ExternalDatasetService.assembling_headers(auth, headers=base_headers)
+
+ # A copy is returned, original is not mutated.
+ assert result == base_headers
+ assert result is not base_headers
+
+
+# ---------------------------------------------------------------------------
+# get_external_knowledge_api_settings
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceGetExternalKnowledgeApiSettings:
+ """
+ Simple shape test for ``get_external_knowledge_api_settings``.
+ """
+
+ def test_get_external_knowledge_api_settings(self):
+ settings_dict: dict[str, Any] = {
+ "url": "https://example.com/retrieval",
+ "request_method": "post",
+ "headers": {"Content-Type": "application/json"},
+ "params": {"foo": "bar"},
+ }
+
+ result = ExternalDatasetService.get_external_knowledge_api_settings(settings_dict)
+
+ assert isinstance(result, ExternalKnowledgeApiSetting)
+ assert result.url == settings_dict["url"]
+ assert result.request_method == settings_dict["request_method"]
+ assert result.headers == settings_dict["headers"]
+ assert result.params == settings_dict["params"]
+
+
+# ---------------------------------------------------------------------------
+# create_external_dataset
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceCreateExternalDataset:
+ """
+ Tests around creating the external dataset and its binding row.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ with patch("services.external_knowledge_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_create_external_dataset_success(self, mock_db_session: MagicMock):
+ """
+ A brand new dataset name with valid external knowledge references
+ should create both the dataset and its binding.
+ """
+
+ tenant_id = "tenant-1"
+ user_id = "user-1"
+
+ args = {
+ "name": "My Dataset",
+ "description": "desc",
+ "external_knowledge_api_id": "api-1",
+ "external_knowledge_id": "knowledge-1",
+ "external_retrieval_model": {"top_k": 3},
+ }
+
+ # No existing dataset with same name.
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ None, # duplicate‑name check
+ Mock(spec=ExternalKnowledgeApis), # external knowledge api
+ ]
+
+ dataset = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
+
+ assert isinstance(dataset, Dataset)
+ assert dataset.provider == "external"
+ assert dataset.retrieval_model == args["external_retrieval_model"]
+
+ assert mock_db_session.add.call_count >= 2 # dataset + binding
+ mock_db_session.flush.assert_called_once()
+ mock_db_session.commit.assert_called_once()
+
+ def test_create_external_dataset_duplicate_name_raises(self, mock_db_session: MagicMock):
+ """
+ When a dataset with the same name already exists,
+ ``DatasetNameDuplicateError`` is raised.
+ """
+
+ existing_dataset = Mock(spec=Dataset)
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_dataset
+
+ args = {
+ "name": "Existing",
+ "external_knowledge_api_id": "api-1",
+ "external_knowledge_id": "knowledge-1",
+ }
+
+ with pytest.raises(DatasetNameDuplicateError):
+ ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
+
+ mock_db_session.add.assert_not_called()
+ mock_db_session.commit.assert_not_called()
+
+ def test_create_external_dataset_missing_api_template_raises(self, mock_db_session: MagicMock):
+ """
+ If the referenced external knowledge API does not exist, a ``ValueError`` is raised.
+ """
+
+ # First call: duplicate name check – not found.
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ None,
+ None, # external knowledge api lookup
+ ]
+
+ args = {
+ "name": "Dataset",
+ "external_knowledge_api_id": "missing",
+ "external_knowledge_id": "knowledge-1",
+ }
+
+ with pytest.raises(ValueError, match="api template not found"):
+ ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
+
+ def test_create_external_dataset_missing_required_ids_raise(self, mock_db_session: MagicMock):
+ """
+ ``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
+ """
+
+ # duplicate name check
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ None,
+ Mock(spec=ExternalKnowledgeApis),
+ ]
+
+ args_missing_knowledge_id = {
+ "name": "Dataset",
+ "external_knowledge_api_id": "api-1",
+ "external_knowledge_id": None,
+ }
+
+ with pytest.raises(ValueError, match="external_knowledge_id is required"):
+ ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_knowledge_id)
+
+ args_missing_api_id = {
+ "name": "Dataset",
+ "external_knowledge_api_id": None,
+ "external_knowledge_id": "k-1",
+ }
+
+ with pytest.raises(ValueError, match="external_knowledge_api_id is required"):
+ ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_api_id)
+
+
+# ---------------------------------------------------------------------------
+# fetch_external_knowledge_retrieval
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
+ """
+ Tests for ``fetch_external_knowledge_retrieval`` which orchestrates
+ external retrieval requests and normalises the response payload.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ with patch("services.external_knowledge_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
+ """
+ With a valid binding and API template, records from the external
+ service should be returned when the HTTP response is 200.
+ """
+
+ tenant_id = "tenant-1"
+ dataset_id = "ds-1"
+ query = "test query"
+ external_retrieval_parameters = {"top_k": 3, "score_threshold_enabled": True, "score_threshold": 0.5}
+
+ binding = ExternalDatasetTestDataFactory.create_external_binding(
+ tenant_id=tenant_id,
+ dataset_id=dataset_id,
+ api_id="api-1",
+ external_knowledge_id="knowledge-1",
+ )
+
+ api = Mock(spec=ExternalKnowledgeApis)
+ api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
+
+ # First query: binding; second query: api.
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ binding,
+ api,
+ ]
+
+ fake_records = [{"content": "doc", "score": 0.9}]
+ fake_response = Mock(spec=httpx.Response)
+ fake_response.status_code = 200
+ fake_response.json.return_value = {"records": fake_records}
+
+ metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})
+
+ with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response) as mock_process:
+ result = ExternalDatasetService.fetch_external_knowledge_retrieval(
+ tenant_id=tenant_id,
+ dataset_id=dataset_id,
+ query=query,
+ external_retrieval_parameters=external_retrieval_parameters,
+ metadata_condition=metadata_condition,
+ )
+
+ assert result == fake_records
+
+ mock_process.assert_called_once()
+ setting_arg = mock_process.call_args.args[0]
+ assert isinstance(setting_arg, ExternalKnowledgeApiSetting)
+ assert setting_arg.url.endswith("/retrieval")
+
+ def test_fetch_external_knowledge_retrieval_binding_not_found_raises(self, mock_db_session: MagicMock):
+ """
+ Missing binding should raise ``ValueError``.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="external knowledge binding not found"):
+ ExternalDatasetService.fetch_external_knowledge_retrieval(
+ tenant_id="tenant-1",
+ dataset_id="missing",
+ query="q",
+ external_retrieval_parameters={},
+ metadata_condition=None,
+ )
+
+ def test_fetch_external_knowledge_retrieval_missing_api_template_raises(self, mock_db_session: MagicMock):
+ """
+ When the API template is missing or has no settings, a ``ValueError`` is raised.
+ """
+
+ binding = ExternalDatasetTestDataFactory.create_external_binding()
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ binding,
+ None,
+ ]
+
+ with pytest.raises(ValueError, match="external api template not found"):
+ ExternalDatasetService.fetch_external_knowledge_retrieval(
+ tenant_id="tenant-1",
+ dataset_id="ds-1",
+ query="q",
+ external_retrieval_parameters={},
+ metadata_condition=None,
+ )
+
+ def test_fetch_external_knowledge_retrieval_non_200_status_returns_empty_list(self, mock_db_session: MagicMock):
+ """
+ Non‑200 responses should be treated as an empty result set.
+ """
+
+ binding = ExternalDatasetTestDataFactory.create_external_binding()
+ api = Mock(spec=ExternalKnowledgeApis)
+ api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
+
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ binding,
+ api,
+ ]
+
+ fake_response = Mock(spec=httpx.Response)
+ fake_response.status_code = 500
+ fake_response.json.return_value = {}
+
+ with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response):
+ result = ExternalDatasetService.fetch_external_knowledge_retrieval(
+ tenant_id="tenant-1",
+ dataset_id="ds-1",
+ query="q",
+ external_retrieval_parameters={},
+ metadata_condition=None,
+ )
+
+ assert result == []
From 38522e5dfa38831d44655faef068a525852f7ea2 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Thu, 27 Nov 2025 08:39:49 +0800
Subject: [PATCH 38/63] fix: use default_factory for callable defaults in ORM
dataclasses (#28730)
---
api/models/account.py | 24 +++++--
api/models/api_based_extension.py | 4 +-
api/models/dataset.py | 106 +++++++++++++++++++++++++-----
api/models/model.py | 52 +++++++++++----
api/models/oauth.py | 12 +++-
api/models/provider.py | 40 ++++++++---
api/models/source.py | 8 ++-
api/models/task.py | 7 +-
api/models/tools.py | 44 +++++++++----
api/models/trigger.py | 36 +++++++---
api/models/web.py | 8 ++-
api/models/workflow.py | 4 +-
12 files changed, 269 insertions(+), 76 deletions(-)
diff --git a/api/models/account.py b/api/models/account.py
index b1dafed0ed..420e6adc6c 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -88,7 +88,9 @@ class Account(UserMixin, TypeBase):
__tablename__ = "accounts"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255))
password: Mapped[str | None] = mapped_column(String(255), default=None)
@@ -235,7 +237,9 @@ class Tenant(TypeBase):
__tablename__ = "tenants"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic")
@@ -275,7 +279,9 @@ class TenantAccountJoin(TypeBase):
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
@@ -297,7 +303,9 @@ class AccountIntegrate(TypeBase):
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
account_id: Mapped[str] = mapped_column(StringUUID)
provider: Mapped[str] = mapped_column(String(16))
open_id: Mapped[str] = mapped_column(String(255))
@@ -348,7 +356,9 @@ class TenantPluginPermission(TypeBase):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(
String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE
@@ -375,7 +385,9 @@ class TenantPluginAutoUpgradeStrategy(TypeBase):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(
String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY
diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py
index 99d33908f8..b5acab5a75 100644
--- a/api/models/api_based_extension.py
+++ b/api/models/api_based_extension.py
@@ -24,7 +24,9 @@ class APIBasedExtension(TypeBase):
sa.Index("api_based_extension_tenant_idx", "tenant_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 2ea6d98b5f..e072711b82 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -920,7 +920,12 @@ class AppDatasetJoin(TypeBase):
)
id: Mapped[str] = mapped_column(
- StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4()), init=False
+ StringUUID,
+ primary_key=True,
+ nullable=False,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -941,7 +946,12 @@ class DatasetQuery(TypeBase):
)
id: Mapped[str] = mapped_column(
- StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4()), init=False
+ StringUUID,
+ primary_key=True,
+ nullable=False,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
@@ -961,7 +971,13 @@ class DatasetKeywordTable(TypeBase):
sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False, unique=True)
keyword_table: Mapped[str] = mapped_column(LongText, nullable=False)
data_source_type: Mapped[str] = mapped_column(
@@ -1012,7 +1028,13 @@ class Embedding(TypeBase):
sa.Index("created_at_idx", "created_at"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
model_name: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'")
)
@@ -1037,7 +1059,13 @@ class DatasetCollectionBinding(TypeBase):
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
@@ -1073,7 +1101,13 @@ class Whitelist(TypeBase):
sa.PrimaryKeyConstraint("id", name="whitelists_pkey"),
sa.Index("whitelists_tenant_idx", "tenant_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
category: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
@@ -1090,7 +1124,13 @@ class DatasetPermission(TypeBase):
sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), primary_key=True, init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ primary_key=True,
+ init=False,
+ )
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1110,7 +1150,13 @@ class ExternalKnowledgeApis(TypeBase):
sa.Index("external_knowledge_apis_name_idx", "name"),
)
- id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ nullable=False,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(String(255), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1167,7 +1213,13 @@ class ExternalKnowledgeBindings(TypeBase):
sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ nullable=False,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
external_knowledge_api_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1191,7 +1243,9 @@ class DatasetAutoDisableLog(TypeBase):
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1209,7 +1263,9 @@ class RateLimitLog(TypeBase):
sa.Index("rate_limit_log_operation_idx", "operation"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False)
operation: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1226,7 +1282,9 @@ class DatasetMetadata(TypeBase):
sa.Index("dataset_metadata_dataset_idx", "dataset_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1255,7 +1313,9 @@ class DatasetMetadataBinding(TypeBase):
sa.Index("dataset_metadata_binding_document_idx", "document_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
metadata_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1270,7 +1330,9 @@ class PipelineBuiltInTemplate(TypeBase):
__tablename__ = "pipeline_built_in_templates"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
@@ -1300,7 +1362,9 @@ class PipelineCustomizedTemplate(TypeBase):
sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
@@ -1335,7 +1399,9 @@ class Pipeline(TypeBase):
__tablename__ = "pipelines"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("''"))
@@ -1368,7 +1434,9 @@ class DocumentPipelineExecutionLog(TypeBase):
sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
pipeline_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
@@ -1385,7 +1453,9 @@ class PipelineRecommendedPlugin(TypeBase):
__tablename__ = "pipeline_recommended_plugins"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
plugin_id: Mapped[str] = mapped_column(LongText, nullable=False)
provider_name: Mapped[str] = mapped_column(LongText, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
diff --git a/api/models/model.py b/api/models/model.py
index 33a94628f0..1731ff5699 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -572,7 +572,9 @@ class InstalledApp(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_owner_tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -606,7 +608,9 @@ class OAuthProviderApp(TypeBase):
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
app_icon: Mapped[str] = mapped_column(String(255), nullable=False)
client_id: Mapped[str] = mapped_column(String(255), nullable=False)
client_secret: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1311,7 +1315,9 @@ class MessageFeedback(TypeBase):
sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1360,7 +1366,9 @@ class MessageFile(TypeBase):
sa.Index("message_file_created_by_idx", "created_by"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False)
@@ -1452,7 +1460,9 @@ class AppAnnotationSetting(TypeBase):
sa.Index("app_annotation_settings_app_idx", "app_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
score_threshold: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
collection_binding_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1488,7 +1498,9 @@ class OperationLog(TypeBase):
sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
action: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1554,7 +1566,9 @@ class AppMCPServer(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1764,7 +1778,9 @@ class ApiRequest(TypeBase):
sa.Index("api_request_token_idx", "tenant_id", "api_token_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
api_token_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
path: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1783,7 +1799,9 @@ class MessageChain(TypeBase):
sa.Index("message_chain_message_id_idx", "message_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
@@ -1914,7 +1932,9 @@ class DatasetRetrieverResource(TypeBase):
sa.Index("dataset_retriever_resource_message_id_idx", "message_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1946,7 +1966,9 @@ class Tag(TypeBase):
TAG_TYPE_LIST = ["knowledge", "app"]
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1964,7 +1986,9 @@ class TagBinding(TypeBase):
sa.Index("tag_bind_tag_id_idx", "tag_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
tag_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
target_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
@@ -1981,7 +2005,9 @@ class TraceAppConfig(TypeBase):
sa.Index("trace_app_config_app_id_idx", "app_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True)
diff --git a/api/models/oauth.py b/api/models/oauth.py
index 2fce67c998..1db2552469 100644
--- a/api/models/oauth.py
+++ b/api/models/oauth.py
@@ -17,7 +17,9 @@ class DatasourceOauthParamConfig(TypeBase):
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
@@ -30,7 +32,9 @@ class DatasourceProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
@@ -60,7 +64,9 @@ class DatasourceOauthTenantParamConfig(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
diff --git a/api/models/provider.py b/api/models/provider.py
index 577e098a2e..2afd8c5329 100644
--- a/api/models/provider.py
+++ b/api/models/provider.py
@@ -58,7 +58,13 @@ class Provider(TypeBase):
),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuidv7()),
+ default_factory=lambda: str(uuidv7()),
+ init=False,
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
provider_type: Mapped[str] = mapped_column(
@@ -132,7 +138,9 @@ class ProviderModel(TypeBase):
),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -173,7 +181,9 @@ class TenantDefaultModel(TypeBase):
sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -193,7 +203,9 @@ class TenantPreferredModelProvider(TypeBase):
sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
@@ -212,7 +224,9 @@ class ProviderOrder(TypeBase):
sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -245,7 +259,9 @@ class ProviderModelSetting(TypeBase):
sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -273,7 +289,9 @@ class LoadBalancingModelConfig(TypeBase):
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -302,7 +320,9 @@ class ProviderCredential(TypeBase):
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -332,7 +352,9 @@ class ProviderModelCredential(TypeBase):
),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
diff --git a/api/models/source.py b/api/models/source.py
index f093048c00..a8addbe342 100644
--- a/api/models/source.py
+++ b/api/models/source.py
@@ -18,7 +18,9 @@ class DataSourceOauthBinding(TypeBase):
adjusted_json_index("source_info_idx", "source_info"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -44,7 +46,9 @@ class DataSourceApiKeyAuthBinding(TypeBase):
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
category: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
diff --git a/api/models/task.py b/api/models/task.py
index 539945b251..d98d99ca2c 100644
--- a/api/models/task.py
+++ b/api/models/task.py
@@ -24,7 +24,8 @@ class CeleryTask(TypeBase):
result: Mapped[bytes | None] = mapped_column(BinaryData, nullable=True, default=None)
date_done: Mapped[datetime | None] = mapped_column(
DateTime,
- default=naive_utc_now,
+ insert_default=naive_utc_now,
+ default=None,
onupdate=naive_utc_now,
nullable=True,
)
@@ -47,4 +48,6 @@ class CeleryTaskSet(TypeBase):
)
taskset_id: Mapped[str] = mapped_column(String(155), unique=True)
result: Mapped[bytes | None] = mapped_column(BinaryData, nullable=True, default=None)
- date_done: Mapped[datetime | None] = mapped_column(DateTime, default=naive_utc_now, nullable=True)
+ date_done: Mapped[datetime | None] = mapped_column(
+ DateTime, insert_default=naive_utc_now, default=None, nullable=True
+ )
diff --git a/api/models/tools.py b/api/models/tools.py
index 0a79f95a70..e4f9bcb582 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -30,7 +30,9 @@ class ToolOAuthSystemClient(TypeBase):
sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the tool provider
@@ -45,7 +47,9 @@ class ToolOAuthTenantClient(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -71,7 +75,9 @@ class BuiltinToolProvider(TypeBase):
)
# id of the tool provider
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
name: Mapped[str] = mapped_column(
String(256),
nullable=False,
@@ -120,7 +126,9 @@ class ApiToolProvider(TypeBase):
sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# name of the api provider
name: Mapped[str] = mapped_column(
String(255),
@@ -192,7 +200,9 @@ class ToolLabelBinding(TypeBase):
sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# tool id
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type
@@ -213,7 +223,9 @@ class WorkflowToolProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# name of the workflow provider
name: Mapped[str] = mapped_column(String(255), nullable=False)
# label of the workflow provider
@@ -279,7 +291,9 @@ class MCPToolProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# name of the mcp provider
name: Mapped[str] = mapped_column(String(40), nullable=False)
# server identifier of the mcp provider
@@ -360,7 +374,9 @@ class ToolModelInvoke(TypeBase):
__tablename__ = "tool_model_invokes"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# who invoke this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
@@ -413,7 +429,9 @@ class ToolConversationVariables(TypeBase):
sa.Index("conversation_id_idx", "conversation_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# conversation user id
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
@@ -450,7 +468,9 @@ class ToolFile(TypeBase):
sa.Index("tool_file_conversation_id_idx", "conversation_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# conversation user id
user_id: Mapped[str] = mapped_column(StringUUID)
# tenant id
@@ -481,7 +501,9 @@ class DeprecatedPublishedAppTool(TypeBase):
sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# id of the app
app_id: Mapped[str] = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
diff --git a/api/models/trigger.py b/api/models/trigger.py
index 088e797f82..87e2a5ccfc 100644
--- a/api/models/trigger.py
+++ b/api/models/trigger.py
@@ -41,7 +41,9 @@ class TriggerSubscription(TypeBase):
UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_provider"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
name: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription instance name")
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -111,7 +113,9 @@ class TriggerOAuthSystemClient(TypeBase):
sa.UniqueConstraint("plugin_id", "provider", name="trigger_oauth_system_client_plugin_id_provider_idx"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the trigger provider
@@ -136,7 +140,9 @@ class TriggerOAuthTenantClient(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -202,7 +208,9 @@ class WorkflowTriggerLog(TypeBase):
sa.Index("workflow_trigger_log_workflow_id_idx", "workflow_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -294,7 +302,9 @@ class WorkflowWebhookTrigger(TypeBase):
sa.UniqueConstraint("webhook_id", name="uniq_webhook_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -351,7 +361,9 @@ class WorkflowPluginTrigger(TypeBase):
sa.UniqueConstraint("app_id", "node_id", name="uniq_app_node_subscription"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -395,7 +407,9 @@ class AppTrigger(TypeBase):
sa.Index("app_trigger_tenant_app_idx", "tenant_id", "app_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str | None] = mapped_column(String(64), nullable=False)
@@ -443,7 +457,13 @@ class WorkflowSchedulePlan(TypeBase):
sa.Index("workflow_schedule_plan_next_idx", "next_run_at"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuidv7()),
+ default_factory=lambda: str(uuidv7()),
+ init=False,
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
diff --git a/api/models/web.py b/api/models/web.py
index 4f0bf7c7da..b2832aa163 100644
--- a/api/models/web.py
+++ b/api/models/web.py
@@ -18,7 +18,9 @@ class SavedMessage(TypeBase):
sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'"))
@@ -42,7 +44,9 @@ class PinnedConversation(TypeBase):
sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role: Mapped[str] = mapped_column(
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 4efa829692..42ee8a1f2b 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -1103,7 +1103,9 @@ class WorkflowAppLog(TypeBase):
sa.Index("workflow_app_log_workflow_run_id_idx", "workflow_run_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
From 64babb35e2c6e75808fab81739b83d2aa6fe8821 Mon Sep 17 00:00:00 2001
From: aka James4u
Date: Wed, 26 Nov 2025 17:55:42 -0800
Subject: [PATCH 39/63] feat: Add comprehensive unit tests for
DatasetCollectionBindingService (dataset collection binding methods) (#28724)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../services/dataset_collection_binding.py | 932 ++++++++++++++++++
1 file changed, 932 insertions(+)
create mode 100644 api/tests/unit_tests/services/dataset_collection_binding.py
diff --git a/api/tests/unit_tests/services/dataset_collection_binding.py b/api/tests/unit_tests/services/dataset_collection_binding.py
new file mode 100644
index 0000000000..2a939a5c1d
--- /dev/null
+++ b/api/tests/unit_tests/services/dataset_collection_binding.py
@@ -0,0 +1,932 @@
+"""
+Comprehensive unit tests for DatasetCollectionBindingService.
+
+This module contains extensive unit tests for the DatasetCollectionBindingService class,
+which handles dataset collection binding operations for vector database collections.
+
+The DatasetCollectionBindingService provides methods for:
+- Retrieving or creating dataset collection bindings by provider, model, and type
+- Retrieving specific collection bindings by ID and type
+- Managing collection bindings for different collection types (dataset, etc.)
+
+Collection bindings are used to map embedding models (provider + model name) to
+specific vector database collections, allowing datasets to share collections when
+they use the same embedding model configuration.
+
+This test suite ensures:
+- Correct retrieval of existing bindings
+- Proper creation of new bindings when they don't exist
+- Accurate filtering by provider, model, and collection type
+- Proper error handling for missing bindings
+- Database transaction handling (add, commit)
+- Collection name generation using Dataset.gen_collection_name_by_id
+
+================================================================================
+ARCHITECTURE OVERVIEW
+================================================================================
+
+The DatasetCollectionBindingService is a critical component in the Dify platform's
+vector database management system. It serves as an abstraction layer between the
+application logic and the underlying vector database collections.
+
+Key Concepts:
+1. Collection Binding: A mapping between an embedding model configuration
+ (provider + model name) and a vector database collection name. This allows
+ multiple datasets to share the same collection when they use identical
+ embedding models, improving resource efficiency.
+
+2. Collection Type: Different types of collections can exist (e.g., "dataset",
+ "custom_type"). This allows for separation of collections based on their
+ intended use case or data structure.
+
+3. Provider and Model: The combination of provider_name (e.g., "openai",
+ "cohere", "huggingface") and model_name (e.g., "text-embedding-ada-002")
+ uniquely identifies an embedding model configuration.
+
+4. Collection Name Generation: When a new binding is created, a unique collection
+ name is generated using Dataset.gen_collection_name_by_id() with a UUID.
+ This ensures each binding has a unique collection identifier.
+
+================================================================================
+TESTING STRATEGY
+================================================================================
+
+This test suite follows a comprehensive testing strategy that covers:
+
+1. Happy Path Scenarios:
+ - Successful retrieval of existing bindings
+ - Successful creation of new bindings
+ - Proper handling of default parameters
+
+2. Edge Cases:
+ - Different collection types
+ - Various provider/model combinations
+ - Default vs explicit parameter usage
+
+3. Error Handling:
+ - Missing bindings (for get_by_id_and_type)
+ - Database query failures
+ - Invalid parameter combinations
+
+4. Database Interaction:
+ - Query construction and execution
+ - Transaction management (add, commit)
+ - Query chaining (where, order_by, first)
+
+5. Mocking Strategy:
+ - Database session mocking
+ - Query builder chain mocking
+ - UUID generation mocking
+ - Collection name generation mocking
+
+================================================================================
+"""
+
+"""
+Import statements for the test module.
+
+This section imports all necessary dependencies for testing the
+DatasetCollectionBindingService, including:
+- unittest.mock for creating mock objects
+- pytest for test framework functionality
+- The Dataset and DatasetCollectionBinding models from the application codebase
+- DatasetCollectionBindingService, the service under test
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from models.dataset import Dataset, DatasetCollectionBinding
+from services.dataset_service import DatasetCollectionBindingService
+
+# ============================================================================
+# Test Data Factory
+# ============================================================================
+# The Test Data Factory pattern is used here to centralize the creation of
+# test objects and mock instances. This approach provides several benefits:
+#
+# 1. Consistency: All test objects are created using the same factory methods,
+# ensuring consistent structure across all tests.
+#
+# 2. Maintainability: If the structure of DatasetCollectionBinding or Dataset
+# changes, we only need to update the factory methods rather than every
+# individual test.
+#
+# 3. Reusability: Factory methods can be reused across multiple test classes,
+# reducing code duplication.
+#
+# 4. Readability: Tests become more readable when they use descriptive factory
+# method calls instead of complex object construction logic.
+#
+# ============================================================================
+
+
+class DatasetCollectionBindingTestDataFactory:
+ """
+ Factory class for creating test data and mock objects for dataset collection binding tests.
+
+ This factory provides static methods to create mock objects for:
+    - DatasetCollectionBinding instances
+    - Dataset instances (used when testing collection name generation)
+    - Arbitrary extra attributes via **kwargs for specialized scenarios
+
+ The factory methods help maintain consistency across tests and reduce
+ code duplication when setting up test scenarios.
+ """
+
+ @staticmethod
+ def create_collection_binding_mock(
+ binding_id: str = "binding-123",
+ provider_name: str = "openai",
+ model_name: str = "text-embedding-ada-002",
+ collection_name: str = "collection-abc",
+ collection_type: str = "dataset",
+ created_at=None,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock DatasetCollectionBinding with specified attributes.
+
+ Args:
+ binding_id: Unique identifier for the binding
+ provider_name: Name of the embedding model provider (e.g., "openai", "cohere")
+ model_name: Name of the embedding model (e.g., "text-embedding-ada-002")
+ collection_name: Name of the vector database collection
+ collection_type: Type of collection (default: "dataset")
+ created_at: Optional datetime for creation timestamp
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a DatasetCollectionBinding instance
+ """
+ binding = Mock(spec=DatasetCollectionBinding)
+ binding.id = binding_id
+ binding.provider_name = provider_name
+ binding.model_name = model_name
+ binding.collection_name = collection_name
+ binding.type = collection_type
+ binding.created_at = created_at
+ for key, value in kwargs.items():
+ setattr(binding, key, value)
+ return binding
+
+ @staticmethod
+ def create_dataset_mock(
+ dataset_id: str = "dataset-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Dataset for testing collection name generation.
+
+ Args:
+ dataset_id: Unique identifier for the dataset
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a Dataset instance
+ """
+ dataset = Mock(spec=Dataset)
+ dataset.id = dataset_id
+ for key, value in kwargs.items():
+ setattr(dataset, key, value)
+ return dataset
+
+
+# ============================================================================
+# Tests for get_dataset_collection_binding
+# ============================================================================
+
+
+class TestDatasetCollectionBindingServiceGetBinding:
+ """
+ Comprehensive unit tests for DatasetCollectionBindingService.get_dataset_collection_binding method.
+
+ This test class covers the main collection binding retrieval/creation functionality,
+ including various provider/model combinations, collection types, and edge cases.
+
+ The get_dataset_collection_binding method:
+ 1. Queries for existing binding by provider_name, model_name, and collection_type
+ 2. Orders results by created_at (ascending) and takes the first match
+ 3. If no binding exists, creates a new one with:
+ - The provided provider_name and model_name
+ - A generated collection_name using Dataset.gen_collection_name_by_id
+ - The provided collection_type
+ 4. Adds the new binding to the database session and commits
+ 5. Returns the binding (either existing or newly created)
+
+ Test scenarios include:
+ - Retrieving existing bindings
+ - Creating new bindings when none exist
+ - Different collection types
+ - Database transaction handling
+ - Collection name generation
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Mock database session for testing database operations.
+
+ Provides a mocked database session that can be used to verify:
+ - Query construction and execution
+ - Add operations for new bindings
+ - Commit operations for transaction completion
+
+ The mock is configured to return a query builder that supports
+ chaining operations like .where(), .order_by(), and .first().
+ """
+ with patch("services.dataset_service.db.session") as mock_db:
+ yield mock_db
+
+ def test_get_dataset_collection_binding_existing_binding_success(self, mock_db_session):
+ """
+ Test successful retrieval of an existing collection binding.
+
+ Verifies that when a binding already exists in the database for the given
+ provider, model, and collection type, the method returns the existing binding
+ without creating a new one.
+
+ This test ensures:
+ - The query is constructed correctly with all three filters
+ - Results are ordered by created_at
+ - The first matching binding is returned
+ - No new binding is created (db.session.add is not called)
+ - No commit is performed (db.session.commit is not called)
+ """
+ # Arrange
+ provider_name = "openai"
+ model_name = "text-embedding-ada-002"
+ collection_type = "dataset"
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id="binding-123",
+ provider_name=provider_name,
+ model_name=model_name,
+ collection_type=collection_type,
+ )
+
+ # Mock the query chain: query().where().order_by().first()
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding(
+ provider_name=provider_name, model_name=model_name, collection_type=collection_type
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.id == "binding-123"
+ assert result.provider_name == provider_name
+ assert result.model_name == model_name
+ assert result.type == collection_type
+
+ # Verify query was constructed correctly
+ # The query should be constructed with DatasetCollectionBinding as the model
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+
+ # Verify the where clause was applied to filter by provider, model, and type
+ mock_query.where.assert_called_once()
+
+ # Verify the results were ordered by created_at (ascending)
+ # This ensures we get the oldest binding if multiple exist
+ mock_where.order_by.assert_called_once()
+
+ # Verify no new binding was created
+ # Since an existing binding was found, we should not create a new one
+ mock_db_session.add.assert_not_called()
+
+ # Verify no commit was performed
+ # Since no new binding was created, no database transaction is needed
+ mock_db_session.commit.assert_not_called()
+
+ def test_get_dataset_collection_binding_create_new_binding_success(self, mock_db_session):
+ """
+ Test successful creation of a new collection binding when none exists.
+
+ Verifies that when no binding exists in the database for the given
+ provider, model, and collection type, the method creates a new binding
+ with a generated collection name and commits it to the database.
+
+ This test ensures:
+ - The query returns None (no existing binding)
+ - A new DatasetCollectionBinding is created with correct attributes
+ - Dataset.gen_collection_name_by_id is called to generate collection name
+ - The new binding is added to the database session
+ - The transaction is committed
+ - The newly created binding is returned
+ """
+ # Arrange
+ provider_name = "cohere"
+ model_name = "embed-english-v3.0"
+ collection_type = "dataset"
+ generated_collection_name = "collection-generated-xyz"
+
+ # Mock the query chain to return None (no existing binding)
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = None # No existing binding
+ mock_db_session.query.return_value = mock_query
+
+ # Mock Dataset.gen_collection_name_by_id to return a generated name
+ with patch("services.dataset_service.Dataset.gen_collection_name_by_id") as mock_gen_name:
+ mock_gen_name.return_value = generated_collection_name
+
+ # Mock uuid.uuid4 for the collection name generation
+ mock_uuid = "test-uuid-123"
+ with patch("services.dataset_service.uuid.uuid4", return_value=mock_uuid):
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding(
+ provider_name=provider_name, model_name=model_name, collection_type=collection_type
+ )
+
+ # Assert
+ assert result is not None
+ assert result.provider_name == provider_name
+ assert result.model_name == model_name
+ assert result.type == collection_type
+ assert result.collection_name == generated_collection_name
+
+ # Verify Dataset.gen_collection_name_by_id was called with the generated UUID
+ # This method generates a unique collection name based on the UUID
+ # The UUID is converted to string before passing to the method
+ mock_gen_name.assert_called_once_with(str(mock_uuid))
+
+ # Verify new binding was added to the database session
+ # The add method should be called exactly once with the new binding instance
+ mock_db_session.add.assert_called_once()
+
+ # Extract the binding that was added to verify its properties
+ added_binding = mock_db_session.add.call_args[0][0]
+
+ # Verify the added binding is an instance of DatasetCollectionBinding
+ # This ensures we're creating the correct type of object
+ assert isinstance(added_binding, DatasetCollectionBinding)
+
+ # Verify all the binding properties are set correctly
+ # These should match the input parameters to the method
+ assert added_binding.provider_name == provider_name
+ assert added_binding.model_name == model_name
+ assert added_binding.type == collection_type
+
+ # Verify the collection name was set from the generated name
+ # This ensures the binding has a valid collection identifier
+ assert added_binding.collection_name == generated_collection_name
+
+ # Verify the transaction was committed
+ # This ensures the new binding is persisted to the database
+ mock_db_session.commit.assert_called_once()
+
+ def test_get_dataset_collection_binding_different_collection_type(self, mock_db_session):
+ """
+ Test retrieval with a different collection type (not "dataset").
+
+ Verifies that the method correctly filters by collection_type, allowing
+ different types of collections to coexist with the same provider/model
+ combination.
+
+ This test ensures:
+ - Collection type is properly used as a filter in the query
+ - Different collection types can have separate bindings
+ - The correct binding is returned based on type
+ """
+ # Arrange
+ provider_name = "openai"
+ model_name = "text-embedding-ada-002"
+ collection_type = "custom_type"
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id="binding-456",
+ provider_name=provider_name,
+ model_name=model_name,
+ collection_type=collection_type,
+ )
+
+ # Mock the query chain
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding(
+ provider_name=provider_name, model_name=model_name, collection_type=collection_type
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.type == collection_type
+
+ # Verify query was constructed with the correct type filter
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+ mock_query.where.assert_called_once()
+
+ def test_get_dataset_collection_binding_default_collection_type(self, mock_db_session):
+ """
+ Test retrieval with default collection type ("dataset").
+
+ Verifies that when collection_type is not provided, it defaults to "dataset"
+ as specified in the method signature.
+
+ This test ensures:
+ - The default value "dataset" is used when type is not specified
+ - The query correctly filters by the default type
+ """
+ # Arrange
+ provider_name = "openai"
+ model_name = "text-embedding-ada-002"
+ # collection_type defaults to "dataset" in method signature
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id="binding-789",
+ provider_name=provider_name,
+ model_name=model_name,
+ collection_type="dataset", # Default type
+ )
+
+ # Mock the query chain
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act - call without specifying collection_type (uses default)
+ result = DatasetCollectionBindingService.get_dataset_collection_binding(
+ provider_name=provider_name, model_name=model_name
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.type == "dataset"
+
+ # Verify query was constructed correctly
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+
+ def test_get_dataset_collection_binding_different_provider_model_combination(self, mock_db_session):
+ """
+ Test retrieval with different provider/model combinations.
+
+ Verifies that bindings are correctly filtered by both provider_name and
+ model_name, ensuring that different model combinations have separate bindings.
+
+ This test ensures:
+ - Provider and model are both used as filters
+ - Different combinations result in different bindings
+ - The correct binding is returned for each combination
+ """
+ # Arrange
+ provider_name = "huggingface"
+ model_name = "sentence-transformers/all-MiniLM-L6-v2"
+ collection_type = "dataset"
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id="binding-hf-123",
+ provider_name=provider_name,
+ model_name=model_name,
+ collection_type=collection_type,
+ )
+
+ # Mock the query chain
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding(
+ provider_name=provider_name, model_name=model_name, collection_type=collection_type
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.provider_name == provider_name
+ assert result.model_name == model_name
+
+ # Verify query filters were applied correctly
+ # The query should filter by both provider_name and model_name
+ # This ensures different model combinations have separate bindings
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+
+ # Verify the where clause was applied with all three filters:
+ # - provider_name filter
+ # - model_name filter
+ # - collection_type filter
+ mock_query.where.assert_called_once()
+
+
+# ============================================================================
+# Tests for get_dataset_collection_binding_by_id_and_type
+# ============================================================================
+# This section contains tests for the get_dataset_collection_binding_by_id_and_type
+# method, which retrieves a specific collection binding by its ID and type.
+#
+# Key differences from get_dataset_collection_binding:
+# 1. This method queries by ID and type, not by provider/model/type
+# 2. This method does NOT create a new binding if one doesn't exist
+# 3. This method raises ValueError if the binding is not found
+# 4. This method is typically used when you already know the binding ID
+#
+# Use cases:
+# - Retrieving a binding that was previously created
+# - Validating that a binding exists before using it
+# - Accessing binding metadata when you have the ID
+#
+# ============================================================================
+
+
+class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
+ """
+ Comprehensive unit tests for DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type method.
+
+ This test class covers collection binding retrieval by ID and type,
+ including success scenarios and error handling for missing bindings.
+
+ The get_dataset_collection_binding_by_id_and_type method:
+ 1. Queries for a binding by collection_binding_id and collection_type
+ 2. Orders results by created_at (ascending) and takes the first match
+ 3. If no binding exists, raises ValueError("Dataset collection binding not found")
+ 4. Returns the found binding
+
+ Unlike get_dataset_collection_binding, this method does NOT create a new
+ binding if one doesn't exist - it only retrieves existing bindings.
+
+ Test scenarios include:
+ - Successful retrieval of existing bindings
+ - Error handling for missing bindings
+ - Different collection types
+ - Default collection type behavior
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Mock database session for testing database operations.
+
+ Provides a mocked database session that can be used to verify:
+ - Query construction with ID and type filters
+ - Ordering by created_at
+ - First result retrieval
+
+ The mock is configured to return a query builder that supports
+ chaining operations like .where(), .order_by(), and .first().
+ """
+ with patch("services.dataset_service.db.session") as mock_db:
+ yield mock_db
+
+ def test_get_dataset_collection_binding_by_id_and_type_success(self, mock_db_session):
+ """
+ Test successful retrieval of a collection binding by ID and type.
+
+ Verifies that when a binding exists in the database with the given
+ ID and collection type, the method returns the binding.
+
+ This test ensures:
+ - The query is constructed correctly with ID and type filters
+ - Results are ordered by created_at
+ - The first matching binding is returned
+ - No error is raised
+ """
+ # Arrange
+ collection_binding_id = "binding-123"
+ collection_type = "dataset"
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id=collection_binding_id,
+ provider_name="openai",
+ model_name="text-embedding-ada-002",
+ collection_type=collection_type,
+ )
+
+ # Mock the query chain: query().where().order_by().first()
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ collection_binding_id=collection_binding_id, collection_type=collection_type
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.id == collection_binding_id
+ assert result.type == collection_type
+
+ # Verify query was constructed correctly
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+ mock_query.where.assert_called_once()
+ mock_where.order_by.assert_called_once()
+
+ def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, mock_db_session):
+ """
+ Test error handling when binding is not found.
+
+ Verifies that when no binding exists in the database with the given
+ ID and collection type, the method raises a ValueError with the
+ message "Dataset collection binding not found".
+
+ This test ensures:
+ - The query returns None (no existing binding)
+ - ValueError is raised with the correct message
+ - No binding is returned
+ """
+ # Arrange
+ collection_binding_id = "non-existent-binding"
+ collection_type = "dataset"
+
+ # Mock the query chain to return None (no existing binding)
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = None # No existing binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Dataset collection binding not found"):
+ DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ collection_binding_id=collection_binding_id, collection_type=collection_type
+ )
+
+ # Verify query was attempted
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+ mock_query.where.assert_called_once()
+
+ def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, mock_db_session):
+ """
+ Test retrieval with a different collection type.
+
+ Verifies that the method correctly filters by collection_type, ensuring
+ that bindings with the same ID but different types are treated as
+ separate entities.
+
+ This test ensures:
+ - Collection type is properly used as a filter in the query
+ - Different collection types can have separate bindings with same ID
+ - The correct binding is returned based on type
+ """
+ # Arrange
+ collection_binding_id = "binding-456"
+ collection_type = "custom_type"
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id=collection_binding_id,
+ provider_name="cohere",
+ model_name="embed-english-v3.0",
+ collection_type=collection_type,
+ )
+
+ # Mock the query chain
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ collection_binding_id=collection_binding_id, collection_type=collection_type
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.id == collection_binding_id
+ assert result.type == collection_type
+
+ # Verify query was constructed with the correct type filter
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+ mock_query.where.assert_called_once()
+
+ def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, mock_db_session):
+ """
+ Test retrieval with default collection type ("dataset").
+
+ Verifies that when collection_type is not provided, it defaults to "dataset"
+ as specified in the method signature.
+
+ This test ensures:
+ - The default value "dataset" is used when type is not specified
+ - The query correctly filters by the default type
+ - The correct binding is returned
+ """
+ # Arrange
+ collection_binding_id = "binding-789"
+ # collection_type defaults to "dataset" in method signature
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id=collection_binding_id,
+ provider_name="openai",
+ model_name="text-embedding-ada-002",
+ collection_type="dataset", # Default type
+ )
+
+ # Mock the query chain
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act - call without specifying collection_type (uses default)
+ result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ collection_binding_id=collection_binding_id
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.id == collection_binding_id
+ assert result.type == "dataset"
+
+ # Verify query was constructed correctly
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+ mock_query.where.assert_called_once()
+
+ def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, mock_db_session):
+ """
+ Test error handling when binding exists but with wrong collection type.
+
+ Verifies that when a binding exists with the given ID but a different
+ collection type, the method raises a ValueError because the binding
+ doesn't match both the ID and type criteria.
+
+ This test ensures:
+ - The query correctly filters by both ID and type
+ - Bindings with matching ID but different type are not returned
+ - ValueError is raised when no matching binding is found
+ """
+ # Arrange
+ collection_binding_id = "binding-123"
+ collection_type = "dataset"
+
+ # Mock the query chain to return None (binding exists but with different type)
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = None # No matching binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Dataset collection binding not found"):
+ DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ collection_binding_id=collection_binding_id, collection_type=collection_type
+ )
+
+ # Verify query was attempted with both ID and type filters
+ # The query should filter by both collection_binding_id and collection_type
+ # This ensures we only get bindings that match both criteria
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+
+ # Verify the where clause was applied with both filters:
+ # - collection_binding_id filter (exact match)
+ # - collection_type filter (exact match)
+ mock_query.where.assert_called_once()
+
+ # Note: The order_by and first() calls are also part of the query chain,
+ # but we don't need to verify them separately since they're part of the
+ # standard query pattern used by both methods in this service.
+
+
+# ============================================================================
+# Additional Test Scenarios and Edge Cases
+# ============================================================================
+# The following section could contain additional test scenarios if needed:
+#
+# Potential additional tests:
+# 1. Test with multiple existing bindings (verify ordering by created_at)
+# 2. Test with very long provider/model names (boundary testing)
+# 3. Test with special characters in provider/model names
+# 4. Test concurrent binding creation (thread safety)
+# 5. Test database rollback scenarios
+# 6. Test with None values for optional parameters
+# 7. Test with empty strings for required parameters
+# 8. Test collection name generation uniqueness
+# 9. Test with different UUID formats
+# 10. Test query performance with large datasets
+#
+# These scenarios are not currently implemented but could be added if needed
+# based on real-world usage patterns or discovered edge cases.
+#
+# ============================================================================
+
+
+# ============================================================================
+# Integration Notes and Best Practices
+# ============================================================================
+#
+# When using DatasetCollectionBindingService in production code, consider:
+#
+# 1. Error Handling:
+# - Always handle ValueError exceptions when calling
+# get_dataset_collection_binding_by_id_and_type
+# - Check return values from get_dataset_collection_binding to ensure
+# bindings were created successfully
+#
+# 2. Performance Considerations:
+# - The service queries the database on every call, so consider caching
+# bindings if they're accessed frequently
+# - Collection bindings are typically long-lived, so caching is safe
+#
+# 3. Transaction Management:
+# - New bindings are automatically committed to the database
+# - If you need to rollback, ensure you're within a transaction context
+#
+# 4. Collection Type Usage:
+# - Use "dataset" for standard dataset collections
+# - Use custom types only when you need to separate collections by purpose
+# - Be consistent with collection type naming across your application
+#
+# 5. Provider and Model Naming:
+# - Use consistent provider names (e.g., "openai", not "OpenAI" or "OPENAI")
+# - Use exact model names as provided by the model provider
+# - These names are case-sensitive and must match exactly
+#
+# ============================================================================
+
+
+# ============================================================================
+# Database Schema Reference
+# ============================================================================
+#
+# The DatasetCollectionBinding model has the following structure:
+#
+# - id: StringUUID (primary key, auto-generated)
+# - provider_name: String(255) (required, e.g., "openai", "cohere")
+# - model_name: String(255) (required, e.g., "text-embedding-ada-002")
+# - type: String(40) (required, default: "dataset")
+# - collection_name: String(64) (required, unique collection identifier)
+# - created_at: DateTime (auto-generated timestamp)
+#
+# Indexes:
+# - Primary key on id
+# - Composite index on (provider_name, model_name) for efficient lookups
+#
+# Relationships:
+# - One binding can be referenced by multiple datasets
+# - Datasets reference bindings via collection_binding_id
+#
+# ============================================================================
+
+
+# ============================================================================
+# Mocking Strategy Documentation
+# ============================================================================
+#
+# This test suite uses extensive mocking to isolate the unit under test.
+# Here's how the mocking strategy works:
+#
+# 1. Database Session Mocking:
+# - db.session is patched to prevent actual database access
+# - Query chains are mocked to return predictable results
+# - Add and commit operations are tracked for verification
+#
+# 2. Query Chain Mocking:
+# - query() returns a mock query object
+# - where() returns a mock where object
+# - order_by() returns a mock order_by object
+# - first() returns the final result (binding or None)
+#
+# 3. UUID Generation Mocking:
+# - uuid.uuid4() is mocked to return predictable UUIDs
+# - This ensures collection names are generated consistently in tests
+#
+# 4. Collection Name Generation Mocking:
+# - Dataset.gen_collection_name_by_id() is mocked
+# - This allows us to verify the method is called correctly
+# - We can control the generated collection name for testing
+#
+# Benefits of this approach:
+# - Tests run quickly (no database I/O)
+# - Tests are deterministic (no random UUIDs)
+# - Tests are isolated (no side effects)
+# - Tests are maintainable (clear mock setup)
+#
+# ============================================================================
From 0fdb4e7c12330216fbcbf674815c795f3a97d9e7 Mon Sep 17 00:00:00 2001
From: Gritty_dev <101377478+codomposer@users.noreply.github.com>
Date: Wed, 26 Nov 2025 20:57:52 -0500
Subject: [PATCH 40/63] chore: enhance the test script of conversation service
(#28739)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../services/test_conversation_service.py | 1412 ++++++++++++++++-
1 file changed, 1339 insertions(+), 73 deletions(-)
diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py
index 9c1c044f03..81135dbbdf 100644
--- a/api/tests/unit_tests/services/test_conversation_service.py
+++ b/api/tests/unit_tests/services/test_conversation_service.py
@@ -1,17 +1,293 @@
+"""
+Comprehensive unit tests for ConversationService.
+
+This test suite provides complete coverage of conversation management operations in Dify,
+following TDD principles with the Arrange-Act-Assert pattern.
+
+## Test Coverage
+
+### 1. Conversation Pagination (TestConversationServicePagination)
+Tests conversation listing and filtering:
+- Empty include_ids returns empty results
+- Non-empty include_ids filters conversations properly
+- Empty exclude_ids doesn't filter results
+- Non-empty exclude_ids excludes specified conversations
+- Null user handling
+- Sorting and pagination edge cases
+
+### 2. Message Creation (TestConversationServiceMessageCreation)
+Tests message operations within conversations:
+- Message pagination without first_id
+- Message pagination with first_id specified
+- Error handling for non-existent messages
+- Empty result handling for null user/conversation
+- Message ordering (ascending/descending)
+- Has_more flag calculation
+
+### 3. Conversation Summarization (TestConversationServiceSummarization)
+Tests auto-generated conversation names:
+- Successful LLM-based name generation
+- Error handling when conversation has no messages
+- Graceful handling of LLM service failures
+- Manual vs auto-generated naming
+- Name update timestamp tracking
+
+### 4. Message Annotation (TestConversationServiceMessageAnnotation)
+Tests annotation creation and management:
+- Creating annotations from existing messages
+- Creating standalone annotations
+- Updating existing annotations
+- Paginated annotation retrieval
+- Annotation search with keywords
+- Annotation export functionality
+
+### 5. Conversation Export (TestConversationServiceExport)
+Tests data retrieval for export:
+- Successful conversation retrieval
+- Error handling for non-existent conversations
+- Message retrieval
+- Annotation export
+- Batch data export operations
+
+## Testing Approach
+
+- **Mocking Strategy**: All external dependencies (database, LLM, Redis) are mocked
+ for fast, isolated unit tests
+- **Factory Pattern**: ConversationServiceTestDataFactory provides consistent test data
+- **Fixtures**: Mock objects are configured per test method
+- **Assertions**: Each test verifies return values and side effects
+ (database operations, method calls)
+
+## Key Concepts
+
+**Conversation Sources:**
+- console: Created by workspace members
+- api: Created by end users via API
+
+**Message Pagination:**
+- first_id: Paginate from a specific message forward
+- last_id: Paginate from a specific message backward
+- Supports ascending/descending order
+
+**Annotations:**
+- Can be attached to messages or standalone
+- Support full-text search
+- Indexed for semantic retrieval
+"""
+
import uuid
-from unittest.mock import MagicMock, patch
+from datetime import UTC, datetime
+from decimal import Decimal
+from unittest.mock import MagicMock, Mock, create_autospec, patch
+
+import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
+from models import Account
+from models.model import App, Conversation, EndUser, Message, MessageAnnotation
+from services.annotation_service import AppAnnotationService
from services.conversation_service import ConversationService
+from services.errors.conversation import ConversationNotExistsError
+from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError
+from services.message_service import MessageService
-class TestConversationService:
+class ConversationServiceTestDataFactory:
+ """
+ Factory for creating test data and mock objects.
+
+ Provides reusable methods to create consistent mock objects for testing
+ conversation-related operations.
+ """
+
+ @staticmethod
+ def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock:
+ """
+ Create a mock Account object.
+
+ Args:
+ account_id: Unique identifier for the account
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock Account object with specified attributes
+ """
+ account = create_autospec(Account, instance=True)
+ account.id = account_id
+ for key, value in kwargs.items():
+ setattr(account, key, value)
+ return account
+
+ @staticmethod
+ def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock:
+ """
+ Create a mock EndUser object.
+
+ Args:
+ user_id: Unique identifier for the end user
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock EndUser object with specified attributes
+ """
+ user = create_autospec(EndUser, instance=True)
+ user.id = user_id
+ for key, value in kwargs.items():
+ setattr(user, key, value)
+ return user
+
+ @staticmethod
+ def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock:
+ """
+ Create a mock App object.
+
+ Args:
+ app_id: Unique identifier for the app
+ tenant_id: Tenant/workspace identifier
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock App object with specified attributes
+ """
+ app = create_autospec(App, instance=True)
+ app.id = app_id
+ app.tenant_id = tenant_id
+ app.name = kwargs.get("name", "Test App")
+ app.mode = kwargs.get("mode", "chat")
+ app.status = kwargs.get("status", "normal")
+ for key, value in kwargs.items():
+ setattr(app, key, value)
+ return app
+
+ @staticmethod
+ def create_conversation_mock(
+ conversation_id: str = "conv-123",
+ app_id: str = "app-123",
+ from_source: str = "console",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Conversation object.
+
+ Args:
+ conversation_id: Unique identifier for the conversation
+ app_id: Associated app identifier
+ from_source: Source of conversation ('console' or 'api')
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock Conversation object with specified attributes
+ """
+ conversation = create_autospec(Conversation, instance=True)
+ conversation.id = conversation_id
+ conversation.app_id = app_id
+ conversation.from_source = from_source
+ conversation.from_end_user_id = kwargs.get("from_end_user_id")
+ conversation.from_account_id = kwargs.get("from_account_id")
+ conversation.is_deleted = kwargs.get("is_deleted", False)
+ conversation.name = kwargs.get("name", "Test Conversation")
+ conversation.status = kwargs.get("status", "normal")
+ conversation.created_at = kwargs.get("created_at", datetime.now(UTC))
+ conversation.updated_at = kwargs.get("updated_at", datetime.now(UTC))
+ for key, value in kwargs.items():
+ setattr(conversation, key, value)
+ return conversation
+
+ @staticmethod
+ def create_message_mock(
+ message_id: str = "msg-123",
+ conversation_id: str = "conv-123",
+ app_id: str = "app-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Message object.
+
+ Args:
+ message_id: Unique identifier for the message
+ conversation_id: Associated conversation identifier
+ app_id: Associated app identifier
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock Message object with specified attributes including
+ query, answer, tokens, and pricing information
+ """
+ message = create_autospec(Message, instance=True)
+ message.id = message_id
+ message.conversation_id = conversation_id
+ message.app_id = app_id
+ message.query = kwargs.get("query", "Test query")
+ message.answer = kwargs.get("answer", "Test answer")
+ message.from_source = kwargs.get("from_source", "console")
+ message.from_end_user_id = kwargs.get("from_end_user_id")
+ message.from_account_id = kwargs.get("from_account_id")
+ message.created_at = kwargs.get("created_at", datetime.now(UTC))
+ message.message = kwargs.get("message", {})
+ message.message_tokens = kwargs.get("message_tokens", 0)
+ message.answer_tokens = kwargs.get("answer_tokens", 0)
+ message.message_unit_price = kwargs.get("message_unit_price", Decimal(0))
+ message.answer_unit_price = kwargs.get("answer_unit_price", Decimal(0))
+ message.message_price_unit = kwargs.get("message_price_unit", Decimal("0.001"))
+ message.answer_price_unit = kwargs.get("answer_price_unit", Decimal("0.001"))
+ message.currency = kwargs.get("currency", "USD")
+ message.status = kwargs.get("status", "normal")
+ for key, value in kwargs.items():
+ setattr(message, key, value)
+ return message
+
+ @staticmethod
+ def create_annotation_mock(
+ annotation_id: str = "anno-123",
+ app_id: str = "app-123",
+ message_id: str = "msg-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock MessageAnnotation object.
+
+ Args:
+ annotation_id: Unique identifier for the annotation
+ app_id: Associated app identifier
+ message_id: Associated message identifier (optional for standalone annotations)
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock MessageAnnotation object with specified attributes including
+ question, content, and hit tracking
+ """
+ annotation = create_autospec(MessageAnnotation, instance=True)
+ annotation.id = annotation_id
+ annotation.app_id = app_id
+ annotation.message_id = message_id
+ annotation.conversation_id = kwargs.get("conversation_id")
+ annotation.question = kwargs.get("question", "Test question")
+ annotation.content = kwargs.get("content", "Test annotation")
+ annotation.account_id = kwargs.get("account_id", "account-123")
+ annotation.hit_count = kwargs.get("hit_count", 0)
+ annotation.created_at = kwargs.get("created_at", datetime.now(UTC))
+ annotation.updated_at = kwargs.get("updated_at", datetime.now(UTC))
+ for key, value in kwargs.items():
+ setattr(annotation, key, value)
+ return annotation
+
+
+class TestConversationServicePagination:
+ """Test conversation pagination operations."""
+
def test_pagination_with_empty_include_ids(self):
- """Test that empty include_ids returns empty result"""
- mock_session = MagicMock()
- mock_app_model = MagicMock(id=str(uuid.uuid4()))
- mock_user = MagicMock(id=str(uuid.uuid4()))
+ """
+ Test that empty include_ids returns empty result.
+ When include_ids is an empty list, the service should short-circuit
+ and return empty results without querying the database.
+ """
+ # Arrange - Set up test data
+ mock_session = MagicMock() # Mock database session
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+ mock_user = ConversationServiceTestDataFactory.create_account_mock()
+
+ # Act - Call the service method with empty include_ids
result = ConversationService.pagination_by_last_id(
session=mock_session,
app_model=mock_app_model,
@@ -19,25 +295,188 @@ class TestConversationService:
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
- include_ids=[], # Empty include_ids should return empty result
+ include_ids=[], # Empty list should trigger early return
exclude_ids=None,
)
+ # Assert - Verify empty result without database query
+ assert result.data == [] # No conversations returned
+ assert result.has_more is False # No more pages available
+ assert result.limit == 20 # Limit preserved in response
+
+ def test_pagination_with_non_empty_include_ids(self):
+ """
+ Test that non-empty include_ids filters properly.
+
+ When include_ids contains conversation IDs, the query should filter
+ to only return conversations matching those IDs.
+ """
+ # Arrange - Set up test data and mocks
+ mock_session = MagicMock() # Mock database session
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+ mock_user = ConversationServiceTestDataFactory.create_account_mock()
+
+ # Create 3 mock conversations that would match the filter
+ mock_conversations = [
+ ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=str(uuid.uuid4()))
+ for _ in range(3)
+ ]
+ # Mock the database query results
+ mock_session.scalars.return_value.all.return_value = mock_conversations
+ mock_session.scalar.return_value = 0 # No additional conversations beyond current page
+
+ # Act
+ with patch("services.conversation_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+ mock_stmt.order_by.return_value = mock_stmt
+ mock_stmt.limit.return_value = mock_stmt
+ mock_stmt.subquery.return_value = MagicMock()
+
+ result = ConversationService.pagination_by_last_id(
+ session=mock_session,
+ app_model=mock_app_model,
+ user=mock_user,
+ last_id=None,
+ limit=20,
+ invoke_from=InvokeFrom.WEB_APP,
+ include_ids=["conv1", "conv2"],
+ exclude_ids=None,
+ )
+
+ # Assert
+ assert mock_stmt.where.called
+
+ def test_pagination_with_empty_exclude_ids(self):
+ """
+ Test that empty exclude_ids doesn't filter.
+
+ When exclude_ids is an empty list, the query should not filter out
+ any conversations.
+ """
+ # Arrange
+ mock_session = MagicMock()
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+ mock_user = ConversationServiceTestDataFactory.create_account_mock()
+ mock_conversations = [
+ ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=str(uuid.uuid4()))
+ for _ in range(5)
+ ]
+ mock_session.scalars.return_value.all.return_value = mock_conversations
+ mock_session.scalar.return_value = 0
+
+ # Act
+ with patch("services.conversation_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+ mock_stmt.order_by.return_value = mock_stmt
+ mock_stmt.limit.return_value = mock_stmt
+ mock_stmt.subquery.return_value = MagicMock()
+
+ result = ConversationService.pagination_by_last_id(
+ session=mock_session,
+ app_model=mock_app_model,
+ user=mock_user,
+ last_id=None,
+ limit=20,
+ invoke_from=InvokeFrom.WEB_APP,
+ include_ids=None,
+ exclude_ids=[],
+ )
+
+ # Assert
+ assert len(result.data) == 5
+
+ def test_pagination_with_non_empty_exclude_ids(self):
+ """
+ Test that non-empty exclude_ids filters properly.
+
+ When exclude_ids contains conversation IDs, the query should filter
+ out conversations matching those IDs.
+ """
+ # Arrange
+ mock_session = MagicMock()
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+ mock_user = ConversationServiceTestDataFactory.create_account_mock()
+ mock_conversations = [
+ ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=str(uuid.uuid4()))
+ for _ in range(3)
+ ]
+ mock_session.scalars.return_value.all.return_value = mock_conversations
+ mock_session.scalar.return_value = 0
+
+ # Act
+ with patch("services.conversation_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+ mock_stmt.order_by.return_value = mock_stmt
+ mock_stmt.limit.return_value = mock_stmt
+ mock_stmt.subquery.return_value = MagicMock()
+
+ result = ConversationService.pagination_by_last_id(
+ session=mock_session,
+ app_model=mock_app_model,
+ user=mock_user,
+ last_id=None,
+ limit=20,
+ invoke_from=InvokeFrom.WEB_APP,
+ include_ids=None,
+ exclude_ids=["conv1", "conv2"],
+ )
+
+ # Assert
+ assert mock_stmt.where.called
+
+ def test_pagination_returns_empty_when_user_is_none(self):
+ """
+ Test that pagination returns empty result when user is None.
+
+ This ensures proper handling of unauthenticated requests.
+ """
+ # Arrange
+ mock_session = MagicMock()
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+
+ # Act
+ result = ConversationService.pagination_by_last_id(
+ session=mock_session,
+ app_model=mock_app_model,
+ user=None, # No user provided
+ last_id=None,
+ limit=20,
+ invoke_from=InvokeFrom.WEB_APP,
+ )
+
+ # Assert - should return empty result without querying database
assert result.data == []
assert result.has_more is False
assert result.limit == 20
- def test_pagination_with_non_empty_include_ids(self):
- """Test that non-empty include_ids filters properly"""
- mock_session = MagicMock()
- mock_app_model = MagicMock(id=str(uuid.uuid4()))
- mock_user = MagicMock(id=str(uuid.uuid4()))
+ def test_pagination_with_sorting_descending(self):
+ """
+ Test pagination with descending sort order.
- # Mock the query results
- mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
- mock_session.scalars.return_value.all.return_value = mock_conversations
+ Verifies that conversations are sorted by updated_at in descending order (newest first).
+ """
+ # Arrange
+ mock_session = MagicMock()
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+ mock_user = ConversationServiceTestDataFactory.create_account_mock()
+
+ # Create conversations with different timestamps
+ conversations = [
+ ConversationServiceTestDataFactory.create_conversation_mock(
+ conversation_id=f"conv-{i}", updated_at=datetime(2024, 1, i + 1, tzinfo=UTC)
+ )
+ for i in range(3)
+ ]
+ mock_session.scalars.return_value.all.return_value = conversations
mock_session.scalar.return_value = 0
+ # Act
with patch("services.conversation_service.select") as mock_select:
mock_stmt = MagicMock()
mock_select.return_value = mock_stmt
@@ -53,75 +492,902 @@ class TestConversationService:
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
- include_ids=["conv1", "conv2"], # Non-empty include_ids
- exclude_ids=None,
+ sort_by="-updated_at", # Descending sort
)
- # Verify the where clause was called with id.in_
- assert mock_stmt.where.called
+ # Assert
+ assert len(result.data) == 3
+ mock_stmt.order_by.assert_called()
- def test_pagination_with_empty_exclude_ids(self):
- """Test that empty exclude_ids doesn't filter"""
- mock_session = MagicMock()
- mock_app_model = MagicMock(id=str(uuid.uuid4()))
- mock_user = MagicMock(id=str(uuid.uuid4()))
- # Mock the query results
- mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(5)]
- mock_session.scalars.return_value.all.return_value = mock_conversations
- mock_session.scalar.return_value = 0
+class TestConversationServiceMessageCreation:
+ """
+ Test message creation and pagination.
- with patch("services.conversation_service.select") as mock_select:
- mock_stmt = MagicMock()
- mock_select.return_value = mock_stmt
- mock_stmt.where.return_value = mock_stmt
- mock_stmt.order_by.return_value = mock_stmt
- mock_stmt.limit.return_value = mock_stmt
- mock_stmt.subquery.return_value = MagicMock()
+ Tests MessageService operations for creating and retrieving messages
+ within conversations.
+ """
- result = ConversationService.pagination_by_last_id(
- session=mock_session,
- app_model=mock_app_model,
- user=mock_user,
- last_id=None,
- limit=20,
- invoke_from=InvokeFrom.WEB_APP,
- include_ids=None,
- exclude_ids=[], # Empty exclude_ids should not filter
+ @patch("services.message_service.db.session")
+ @patch("services.message_service.ConversationService.get_conversation")
+ def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
+ """
+ Test message pagination without specifying first_id.
+
+ When first_id is None, the service should return the most recent messages
+ up to the specified limit.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Create 3 test messages in the conversation
+ messages = [
+ ConversationServiceTestDataFactory.create_message_mock(
+ message_id=f"msg-{i}", conversation_id=conversation.id
+ )
+ for i in range(3)
+ ]
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Set up the database query mock chain
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # WHERE clause returns self for chaining
+ mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
+ mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
+ mock_query.all.return_value = messages # Final .all() returns the messages
+
+ # Act - Call the pagination method without first_id
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id=conversation.id,
+ first_id=None, # No starting point specified
+ limit=10,
+ )
+
+ # Assert - Verify the results
+ assert len(result.data) == 3 # All 3 messages returned
+ assert result.has_more is False # No more messages available (3 < limit of 10)
+ # Verify conversation was looked up with correct parameters
+ mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
+
+ @patch("services.message_service.db.session")
+ @patch("services.message_service.ConversationService.get_conversation")
+ def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
+ """
+ Test message pagination with first_id specified.
+
+ When first_id is provided, the service should return messages starting
+ from the specified message up to the limit.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+ first_message = ConversationServiceTestDataFactory.create_message_mock(
+ message_id="msg-first", conversation_id=conversation.id
+ )
+ messages = [
+ ConversationServiceTestDataFactory.create_message_mock(
+ message_id=f"msg-{i}", conversation_id=conversation.id
+ )
+ for i in range(2)
+ ]
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Set up the database query mock chain
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # WHERE clause returns self for chaining
+ mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
+ mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
+ mock_query.first.return_value = first_message # First message returned
+ mock_query.all.return_value = messages # Remaining messages returned
+
+ # Act - Call the pagination method with first_id
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id=conversation.id,
+ first_id="msg-first",
+ limit=10,
+ )
+
+ # Assert - Verify the results
+ assert len(result.data) == 2 # Only 2 messages returned after first_id
+ assert result.has_more is False # No more messages available (2 < limit of 10)
+
+ @patch("services.message_service.db.session")
+ @patch("services.message_service.ConversationService.get_conversation")
+ def test_pagination_by_first_id_raises_error_when_first_message_not_found(
+ self, mock_get_conversation, mock_db_session
+ ):
+ """
+ Test that FirstMessageNotExistsError is raised when first_id doesn't exist.
+
+ When the specified first_id does not exist in the conversation,
+ the service should raise an error.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Set up the database query mock chain
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # WHERE clause returns self for chaining
+ mock_query.first.return_value = None # No message found for first_id
+
+ # Act & Assert
+ with pytest.raises(FirstMessageNotExistsError):
+ MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id=conversation.id,
+ first_id="non-existent-msg",
+ limit=10,
)
- # Result should contain the mocked conversations
- assert len(result.data) == 5
+ def test_pagination_returns_empty_when_no_user(self):
+ """
+ Test that pagination returns empty result when user is None.
- def test_pagination_with_non_empty_exclude_ids(self):
- """Test that non-empty exclude_ids filters properly"""
- mock_session = MagicMock()
- mock_app_model = MagicMock(id=str(uuid.uuid4()))
- mock_user = MagicMock(id=str(uuid.uuid4()))
+ This ensures proper handling of unauthenticated requests.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
- # Mock the query results
- mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
- mock_session.scalars.return_value.all.return_value = mock_conversations
- mock_session.scalar.return_value = 0
+ # Act
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=None,
+ conversation_id="conv-123",
+ first_id=None,
+ limit=10,
+ )
- with patch("services.conversation_service.select") as mock_select:
- mock_stmt = MagicMock()
- mock_select.return_value = mock_stmt
- mock_stmt.where.return_value = mock_stmt
- mock_stmt.order_by.return_value = mock_stmt
- mock_stmt.limit.return_value = mock_stmt
- mock_stmt.subquery.return_value = MagicMock()
+ # Assert
+ assert result.data == []
+ assert result.has_more is False
- result = ConversationService.pagination_by_last_id(
- session=mock_session,
- app_model=mock_app_model,
- user=mock_user,
- last_id=None,
- limit=20,
- invoke_from=InvokeFrom.WEB_APP,
- include_ids=None,
- exclude_ids=["conv1", "conv2"], # Non-empty exclude_ids
+ def test_pagination_returns_empty_when_no_conversation_id(self):
+ """
+ Test that pagination returns empty result when conversation_id is None.
+
+ This ensures proper handling of invalid requests.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+
+ # Act
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id="",
+ first_id=None,
+ limit=10,
+ )
+
+ # Assert
+ assert result.data == []
+ assert result.has_more is False
+
+ @patch("services.message_service.db.session")
+ @patch("services.message_service.ConversationService.get_conversation")
+ def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
+ """
+ Test that has_more flag is correctly set when there are more messages.
+
+ The service fetches limit+1 messages to determine if more exist.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Create limit+1 messages to trigger has_more
+ limit = 5
+ messages = [
+ ConversationServiceTestDataFactory.create_message_mock(
+ message_id=f"msg-{i}", conversation_id=conversation.id
)
+ for i in range(limit + 1) # One extra message
+ ]
- # Verify the where clause was called for exclusion
- assert mock_stmt.where.called
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Set up the database query mock chain
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # WHERE clause returns self for chaining
+ mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
+ mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
+ mock_query.all.return_value = messages # Final .all() returns the messages
+
+ # Act
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id=conversation.id,
+ first_id=None,
+ limit=limit,
+ )
+
+ # Assert
+ assert len(result.data) == limit # Extra message should be removed
+ assert result.has_more is True # Flag should be set
+
+ @patch("services.message_service.db.session")
+ @patch("services.message_service.ConversationService.get_conversation")
+ def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
+ """
+ Test message pagination with ascending order.
+
+ Messages should be returned in chronological order (oldest first).
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Create messages with different timestamps
+ messages = [
+ ConversationServiceTestDataFactory.create_message_mock(
+ message_id=f"msg-{i}", conversation_id=conversation.id, created_at=datetime(2024, 1, i + 1, tzinfo=UTC)
+ )
+ for i in range(3)
+ ]
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Set up the database query mock chain
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # WHERE clause returns self for chaining
+ mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
+ mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
+ mock_query.all.return_value = messages # Final .all() returns the messages
+
+ # Act
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id=conversation.id,
+ first_id=None,
+ limit=10,
+ order="asc", # Ascending order
+ )
+
+ # Assert
+ assert len(result.data) == 3
+ # Messages should be in ascending order after reversal
+
+
+class TestConversationServiceSummarization:
+ """
+ Test conversation summarization (auto-generated names).
+
+ Tests the auto_generate_name functionality that creates conversation
+ titles based on the first message.
+ """
+
+ @patch("services.conversation_service.LLMGenerator.generate_conversation_name")
+ @patch("services.conversation_service.db.session")
+ def test_auto_generate_name_success(self, mock_db_session, mock_llm_generator):
+ """
+ Test successful auto-generation of conversation name.
+
+ The service uses an LLM to generate a descriptive name based on
+ the first message in the conversation.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Create the first message that will be used to generate the name
+ first_message = ConversationServiceTestDataFactory.create_message_mock(
+ conversation_id=conversation.id, query="What is machine learning?"
+ )
+ # Expected name from LLM
+ generated_name = "Machine Learning Discussion"
+
+ # Set up database query mock to return the first message
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # Filter by app_id and conversation_id
+ mock_query.order_by.return_value = mock_query # Order by created_at ascending
+ mock_query.first.return_value = first_message # Return the first message
+
+ # Mock the LLM to return our expected name
+ mock_llm_generator.return_value = generated_name
+
+ # Act
+ result = ConversationService.auto_generate_name(app_model, conversation)
+
+ # Assert
+ assert conversation.name == generated_name # Name updated on conversation object
+ # Verify LLM was called with correct parameters
+ mock_llm_generator.assert_called_once_with(
+ app_model.tenant_id, first_message.query, conversation.id, app_model.id
+ )
+ mock_db_session.commit.assert_called_once() # Changes committed to database
+
+ @patch("services.conversation_service.db.session")
+ def test_auto_generate_name_raises_error_when_no_message(self, mock_db_session):
+ """
+ Test that MessageNotExistsError is raised when conversation has no messages.
+
+ When the conversation has no messages, the service should raise an error.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Set up database query mock to return no messages
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # Filter by app_id and conversation_id
+ mock_query.order_by.return_value = mock_query # Order by created_at ascending
+ mock_query.first.return_value = None # No messages found
+
+ # Act & Assert
+ with pytest.raises(MessageNotExistsError):
+ ConversationService.auto_generate_name(app_model, conversation)
+
+ @patch("services.conversation_service.LLMGenerator.generate_conversation_name")
+ @patch("services.conversation_service.db.session")
+ def test_auto_generate_name_handles_llm_failure_gracefully(self, mock_db_session, mock_llm_generator):
+ """
+ Test that LLM generation failures are suppressed and don't crash.
+
+ When the LLM fails to generate a name, the service should not crash
+ and should return the original conversation name.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+ first_message = ConversationServiceTestDataFactory.create_message_mock(conversation_id=conversation.id)
+ original_name = conversation.name
+
+ # Set up database query mock to return the first message
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # Filter by app_id and conversation_id
+ mock_query.order_by.return_value = mock_query # Order by created_at ascending
+ mock_query.first.return_value = first_message # Return the first message
+
+ # Mock the LLM to raise an exception
+ mock_llm_generator.side_effect = Exception("LLM service unavailable")
+
+ # Act
+ result = ConversationService.auto_generate_name(app_model, conversation)
+
+ # Assert
+ assert conversation.name == original_name # Name remains unchanged
+ mock_db_session.commit.assert_called_once() # Changes committed to database
+
+ @patch("services.conversation_service.db.session")
+ @patch("services.conversation_service.ConversationService.get_conversation")
+ @patch("services.conversation_service.ConversationService.auto_generate_name")
+ def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session):
+ """
+ Test renaming conversation with auto-generation enabled.
+
+ When auto_generate is True, the service should call the auto_generate_name
+ method to generate a new name for the conversation.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+ conversation.name = "Auto-generated Name"
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Mock the auto_generate_name method to return the conversation
+ mock_auto_generate.return_value = conversation
+
+ # Act
+ result = ConversationService.rename(
+ app_model=app_model,
+ conversation_id=conversation.id,
+ user=user,
+ name="",
+ auto_generate=True,
+ )
+
+ # Assert
+ mock_auto_generate.assert_called_once_with(app_model, conversation)
+ assert result == conversation
+
+ @patch("services.conversation_service.db.session")
+ @patch("services.conversation_service.ConversationService.get_conversation")
+ @patch("services.conversation_service.naive_utc_now")
+ def test_rename_with_manual_name(self, mock_naive_utc_now, mock_get_conversation, mock_db_session):
+ """
+ Test renaming conversation with manual name.
+
+ When auto_generate is False, the service should update the conversation
+ name with the provided manual name.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+ new_name = "My Custom Conversation Name"
+ mock_time = datetime(2024, 1, 1, 12, 0, 0)
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Mock the current time to return our mock time
+ mock_naive_utc_now.return_value = mock_time
+
+ # Act
+ result = ConversationService.rename(
+ app_model=app_model,
+ conversation_id=conversation.id,
+ user=user,
+ name=new_name,
+ auto_generate=False,
+ )
+
+ # Assert
+ assert conversation.name == new_name
+ assert conversation.updated_at == mock_time
+ mock_db_session.commit.assert_called_once()
+
+
+class TestConversationServiceMessageAnnotation:
+ """
+ Test message annotation operations.
+
+ Tests AppAnnotationService operations for creating and managing
+ message annotations.
+ """
+
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_create_annotation_from_message(self, mock_current_account, mock_db_session):
+ """
+ Test creating annotation from existing message.
+
+ Annotations can be attached to messages to provide curated responses
+ that override the AI-generated answers.
+ """
+ # Arrange
+ app_id = "app-123"
+ message_id = "msg-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+
+ # Create a message that doesn't have an annotation yet
+ message = ConversationServiceTestDataFactory.create_message_mock(
+ message_id=message_id, app_id=app_id, query="What is AI?"
+ )
+ message.annotation = None # No existing annotation
+
+ # Mock the authentication context to return current user and tenant
+ mock_current_account.return_value = (account, tenant_id)
+
+ # Set up database query mock
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ # First call returns app, second returns message, third returns None (no annotation setting)
+ mock_query.first.side_effect = [app, message, None]
+
+ # Annotation data to create
+ args = {"message_id": message_id, "answer": "AI is artificial intelligence"}
+
+ # Act
+ with patch("services.annotation_service.add_annotation_to_index_task"):
+ result = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
+
+ # Assert
+ mock_db_session.add.assert_called_once() # Annotation added to session
+ mock_db_session.commit.assert_called_once() # Changes committed
+
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_create_annotation_without_message(self, mock_current_account, mock_db_session):
+ """
+ Test creating standalone annotation without message.
+
+ Annotations can be created without a message reference for bulk imports
+ or manual annotation creation.
+ """
+ # Arrange
+ app_id = "app-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+
+ # Mock the authentication context to return current user and tenant
+ mock_current_account.return_value = (account, tenant_id)
+
+ # Set up database query mock
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ # First call returns app, second returns None (no message)
+ mock_query.first.side_effect = [app, None]
+
+ # Annotation data to create
+ args = {
+ "question": "What is natural language processing?",
+ "answer": "NLP is a field of AI focused on language understanding",
+ }
+
+ # Act
+ with patch("services.annotation_service.add_annotation_to_index_task"):
+ result = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
+
+ # Assert
+ mock_db_session.add.assert_called_once() # Annotation added to session
+ mock_db_session.commit.assert_called_once() # Changes committed
+
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_update_existing_annotation(self, mock_current_account, mock_db_session):
+ """
+ Test updating an existing annotation.
+
+ When a message already has an annotation, calling the service again
+ should update the existing annotation rather than creating a new one.
+ """
+ # Arrange
+ app_id = "app-123"
+ message_id = "msg-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+ message = ConversationServiceTestDataFactory.create_message_mock(message_id=message_id, app_id=app_id)
+
+ # Create an existing annotation with old content
+ existing_annotation = ConversationServiceTestDataFactory.create_annotation_mock(
+ app_id=app_id, message_id=message_id, content="Old annotation"
+ )
+ message.annotation = existing_annotation # Message already has annotation
+
+ # Mock the authentication context to return current user and tenant
+ mock_current_account.return_value = (account, tenant_id)
+
+ # Set up database query mock
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ # First call returns app, second returns message, third returns None (no annotation setting)
+ mock_query.first.side_effect = [app, message, None]
+
+ # New content to update the annotation with
+ args = {"message_id": message_id, "answer": "Updated annotation content"}
+
+ # Act
+ with patch("services.annotation_service.add_annotation_to_index_task"):
+ result = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
+
+ # Assert
+ assert existing_annotation.content == "Updated annotation content" # Content updated
+ mock_db_session.add.assert_called_once() # Annotation re-added to session
+ mock_db_session.commit.assert_called_once() # Changes committed
+
+ @patch("services.annotation_service.db.paginate")
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_get_annotation_list(self, mock_current_account, mock_db_session, mock_db_paginate):
+ """
+ Test retrieving paginated annotation list.
+
+ Annotations can be retrieved in a paginated list for display in the UI.
+ """
+ """Test retrieving paginated annotation list."""
+ # Arrange
+ app_id = "app-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+ annotations = [
+ ConversationServiceTestDataFactory.create_annotation_mock(annotation_id=f"anno-{i}", app_id=app_id)
+ for i in range(5)
+ ]
+
+ mock_current_account.return_value = (account, tenant_id)
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = app
+
+ mock_paginate = MagicMock()
+ mock_paginate.items = annotations
+ mock_paginate.total = 5
+ mock_db_paginate.return_value = mock_paginate
+
+ # Act
+ result_items, result_total = AppAnnotationService.get_annotation_list_by_app_id(
+ app_id=app_id, page=1, limit=10, keyword=""
+ )
+
+ # Assert
+ assert len(result_items) == 5
+ assert result_total == 5
+
+ @patch("services.annotation_service.db.paginate")
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_get_annotation_list_with_keyword_search(self, mock_current_account, mock_db_session, mock_db_paginate):
+ """
+ Test retrieving annotations with keyword filtering.
+
+ Annotations can be searched by question or content using case-insensitive matching.
+ """
+ # Arrange
+ app_id = "app-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+
+ # Create annotations with searchable content
+ annotations = [
+ ConversationServiceTestDataFactory.create_annotation_mock(
+ annotation_id="anno-1",
+ app_id=app_id,
+ question="What is machine learning?",
+ content="ML is a subset of AI",
+ ),
+ ConversationServiceTestDataFactory.create_annotation_mock(
+ annotation_id="anno-2",
+ app_id=app_id,
+ question="What is deep learning?",
+ content="Deep learning uses neural networks",
+ ),
+ ]
+
+ mock_current_account.return_value = (account, tenant_id)
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = app
+
+ mock_paginate = MagicMock()
+ mock_paginate.items = [annotations[0]] # Only first annotation matches
+ mock_paginate.total = 1
+ mock_db_paginate.return_value = mock_paginate
+
+ # Act
+ result_items, result_total = AppAnnotationService.get_annotation_list_by_app_id(
+ app_id=app_id,
+ page=1,
+ limit=10,
+ keyword="machine", # Search keyword
+ )
+
+ # Assert
+ assert len(result_items) == 1
+ assert result_total == 1
+
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_insert_annotation_directly(self, mock_current_account, mock_db_session):
+ """
+ Test direct annotation insertion without message reference.
+
+ This is used for bulk imports or manual annotation creation.
+ """
+ # Arrange
+ app_id = "app-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+
+ mock_current_account.return_value = (account, tenant_id)
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.side_effect = [app, None]
+
+ args = {
+ "question": "What is natural language processing?",
+ "answer": "NLP is a field of AI focused on language understanding",
+ }
+
+ # Act
+ with patch("services.annotation_service.add_annotation_to_index_task"):
+ result = AppAnnotationService.insert_app_annotation_directly(args, app_id)
+
+ # Assert
+ mock_db_session.add.assert_called_once()
+ mock_db_session.commit.assert_called_once()
+
+
+class TestConversationServiceExport:
+ """
+ Test conversation export/retrieval operations.
+
+ Tests retrieving conversation data for export purposes.
+ """
+
+ @patch("services.conversation_service.db.session")
+ def test_get_conversation_success(self, mock_db_session):
+ """Test successful retrieval of conversation."""
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock(
+ app_id=app_model.id, from_account_id=user.id, from_source="console"
+ )
+
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = conversation
+
+ # Act
+ result = ConversationService.get_conversation(app_model=app_model, conversation_id=conversation.id, user=user)
+
+ # Assert
+ assert result == conversation
+
+ @patch("services.conversation_service.db.session")
+ def test_get_conversation_not_found(self, mock_db_session):
+ """Test ConversationNotExistsError when conversation doesn't exist."""
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = None
+
+ # Act & Assert
+ with pytest.raises(ConversationNotExistsError):
+ ConversationService.get_conversation(app_model=app_model, conversation_id="non-existent", user=user)
+
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_export_annotation_list(self, mock_current_account, mock_db_session):
+ """Test exporting all annotations for an app."""
+ # Arrange
+ app_id = "app-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+ annotations = [
+ ConversationServiceTestDataFactory.create_annotation_mock(annotation_id=f"anno-{i}", app_id=app_id)
+ for i in range(10)
+ ]
+
+ mock_current_account.return_value = (account, tenant_id)
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+ mock_query.first.return_value = app
+ mock_query.all.return_value = annotations
+
+ # Act
+ result = AppAnnotationService.export_annotation_list_by_app_id(app_id)
+
+ # Assert
+ assert len(result) == 10
+ assert result == annotations
+
+ @patch("services.message_service.db.session")
+ def test_get_message_success(self, mock_db_session):
+ """Test successful retrieval of a message."""
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ message = ConversationServiceTestDataFactory.create_message_mock(
+ app_id=app_model.id, from_account_id=user.id, from_source="console"
+ )
+
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = message
+
+ # Act
+ result = MessageService.get_message(app_model=app_model, user=user, message_id=message.id)
+
+ # Assert
+ assert result == message
+
+ @patch("services.message_service.db.session")
+ def test_get_message_not_found(self, mock_db_session):
+ """Test MessageNotExistsError when message doesn't exist."""
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = None
+
+ # Act & Assert
+ with pytest.raises(MessageNotExistsError):
+ MessageService.get_message(app_model=app_model, user=user, message_id="non-existent")
+
+ @patch("services.conversation_service.db.session")
+ def test_get_conversation_for_end_user(self, mock_db_session):
+ """
+ Test retrieving conversation created by end user via API.
+
+ End users (API) and accounts (console) have different access patterns.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ end_user = ConversationServiceTestDataFactory.create_end_user_mock()
+
+ # Conversation created by end user via API
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock(
+ app_id=app_model.id,
+ from_end_user_id=end_user.id,
+ from_source="api", # API source for end users
+ )
+
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = conversation
+
+ # Act
+ result = ConversationService.get_conversation(
+ app_model=app_model, conversation_id=conversation.id, user=end_user
+ )
+
+ # Assert
+ assert result == conversation
+ # Verify query filters for API source
+ mock_query.where.assert_called()
+
+ @patch("services.conversation_service.delete_conversation_related_data") # Mock Celery task
+ @patch("services.conversation_service.db.session") # Mock database session
+ def test_delete_conversation(self, mock_db_session, mock_delete_task):
+ """
+ Test conversation deletion with async cleanup.
+
+ Deletion is a two-step process:
+ 1. Immediately delete the conversation record from database
+ 2. Trigger async background task to clean up related data
+ (messages, annotations, vector embeddings, file uploads)
+ """
+ # Arrange - Set up test data
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation_id = "conv-to-delete"
+
+ # Set up database query mock
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # Filter by conversation_id
+
+ # Act - Delete the conversation
+ ConversationService.delete(app_model=app_model, conversation_id=conversation_id, user=user)
+
+ # Assert - Verify two-step deletion process
+ # Step 1: Immediate database deletion
+ mock_query.delete.assert_called_once() # DELETE query executed
+ mock_db_session.commit.assert_called_once() # Transaction committed
+
+ # Step 2: Async cleanup task triggered
+ # The Celery task will handle cleanup of messages, annotations, etc.
+ mock_delete_task.delay.assert_called_once_with(conversation_id)
From 766e16b26f5974d689269c14eab7dc8a0976ece8 Mon Sep 17 00:00:00 2001
From: Satoshi Dev <162055292+0xsatoshi99@users.noreply.github.com>
Date: Wed, 26 Nov 2025 18:36:37 -0800
Subject: [PATCH 41/63] add unit tests for code node (#28717)
---
.../core/workflow/nodes/code/__init__.py | 0
.../workflow/nodes/code/code_node_spec.py | 488 ++++++++++++++++++
.../core/workflow/nodes/code/entities_spec.py | 353 +++++++++++++
3 files changed, 841 insertions(+)
create mode 100644 api/tests/unit_tests/core/workflow/nodes/code/__init__.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py
diff --git a/api/tests/unit_tests/core/workflow/nodes/code/__init__.py b/api/tests/unit_tests/core/workflow/nodes/code/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
new file mode 100644
index 0000000000..f62c714820
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
@@ -0,0 +1,488 @@
+from core.helper.code_executor.code_executor import CodeLanguage
+from core.variables.types import SegmentType
+from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.code.entities import CodeNodeData
+from core.workflow.nodes.code.exc import (
+ CodeNodeError,
+ DepthLimitError,
+ OutputValidationError,
+)
+
+
+class TestCodeNodeExceptions:
+ """Test suite for code node exceptions."""
+
+ def test_code_node_error_is_value_error(self):
+ """Test CodeNodeError inherits from ValueError."""
+ error = CodeNodeError("test error")
+
+ assert isinstance(error, ValueError)
+ assert str(error) == "test error"
+
+ def test_output_validation_error_is_code_node_error(self):
+ """Test OutputValidationError inherits from CodeNodeError."""
+ error = OutputValidationError("validation failed")
+
+ assert isinstance(error, CodeNodeError)
+ assert isinstance(error, ValueError)
+ assert str(error) == "validation failed"
+
+ def test_depth_limit_error_is_code_node_error(self):
+ """Test DepthLimitError inherits from CodeNodeError."""
+ error = DepthLimitError("depth exceeded")
+
+ assert isinstance(error, CodeNodeError)
+ assert isinstance(error, ValueError)
+ assert str(error) == "depth exceeded"
+
+ def test_code_node_error_with_empty_message(self):
+ """Test CodeNodeError with empty message."""
+ error = CodeNodeError("")
+
+ assert str(error) == ""
+
+ def test_output_validation_error_with_field_info(self):
+ """Test OutputValidationError with field information."""
+ error = OutputValidationError("Output 'result' is not a valid type")
+
+ assert "result" in str(error)
+ assert "not a valid type" in str(error)
+
+ def test_depth_limit_error_with_limit_info(self):
+ """Test DepthLimitError with limit information."""
+ error = DepthLimitError("Depth limit 5 reached, object too deep")
+
+ assert "5" in str(error)
+ assert "too deep" in str(error)
+
+
+class TestCodeNodeClassMethods:
+ """Test suite for CodeNode class methods."""
+
+ def test_code_node_version(self):
+ """Test CodeNode version method."""
+ version = CodeNode.version()
+
+ assert version == "1"
+
+ def test_get_default_config_python3(self):
+ """Test get_default_config for Python3."""
+ config = CodeNode.get_default_config(filters={"code_language": CodeLanguage.PYTHON3})
+
+ assert config is not None
+ assert isinstance(config, dict)
+
+ def test_get_default_config_javascript(self):
+ """Test get_default_config for JavaScript."""
+ config = CodeNode.get_default_config(filters={"code_language": CodeLanguage.JAVASCRIPT})
+
+ assert config is not None
+ assert isinstance(config, dict)
+
+ def test_get_default_config_no_filters(self):
+ """Test get_default_config with no filters defaults to Python3."""
+ config = CodeNode.get_default_config()
+
+ assert config is not None
+ assert isinstance(config, dict)
+
+ def test_get_default_config_empty_filters(self):
+ """Test get_default_config with empty filters."""
+ config = CodeNode.get_default_config(filters={})
+
+ assert config is not None
+
+
+class TestCodeNodeCheckMethods:
+ """Test suite for CodeNode check methods."""
+
+ def test_check_string_none_value(self):
+ """Test _check_string with None value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_string(None, "test_var")
+
+ assert result is None
+
+ def test_check_string_removes_null_bytes(self):
+ """Test _check_string removes null bytes."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_string("hello\x00world", "test_var")
+
+ assert result == "helloworld"
+ assert "\x00" not in result
+
+ def test_check_string_valid_string(self):
+ """Test _check_string with valid string."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_string("valid string", "test_var")
+
+ assert result == "valid string"
+
+ def test_check_string_empty_string(self):
+ """Test _check_string with empty string."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_string("", "test_var")
+
+ assert result == ""
+
+ def test_check_string_with_unicode(self):
+ """Test _check_string with unicode characters."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_string("你好世界🌍", "test_var")
+
+ assert result == "你好世界🌍"
+
+ def test_check_boolean_none_value(self):
+ """Test _check_boolean with None value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_boolean(None, "test_var")
+
+ assert result is None
+
+ def test_check_boolean_true_value(self):
+ """Test _check_boolean with True value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_boolean(True, "test_var")
+
+ assert result is True
+
+ def test_check_boolean_false_value(self):
+ """Test _check_boolean with False value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_boolean(False, "test_var")
+
+ assert result is False
+
+ def test_check_number_none_value(self):
+ """Test _check_number with None value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(None, "test_var")
+
+ assert result is None
+
+ def test_check_number_integer_value(self):
+ """Test _check_number with integer value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(42, "test_var")
+
+ assert result == 42
+
+ def test_check_number_float_value(self):
+ """Test _check_number with float value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(3.14, "test_var")
+
+ assert result == 3.14
+
+ def test_check_number_zero(self):
+ """Test _check_number with zero."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(0, "test_var")
+
+ assert result == 0
+
+ def test_check_number_negative(self):
+ """Test _check_number with negative number."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(-100, "test_var")
+
+ assert result == -100
+
+ def test_check_number_negative_float(self):
+ """Test _check_number with negative float."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(-3.14159, "test_var")
+
+ assert result == -3.14159
+
+
+class TestCodeNodeConvertBooleanToInt:
+ """Test suite for _convert_boolean_to_int static method."""
+
+ def test_convert_none_returns_none(self):
+ """Test converting None returns None."""
+ result = CodeNode._convert_boolean_to_int(None)
+
+ assert result is None
+
+ def test_convert_true_returns_one(self):
+ """Test converting True returns 1."""
+ result = CodeNode._convert_boolean_to_int(True)
+
+ assert result == 1
+ assert isinstance(result, int)
+
+ def test_convert_false_returns_zero(self):
+ """Test converting False returns 0."""
+ result = CodeNode._convert_boolean_to_int(False)
+
+ assert result == 0
+ assert isinstance(result, int)
+
+ def test_convert_integer_returns_same(self):
+ """Test converting integer returns same value."""
+ result = CodeNode._convert_boolean_to_int(42)
+
+ assert result == 42
+
+ def test_convert_float_returns_same(self):
+ """Test converting float returns same value."""
+ result = CodeNode._convert_boolean_to_int(3.14)
+
+ assert result == 3.14
+
+ def test_convert_zero_returns_zero(self):
+ """Test converting zero returns zero."""
+ result = CodeNode._convert_boolean_to_int(0)
+
+ assert result == 0
+
+ def test_convert_negative_returns_same(self):
+ """Test converting negative number returns same value."""
+ result = CodeNode._convert_boolean_to_int(-100)
+
+ assert result == -100
+
+
+class TestCodeNodeExtractVariableSelector:
+ """Test suite for _extract_variable_selector_to_variable_mapping."""
+
+ def test_extract_empty_variables(self):
+ """Test extraction with no variables."""
+ node_data = {
+ "title": "Test",
+ "variables": [],
+ "code_language": "python3",
+ "code": "def main(): return {}",
+ "outputs": {},
+ }
+
+ result = CodeNode._extract_variable_selector_to_variable_mapping(
+ graph_config={},
+ node_id="node_1",
+ node_data=node_data,
+ )
+
+ assert result == {}
+
+ def test_extract_single_variable(self):
+ """Test extraction with single variable."""
+ node_data = {
+ "title": "Test",
+ "variables": [
+ {"variable": "input_text", "value_selector": ["start", "text"]},
+ ],
+ "code_language": "python3",
+ "code": "def main(): return {}",
+ "outputs": {},
+ }
+
+ result = CodeNode._extract_variable_selector_to_variable_mapping(
+ graph_config={},
+ node_id="node_1",
+ node_data=node_data,
+ )
+
+ assert "node_1.input_text" in result
+ assert result["node_1.input_text"] == ["start", "text"]
+
+ def test_extract_multiple_variables(self):
+ """Test extraction with multiple variables."""
+ node_data = {
+ "title": "Test",
+ "variables": [
+ {"variable": "var1", "value_selector": ["node_a", "output1"]},
+ {"variable": "var2", "value_selector": ["node_b", "output2"]},
+ {"variable": "var3", "value_selector": ["node_c", "output3"]},
+ ],
+ "code_language": "python3",
+ "code": "def main(): return {}",
+ "outputs": {},
+ }
+
+ result = CodeNode._extract_variable_selector_to_variable_mapping(
+ graph_config={},
+ node_id="code_node",
+ node_data=node_data,
+ )
+
+ assert len(result) == 3
+ assert "code_node.var1" in result
+ assert "code_node.var2" in result
+ assert "code_node.var3" in result
+
+ def test_extract_with_nested_selector(self):
+ """Test extraction with nested value selector."""
+ node_data = {
+ "title": "Test",
+ "variables": [
+ {"variable": "deep_var", "value_selector": ["node", "obj", "nested", "value"]},
+ ],
+ "code_language": "python3",
+ "code": "def main(): return {}",
+ "outputs": {},
+ }
+
+ result = CodeNode._extract_variable_selector_to_variable_mapping(
+ graph_config={},
+ node_id="node_x",
+ node_data=node_data,
+ )
+
+ assert result["node_x.deep_var"] == ["node", "obj", "nested", "value"]
+
+
+class TestCodeNodeDataValidation:
+ """Test suite for CodeNodeData validation scenarios."""
+
+ def test_valid_python3_code_node_data(self):
+ """Test valid Python3 CodeNodeData."""
+ data = CodeNodeData(
+ title="Python Code",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'result': 1}",
+ outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
+ )
+
+ assert data.code_language == CodeLanguage.PYTHON3
+
+ def test_valid_javascript_code_node_data(self):
+ """Test valid JavaScript CodeNodeData."""
+ data = CodeNodeData(
+ title="JS Code",
+ variables=[],
+ code_language=CodeLanguage.JAVASCRIPT,
+ code="function main() { return { result: 1 }; }",
+ outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
+ )
+
+ assert data.code_language == CodeLanguage.JAVASCRIPT
+
+ def test_code_node_data_with_all_output_types(self):
+ """Test CodeNodeData with all valid output types."""
+ data = CodeNodeData(
+ title="All Types",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {}",
+ outputs={
+ "str_out": CodeNodeData.Output(type=SegmentType.STRING),
+ "num_out": CodeNodeData.Output(type=SegmentType.NUMBER),
+ "bool_out": CodeNodeData.Output(type=SegmentType.BOOLEAN),
+ "obj_out": CodeNodeData.Output(type=SegmentType.OBJECT),
+ "arr_str": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
+ "arr_num": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER),
+ "arr_bool": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN),
+ "arr_obj": CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT),
+ },
+ )
+
+ assert len(data.outputs) == 8
+
+ def test_code_node_data_complex_nested_output(self):
+ """Test CodeNodeData with complex nested output structure."""
+ data = CodeNodeData(
+ title="Complex Output",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {}",
+ outputs={
+ "response": CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={
+ "data": CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={
+ "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
+ "count": CodeNodeData.Output(type=SegmentType.NUMBER),
+ },
+ ),
+ "status": CodeNodeData.Output(type=SegmentType.STRING),
+ "success": CodeNodeData.Output(type=SegmentType.BOOLEAN),
+ },
+ ),
+ },
+ )
+
+ assert data.outputs["response"].type == SegmentType.OBJECT
+ assert data.outputs["response"].children is not None
+ assert "data" in data.outputs["response"].children
+ assert data.outputs["response"].children["data"].children is not None
+
+
+class TestCodeNodeInitialization:
+ """Test suite for CodeNode initialization methods."""
+
+ def test_init_node_data_python3(self):
+ """Test init_node_data with Python3 configuration."""
+ node = CodeNode.__new__(CodeNode)
+ data = {
+ "title": "Test Node",
+ "variables": [],
+ "code_language": "python3",
+ "code": "def main(): return {'x': 1}",
+ "outputs": {"x": {"type": "number"}},
+ }
+
+ node.init_node_data(data)
+
+ assert node._node_data.title == "Test Node"
+ assert node._node_data.code_language == CodeLanguage.PYTHON3
+
+ def test_init_node_data_javascript(self):
+ """Test init_node_data with JavaScript configuration."""
+ node = CodeNode.__new__(CodeNode)
+ data = {
+ "title": "JS Node",
+ "variables": [],
+ "code_language": "javascript",
+ "code": "function main() { return { x: 1 }; }",
+ "outputs": {"x": {"type": "number"}},
+ }
+
+ node.init_node_data(data)
+
+ assert node._node_data.code_language == CodeLanguage.JAVASCRIPT
+
+ def test_get_title(self):
+ """Test _get_title method."""
+ node = CodeNode.__new__(CodeNode)
+ node._node_data = CodeNodeData(
+ title="My Code Node",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="",
+ outputs={},
+ )
+
+ assert node._get_title() == "My Code Node"
+
+ def test_get_description_none(self):
+ """Test _get_description returns None when not set."""
+ node = CodeNode.__new__(CodeNode)
+ node._node_data = CodeNodeData(
+ title="Test",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="",
+ outputs={},
+ )
+
+ assert node._get_description() is None
+
+ def test_get_base_node_data(self):
+ """Test get_base_node_data returns node data."""
+ node = CodeNode.__new__(CodeNode)
+ node._node_data = CodeNodeData(
+ title="Base Test",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="",
+ outputs={},
+ )
+
+ result = node.get_base_node_data()
+
+ assert result == node._node_data
+ assert result.title == "Base Test"
diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py
new file mode 100644
index 0000000000..d14a6ea69c
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py
@@ -0,0 +1,353 @@
+import pytest
+from pydantic import ValidationError
+
+from core.helper.code_executor.code_executor import CodeLanguage
+from core.variables.types import SegmentType
+from core.workflow.nodes.code.entities import CodeNodeData
+
+
+class TestCodeNodeDataOutput:
+ """Test suite for CodeNodeData.Output model."""
+
+ def test_output_with_string_type(self):
+ """Test Output with STRING type."""
+ output = CodeNodeData.Output(type=SegmentType.STRING)
+
+ assert output.type == SegmentType.STRING
+ assert output.children is None
+
+ def test_output_with_number_type(self):
+ """Test Output with NUMBER type."""
+ output = CodeNodeData.Output(type=SegmentType.NUMBER)
+
+ assert output.type == SegmentType.NUMBER
+ assert output.children is None
+
+ def test_output_with_boolean_type(self):
+ """Test Output with BOOLEAN type."""
+ output = CodeNodeData.Output(type=SegmentType.BOOLEAN)
+
+ assert output.type == SegmentType.BOOLEAN
+
+ def test_output_with_object_type(self):
+ """Test Output with OBJECT type."""
+ output = CodeNodeData.Output(type=SegmentType.OBJECT)
+
+ assert output.type == SegmentType.OBJECT
+
+ def test_output_with_array_string_type(self):
+ """Test Output with ARRAY_STRING type."""
+ output = CodeNodeData.Output(type=SegmentType.ARRAY_STRING)
+
+ assert output.type == SegmentType.ARRAY_STRING
+
+ def test_output_with_array_number_type(self):
+ """Test Output with ARRAY_NUMBER type."""
+ output = CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)
+
+ assert output.type == SegmentType.ARRAY_NUMBER
+
+ def test_output_with_array_object_type(self):
+ """Test Output with ARRAY_OBJECT type."""
+ output = CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT)
+
+ assert output.type == SegmentType.ARRAY_OBJECT
+
+ def test_output_with_array_boolean_type(self):
+ """Test Output with ARRAY_BOOLEAN type."""
+ output = CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)
+
+ assert output.type == SegmentType.ARRAY_BOOLEAN
+
+ def test_output_with_nested_children(self):
+ """Test Output with nested children for OBJECT type."""
+ child_output = CodeNodeData.Output(type=SegmentType.STRING)
+ parent_output = CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={"name": child_output},
+ )
+
+ assert parent_output.type == SegmentType.OBJECT
+ assert parent_output.children is not None
+ assert "name" in parent_output.children
+ assert parent_output.children["name"].type == SegmentType.STRING
+
+ def test_output_with_deeply_nested_children(self):
+ """Test Output with deeply nested children."""
+ inner_child = CodeNodeData.Output(type=SegmentType.NUMBER)
+ middle_child = CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={"value": inner_child},
+ )
+ outer_output = CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={"nested": middle_child},
+ )
+
+ assert outer_output.children is not None
+ assert outer_output.children["nested"].children is not None
+ assert outer_output.children["nested"].children["value"].type == SegmentType.NUMBER
+
+ def test_output_with_multiple_children(self):
+ """Test Output with multiple children."""
+ output = CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={
+ "name": CodeNodeData.Output(type=SegmentType.STRING),
+ "age": CodeNodeData.Output(type=SegmentType.NUMBER),
+ "active": CodeNodeData.Output(type=SegmentType.BOOLEAN),
+ },
+ )
+
+ assert output.children is not None
+ assert len(output.children) == 3
+ assert output.children["name"].type == SegmentType.STRING
+ assert output.children["age"].type == SegmentType.NUMBER
+ assert output.children["active"].type == SegmentType.BOOLEAN
+
+ def test_output_rejects_invalid_type(self):
+ """Test Output rejects invalid segment types."""
+ with pytest.raises(ValidationError):
+ CodeNodeData.Output(type=SegmentType.FILE)
+
+ def test_output_rejects_array_file_type(self):
+ """Test Output rejects ARRAY_FILE type."""
+ with pytest.raises(ValidationError):
+ CodeNodeData.Output(type=SegmentType.ARRAY_FILE)
+
+
+class TestCodeNodeDataDependency:
+ """Test suite for CodeNodeData.Dependency model."""
+
+ def test_dependency_basic(self):
+ """Test Dependency with name and version."""
+ dependency = CodeNodeData.Dependency(name="numpy", version="1.24.0")
+
+ assert dependency.name == "numpy"
+ assert dependency.version == "1.24.0"
+
+ def test_dependency_with_complex_version(self):
+ """Test Dependency with complex version string."""
+ dependency = CodeNodeData.Dependency(name="pandas", version=">=2.0.0,<3.0.0")
+
+ assert dependency.name == "pandas"
+ assert dependency.version == ">=2.0.0,<3.0.0"
+
+ def test_dependency_with_empty_version(self):
+ """Test Dependency with empty version."""
+ dependency = CodeNodeData.Dependency(name="requests", version="")
+
+ assert dependency.name == "requests"
+ assert dependency.version == ""
+
+
+class TestCodeNodeData:
+ """Test suite for CodeNodeData model."""
+
+ def test_code_node_data_python3(self):
+ """Test CodeNodeData with Python3 language."""
+ data = CodeNodeData(
+ title="Test Code Node",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'result': 42}",
+ outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
+ )
+
+ assert data.title == "Test Code Node"
+ assert data.code_language == CodeLanguage.PYTHON3
+ assert data.code == "def main(): return {'result': 42}"
+ assert "result" in data.outputs
+ assert data.dependencies is None
+
+ def test_code_node_data_javascript(self):
+ """Test CodeNodeData with JavaScript language."""
+ data = CodeNodeData(
+ title="JS Code Node",
+ variables=[],
+ code_language=CodeLanguage.JAVASCRIPT,
+ code="function main() { return { result: 'hello' }; }",
+ outputs={"result": CodeNodeData.Output(type=SegmentType.STRING)},
+ )
+
+ assert data.code_language == CodeLanguage.JAVASCRIPT
+ assert "result" in data.outputs
+ assert data.outputs["result"].type == SegmentType.STRING
+
+ def test_code_node_data_with_dependencies(self):
+ """Test CodeNodeData with dependencies."""
+ data = CodeNodeData(
+ title="Code with Deps",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="import numpy as np\ndef main(): return {'sum': 10}",
+ outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)},
+ dependencies=[
+ CodeNodeData.Dependency(name="numpy", version="1.24.0"),
+ CodeNodeData.Dependency(name="pandas", version="2.0.0"),
+ ],
+ )
+
+ assert data.dependencies is not None
+ assert len(data.dependencies) == 2
+ assert data.dependencies[0].name == "numpy"
+ assert data.dependencies[1].name == "pandas"
+
+ def test_code_node_data_with_multiple_outputs(self):
+ """Test CodeNodeData with multiple outputs."""
+ data = CodeNodeData(
+ title="Multi Output",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'name': 'test', 'count': 5, 'items': ['a', 'b']}",
+ outputs={
+ "name": CodeNodeData.Output(type=SegmentType.STRING),
+ "count": CodeNodeData.Output(type=SegmentType.NUMBER),
+ "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
+ },
+ )
+
+ assert len(data.outputs) == 3
+ assert data.outputs["name"].type == SegmentType.STRING
+ assert data.outputs["count"].type == SegmentType.NUMBER
+ assert data.outputs["items"].type == SegmentType.ARRAY_STRING
+
+ def test_code_node_data_with_object_output(self):
+ """Test CodeNodeData with nested object output."""
+ data = CodeNodeData(
+ title="Object Output",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'user': {'name': 'John', 'age': 30}}",
+ outputs={
+ "user": CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={
+ "name": CodeNodeData.Output(type=SegmentType.STRING),
+ "age": CodeNodeData.Output(type=SegmentType.NUMBER),
+ },
+ ),
+ },
+ )
+
+ assert data.outputs["user"].type == SegmentType.OBJECT
+ assert data.outputs["user"].children is not None
+ assert len(data.outputs["user"].children) == 2
+
+ def test_code_node_data_with_array_object_output(self):
+ """Test CodeNodeData with array of objects output."""
+ data = CodeNodeData(
+ title="Array Object Output",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'users': [{'name': 'A'}, {'name': 'B'}]}",
+ outputs={
+ "users": CodeNodeData.Output(
+ type=SegmentType.ARRAY_OBJECT,
+ children={
+ "name": CodeNodeData.Output(type=SegmentType.STRING),
+ },
+ ),
+ },
+ )
+
+ assert data.outputs["users"].type == SegmentType.ARRAY_OBJECT
+ assert data.outputs["users"].children is not None
+
+ def test_code_node_data_empty_code(self):
+ """Test CodeNodeData with empty code."""
+ data = CodeNodeData(
+ title="Empty Code",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="",
+ outputs={},
+ )
+
+ assert data.code == ""
+ assert len(data.outputs) == 0
+
+ def test_code_node_data_multiline_code(self):
+ """Test CodeNodeData with multiline code."""
+ multiline_code = """
+def main():
+ result = 0
+ for i in range(10):
+ result += i
+ return {'sum': result}
+"""
+ data = CodeNodeData(
+ title="Multiline Code",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code=multiline_code,
+ outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)},
+ )
+
+ assert "for i in range(10)" in data.code
+ assert "result += i" in data.code
+
+ def test_code_node_data_with_special_characters_in_code(self):
+ """Test CodeNodeData with special characters in code."""
+ code_with_special = "def main(): return {'msg': 'Hello\\nWorld\\t!'}"
+ data = CodeNodeData(
+ title="Special Chars",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code=code_with_special,
+ outputs={"msg": CodeNodeData.Output(type=SegmentType.STRING)},
+ )
+
+ assert "\\n" in data.code
+ assert "\\t" in data.code
+
+ def test_code_node_data_with_unicode_in_code(self):
+ """Test CodeNodeData with unicode characters in code."""
+ unicode_code = "def main(): return {'greeting': '你好世界'}"
+ data = CodeNodeData(
+ title="Unicode Code",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code=unicode_code,
+ outputs={"greeting": CodeNodeData.Output(type=SegmentType.STRING)},
+ )
+
+ assert "你好世界" in data.code
+
+ def test_code_node_data_empty_dependencies_list(self):
+ """Test CodeNodeData with empty dependencies list."""
+ data = CodeNodeData(
+ title="No Deps",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {}",
+ outputs={},
+ dependencies=[],
+ )
+
+ assert data.dependencies is not None
+ assert len(data.dependencies) == 0
+
+ def test_code_node_data_with_boolean_array_output(self):
+ """Test CodeNodeData with boolean array output."""
+ data = CodeNodeData(
+ title="Boolean Array",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'flags': [True, False, True]}",
+ outputs={"flags": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)},
+ )
+
+ assert data.outputs["flags"].type == SegmentType.ARRAY_BOOLEAN
+
+ def test_code_node_data_with_number_array_output(self):
+ """Test CodeNodeData with number array output."""
+ data = CodeNodeData(
+ title="Number Array",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'values': [1, 2, 3, 4, 5]}",
+ outputs={"values": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)},
+ )
+
+ assert data.outputs["values"].type == SegmentType.ARRAY_NUMBER
From 5815950092b93cecc69b89f0c84f23e5a9604cc6 Mon Sep 17 00:00:00 2001
From: Satoshi Dev <162055292+0xsatoshi99@users.noreply.github.com>
Date: Wed, 26 Nov 2025 18:36:47 -0800
Subject: [PATCH 42/63] add unit tests for iteration node (#28719)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../core/workflow/nodes/iteration/__init__.py | 0
.../workflow/nodes/iteration/entities_spec.py | 339 +++++++++++++++
.../nodes/iteration/iteration_node_spec.py | 390 ++++++++++++++++++
3 files changed, 729 insertions(+)
create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py b/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py
new file mode 100644
index 0000000000..d669cc7465
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py
@@ -0,0 +1,339 @@
+from core.workflow.nodes.iteration.entities import (
+ ErrorHandleMode,
+ IterationNodeData,
+ IterationStartNodeData,
+ IterationState,
+)
+
+
+class TestErrorHandleMode:
+ """Test suite for ErrorHandleMode enum."""
+
+ def test_terminated_value(self):
+ """Test TERMINATED enum value."""
+ assert ErrorHandleMode.TERMINATED == "terminated"
+ assert ErrorHandleMode.TERMINATED.value == "terminated"
+
+ def test_continue_on_error_value(self):
+ """Test CONTINUE_ON_ERROR enum value."""
+ assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
+ assert ErrorHandleMode.CONTINUE_ON_ERROR.value == "continue-on-error"
+
+ def test_remove_abnormal_output_value(self):
+ """Test REMOVE_ABNORMAL_OUTPUT enum value."""
+ assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT == "remove-abnormal-output"
+ assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.value == "remove-abnormal-output"
+
+ def test_error_handle_mode_is_str_enum(self):
+ """Test ErrorHandleMode is a string enum."""
+ assert isinstance(ErrorHandleMode.TERMINATED, str)
+ assert isinstance(ErrorHandleMode.CONTINUE_ON_ERROR, str)
+ assert isinstance(ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, str)
+
+ def test_error_handle_mode_comparison(self):
+ """Test ErrorHandleMode can be compared with strings."""
+ assert ErrorHandleMode.TERMINATED == "terminated"
+ assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
+
+ def test_all_error_handle_modes(self):
+ """Test all ErrorHandleMode values are accessible."""
+ modes = list(ErrorHandleMode)
+
+ assert len(modes) == 3
+ assert ErrorHandleMode.TERMINATED in modes
+ assert ErrorHandleMode.CONTINUE_ON_ERROR in modes
+ assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT in modes
+
+
+class TestIterationNodeData:
+ """Test suite for IterationNodeData model."""
+
+ def test_iteration_node_data_basic(self):
+ """Test IterationNodeData with basic configuration."""
+ data = IterationNodeData(
+ title="Test Iteration",
+ iterator_selector=["node1", "output"],
+ output_selector=["iteration", "result"],
+ )
+
+ assert data.title == "Test Iteration"
+ assert data.iterator_selector == ["node1", "output"]
+ assert data.output_selector == ["iteration", "result"]
+
+ def test_iteration_node_data_default_values(self):
+ """Test IterationNodeData default values."""
+ data = IterationNodeData(
+ title="Default Test",
+ iterator_selector=["start", "items"],
+ output_selector=["iter", "out"],
+ )
+
+ assert data.parent_loop_id is None
+ assert data.is_parallel is False
+ assert data.parallel_nums == 10
+ assert data.error_handle_mode == ErrorHandleMode.TERMINATED
+ assert data.flatten_output is True
+
+ def test_iteration_node_data_parallel_mode(self):
+ """Test IterationNodeData with parallel mode enabled."""
+ data = IterationNodeData(
+ title="Parallel Iteration",
+ iterator_selector=["node", "list"],
+ output_selector=["iter", "output"],
+ is_parallel=True,
+ parallel_nums=5,
+ )
+
+ assert data.is_parallel is True
+ assert data.parallel_nums == 5
+
+ def test_iteration_node_data_custom_parallel_nums(self):
+ """Test IterationNodeData with custom parallel numbers."""
+ data = IterationNodeData(
+ title="Custom Parallel",
+ iterator_selector=["a", "b"],
+ output_selector=["c", "d"],
+ parallel_nums=20,
+ )
+
+ assert data.parallel_nums == 20
+
+ def test_iteration_node_data_continue_on_error(self):
+ """Test IterationNodeData with continue on error mode."""
+ data = IterationNodeData(
+ title="Continue Error",
+ iterator_selector=["x", "y"],
+ output_selector=["z", "w"],
+ error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
+ )
+
+ assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
+
+ def test_iteration_node_data_remove_abnormal_output(self):
+ """Test IterationNodeData with remove abnormal output mode."""
+ data = IterationNodeData(
+ title="Remove Abnormal",
+ iterator_selector=["input", "array"],
+ output_selector=["output", "result"],
+ error_handle_mode=ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
+ )
+
+ assert data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
+
+ def test_iteration_node_data_flatten_output_disabled(self):
+ """Test IterationNodeData with flatten output disabled."""
+ data = IterationNodeData(
+ title="No Flatten",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ flatten_output=False,
+ )
+
+ assert data.flatten_output is False
+
+ def test_iteration_node_data_with_parent_loop_id(self):
+ """Test IterationNodeData with parent loop ID."""
+ data = IterationNodeData(
+ title="Nested Loop",
+ iterator_selector=["parent", "items"],
+ output_selector=["child", "output"],
+ parent_loop_id="parent_loop_123",
+ )
+
+ assert data.parent_loop_id == "parent_loop_123"
+
+ def test_iteration_node_data_complex_selectors(self):
+ """Test IterationNodeData with complex selectors."""
+ data = IterationNodeData(
+ title="Complex Selectors",
+ iterator_selector=["node1", "output", "data", "items"],
+ output_selector=["iteration", "result", "value"],
+ )
+
+ assert len(data.iterator_selector) == 4
+ assert len(data.output_selector) == 3
+
+ def test_iteration_node_data_all_options(self):
+ """Test IterationNodeData with all options configured."""
+ data = IterationNodeData(
+ title="Full Config",
+ iterator_selector=["start", "list"],
+ output_selector=["end", "result"],
+ parent_loop_id="outer_loop",
+ is_parallel=True,
+ parallel_nums=15,
+ error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
+ flatten_output=False,
+ )
+
+ assert data.title == "Full Config"
+ assert data.parent_loop_id == "outer_loop"
+ assert data.is_parallel is True
+ assert data.parallel_nums == 15
+ assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
+ assert data.flatten_output is False
+
+
+class TestIterationStartNodeData:
+ """Test suite for IterationStartNodeData model."""
+
+ def test_iteration_start_node_data_basic(self):
+ """Test IterationStartNodeData basic creation."""
+ data = IterationStartNodeData(title="Iteration Start")
+
+ assert data.title == "Iteration Start"
+
+ def test_iteration_start_node_data_with_description(self):
+ """Test IterationStartNodeData with description."""
+ data = IterationStartNodeData(
+ title="Start Node",
+ desc="This is the start of iteration",
+ )
+
+ assert data.title == "Start Node"
+ assert data.desc == "This is the start of iteration"
+
+
+class TestIterationState:
+ """Test suite for IterationState model."""
+
+ def test_iteration_state_default_values(self):
+ """Test IterationState default values."""
+ state = IterationState()
+
+ assert state.outputs == []
+ assert state.current_output is None
+
+ def test_iteration_state_with_outputs(self):
+ """Test IterationState with outputs."""
+ state = IterationState(outputs=["result1", "result2", "result3"])
+
+ assert len(state.outputs) == 3
+ assert state.outputs[0] == "result1"
+ assert state.outputs[2] == "result3"
+
+ def test_iteration_state_with_current_output(self):
+ """Test IterationState with current output."""
+ state = IterationState(current_output="current_value")
+
+ assert state.current_output == "current_value"
+
+ def test_iteration_state_get_last_output_with_outputs(self):
+ """Test get_last_output with outputs present."""
+ state = IterationState(outputs=["first", "second", "last"])
+
+ result = state.get_last_output()
+
+ assert result == "last"
+
+ def test_iteration_state_get_last_output_empty(self):
+ """Test get_last_output with empty outputs."""
+ state = IterationState(outputs=[])
+
+ result = state.get_last_output()
+
+ assert result is None
+
+ def test_iteration_state_get_last_output_single(self):
+ """Test get_last_output with single output."""
+ state = IterationState(outputs=["only_one"])
+
+ result = state.get_last_output()
+
+ assert result == "only_one"
+
+ def test_iteration_state_get_current_output(self):
+ """Test get_current_output method."""
+ state = IterationState(current_output={"key": "value"})
+
+ result = state.get_current_output()
+
+ assert result == {"key": "value"}
+
+ def test_iteration_state_get_current_output_none(self):
+ """Test get_current_output when None."""
+ state = IterationState()
+
+ result = state.get_current_output()
+
+ assert result is None
+
+ def test_iteration_state_with_complex_outputs(self):
+ """Test IterationState with complex output types."""
+ state = IterationState(
+ outputs=[
+ {"id": 1, "name": "first"},
+ {"id": 2, "name": "second"},
+ [1, 2, 3],
+ "string_output",
+ ]
+ )
+
+ assert len(state.outputs) == 4
+ assert state.outputs[0] == {"id": 1, "name": "first"}
+ assert state.outputs[2] == [1, 2, 3]
+
+ def test_iteration_state_with_none_outputs(self):
+ """Test IterationState with None values in outputs."""
+ state = IterationState(outputs=["value1", None, "value3"])
+
+ assert len(state.outputs) == 3
+ assert state.outputs[1] is None
+
+ def test_iteration_state_get_last_output_with_none(self):
+ """Test get_last_output when last output is None."""
+ state = IterationState(outputs=["first", None])
+
+ result = state.get_last_output()
+
+ assert result is None
+
+ def test_iteration_state_metadata_class(self):
+ """Test IterationState.MetaData class."""
+ metadata = IterationState.MetaData(iterator_length=10)
+
+ assert metadata.iterator_length == 10
+
+ def test_iteration_state_metadata_different_lengths(self):
+ """Test IterationState.MetaData with different lengths."""
+ metadata1 = IterationState.MetaData(iterator_length=0)
+ metadata2 = IterationState.MetaData(iterator_length=100)
+ metadata3 = IterationState.MetaData(iterator_length=1000000)
+
+ assert metadata1.iterator_length == 0
+ assert metadata2.iterator_length == 100
+ assert metadata3.iterator_length == 1000000
+
+ def test_iteration_state_outputs_modification(self):
+ """Test modifying IterationState outputs."""
+ state = IterationState(outputs=[])
+
+ state.outputs.append("new_output")
+ state.outputs.append("another_output")
+
+ assert len(state.outputs) == 2
+ assert state.get_last_output() == "another_output"
+
+ def test_iteration_state_current_output_update(self):
+ """Test updating current_output."""
+ state = IterationState()
+
+ state.current_output = "first_value"
+ assert state.get_current_output() == "first_value"
+
+ state.current_output = "updated_value"
+ assert state.get_current_output() == "updated_value"
+
+ def test_iteration_state_with_numeric_outputs(self):
+ """Test IterationState with numeric outputs."""
+ state = IterationState(outputs=[1, 2, 3, 4, 5])
+
+ assert state.get_last_output() == 5
+ assert len(state.outputs) == 5
+
+ def test_iteration_state_with_boolean_outputs(self):
+ """Test IterationState with boolean outputs."""
+ state = IterationState(outputs=[True, False, True])
+
+ assert state.get_last_output() is True
+ assert state.outputs[1] is False
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py
new file mode 100644
index 0000000000..51af4367f7
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py
@@ -0,0 +1,390 @@
+from core.workflow.enums import NodeType
+from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
+from core.workflow.nodes.iteration.exc import (
+ InvalidIteratorValueError,
+ IterationGraphNotFoundError,
+ IterationIndexNotFoundError,
+ IterationNodeError,
+ IteratorVariableNotFoundError,
+ StartNodeIdNotFoundError,
+)
+from core.workflow.nodes.iteration.iteration_node import IterationNode
+
+
+class TestIterationNodeExceptions:
+ """Test suite for iteration node exceptions."""
+
+ def test_iteration_node_error_is_value_error(self):
+ """Test IterationNodeError inherits from ValueError."""
+ error = IterationNodeError("test error")
+
+ assert isinstance(error, ValueError)
+ assert str(error) == "test error"
+
+ def test_iterator_variable_not_found_error(self):
+ """Test IteratorVariableNotFoundError."""
+ error = IteratorVariableNotFoundError("Iterator variable not found")
+
+ assert isinstance(error, IterationNodeError)
+ assert isinstance(error, ValueError)
+ assert "Iterator variable not found" in str(error)
+
+ def test_invalid_iterator_value_error(self):
+ """Test InvalidIteratorValueError."""
+ error = InvalidIteratorValueError("Invalid iterator value")
+
+ assert isinstance(error, IterationNodeError)
+ assert "Invalid iterator value" in str(error)
+
+ def test_start_node_id_not_found_error(self):
+ """Test StartNodeIdNotFoundError."""
+ error = StartNodeIdNotFoundError("Start node ID not found")
+
+ assert isinstance(error, IterationNodeError)
+ assert "Start node ID not found" in str(error)
+
+ def test_iteration_graph_not_found_error(self):
+ """Test IterationGraphNotFoundError."""
+ error = IterationGraphNotFoundError("Iteration graph not found")
+
+ assert isinstance(error, IterationNodeError)
+ assert "Iteration graph not found" in str(error)
+
+ def test_iteration_index_not_found_error(self):
+ """Test IterationIndexNotFoundError."""
+ error = IterationIndexNotFoundError("Iteration index not found")
+
+ assert isinstance(error, IterationNodeError)
+ assert "Iteration index not found" in str(error)
+
+ def test_exception_with_empty_message(self):
+ """Test exception with empty message."""
+ error = IterationNodeError("")
+
+ assert str(error) == ""
+
+ def test_exception_with_detailed_message(self):
+ """Test exception with detailed message."""
+ error = IteratorVariableNotFoundError("Variable 'items' not found in node 'start_node'")
+
+ assert "items" in str(error)
+ assert "start_node" in str(error)
+
+ def test_all_exceptions_inherit_from_base(self):
+ """Test all exceptions inherit from IterationNodeError."""
+ exceptions = [
+ IteratorVariableNotFoundError("test"),
+ InvalidIteratorValueError("test"),
+ StartNodeIdNotFoundError("test"),
+ IterationGraphNotFoundError("test"),
+ IterationIndexNotFoundError("test"),
+ ]
+
+ for exc in exceptions:
+ assert isinstance(exc, IterationNodeError)
+ assert isinstance(exc, ValueError)
+
+
+class TestIterationNodeClassAttributes:
+ """Test suite for IterationNode class attributes."""
+
+ def test_node_type(self):
+ """Test IterationNode node_type attribute."""
+ assert IterationNode.node_type == NodeType.ITERATION
+
+ def test_version(self):
+ """Test IterationNode version method."""
+ version = IterationNode.version()
+
+ assert version == "1"
+
+
+class TestIterationNodeDefaultConfig:
+ """Test suite for IterationNode get_default_config."""
+
+ def test_get_default_config_returns_dict(self):
+ """Test get_default_config returns a dictionary."""
+ config = IterationNode.get_default_config()
+
+ assert isinstance(config, dict)
+
+ def test_get_default_config_type(self):
+ """Test get_default_config includes type."""
+ config = IterationNode.get_default_config()
+
+ assert config.get("type") == "iteration"
+
+ def test_get_default_config_has_config_section(self):
+ """Test get_default_config has config section."""
+ config = IterationNode.get_default_config()
+
+ assert "config" in config
+ assert isinstance(config["config"], dict)
+
+ def test_get_default_config_is_parallel_default(self):
+ """Test get_default_config is_parallel default value."""
+ config = IterationNode.get_default_config()
+
+ assert config["config"]["is_parallel"] is False
+
+ def test_get_default_config_parallel_nums_default(self):
+ """Test get_default_config parallel_nums default value."""
+ config = IterationNode.get_default_config()
+
+ assert config["config"]["parallel_nums"] == 10
+
+ def test_get_default_config_error_handle_mode_default(self):
+ """Test get_default_config error_handle_mode default value."""
+ config = IterationNode.get_default_config()
+
+ assert config["config"]["error_handle_mode"] == ErrorHandleMode.TERMINATED
+
+ def test_get_default_config_flatten_output_default(self):
+ """Test get_default_config flatten_output default value."""
+ config = IterationNode.get_default_config()
+
+ assert config["config"]["flatten_output"] is True
+
+ def test_get_default_config_with_none_filters(self):
+ """Test get_default_config with None filters."""
+ config = IterationNode.get_default_config(filters=None)
+
+ assert config is not None
+ assert "type" in config
+
+ def test_get_default_config_with_empty_filters(self):
+ """Test get_default_config with empty filters."""
+ config = IterationNode.get_default_config(filters={})
+
+ assert config is not None
+
+
+class TestIterationNodeInitialization:
+ """Test suite for IterationNode initialization."""
+
+ def test_init_node_data_basic(self):
+ """Test init_node_data with basic configuration."""
+ node = IterationNode.__new__(IterationNode)
+ data = {
+ "title": "Test Iteration",
+ "iterator_selector": ["start", "items"],
+ "output_selector": ["iteration", "result"],
+ }
+
+ node.init_node_data(data)
+
+ assert node._node_data.title == "Test Iteration"
+ assert node._node_data.iterator_selector == ["start", "items"]
+
+ def test_init_node_data_with_parallel(self):
+ """Test init_node_data with parallel configuration."""
+ node = IterationNode.__new__(IterationNode)
+ data = {
+ "title": "Parallel Iteration",
+ "iterator_selector": ["node", "list"],
+ "output_selector": ["out", "result"],
+ "is_parallel": True,
+ "parallel_nums": 5,
+ }
+
+ node.init_node_data(data)
+
+ assert node._node_data.is_parallel is True
+ assert node._node_data.parallel_nums == 5
+
+ def test_init_node_data_with_error_handle_mode(self):
+ """Test init_node_data with error handle mode."""
+ node = IterationNode.__new__(IterationNode)
+ data = {
+ "title": "Error Handle Test",
+ "iterator_selector": ["a", "b"],
+ "output_selector": ["c", "d"],
+ "error_handle_mode": "continue-on-error",
+ }
+
+ node.init_node_data(data)
+
+ assert node._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
+
+ def test_get_title(self):
+ """Test _get_title method."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="My Iteration",
+ iterator_selector=["x"],
+ output_selector=["y"],
+ )
+
+ assert node._get_title() == "My Iteration"
+
+ def test_get_description_none(self):
+ """Test _get_description returns None when not set."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Test",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ )
+
+ assert node._get_description() is None
+
+ def test_get_description_with_value(self):
+ """Test _get_description with value."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Test",
+ desc="This is a description",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ )
+
+ assert node._get_description() == "This is a description"
+
+ def test_get_base_node_data(self):
+ """Test get_base_node_data returns node data."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Base Test",
+ iterator_selector=["x"],
+ output_selector=["y"],
+ )
+
+ result = node.get_base_node_data()
+
+ assert result == node._node_data
+
+
+class TestIterationNodeDataValidation:
+ """Test suite for IterationNodeData validation scenarios."""
+
+ def test_valid_iteration_node_data(self):
+ """Test valid IterationNodeData creation."""
+ data = IterationNodeData(
+ title="Valid Iteration",
+ iterator_selector=["start", "items"],
+ output_selector=["end", "result"],
+ )
+
+ assert data.title == "Valid Iteration"
+
+ def test_iteration_node_data_with_all_error_modes(self):
+ """Test IterationNodeData with all error handle modes."""
+ modes = [
+ ErrorHandleMode.TERMINATED,
+ ErrorHandleMode.CONTINUE_ON_ERROR,
+ ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
+ ]
+
+ for mode in modes:
+ data = IterationNodeData(
+ title=f"Test {mode}",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ error_handle_mode=mode,
+ )
+ assert data.error_handle_mode == mode
+
+ def test_iteration_node_data_parallel_configuration(self):
+ """Test IterationNodeData parallel configuration combinations."""
+ configs = [
+ (False, 10),
+ (True, 1),
+ (True, 5),
+ (True, 20),
+ (True, 100),
+ ]
+
+ for is_parallel, parallel_nums in configs:
+ data = IterationNodeData(
+ title="Parallel Test",
+ iterator_selector=["x"],
+ output_selector=["y"],
+ is_parallel=is_parallel,
+ parallel_nums=parallel_nums,
+ )
+ assert data.is_parallel == is_parallel
+ assert data.parallel_nums == parallel_nums
+
+ def test_iteration_node_data_flatten_output_options(self):
+ """Test IterationNodeData flatten_output options."""
+ data_flatten = IterationNodeData(
+ title="Flatten True",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ flatten_output=True,
+ )
+
+ data_no_flatten = IterationNodeData(
+ title="Flatten False",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ flatten_output=False,
+ )
+
+ assert data_flatten.flatten_output is True
+ assert data_no_flatten.flatten_output is False
+
+ def test_iteration_node_data_complex_selectors(self):
+ """Test IterationNodeData with complex selectors."""
+ data = IterationNodeData(
+ title="Complex",
+ iterator_selector=["node1", "output", "data", "items", "list"],
+ output_selector=["iteration", "result", "value", "final"],
+ )
+
+ assert len(data.iterator_selector) == 5
+ assert len(data.output_selector) == 4
+
+ def test_iteration_node_data_single_element_selectors(self):
+ """Test IterationNodeData with single element selectors."""
+ data = IterationNodeData(
+ title="Single",
+ iterator_selector=["items"],
+ output_selector=["result"],
+ )
+
+ assert len(data.iterator_selector) == 1
+ assert len(data.output_selector) == 1
+
+
+class TestIterationNodeErrorStrategies:
+ """Test suite for IterationNode error strategies."""
+
+ def test_get_error_strategy_default(self):
+ """Test _get_error_strategy with default value."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Test",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ )
+
+ result = node._get_error_strategy()
+
+ assert result is None or result == node._node_data.error_strategy
+
+ def test_get_retry_config(self):
+ """Test _get_retry_config method."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Test",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ )
+
+ result = node._get_retry_config()
+
+ assert result is not None
+
+ def test_get_default_value_dict(self):
+ """Test _get_default_value_dict method."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Test",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ )
+
+ result = node._get_default_value_dict()
+
+ assert isinstance(result, dict)
From 01afa5616652e3cdf41029b6a4e95f0742c504d1 Mon Sep 17 00:00:00 2001
From: Gritty_dev <101377478+codomposer@users.noreply.github.com>
Date: Wed, 26 Nov 2025 21:37:24 -0500
Subject: [PATCH 43/63] chore: enhance the test script of current billing
service (#28747)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../services/test_billing_service.py | 1065 ++++++++++++++++-
1 file changed, 1064 insertions(+), 1 deletion(-)
diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py
index dc13143417..915aee3fa7 100644
--- a/api/tests/unit_tests/services/test_billing_service.py
+++ b/api/tests/unit_tests/services/test_billing_service.py
@@ -1,3 +1,18 @@
+"""Comprehensive unit tests for BillingService.
+
+This test module covers all aspects of the billing service including:
+- HTTP request handling with retry logic
+- Subscription tier management and billing information retrieval
+- Usage calculation and credit management (positive/negative deltas)
+- Rate limit enforcement for compliance downloads and education features
+- Account management and permission checks
+- Cache management for billing data
+- Partner integration features
+
+All tests use mocking to avoid external dependencies and ensure fast, reliable execution.
+Tests follow the Arrange-Act-Assert pattern for clarity.
+"""
+
import json
from unittest.mock import MagicMock, patch
@@ -5,11 +20,20 @@ import httpx
import pytest
from werkzeug.exceptions import InternalServerError
+from enums.cloud_plan import CloudPlan
+from models import Account, TenantAccountJoin, TenantAccountRole
from services.billing_service import BillingService
class TestBillingServiceSendRequest:
- """Unit tests for BillingService._send_request method."""
+ """Unit tests for BillingService._send_request method.
+
+ Tests cover:
+ - Successful GET/PUT/POST/DELETE requests
+ - Error handling for various HTTP status codes
+ - Retry logic on network failures
+ - Request header and parameter validation
+ """
@pytest.fixture
def mock_httpx_request(self):
@@ -234,3 +258,1042 @@ class TestBillingServiceSendRequest:
# Should retry multiple times (wait=2, stop_before_delay=10 means ~5 attempts)
assert mock_httpx_request.call_count > 1
+
+
+class TestBillingServiceSubscriptionInfo:
+ """Unit tests for subscription tier and billing info retrieval.
+
+ Tests cover:
+ - Billing information retrieval
+ - Knowledge base rate limits with default and custom values
+ - Payment link generation for subscriptions and model providers
+ - Invoice retrieval
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_get_info_success(self, mock_send_request):
+ """Test successful retrieval of billing information."""
+ # Arrange
+ tenant_id = "tenant-123"
+ expected_response = {
+ "subscription_plan": "professional",
+ "billing_cycle": "monthly",
+ "status": "active",
+ }
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_info(tenant_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": tenant_id})
+
+ def test_get_knowledge_rate_limit_with_defaults(self, mock_send_request):
+ """Test knowledge rate limit retrieval with default values."""
+ # Arrange
+ tenant_id = "tenant-456"
+ mock_send_request.return_value = {}
+
+ # Act
+ result = BillingService.get_knowledge_rate_limit(tenant_id)
+
+ # Assert
+ assert result["limit"] == 10 # Default limit
+ assert result["subscription_plan"] == CloudPlan.SANDBOX # Default plan
+ mock_send_request.assert_called_once_with(
+ "GET", "/subscription/knowledge-rate-limit", params={"tenant_id": tenant_id}
+ )
+
+ def test_get_knowledge_rate_limit_with_custom_values(self, mock_send_request):
+ """Test knowledge rate limit retrieval with custom values."""
+ # Arrange
+ tenant_id = "tenant-789"
+ mock_send_request.return_value = {"limit": 100, "subscription_plan": CloudPlan.PROFESSIONAL}
+
+ # Act
+ result = BillingService.get_knowledge_rate_limit(tenant_id)
+
+ # Assert
+ assert result["limit"] == 100
+ assert result["subscription_plan"] == CloudPlan.PROFESSIONAL
+
+ def test_get_subscription_payment_link(self, mock_send_request):
+ """Test subscription payment link generation."""
+ # Arrange
+ plan = "professional"
+ interval = "monthly"
+ email = "user@example.com"
+ tenant_id = "tenant-123"
+ expected_response = {"payment_link": "https://payment.example.com/checkout"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_subscription(plan, interval, email, tenant_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET",
+ "/subscription/payment-link",
+ params={"plan": plan, "interval": interval, "prefilled_email": email, "tenant_id": tenant_id},
+ )
+
+ def test_get_model_provider_payment_link(self, mock_send_request):
+ """Test model provider payment link generation."""
+ # Arrange
+ provider_name = "openai"
+ tenant_id = "tenant-123"
+ account_id = "account-456"
+ email = "user@example.com"
+ expected_response = {"payment_link": "https://payment.example.com/provider"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_model_provider_payment_link(provider_name, tenant_id, account_id, email)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET",
+ "/model-provider/payment-link",
+ params={
+ "provider_name": provider_name,
+ "tenant_id": tenant_id,
+ "account_id": account_id,
+ "prefilled_email": email,
+ },
+ )
+
+ def test_get_invoices(self, mock_send_request):
+ """Test invoice retrieval."""
+ # Arrange
+ email = "user@example.com"
+ tenant_id = "tenant-123"
+ expected_response = {"invoices": [{"id": "inv-1", "amount": 100}]}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_invoices(email, tenant_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET", "/invoices", params={"prefilled_email": email, "tenant_id": tenant_id}
+ )
+
+
+class TestBillingServiceUsageCalculation:
+ """Unit tests for usage calculation and credit management.
+
+ Tests cover:
+ - Feature plan usage information retrieval
+ - Credit addition (positive delta)
+ - Credit consumption (negative delta)
+ - Usage refunds
+ - Specific feature usage queries
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
+ """Test retrieval of tenant feature plan usage information."""
+ # Arrange
+ tenant_id = "tenant-123"
+ expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
+
+ def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
+ """Test updating tenant feature usage with positive delta (adding credits)."""
+ # Arrange
+ tenant_id = "tenant-123"
+ feature_key = "trigger"
+ delta = 10
+ expected_response = {"result": "success", "history_id": "hist-uuid-123"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta)
+
+ # Assert
+ assert result == expected_response
+ assert result["result"] == "success"
+ assert "history_id" in result
+ mock_send_request.assert_called_once_with(
+ "POST",
+ "/tenant-feature-usage/usage",
+ params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta},
+ )
+
+ def test_update_tenant_feature_plan_usage_negative_delta(self, mock_send_request):
+ """Test updating tenant feature usage with negative delta (consuming credits)."""
+ # Arrange
+ tenant_id = "tenant-456"
+ feature_key = "workflow"
+ delta = -5
+ expected_response = {"result": "success", "history_id": "hist-uuid-456"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "POST",
+ "/tenant-feature-usage/usage",
+ params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta},
+ )
+
+ def test_refund_tenant_feature_plan_usage(self, mock_send_request):
+ """Test refunding a previous usage charge."""
+ # Arrange
+ history_id = "hist-uuid-789"
+ expected_response = {"result": "success", "history_id": history_id}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.refund_tenant_feature_plan_usage(history_id)
+
+ # Assert
+ assert result == expected_response
+ assert result["result"] == "success"
+ mock_send_request.assert_called_once_with(
+ "POST", "/tenant-feature-usage/refund", params={"quota_usage_history_id": history_id}
+ )
+
+ def test_get_tenant_feature_plan_usage(self, mock_send_request):
+ """Test getting specific feature usage for a tenant."""
+ # Arrange
+ tenant_id = "tenant-123"
+ feature_key = "trigger"
+ expected_response = {"used": 75, "limit": 100, "remaining": 25}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_tenant_feature_plan_usage(tenant_id, feature_key)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET", "/billing/tenant_feature_plan/usage", params={"tenant_id": tenant_id, "feature_key": feature_key}
+ )
+
+
+class TestBillingServiceRateLimitEnforcement:
+ """Unit tests for rate limit enforcement mechanisms.
+
+ Tests cover:
+ - Compliance download rate limiting (4 requests per 60 seconds)
+ - Education verification rate limiting (10 requests per 60 seconds)
+ - Education activation rate limiting (10 requests per 60 seconds)
+ - Rate limit increment after successful operations
+ - Proper exception raising when limits are exceeded
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_compliance_download_rate_limiter_not_limited(self, mock_send_request):
+ """Test compliance download when rate limit is not exceeded."""
+ # Arrange
+ doc_name = "compliance_report.pdf"
+ account_id = "account-123"
+ tenant_id = "tenant-456"
+ ip = "192.168.1.1"
+ device_info = "Mozilla/5.0"
+ expected_response = {"download_link": "https://example.com/download"}
+
+ # Mock the rate limiter to return False (not limited)
+ with (
+ patch.object(
+ BillingService.compliance_download_rate_limiter, "is_rate_limited", return_value=False
+ ) as mock_is_limited,
+ patch.object(BillingService.compliance_download_rate_limiter, "increment_rate_limit") as mock_increment,
+ ):
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_compliance_download_link(doc_name, account_id, tenant_id, ip, device_info)
+
+ # Assert
+ assert result == expected_response
+ mock_is_limited.assert_called_once_with(f"{account_id}:{tenant_id}")
+ mock_send_request.assert_called_once_with(
+ "POST",
+ "/compliance/download",
+ json={
+ "doc_name": doc_name,
+ "account_id": account_id,
+ "tenant_id": tenant_id,
+ "ip_address": ip,
+ "device_info": device_info,
+ },
+ )
+ # Verify rate limit was incremented after successful download
+ mock_increment.assert_called_once_with(f"{account_id}:{tenant_id}")
+
+ def test_compliance_download_rate_limiter_exceeded(self, mock_send_request):
+ """Test compliance download when rate limit is exceeded."""
+ # Arrange
+ doc_name = "compliance_report.pdf"
+ account_id = "account-123"
+ tenant_id = "tenant-456"
+ ip = "192.168.1.1"
+ device_info = "Mozilla/5.0"
+
+ # Import the error class to properly catch it
+ from controllers.console.error import ComplianceRateLimitError
+
+ # Mock the rate limiter to return True (rate limited)
+ with patch.object(
+ BillingService.compliance_download_rate_limiter, "is_rate_limited", return_value=True
+ ) as mock_is_limited:
+ # Act & Assert
+ with pytest.raises(ComplianceRateLimitError):
+ BillingService.get_compliance_download_link(doc_name, account_id, tenant_id, ip, device_info)
+
+ mock_is_limited.assert_called_once_with(f"{account_id}:{tenant_id}")
+ mock_send_request.assert_not_called()
+
+ def test_education_verify_rate_limit_not_exceeded(self, mock_send_request):
+ """Test education verification when rate limit is not exceeded."""
+ # Arrange
+ account_id = "account-123"
+ account_email = "student@university.edu"
+ expected_response = {"verified": True, "institution": "University"}
+
+ # Mock the rate limiter to return False (not limited)
+ with (
+ patch.object(
+ BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=False
+ ) as mock_is_limited,
+ patch.object(
+ BillingService.EducationIdentity.verification_rate_limit, "increment_rate_limit"
+ ) as mock_increment,
+ ):
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.EducationIdentity.verify(account_id, account_email)
+
+ # Assert
+ assert result == expected_response
+ mock_is_limited.assert_called_once_with(account_email)
+ mock_send_request.assert_called_once_with("GET", "/education/verify", params={"account_id": account_id})
+ mock_increment.assert_called_once_with(account_email)
+
+ def test_education_verify_rate_limit_exceeded(self, mock_send_request):
+ """Test education verification when rate limit is exceeded."""
+ # Arrange
+ account_id = "account-123"
+ account_email = "student@university.edu"
+
+ # Import the error class to properly catch it
+ from controllers.console.error import EducationVerifyLimitError
+
+ # Mock the rate limiter to return True (rate limited)
+ with patch.object(
+ BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=True
+ ) as mock_is_limited:
+ # Act & Assert
+ with pytest.raises(EducationVerifyLimitError):
+ BillingService.EducationIdentity.verify(account_id, account_email)
+
+ mock_is_limited.assert_called_once_with(account_email)
+ mock_send_request.assert_not_called()
+
+ def test_education_activate_rate_limit_not_exceeded(self, mock_send_request):
+ """Test education activation when rate limit is not exceeded."""
+ # Arrange
+ account = MagicMock(spec=Account)
+ account.id = "account-123"
+ account.email = "student@university.edu"
+ account.current_tenant_id = "tenant-456"
+ token = "verification-token"
+ institution = "MIT"
+ role = "student"
+ expected_response = {"result": "success", "activated": True}
+
+ # Mock the rate limiter to return False (not limited)
+ with (
+ patch.object(
+ BillingService.EducationIdentity.activation_rate_limit, "is_rate_limited", return_value=False
+ ) as mock_is_limited,
+ patch.object(
+ BillingService.EducationIdentity.activation_rate_limit, "increment_rate_limit"
+ ) as mock_increment,
+ ):
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.EducationIdentity.activate(account, token, institution, role)
+
+ # Assert
+ assert result == expected_response
+ mock_is_limited.assert_called_once_with(account.email)
+ mock_send_request.assert_called_once_with(
+ "POST",
+ "/education/",
+ json={"institution": institution, "token": token, "role": role},
+ params={"account_id": account.id, "curr_tenant_id": account.current_tenant_id},
+ )
+ mock_increment.assert_called_once_with(account.email)
+
+ def test_education_activate_rate_limit_exceeded(self, mock_send_request):
+ """Test education activation when rate limit is exceeded."""
+ # Arrange
+ account = MagicMock(spec=Account)
+ account.id = "account-123"
+ account.email = "student@university.edu"
+ account.current_tenant_id = "tenant-456"
+ token = "verification-token"
+ institution = "MIT"
+ role = "student"
+
+ # Import the error class to properly catch it
+ from controllers.console.error import EducationActivateLimitError
+
+ # Mock the rate limiter to return True (rate limited)
+ with patch.object(
+ BillingService.EducationIdentity.activation_rate_limit, "is_rate_limited", return_value=True
+ ) as mock_is_limited:
+ # Act & Assert
+ with pytest.raises(EducationActivateLimitError):
+ BillingService.EducationIdentity.activate(account, token, institution, role)
+
+ mock_is_limited.assert_called_once_with(account.email)
+ mock_send_request.assert_not_called()
+
+
+class TestBillingServiceEducationIdentity:
+ """Unit tests for education identity verification and management.
+
+ Tests cover:
+ - Education verification status checking
+ - Institution autocomplete with pagination
+ - Default parameter handling
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_education_status(self, mock_send_request):
+ """Test checking education verification status."""
+ # Arrange
+ account_id = "account-123"
+ expected_response = {"verified": True, "institution": "MIT", "role": "student"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.EducationIdentity.status(account_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with("GET", "/education/status", params={"account_id": account_id})
+
+ def test_education_autocomplete(self, mock_send_request):
+ """Test education institution autocomplete."""
+ # Arrange
+ keywords = "Massachusetts"
+ page = 0
+ limit = 20
+ expected_response = {
+ "institutions": [
+ {"name": "Massachusetts Institute of Technology", "domain": "mit.edu"},
+ {"name": "University of Massachusetts", "domain": "umass.edu"},
+ ]
+ }
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.EducationIdentity.autocomplete(keywords, page, limit)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET", "/education/autocomplete", params={"keywords": keywords, "page": page, "limit": limit}
+ )
+
+ def test_education_autocomplete_with_defaults(self, mock_send_request):
+ """Test education institution autocomplete with default parameters."""
+ # Arrange
+ keywords = "Stanford"
+ expected_response = {"institutions": [{"name": "Stanford University", "domain": "stanford.edu"}]}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.EducationIdentity.autocomplete(keywords)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET", "/education/autocomplete", params={"keywords": keywords, "page": 0, "limit": 20}
+ )
+
+
+class TestBillingServiceAccountManagement:
+ """Unit tests for account-related billing operations.
+
+ Tests cover:
+ - Account deletion
+ - Email freeze status checking
+ - Account deletion feedback submission
+ - Tenant owner/admin permission validation
+ - Error handling for missing tenant joins
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """Mock database session."""
+ with patch("services.billing_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_delete_account(self, mock_send_request):
+ """Test account deletion."""
+ # Arrange
+ account_id = "account-123"
+ expected_response = {"result": "success", "deleted": True}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.delete_account(account_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with("DELETE", "/account/", params={"account_id": account_id})
+
+ def test_is_email_in_freeze_true(self, mock_send_request):
+ """Test checking if email is frozen (returns True)."""
+ # Arrange
+ email = "frozen@example.com"
+ mock_send_request.return_value = {"data": True}
+
+ # Act
+ result = BillingService.is_email_in_freeze(email)
+
+ # Assert
+ assert result is True
+ mock_send_request.assert_called_once_with("GET", "/account/in-freeze", params={"email": email})
+
+ def test_is_email_in_freeze_false(self, mock_send_request):
+ """Test checking if email is frozen (returns False)."""
+ # Arrange
+ email = "active@example.com"
+ mock_send_request.return_value = {"data": False}
+
+ # Act
+ result = BillingService.is_email_in_freeze(email)
+
+ # Assert
+ assert result is False
+ mock_send_request.assert_called_once_with("GET", "/account/in-freeze", params={"email": email})
+
+ def test_is_email_in_freeze_exception_returns_false(self, mock_send_request):
+ """Test that is_email_in_freeze returns False on exception."""
+ # Arrange
+ email = "error@example.com"
+ mock_send_request.side_effect = Exception("Network error")
+
+ # Act
+ result = BillingService.is_email_in_freeze(email)
+
+ # Assert
+ assert result is False
+
+ def test_update_account_deletion_feedback(self, mock_send_request):
+ """Test updating account deletion feedback."""
+ # Arrange
+ email = "user@example.com"
+ feedback = "Service was too expensive"
+ expected_response = {"result": "success"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.update_account_deletion_feedback(email, feedback)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "POST", "/account/delete-feedback", json={"email": email, "feedback": feedback}
+ )
+
+ def test_is_tenant_owner_or_admin_owner(self, mock_db_session):
+ """Test tenant owner/admin check for owner role."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_join = MagicMock(spec=TenantAccountJoin)
+ mock_join.role = TenantAccountRole.OWNER
+
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = mock_join
+ mock_db_session.query.return_value = mock_query
+
+ # Act - should not raise exception
+ BillingService.is_tenant_owner_or_admin(current_user)
+
+ # Assert
+ mock_db_session.query.assert_called_once()
+
+ def test_is_tenant_owner_or_admin_admin(self, mock_db_session):
+ """Test tenant owner/admin check for admin role."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_join = MagicMock(spec=TenantAccountJoin)
+ mock_join.role = TenantAccountRole.ADMIN
+
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = mock_join
+ mock_db_session.query.return_value = mock_query
+
+ # Act - should not raise exception
+ BillingService.is_tenant_owner_or_admin(current_user)
+
+ # Assert
+ mock_db_session.query.assert_called_once()
+
+ def test_is_tenant_owner_or_admin_normal_user_raises_error(self, mock_db_session):
+ """Test tenant owner/admin check raises error for normal user."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_join = MagicMock(spec=TenantAccountJoin)
+ mock_join.role = TenantAccountRole.NORMAL
+
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = mock_join
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError) as exc_info:
+ BillingService.is_tenant_owner_or_admin(current_user)
+ assert "Only team owner or team admin can perform this action" in str(exc_info.value)
+
+ def test_is_tenant_owner_or_admin_no_join_raises_error(self, mock_db_session):
+ """Test tenant owner/admin check raises error when join not found."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = None
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError) as exc_info:
+ BillingService.is_tenant_owner_or_admin(current_user)
+ assert "Tenant account join not found" in str(exc_info.value)
+
+
+class TestBillingServiceCacheManagement:
+ """Unit tests for billing cache management.
+
+ Tests cover:
+ - Billing info cache invalidation
+ - Proper Redis key formatting
+ """
+
+ @pytest.fixture
+ def mock_redis_client(self):
+ """Mock Redis client."""
+ with patch("services.billing_service.redis_client") as mock_redis:
+ yield mock_redis
+
+ def test_clean_billing_info_cache(self, mock_redis_client):
+ """Test cleaning billing info cache."""
+ # Arrange
+ tenant_id = "tenant-123"
+ expected_key = f"tenant:{tenant_id}:billing_info"
+
+ # Act
+ BillingService.clean_billing_info_cache(tenant_id)
+
+ # Assert
+ mock_redis_client.delete.assert_called_once_with(expected_key)
+
+
+class TestBillingServicePartnerIntegration:
+ """Unit tests for partner integration features.
+
+ Tests cover:
+ - Partner tenant binding synchronization
+ - Click ID tracking
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_sync_partner_tenants_bindings(self, mock_send_request):
+ """Test syncing partner tenant bindings."""
+ # Arrange
+ account_id = "account-123"
+ partner_key = "partner-xyz"
+ click_id = "click-789"
+ expected_response = {"result": "success", "synced": True}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.sync_partner_tenants_bindings(account_id, partner_key, click_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "PUT", f"/partners/{partner_key}/tenants", json={"account_id": account_id, "click_id": click_id}
+ )
+
+
+class TestBillingServiceEdgeCases:
+ """Unit tests for edge cases and error scenarios.
+
+ Tests cover:
+ - Empty responses from billing API
+    - Empty string parameters and varied refund history ID formats
+ - Boundary conditions for rate limits
+ - Multiple subscription tiers
+ - Zero and negative usage deltas
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_get_info_empty_response(self, mock_send_request):
+ """Test handling of empty billing info response."""
+ # Arrange
+ tenant_id = "tenant-empty"
+ mock_send_request.return_value = {}
+
+ # Act
+ result = BillingService.get_info(tenant_id)
+
+ # Assert
+ assert result == {}
+ mock_send_request.assert_called_once()
+
+ def test_update_tenant_feature_plan_usage_zero_delta(self, mock_send_request):
+ """Test updating tenant feature usage with zero delta (no change)."""
+ # Arrange
+ tenant_id = "tenant-123"
+ feature_key = "trigger"
+ delta = 0 # No change
+ expected_response = {"result": "success", "history_id": "hist-uuid-zero"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "POST",
+ "/tenant-feature-usage/usage",
+ params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta},
+ )
+
+ def test_update_tenant_feature_plan_usage_large_negative_delta(self, mock_send_request):
+ """Test updating tenant feature usage with large negative delta."""
+ # Arrange
+ tenant_id = "tenant-456"
+ feature_key = "workflow"
+ delta = -1000 # Large consumption
+ expected_response = {"result": "success", "history_id": "hist-uuid-large"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once()
+
+ def test_get_knowledge_rate_limit_all_subscription_tiers(self, mock_send_request):
+ """Test knowledge rate limit for all subscription tiers."""
+ # Test SANDBOX tier
+ mock_send_request.return_value = {"limit": 10, "subscription_plan": CloudPlan.SANDBOX}
+ result = BillingService.get_knowledge_rate_limit("tenant-sandbox")
+ assert result["subscription_plan"] == CloudPlan.SANDBOX
+ assert result["limit"] == 10
+
+ # Test PROFESSIONAL tier
+ mock_send_request.return_value = {"limit": 100, "subscription_plan": CloudPlan.PROFESSIONAL}
+ result = BillingService.get_knowledge_rate_limit("tenant-pro")
+ assert result["subscription_plan"] == CloudPlan.PROFESSIONAL
+ assert result["limit"] == 100
+
+ # Test TEAM tier
+ mock_send_request.return_value = {"limit": 500, "subscription_plan": CloudPlan.TEAM}
+ result = BillingService.get_knowledge_rate_limit("tenant-team")
+ assert result["subscription_plan"] == CloudPlan.TEAM
+ assert result["limit"] == 500
+
+ def test_get_subscription_with_empty_optional_params(self, mock_send_request):
+ """Test subscription payment link with empty optional parameters."""
+ # Arrange
+ plan = "professional"
+ interval = "yearly"
+ expected_response = {"payment_link": "https://payment.example.com/checkout"}
+ mock_send_request.return_value = expected_response
+
+ # Act - empty email and tenant_id
+ result = BillingService.get_subscription(plan, interval, "", "")
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET",
+ "/subscription/payment-link",
+ params={"plan": plan, "interval": interval, "prefilled_email": "", "tenant_id": ""},
+ )
+
+ def test_get_invoices_with_empty_params(self, mock_send_request):
+ """Test invoice retrieval with empty parameters."""
+ # Arrange
+ expected_response = {"invoices": []}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_invoices("", "")
+
+ # Assert
+ assert result == expected_response
+ assert result["invoices"] == []
+
+ def test_refund_with_invalid_history_id_format(self, mock_send_request):
+ """Test refund with various history ID formats."""
+ # Arrange - test with different ID formats
+ test_ids = ["hist-123", "uuid-abc-def", "12345", ""]
+
+ for history_id in test_ids:
+ expected_response = {"result": "success", "history_id": history_id}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.refund_tenant_feature_plan_usage(history_id)
+
+ # Assert
+ assert result["history_id"] == history_id
+
+ def test_is_tenant_owner_or_admin_editor_role_raises_error(self):
+ """Test tenant owner/admin check raises error for editor role."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_join = MagicMock(spec=TenantAccountJoin)
+ mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged
+
+ with patch("services.billing_service.db.session") as mock_session:
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = mock_join
+ mock_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError) as exc_info:
+ BillingService.is_tenant_owner_or_admin(current_user)
+ assert "Only team owner or team admin can perform this action" in str(exc_info.value)
+
+ def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self):
+ """Test tenant owner/admin check raises error for dataset operator role."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_join = MagicMock(spec=TenantAccountJoin)
+ mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged
+
+ with patch("services.billing_service.db.session") as mock_session:
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = mock_join
+ mock_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError) as exc_info:
+ BillingService.is_tenant_owner_or_admin(current_user)
+ assert "Only team owner or team admin can perform this action" in str(exc_info.value)
+
+
+class TestBillingServiceIntegrationScenarios:
+ """Integration-style tests simulating real-world usage scenarios.
+
+ These tests combine multiple service methods to test common workflows:
+ - Complete subscription upgrade flow
+ - Usage tracking and refund workflow
+    - Repeated compliance downloads within the rate limit and the education verification/activation flow
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_subscription_upgrade_workflow(self, mock_send_request):
+ """Test complete subscription upgrade workflow."""
+ # Arrange
+ tenant_id = "tenant-upgrade"
+
+ # Step 1: Get current billing info
+ mock_send_request.return_value = {
+ "subscription_plan": "sandbox",
+ "billing_cycle": "monthly",
+ "status": "active",
+ }
+ current_info = BillingService.get_info(tenant_id)
+ assert current_info["subscription_plan"] == "sandbox"
+
+ # Step 2: Get payment link for upgrade
+ mock_send_request.return_value = {"payment_link": "https://payment.example.com/upgrade"}
+ payment_link = BillingService.get_subscription("professional", "monthly", "user@example.com", tenant_id)
+ assert "payment_link" in payment_link
+
+ # Step 3: Verify new rate limits after upgrade
+ mock_send_request.return_value = {"limit": 100, "subscription_plan": CloudPlan.PROFESSIONAL}
+ rate_limit = BillingService.get_knowledge_rate_limit(tenant_id)
+ assert rate_limit["subscription_plan"] == CloudPlan.PROFESSIONAL
+ assert rate_limit["limit"] == 100
+
+ def test_usage_tracking_and_refund_workflow(self, mock_send_request):
+ """Test usage tracking with subsequent refund."""
+ # Arrange
+ tenant_id = "tenant-usage"
+ feature_key = "workflow"
+
+ # Step 1: Consume credits
+ mock_send_request.return_value = {"result": "success", "history_id": "hist-consume-123"}
+ consume_result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, -10)
+ history_id = consume_result["history_id"]
+ assert history_id == "hist-consume-123"
+
+ # Step 2: Check current usage
+ mock_send_request.return_value = {"used": 10, "limit": 100, "remaining": 90}
+ usage = BillingService.get_tenant_feature_plan_usage(tenant_id, feature_key)
+ assert usage["used"] == 10
+ assert usage["remaining"] == 90
+
+ # Step 3: Refund the usage
+ mock_send_request.return_value = {"result": "success", "history_id": history_id}
+ refund_result = BillingService.refund_tenant_feature_plan_usage(history_id)
+ assert refund_result["result"] == "success"
+
+ # Step 4: Verify usage after refund
+ mock_send_request.return_value = {"used": 0, "limit": 100, "remaining": 100}
+ updated_usage = BillingService.get_tenant_feature_plan_usage(tenant_id, feature_key)
+ assert updated_usage["used"] == 0
+ assert updated_usage["remaining"] == 100
+
+ def test_compliance_download_multiple_requests_within_limit(self, mock_send_request):
+ """Test multiple compliance downloads within rate limit."""
+ # Arrange
+ account_id = "account-compliance"
+ tenant_id = "tenant-compliance"
+ doc_name = "compliance_report.pdf"
+ ip = "192.168.1.1"
+ device_info = "Mozilla/5.0"
+
+ # Mock rate limiter to allow 3 requests (under limit of 4)
+ with (
+ patch.object(
+ BillingService.compliance_download_rate_limiter, "is_rate_limited", side_effect=[False, False, False]
+ ) as mock_is_limited,
+ patch.object(BillingService.compliance_download_rate_limiter, "increment_rate_limit") as mock_increment,
+ ):
+ mock_send_request.return_value = {"download_link": "https://example.com/download"}
+
+ # Act - Make 3 requests
+ for i in range(3):
+ result = BillingService.get_compliance_download_link(doc_name, account_id, tenant_id, ip, device_info)
+ assert "download_link" in result
+
+ # Assert - All 3 requests succeeded
+ assert mock_is_limited.call_count == 3
+ assert mock_increment.call_count == 3
+
+ def test_education_verification_and_activation_flow(self, mock_send_request):
+ """Test complete education verification and activation flow."""
+ # Arrange
+ account = MagicMock(spec=Account)
+ account.id = "account-edu"
+ account.email = "student@mit.edu"
+ account.current_tenant_id = "tenant-edu"
+
+ # Step 1: Search for institution
+ with (
+ patch.object(
+ BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=False
+ ),
+ patch.object(BillingService.EducationIdentity.verification_rate_limit, "increment_rate_limit"),
+ ):
+ mock_send_request.return_value = {
+ "institutions": [{"name": "Massachusetts Institute of Technology", "domain": "mit.edu"}]
+ }
+ institutions = BillingService.EducationIdentity.autocomplete("MIT")
+ assert len(institutions["institutions"]) > 0
+
+ # Step 2: Verify email
+ with (
+ patch.object(
+ BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=False
+ ),
+ patch.object(BillingService.EducationIdentity.verification_rate_limit, "increment_rate_limit"),
+ ):
+ mock_send_request.return_value = {"verified": True, "institution": "MIT"}
+ verify_result = BillingService.EducationIdentity.verify(account.id, account.email)
+ assert verify_result["verified"] is True
+
+ # Step 3: Check status
+ mock_send_request.return_value = {"verified": True, "institution": "MIT", "role": "student"}
+ status = BillingService.EducationIdentity.status(account.id)
+ assert status["verified"] is True
+
+ # Step 4: Activate education benefits
+ with (
+ patch.object(BillingService.EducationIdentity.activation_rate_limit, "is_rate_limited", return_value=False),
+ patch.object(BillingService.EducationIdentity.activation_rate_limit, "increment_rate_limit"),
+ ):
+ mock_send_request.return_value = {"result": "success", "activated": True}
+ activate_result = BillingService.EducationIdentity.activate(account, "token-123", "MIT", "student")
+ assert activate_result["activated"] is True
From 2551f6f27967f663357c89f33f0f005a27913be1 Mon Sep 17 00:00:00 2001
From: jiangbo721
Date: Thu, 27 Nov 2025 10:51:48 +0800
Subject: [PATCH 44/63] =?UTF-8?q?feat:=20add=20APP=5FDEFAULT=5FACTIVE=5FRE?=
=?UTF-8?q?QUESTS=20as=20the=20default=20value=20for=20APP=5FAC=E2=80=A6?=
=?UTF-8?q?=20(#26930)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
api/.env.example | 1 +
api/configs/feature/__init__.py | 4 ++++
api/services/app_generate_service.py | 2 +-
api/services/rag_pipeline/pipeline_generate_service.py | 9 +++++----
api/tests/integration_tests/.env.example | 1 +
.../services/test_app_generate_service.py | 1 +
docker/.env.example | 2 ++
docker/docker-compose.yaml | 1 +
8 files changed, 16 insertions(+), 5 deletions(-)
diff --git a/api/.env.example b/api/.env.example
index fbf0b12f40..50607f5b35 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -540,6 +540,7 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# App configuration
APP_MAX_EXECUTION_TIME=1200
+APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 7cce3847b4..9c0c48c955 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -73,6 +73,10 @@ class AppExecutionConfig(BaseSettings):
description="Maximum allowed execution time for the application in seconds",
default=1200,
)
+ APP_DEFAULT_ACTIVE_REQUESTS: NonNegativeInt = Field(
+ description="Default number of concurrent active requests per app (0 for unlimited)",
+ default=0,
+ )
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Maximum number of concurrent active requests per app (0 for unlimited)",
default=0,
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index bb1ea742d0..dc85929b98 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -135,7 +135,7 @@ class AppGenerateService:
Returns:
The maximum number of active requests allowed
"""
- app_limit = app.max_active_requests or 0
+ app_limit = app.max_active_requests or dify_config.APP_DEFAULT_ACTIVE_REQUESTS
config_limit = dify_config.APP_MAX_ACTIVE_REQUESTS
# Filter out infinite (0) values and return the minimum, or 0 if both are infinite
diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py
index e6cee64df6..f397b28283 100644
--- a/api/services/rag_pipeline/pipeline_generate_service.py
+++ b/api/services/rag_pipeline/pipeline_generate_service.py
@@ -53,10 +53,11 @@ class PipelineGenerateService:
@staticmethod
def _get_max_active_requests(app_model: App) -> int:
- max_active_requests = app_model.max_active_requests
- if max_active_requests is None:
- max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
- return max_active_requests
+ app_limit = app_model.max_active_requests or dify_config.APP_DEFAULT_ACTIVE_REQUESTS
+ config_limit = dify_config.APP_MAX_ACTIVE_REQUESTS
+ # Filter out infinite (0) values and return the minimum, or 0 if both are infinite
+ limits = [limit for limit in [app_limit, config_limit] if limit > 0]
+ return min(limits) if limits else 0
@classmethod
def generate_single_iteration(
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 46d13079db..e508ceef66 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -175,6 +175,7 @@ MAX_VARIABLE_SIZE=204800
# App configuration
APP_MAX_EXECUTION_TIME=1200
+APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py
index 0f9ed94017..476f58585d 100644
--- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py
@@ -82,6 +82,7 @@ class TestAppGenerateService:
# Setup dify_config mock returns
mock_dify_config.BILLING_ENABLED = False
mock_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
+ mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
mock_global_dify_config.BILLING_ENABLED = False
diff --git a/docker/.env.example b/docker/.env.example
index 0bfdc6b495..c9981baaba 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -133,6 +133,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES=60
# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30
+# The default number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
+APP_DEFAULT_ACTIVE_REQUESTS=0
# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
APP_MAX_ACTIVE_REQUESTS=0
APP_MAX_EXECUTION_TIME=1200
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 0302612045..17f33bbf72 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -34,6 +34,7 @@ x-shared-env: &shared-api-worker-env
FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30}
+ APP_DEFAULT_ACTIVE_REQUESTS: ${APP_DEFAULT_ACTIVE_REQUESTS:-0}
APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0}
APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
From 2f6b3f1c5fc54121765d2201d8dd6bf0c89a5cc3 Mon Sep 17 00:00:00 2001
From: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Date: Thu, 27 Nov 2025 10:54:00 +0800
Subject: [PATCH 45/63] hotfix: fix _extract_filename for rfc 5987 (#26230)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
---
api/factories/file_factory.py | 43 ++++++-
.../unit_tests/factories/test_file_factory.py | 119 +++++++++++++++++-
2 files changed, 156 insertions(+), 6 deletions(-)
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 2316e45179..737a79f2b0 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -1,5 +1,6 @@
import mimetypes
import os
+import re
import urllib.parse
import uuid
from collections.abc import Callable, Mapping, Sequence
@@ -268,15 +269,47 @@ def _build_from_remote_url(
def _extract_filename(url_path: str, content_disposition: str | None) -> str | None:
- filename = None
+ filename: str | None = None
# Try to extract from Content-Disposition header first
if content_disposition:
- _, params = parse_options_header(content_disposition)
- # RFC 5987 https://datatracker.ietf.org/doc/html/rfc5987: filename* takes precedence over filename
- filename = params.get("filename*") or params.get("filename")
+ # Manually extract filename* parameter since parse_options_header doesn't support it
+ filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
+ if filename_star_match:
+ raw_star = filename_star_match.group(1).strip()
+ # Remove trailing quotes if present
+ raw_star = raw_star.removesuffix('"')
+ # format: charset'lang'value
+ try:
+ parts = raw_star.split("'", 2)
+ charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8"
+ value = parts[2] if len(parts) == 3 else parts[-1]
+ filename = urllib.parse.unquote(value, encoding=charset, errors="replace")
+ except Exception:
+ # Fallback: try to extract value after the last single quote
+ if "''" in raw_star:
+ filename = urllib.parse.unquote(raw_star.split("''")[-1])
+ else:
+ filename = urllib.parse.unquote(raw_star)
+
+ if not filename:
+ # Fallback to regular filename parameter
+ _, params = parse_options_header(content_disposition)
+ raw = params.get("filename")
+ if raw:
+ # Strip surrounding quotes and percent-decode if present
+ if len(raw) >= 2 and raw[0] == raw[-1] == '"':
+ raw = raw[1:-1]
+ filename = urllib.parse.unquote(raw)
# Fallback to URL path if no filename from header
if not filename:
- filename = os.path.basename(url_path)
+ candidate = os.path.basename(url_path)
+ filename = urllib.parse.unquote(candidate) if candidate else None
+ # Defense-in-depth: ensure basename only
+ if filename:
+ filename = os.path.basename(filename)
+ # Return None if filename is empty or only whitespace
+ if not filename or not filename.strip():
+ filename = None
return filename or None
diff --git a/api/tests/unit_tests/factories/test_file_factory.py b/api/tests/unit_tests/factories/test_file_factory.py
index 777fe5a6e7..e5f45044fa 100644
--- a/api/tests/unit_tests/factories/test_file_factory.py
+++ b/api/tests/unit_tests/factories/test_file_factory.py
@@ -2,7 +2,7 @@ import re
import pytest
-from factories.file_factory import _get_remote_file_info
+from factories.file_factory import _extract_filename, _get_remote_file_info
class _FakeResponse:
@@ -113,3 +113,120 @@ class TestGetRemoteFileInfo:
# Should generate a random hex filename with .bin extension
assert re.match(r"^[0-9a-f]{32}\.bin$", filename) is not None
assert mime_type == "application/octet-stream"
+
+
+class TestExtractFilename:
+ """Tests for _extract_filename function focusing on RFC5987 parsing and security."""
+
+ def test_no_content_disposition_uses_url_basename(self):
+ """Test that URL basename is used when no Content-Disposition header."""
+ result = _extract_filename("http://example.com/path/file.txt", None)
+ assert result == "file.txt"
+
+ def test_no_content_disposition_with_percent_encoded_url(self):
+ """Test that percent-encoded URL basename is decoded."""
+ result = _extract_filename("http://example.com/path/file%20name.txt", None)
+ assert result == "file name.txt"
+
+ def test_no_content_disposition_empty_url_path(self):
+ """Test that empty URL path returns None."""
+ result = _extract_filename("http://example.com/", None)
+ assert result is None
+
+ def test_simple_filename_header(self):
+ """Test basic filename extraction from Content-Disposition."""
+ result = _extract_filename("http://example.com/", 'attachment; filename="test.txt"')
+ assert result == "test.txt"
+
+ def test_quoted_filename_with_spaces(self):
+ """Test filename with spaces in quotes."""
+ result = _extract_filename("http://example.com/", 'attachment; filename="my file.txt"')
+ assert result == "my file.txt"
+
+ def test_unquoted_filename(self):
+ """Test unquoted filename."""
+ result = _extract_filename("http://example.com/", "attachment; filename=test.txt")
+ assert result == "test.txt"
+
+ def test_percent_encoded_filename(self):
+ """Test percent-encoded filename."""
+ result = _extract_filename("http://example.com/", 'attachment; filename="file%20name.txt"')
+ assert result == "file name.txt"
+
+ def test_rfc5987_filename_star_utf8(self):
+ """Test RFC5987 filename* with UTF-8 encoding."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=UTF-8''file%20name.txt")
+ assert result == "file name.txt"
+
+ def test_rfc5987_filename_star_chinese(self):
+ """Test RFC5987 filename* with Chinese characters."""
+ result = _extract_filename(
+ "http://example.com/", "attachment; filename*=UTF-8''%E6%B5%8B%E8%AF%95%E6%96%87%E4%BB%B6.txt"
+ )
+ assert result == "测试文件.txt"
+
+ def test_rfc5987_filename_star_with_language(self):
+ """Test RFC5987 filename* with language tag."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=UTF-8'en'file%20name.txt")
+ assert result == "file name.txt"
+
+ def test_rfc5987_filename_star_fallback_charset(self):
+ """Test RFC5987 filename* with fallback charset."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=''file%20name.txt")
+ assert result == "file name.txt"
+
+ def test_rfc5987_filename_star_malformed_fallback(self):
+ """Test RFC5987 filename* with malformed format falls back to simple unquote."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=malformed%20filename.txt")
+ assert result == "malformed filename.txt"
+
+ def test_filename_star_takes_precedence_over_filename(self):
+ """Test that filename* takes precedence over filename."""
+ test_string = 'attachment; filename="old.txt"; filename*=UTF-8\'\'new.txt"'
+ result = _extract_filename("http://example.com/", test_string)
+ assert result == "new.txt"
+
+ def test_path_injection_protection(self):
+ """Test that path injection attempts are blocked by os.path.basename."""
+ result = _extract_filename("http://example.com/", 'attachment; filename="../../../etc/passwd"')
+ assert result == "passwd"
+
+ def test_path_injection_protection_rfc5987(self):
+ """Test that path injection attempts in RFC5987 are blocked."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=UTF-8''..%2F..%2F..%2Fetc%2Fpasswd")
+ assert result == "passwd"
+
+ def test_empty_filename_returns_none(self):
+ """Test that empty filename returns None."""
+ result = _extract_filename("http://example.com/", 'attachment; filename=""')
+ assert result is None
+
+ def test_whitespace_only_filename_returns_none(self):
+ """Test that whitespace-only filename returns None."""
+ result = _extract_filename("http://example.com/", 'attachment; filename=" "')
+ assert result is None
+
+ def test_complex_rfc5987_encoding(self):
+ """Test complex RFC5987 encoding with special characters."""
+ result = _extract_filename(
+ "http://example.com/",
+ "attachment; filename*=UTF-8''%E4%B8%AD%E6%96%87%E6%96%87%E4%BB%B6%20%28%E5%89%AF%E6%9C%AC%29.pdf",
+ )
+ assert result == "中文文件 (副本).pdf"
+
+ def test_iso8859_1_encoding(self):
+ """Test ISO-8859-1 encoding in RFC5987."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=ISO-8859-1''file%20name.txt")
+ assert result == "file name.txt"
+
+ def test_encoding_error_fallback(self):
+ """Test that encoding errors fall back to safe ASCII filename."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=INVALID-CHARSET''file%20name.txt")
+ assert result == "file name.txt"
+
+ def test_mixed_quotes_and_encoding(self):
+ """Test filename with mixed quotes and percent encoding."""
+ result = _extract_filename(
+ "http://example.com/", 'attachment; filename="file%20with%20quotes%20%26%20encoding.txt"'
+ )
+ assert result == "file with quotes & encoding.txt"
From 09a8046b10809d583825f3fed400ea47c1705f65 Mon Sep 17 00:00:00 2001
From: Will
Date: Thu, 27 Nov 2025 10:56:21 +0800
Subject: [PATCH 46/63] fix: querying webhook trigger issue (#28753)
---
api/controllers/console/app/workflow_trigger.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py
index b3e5c9619f..5d16e4f979 100644
--- a/api/controllers/console/app/workflow_trigger.py
+++ b/api/controllers/console/app/workflow_trigger.py
@@ -43,7 +43,7 @@ console_ns.schema_model(
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
- @console_ns.expect(console_ns.models[Parser.__name__], validate=True)
+ @console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required
From b786e101e52a4f763c4818f4f7637b191a611c09 Mon Sep 17 00:00:00 2001
From: Will
Date: Thu, 27 Nov 2025 10:58:35 +0800
Subject: [PATCH 47/63] fix: querying and setting the system default model
(#28743)
---
api/controllers/console/workspace/models.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 8e402b4bae..c820a8d1f2 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -1,5 +1,5 @@
import logging
-from typing import Any
+from typing import Any, cast
from flask import request
from flask_restx import Resource
@@ -26,7 +26,7 @@ class ParserGetDefault(BaseModel):
class ParserPostDefault(BaseModel):
class Inner(BaseModel):
model_type: ModelType
- model: str
+ model: str | None = None
provider: str | None = None
model_settings: list[Inner]
@@ -150,7 +150,7 @@ console_ns.schema_model(
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
- @console_ns.expect(console_ns.models[ParserGetDefault.__name__], validate=True)
+ @console_ns.expect(console_ns.models[ParserGetDefault.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -186,7 +186,7 @@ class DefaultModelApi(Resource):
tenant_id=tenant_id,
model_type=model_setting.model_type,
provider=model_setting.provider,
- model=model_setting.model,
+ model=cast(str, model_setting.model),
)
except Exception as ex:
logger.exception(
From 7efa0df1fd119037386b5627652e02e621f0e1d1 Mon Sep 17 00:00:00 2001
From: aka James4u
Date: Wed, 26 Nov 2025 18:59:17 -0800
Subject: [PATCH 48/63] Add comprehensive API/controller tests for dataset
endpoints (list, create, update, delete, documents, segments, hit testing,
external datasets) (#28750)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../unit_tests/services/controller_api.py | 1082 +++++++++++++++++
1 file changed, 1082 insertions(+)
create mode 100644 api/tests/unit_tests/services/controller_api.py
diff --git a/api/tests/unit_tests/services/controller_api.py b/api/tests/unit_tests/services/controller_api.py
new file mode 100644
index 0000000000..762d7b9090
--- /dev/null
+++ b/api/tests/unit_tests/services/controller_api.py
@@ -0,0 +1,1082 @@
+"""
+Comprehensive API/Controller tests for Dataset endpoints.
+
+This module contains extensive integration tests for the dataset-related
+controller endpoints, testing the HTTP API layer that exposes dataset
+functionality through REST endpoints.
+
+The controller endpoints provide HTTP access to:
+- Dataset CRUD operations (list, create, update, delete)
+- Document management operations
+- Segment management operations
+- Hit testing (retrieval testing) operations
+- External dataset and knowledge API operations
+
+These tests verify that:
+- HTTP requests are properly routed to service methods
+- Request validation works correctly
+- Response formatting is correct
+- Authentication and authorization are enforced
+- Error handling returns appropriate HTTP status codes
+- Request/response serialization works properly
+
+================================================================================
+ARCHITECTURE OVERVIEW
+================================================================================
+
+The controller layer in Dify uses Flask-RESTX to provide RESTful API endpoints.
+Controllers act as a thin layer between HTTP requests and service methods,
+handling:
+
+1. Request Parsing: Extracting and validating parameters from HTTP requests
+2. Authentication: Verifying user identity and permissions
+3. Authorization: Checking if user has permission to perform operations
+4. Service Invocation: Calling appropriate service methods
+5. Response Formatting: Serializing service results to HTTP responses
+6. Error Handling: Converting exceptions to appropriate HTTP status codes
+
+Key Components:
+- Flask-RESTX Resources: Define endpoint classes with HTTP methods
+- Decorators: Handle authentication, authorization, and setup requirements
+- Request Parsers: Validate and extract request parameters
+- Response Models: Define response structure for Swagger documentation
+- Error Handlers: Convert exceptions to HTTP error responses
+
+================================================================================
+TESTING STRATEGY
+================================================================================
+
+This test suite follows a comprehensive testing strategy that covers:
+
+1. HTTP Request/Response Testing:
+ - GET, POST, PATCH, DELETE methods
+ - Query parameters and request body validation
+ - Response status codes and body structure
+ - Headers and content types
+
+2. Authentication and Authorization:
+ - Login required checks
+ - Account initialization checks
+ - Permission validation
+ - Role-based access control
+
+3. Request Validation:
+ - Required parameter validation
+ - Parameter type validation
+ - Parameter range validation
+ - Custom validation rules
+
+4. Error Handling:
+ - 400 Bad Request (validation errors)
+ - 401 Unauthorized (authentication errors)
+ - 403 Forbidden (authorization errors)
+ - 404 Not Found (resource not found)
+ - 500 Internal Server Error (unexpected errors)
+
+5. Service Integration:
+ - Service method invocation
+ - Service method parameter passing
+ - Service method return value handling
+ - Service exception handling
+
+================================================================================
+"""
+
+from unittest.mock import Mock, patch
+from uuid import uuid4
+
+import pytest
+from flask import Flask
+from flask_restx import Api
+
+from controllers.console.datasets.datasets import DatasetApi, DatasetListApi
+from controllers.console.datasets.external import (
+ ExternalApiTemplateListApi,
+)
+from controllers.console.datasets.hit_testing import HitTestingApi
+from models.dataset import Dataset, DatasetPermissionEnum
+
+# ============================================================================
+# Test Data Factory
+# ============================================================================
+# The Test Data Factory pattern is used here to centralize the creation of
+# test objects and mock instances. This approach provides several benefits:
+#
+# 1. Consistency: All test objects are created using the same factory methods,
+# ensuring consistent structure across all tests.
+#
+# 2. Maintainability: If the structure of models or services changes, we only
+# need to update the factory methods rather than every individual test.
+#
+# 3. Reusability: Factory methods can be reused across multiple test classes,
+# reducing code duplication.
+#
+# 4. Readability: Tests become more readable when they use descriptive factory
+# method calls instead of complex object construction logic.
+#
+# ============================================================================
+
+
+class ControllerApiTestDataFactory:
+ """
+ Factory class for creating test data and mock objects for controller API tests.
+
+ This factory provides static methods to create mock objects for:
+ - Flask application and test client setup
+ - Dataset instances and related models
+ - User and authentication context
+ - HTTP request/response objects
+ - Service method return values
+
+ The factory methods help maintain consistency across tests and reduce
+ code duplication when setting up test scenarios.
+ """
+
+ @staticmethod
+ def create_flask_app():
+ """
+ Create a Flask test application for API testing.
+
+ Returns:
+ Flask application instance configured for testing
+ """
+ app = Flask(__name__)
+ app.config["TESTING"] = True
+ app.config["SECRET_KEY"] = "test-secret-key"
+ return app
+
+ @staticmethod
+ def create_api_instance(app):
+ """
+ Create a Flask-RESTX API instance.
+
+ Args:
+ app: Flask application instance
+
+ Returns:
+ Api instance configured for the application
+ """
+ api = Api(app, doc="/docs/")
+ return api
+
+ @staticmethod
+ def create_test_client(app, api, resource_class, route):
+ """
+ Create a Flask test client with a resource registered.
+
+ Args:
+ app: Flask application instance
+ api: Flask-RESTX API instance
+ resource_class: Resource class to register
+ route: URL route for the resource
+
+ Returns:
+ Flask test client instance
+ """
+ api.add_resource(resource_class, route)
+ return app.test_client()
+
+ @staticmethod
+ def create_dataset_mock(
+ dataset_id: str = "dataset-123",
+ name: str = "Test Dataset",
+ tenant_id: str = "tenant-123",
+ permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Dataset instance.
+
+ Args:
+ dataset_id: Unique identifier for the dataset
+ name: Name of the dataset
+ tenant_id: Tenant identifier
+ permission: Dataset permission level
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a Dataset instance
+ """
+ dataset = Mock(spec=Dataset)
+ dataset.id = dataset_id
+ dataset.name = name
+ dataset.tenant_id = tenant_id
+ dataset.permission = permission
+ dataset.to_dict.return_value = {
+ "id": dataset_id,
+ "name": name,
+ "tenant_id": tenant_id,
+ "permission": permission.value,
+ }
+ for key, value in kwargs.items():
+ setattr(dataset, key, value)
+ return dataset
+
+ @staticmethod
+ def create_user_mock(
+ user_id: str = "user-123",
+ tenant_id: str = "tenant-123",
+ is_dataset_editor: bool = True,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock user/account instance.
+
+ Args:
+ user_id: Unique identifier for the user
+ tenant_id: Tenant identifier
+ is_dataset_editor: Whether user has dataset editor permissions
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a user/account instance
+ """
+ user = Mock()
+ user.id = user_id
+ user.current_tenant_id = tenant_id
+ user.is_dataset_editor = is_dataset_editor
+ user.has_edit_permission = True
+ user.is_dataset_operator = False
+ for key, value in kwargs.items():
+ setattr(user, key, value)
+ return user
+
+ @staticmethod
+ def create_paginated_response(items, total, page=1, per_page=20):
+ """
+ Create a mock paginated response.
+
+ Args:
+ items: List of items in the current page
+ total: Total number of items
+ page: Current page number
+ per_page: Items per page
+
+ Returns:
+ Mock paginated response object
+ """
+ response = Mock()
+ response.items = items
+ response.total = total
+ response.page = page
+ response.per_page = per_page
+ response.pages = (total + per_page - 1) // per_page
+ return response
+
+
+# ============================================================================
+# Tests for Dataset List Endpoint (GET /datasets)
+# ============================================================================
+
+
+class TestDatasetListApi:
+ """
+ Comprehensive API tests for DatasetListApi (GET /datasets endpoint).
+
+ This test class covers the dataset listing functionality through the
+ HTTP API, including pagination, search, filtering, and permissions.
+
+ The GET /datasets endpoint:
+ 1. Requires authentication and account initialization
+ 2. Supports pagination (page, limit parameters)
+ 3. Supports search by keyword
+ 4. Supports filtering by tag IDs
+ 5. Supports including all datasets (for admins)
+ 6. Returns paginated list of datasets
+
+ Test scenarios include:
+ - Successful dataset listing with pagination
+ - Search functionality
+ - Tag filtering
+ - Permission-based filtering
+ - Error handling (authentication, authorization)
+ """
+
+ @pytest.fixture
+ def app(self):
+ """
+ Create Flask test application.
+
+ Provides a Flask application instance configured for testing.
+ """
+ return ControllerApiTestDataFactory.create_flask_app()
+
+ @pytest.fixture
+ def api(self, app):
+ """
+ Create Flask-RESTX API instance.
+
+ Provides an API instance for registering resources.
+ """
+ return ControllerApiTestDataFactory.create_api_instance(app)
+
+ @pytest.fixture
+ def client(self, app, api):
+ """
+ Create test client with DatasetListApi registered.
+
+ Provides a Flask test client that can make HTTP requests to
+ the dataset list endpoint.
+ """
+ return ControllerApiTestDataFactory.create_test_client(app, api, DatasetListApi, "/datasets")
+
+ @pytest.fixture
+ def mock_current_user(self):
+ """
+ Mock current user and tenant context.
+
+ Provides mocked current_account_with_tenant function that returns
+ a user and tenant ID for testing authentication.
+ """
+ with patch("controllers.console.datasets.datasets.current_account_with_tenant") as mock_get_user:
+ mock_user = ControllerApiTestDataFactory.create_user_mock()
+ mock_tenant_id = "tenant-123"
+ mock_get_user.return_value = (mock_user, mock_tenant_id)
+ yield mock_get_user
+
+ def test_get_datasets_success(self, client, mock_current_user):
+ """
+ Test successful retrieval of dataset list.
+
+ Verifies that when authentication passes, the endpoint returns
+ a paginated list of datasets.
+
+ This test ensures:
+ - Authentication is checked
+ - Service method is called with correct parameters
+ - Response has correct structure
+ - Status code is 200
+ """
+ # Arrange
+ datasets = [
+ ControllerApiTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", name=f"Dataset {i}")
+ for i in range(3)
+ ]
+
+ paginated_response = ControllerApiTestDataFactory.create_paginated_response(
+ items=datasets, total=3, page=1, per_page=20
+ )
+
+ with patch("controllers.console.datasets.datasets.DatasetService.get_datasets") as mock_get_datasets:
+ mock_get_datasets.return_value = (datasets, 3)
+
+ # Act
+ response = client.get("/datasets?page=1&limit=20")
+
+ # Assert
+ assert response.status_code == 200
+ data = response.get_json()
+ assert "data" in data
+ assert len(data["data"]) == 3
+ assert data["total"] == 3
+ assert data["page"] == 1
+ assert data["limit"] == 20
+
+ # Verify service was called
+ mock_get_datasets.assert_called_once()
+
+ def test_get_datasets_with_search(self, client, mock_current_user):
+ """
+ Test dataset listing with search keyword.
+
+ Verifies that search functionality works correctly through the API.
+
+ This test ensures:
+ - Search keyword is passed to service method
+ - Filtered results are returned
+ - Response structure is correct
+ """
+ # Arrange
+ search_keyword = "test"
+ datasets = [ControllerApiTestDataFactory.create_dataset_mock(dataset_id="dataset-1", name="Test Dataset")]
+
+ with patch("controllers.console.datasets.datasets.DatasetService.get_datasets") as mock_get_datasets:
+ mock_get_datasets.return_value = (datasets, 1)
+
+ # Act
+ response = client.get(f"/datasets?keyword={search_keyword}")
+
+ # Assert
+ assert response.status_code == 200
+ data = response.get_json()
+ assert len(data["data"]) == 1
+
+ # Verify search keyword was passed
+ call_args = mock_get_datasets.call_args
+ assert call_args[1]["search"] == search_keyword
+
+ def test_get_datasets_with_pagination(self, client, mock_current_user):
+ """
+ Test dataset listing with pagination parameters.
+
+ Verifies that pagination works correctly through the API.
+
+ This test ensures:
+ - Page and limit parameters are passed correctly
+ - Pagination metadata is included in response
+ - Correct datasets are returned for the page
+ """
+ # Arrange
+ datasets = [
+ ControllerApiTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", name=f"Dataset {i}")
+ for i in range(5)
+ ]
+
+ with patch("controllers.console.datasets.datasets.DatasetService.get_datasets") as mock_get_datasets:
+ mock_get_datasets.return_value = (datasets[:3], 5) # First page with 3 items
+
+ # Act
+ response = client.get("/datasets?page=1&limit=3")
+
+ # Assert
+ assert response.status_code == 200
+ data = response.get_json()
+ assert len(data["data"]) == 3
+ assert data["page"] == 1
+ assert data["limit"] == 3
+
+ # Verify pagination parameters were passed
+ call_args = mock_get_datasets.call_args
+ assert call_args[0][0] == 1 # page
+ assert call_args[0][1] == 3 # per_page
+
+
+# ============================================================================
+# Tests for Dataset Detail Endpoint (GET /datasets/{id})
+# ============================================================================
+
+
+class TestDatasetApiGet:
+ """
+ Comprehensive API tests for DatasetApi GET method (GET /datasets/{id} endpoint).
+
+ This test class covers the single dataset retrieval functionality through
+ the HTTP API.
+
+ The GET /datasets/{id} endpoint:
+ 1. Requires authentication and account initialization
+ 2. Validates dataset exists
+ 3. Checks user permissions
+ 4. Returns dataset details
+
+ Test scenarios include:
+ - Successful dataset retrieval
+ - Dataset not found (404)
+ - Permission denied (403)
+ - Authentication required
+ """
+
+ @pytest.fixture
+ def app(self):
+ """Create Flask test application."""
+ return ControllerApiTestDataFactory.create_flask_app()
+
+ @pytest.fixture
+ def api(self, app):
+ """Create Flask-RESTX API instance."""
+ return ControllerApiTestDataFactory.create_api_instance(app)
+
+ @pytest.fixture
+ def client(self, app, api):
+ """Create test client with DatasetApi registered."""
+ return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets/")
+
+ @pytest.fixture
+ def mock_current_user(self):
+ """Mock current user and tenant context."""
+ with patch("controllers.console.datasets.datasets.current_account_with_tenant") as mock_get_user:
+ mock_user = ControllerApiTestDataFactory.create_user_mock()
+ mock_tenant_id = "tenant-123"
+ mock_get_user.return_value = (mock_user, mock_tenant_id)
+ yield mock_get_user
+
+ def test_get_dataset_success(self, client, mock_current_user):
+ """
+ Test successful retrieval of a single dataset.
+
+ Verifies that when authentication and permissions pass, the endpoint
+ returns dataset details.
+
+ This test ensures:
+ - Authentication is checked
+ - Dataset existence is validated
+ - Permissions are checked
+ - Dataset details are returned
+ - Status code is 200
+ """
+ # Arrange
+ dataset_id = str(uuid4())
+ dataset = ControllerApiTestDataFactory.create_dataset_mock(dataset_id=dataset_id, name="Test Dataset")
+
+ with (
+ patch("controllers.console.datasets.datasets.DatasetService.get_dataset") as mock_get_dataset,
+ patch("controllers.console.datasets.datasets.DatasetService.check_dataset_permission") as mock_check_perm,
+ ):
+ mock_get_dataset.return_value = dataset
+ mock_check_perm.return_value = None # No exception = permission granted
+
+ # Act
+ response = client.get(f"/datasets/{dataset_id}")
+
+ # Assert
+ assert response.status_code == 200
+ data = response.get_json()
+ assert data["id"] == dataset_id
+ assert data["name"] == "Test Dataset"
+
+ # Verify service methods were called
+ mock_get_dataset.assert_called_once_with(dataset_id)
+ mock_check_perm.assert_called_once()
+
+ def test_get_dataset_not_found(self, client, mock_current_user):
+ """
+ Test error handling when dataset is not found.
+
+ Verifies that when dataset doesn't exist, a 404 error is returned.
+
+ This test ensures:
+ - 404 status code is returned
+ - Error message is appropriate
+ - Service method is called
+ """
+ # Arrange
+ dataset_id = str(uuid4())
+
+ with (
+ patch("controllers.console.datasets.datasets.DatasetService.get_dataset") as mock_get_dataset,
+ patch("controllers.console.datasets.datasets.DatasetService.check_dataset_permission") as mock_check_perm,
+ ):
+ mock_get_dataset.return_value = None # Dataset not found
+
+ # Act
+ response = client.get(f"/datasets/{dataset_id}")
+
+ # Assert
+ assert response.status_code == 404
+
+ # Verify service was called
+ mock_get_dataset.assert_called_once()
+
+
+# ============================================================================
+# Tests for Dataset Create Endpoint (POST /datasets)
+# ============================================================================
+
+
class TestDatasetApiCreate:
    """
    HTTP-level tests for dataset creation (POST /datasets).

    The endpoint under test:
    1. Requires an authenticated, initialized account
    2. Validates the JSON request body
    3. Delegates creation to the dataset service
    4. Returns the newly created dataset with status 201

    Scenarios covered: successful creation, request validation errors,
    duplicate-name errors, and the authentication requirement.
    """

    @pytest.fixture
    def app(self):
        """Flask application configured for testing."""
        return ControllerApiTestDataFactory.create_flask_app()

    @pytest.fixture
    def api(self, app):
        """Flask-RESTX API bound to the test app."""
        return ControllerApiTestDataFactory.create_api_instance(app)

    @pytest.fixture
    def client(self, app, api):
        """HTTP test client with DatasetApi mounted at /datasets."""
        return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets")

    @pytest.fixture
    def mock_current_user(self):
        """Patch the authenticated user/tenant pair for the request context."""
        with patch("controllers.console.datasets.datasets.current_account_with_tenant") as mock_get_user:
            mock_get_user.return_value = (
                ControllerApiTestDataFactory.create_user_mock(),
                "tenant-123",
            )
            yield mock_get_user

    def test_create_dataset_success(self, client, mock_current_user):
        """
        A valid creation request returns 201 with the created dataset.

        Ensures the request body passes validation, the service layer is
        invoked, and the serialized dataset is echoed back to the caller.
        """
        # Arrange
        new_id = str(uuid4())
        created = ControllerApiTestDataFactory.create_dataset_mock(dataset_id=new_id, name="New Dataset")
        payload = {
            "name": "New Dataset",
            "description": "Test description",
            "permission": "only_me",
        }

        with patch("controllers.console.datasets.datasets.DatasetService.create_empty_dataset") as create_stub:
            create_stub.return_value = created

            # Act
            response = client.post(
                "/datasets",
                json=payload,
                content_type="application/json",
            )

            # Assert: the created resource comes back with a 201 status.
            assert response.status_code == 201
            body = response.get_json()
            assert body["id"] == new_id
            assert body["name"] == "New Dataset"
            create_stub.assert_called_once()
+
+
+# ============================================================================
+# Tests for Hit Testing Endpoint (POST /datasets/{id}/hit-testing)
+# ============================================================================
+
+
class TestHitTestingApi:
    """
    Comprehensive API tests for HitTestingApi (POST /datasets/{id}/hit-testing endpoint).

    This test class covers the hit testing (retrieval testing) functionality
    through the HTTP API.

    The POST /datasets/{id}/hit-testing endpoint:
    1. Requires authentication and account initialization
    2. Validates dataset exists and user has permission
    3. Validates query parameters
    4. Performs retrieval testing
    5. Returns test results

    Test scenarios include:
    - Successful hit testing
    - Query validation errors
    - Dataset not found
    - Permission denied
    """

    @pytest.fixture
    def app(self):
        """Create Flask test application."""
        return ControllerApiTestDataFactory.create_flask_app()

    @pytest.fixture
    def api(self, app):
        """Create Flask-RESTX API instance."""
        return ControllerApiTestDataFactory.create_api_instance(app)

    @pytest.fixture
    def client(self, app, api):
        """Create test client with HitTestingApi registered."""
        # FIX: the route was previously registered as "/datasets//hit-testing"
        # (an empty path segment, most likely a lost f-string interpolation),
        # so the POST to f"/datasets/{dataset_id}/hit-testing" below could
        # never match it. Register a URL converter for the dataset id instead.
        # NOTE(review): assumes HitTestingApi.post takes a dataset_id path
        # parameter and ids are UUIDs — confirm against the resource signature.
        return ControllerApiTestDataFactory.create_test_client(
            app, api, HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing"
        )

    @pytest.fixture
    def mock_current_user(self):
        """Mock current user and tenant context."""
        with patch("controllers.console.datasets.hit_testing.current_account_with_tenant") as mock_get_user:
            mock_user = ControllerApiTestDataFactory.create_user_mock()
            mock_get_user.return_value = (mock_user, "tenant-123")
            yield mock_get_user

    def test_hit_testing_success(self, client, mock_current_user):
        """
        Test successful hit testing operation.

        Verifies that when dataset lookup, argument parsing, and argument
        validation all succeed, the hit test is performed and its results
        are returned with a 200 status.
        """
        # Arrange
        dataset_id = str(uuid4())
        dataset = ControllerApiTestDataFactory.create_dataset_mock(dataset_id=dataset_id)

        request_data = {
            "query": "test query",
            "top_k": 10,
        }

        expected_result = {
            "query": {"content": "test query"},
            "records": [
                {"content": "Result 1", "score": 0.95},
                {"content": "Result 2", "score": 0.85},
            ],
        }

        with (
            patch(
                "controllers.console.datasets.hit_testing.HitTestingApi.get_and_validate_dataset"
            ) as mock_get_dataset,
            patch("controllers.console.datasets.hit_testing.HitTestingApi.parse_args") as mock_parse_args,
            patch("controllers.console.datasets.hit_testing.HitTestingApi.hit_testing_args_check") as mock_check_args,
            patch("controllers.console.datasets.hit_testing.HitTestingApi.perform_hit_testing") as mock_perform,
        ):
            mock_get_dataset.return_value = dataset
            mock_parse_args.return_value = request_data
            mock_check_args.return_value = None  # no validation error
            mock_perform.return_value = expected_result

            # Act
            response = client.post(
                f"/datasets/{dataset_id}/hit-testing",
                json=request_data,
                content_type="application/json",
            )

            # Assert
            assert response.status_code == 200
            data = response.get_json()
            assert "query" in data
            assert "records" in data
            assert len(data["records"]) == 2

            # Verify each collaborating method participated in the request.
            mock_get_dataset.assert_called_once()
            mock_parse_args.assert_called_once()
            mock_check_args.assert_called_once()
            mock_perform.assert_called_once()
+
+
+# ============================================================================
+# Tests for External Dataset Endpoints
+# ============================================================================
+
+
class TestExternalDatasetApi:
    """
    Comprehensive API tests for External Dataset endpoints.

    Covers the external knowledge API management surface exposed over HTTP:

    - GET /datasets/external-knowledge-api - List external knowledge APIs
    - POST /datasets/external-knowledge-api - Create external knowledge API
    - GET /datasets/external-knowledge-api/{id} - Get external knowledge API
    - PATCH /datasets/external-knowledge-api/{id} - Update external knowledge API
    - DELETE /datasets/external-knowledge-api/{id} - Delete external knowledge API
    - POST /datasets/external - Create external dataset

    Scenarios exercised: successful CRUD operations, request validation,
    authentication and authorization, and error handling.
    """

    @pytest.fixture
    def app(self):
        """Flask application configured for testing."""
        return ControllerApiTestDataFactory.create_flask_app()

    @pytest.fixture
    def api(self, app):
        """Flask-RESTX API bound to the test app."""
        return ControllerApiTestDataFactory.create_api_instance(app)

    @pytest.fixture
    def client_list(self, app, api):
        """Client with the external-knowledge-api list resource mounted."""
        return ControllerApiTestDataFactory.create_test_client(
            app, api, ExternalApiTemplateListApi, "/datasets/external-knowledge-api"
        )

    @pytest.fixture
    def mock_current_user(self):
        """Patch the authenticated user (a dataset editor) and tenant."""
        with patch("controllers.console.datasets.external.current_account_with_tenant") as mock_get_user:
            editor = ControllerApiTestDataFactory.create_user_mock(is_dataset_editor=True)
            mock_get_user.return_value = (editor, "tenant-123")
            yield mock_get_user

    def test_get_external_knowledge_apis_success(self, client_list, mock_current_user):
        """
        Listing external knowledge APIs returns a paginated 200 response.

        Confirms authentication is applied, the service is queried, and the
        page payload carries the expected rows and total count.
        """
        # Arrange: three fake API rows plus a matching total.
        rows = [{"id": f"api-{i}", "name": f"API {i}", "endpoint": f"https://api{i}.com"} for i in range(3)]

        with patch(
            "controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis"
        ) as list_stub:
            list_stub.return_value = (rows, 3)

            # Act
            response = client_list.get("/datasets/external-knowledge-api?page=1&limit=20")

            # Assert
            assert response.status_code == 200
            payload = response.get_json()
            assert "data" in payload
            assert len(payload["data"]) == 3
            assert payload["total"] == 3
            list_stub.assert_called_once()
+
+
+# ============================================================================
+# Additional Documentation and Notes
+# ============================================================================
+#
+# This test suite covers the core API endpoints for dataset operations.
+# Additional test scenarios that could be added:
+#
+# 1. Document Endpoints:
+# - POST /datasets/{id}/documents - Upload/create documents
+# - GET /datasets/{id}/documents - List documents
+# - GET /datasets/{id}/documents/{doc_id} - Get document details
+# - PATCH /datasets/{id}/documents/{doc_id} - Update document
+# - DELETE /datasets/{id}/documents/{doc_id} - Delete document
+# - POST /datasets/{id}/documents/batch - Batch operations
+#
+# 2. Segment Endpoints:
+# - GET /datasets/{id}/segments - List segments
+# - GET /datasets/{id}/segments/{segment_id} - Get segment details
+# - PATCH /datasets/{id}/segments/{segment_id} - Update segment
+# - DELETE /datasets/{id}/segments/{segment_id} - Delete segment
+#
+# 3. Dataset Update/Delete Endpoints:
+# - PATCH /datasets/{id} - Update dataset
+# - DELETE /datasets/{id} - Delete dataset
+#
+# 4. Advanced Scenarios:
+# - File upload handling
+# - Large payload handling
+# - Concurrent request handling
+# - Rate limiting
+# - CORS headers
+#
+# These scenarios are not currently implemented but could be added if needed
+# based on real-world usage patterns or discovered edge cases.
+#
+# ============================================================================
+
+
+# ============================================================================
+# API Testing Best Practices
+# ============================================================================
+#
+# When writing API tests, consider the following best practices:
+#
+# 1. Test Structure:
+# - Use descriptive test names that explain what is being tested
+# - Follow Arrange-Act-Assert pattern
+# - Keep tests focused on a single scenario
+# - Use fixtures for common setup
+#
+# 2. Mocking Strategy:
+# - Mock external dependencies (database, services, etc.)
+# - Mock authentication and authorization
+# - Use realistic mock data
+# - Verify mock calls to ensure correct integration
+#
+# 3. Assertions:
+# - Verify HTTP status codes
+# - Verify response structure
+# - Verify response data values
+# - Verify service method calls
+# - Verify error messages when appropriate
+#
+# 4. Error Testing:
+# - Test all error paths (400, 401, 403, 404, 500)
+# - Test validation errors
+# - Test authentication failures
+# - Test authorization failures
+# - Test not found scenarios
+#
+# 5. Edge Cases:
+# - Test with empty data
+# - Test with missing required fields
+# - Test with invalid data types
+# - Test with boundary values
+# - Test with special characters
+#
+# ============================================================================
+
+
+# ============================================================================
+# Flask-RESTX Resource Testing Patterns
+# ============================================================================
+#
+# Flask-RESTX resources are tested using Flask's test client. The typical
+# pattern involves:
+#
+# 1. Creating a Flask test application
+# 2. Creating a Flask-RESTX API instance
+# 3. Registering the resource with a route
+# 4. Creating a test client
+# 5. Making HTTP requests through the test client
+# 6. Asserting on the response
+#
+# Example pattern:
+#
+# app = Flask(__name__)
+# app.config["TESTING"] = True
+# api = Api(app)
+# api.add_resource(MyResource, "/my-endpoint")
+# client = app.test_client()
+# response = client.get("/my-endpoint")
+# assert response.status_code == 200
+#
+# Decorators on resources (like @login_required) need to be mocked or
+# bypassed in tests. This is typically done by mocking the decorator
+# functions or the authentication functions they call.
+#
+# ============================================================================
+
+
+# ============================================================================
+# Request/Response Validation
+# ============================================================================
+#
+# API endpoints use Flask-RESTX request parsers to validate incoming requests.
+# These parsers:
+#
+# 1. Extract parameters from query strings, form data, or JSON body
+# 2. Validate parameter types (string, integer, float, boolean, etc.)
+# 3. Validate parameter ranges and constraints
+# 4. Provide default values when parameters are missing
+# 5. Raise BadRequest exceptions when validation fails
+#
+# Response formatting is handled by Flask-RESTX's marshal_with decorator
+# or marshal function, which:
+#
+# 1. Formats response data according to defined models
+# 2. Handles nested objects and lists
+# 3. Filters out fields not in the model
+# 4. Provides consistent response structure
+#
+# Tests should verify:
+# - Request validation works correctly
+# - Invalid requests return 400 Bad Request
+# - Response structure matches the defined model
+# - Response data values are correct
+#
+# ============================================================================
+
+
+# ============================================================================
+# Authentication and Authorization Testing
+# ============================================================================
+#
+# Most API endpoints require authentication and authorization. Testing these
+# aspects involves:
+#
+# 1. Authentication Testing:
+# - Test that unauthenticated requests are rejected (401)
+# - Test that authenticated requests are accepted
+# - Mock the authentication decorators/functions
+# - Verify user context is passed correctly
+#
+# 2. Authorization Testing:
+# - Test that unauthorized requests are rejected (403)
+# - Test that authorized requests are accepted
+# - Test different user roles and permissions
+# - Verify permission checks are performed
+#
+# 3. Common Patterns:
+# - Mock current_account_with_tenant() to return test user
+# - Mock permission check functions
+# - Test with different user roles (admin, editor, operator, etc.)
+# - Test with different permission levels (only_me, all_team, etc.)
+#
+# ============================================================================
+
+
+# ============================================================================
+# Error Handling in API Tests
+# ============================================================================
+#
+# API endpoints should handle errors gracefully and return appropriate HTTP
+# status codes. Testing error handling involves:
+#
+# 1. Service Exception Mapping:
+# - ValueError -> 400 Bad Request
+# - NotFound -> 404 Not Found
+# - Forbidden -> 403 Forbidden
+# - Unauthorized -> 401 Unauthorized
+# - Internal errors -> 500 Internal Server Error
+#
+# 2. Validation Error Testing:
+# - Test missing required parameters
+# - Test invalid parameter types
+# - Test parameter range violations
+# - Test custom validation rules
+#
+# 3. Error Response Structure:
+# - Verify error status code
+# - Verify error message is included
+# - Verify error structure is consistent
+# - Verify error details are helpful
+#
+# ============================================================================
+
+
+# ============================================================================
+# Performance and Scalability Considerations
+# ============================================================================
+#
+# While unit tests focus on correctness, API tests should also consider:
+#
+# 1. Response Time:
+# - Tests should complete quickly
+# - Avoid actual database or network calls
+# - Use mocks for slow operations
+#
+# 2. Resource Usage:
+# - Tests should not consume excessive memory
+# - Tests should clean up after themselves
+# - Use fixtures for resource management
+#
+# 3. Test Isolation:
+# - Tests should not depend on each other
+# - Tests should not share state
+# - Each test should be independently runnable
+#
+# 4. Maintainability:
+# - Tests should be easy to understand
+# - Tests should be easy to modify
+# - Use descriptive names and comments
+# - Follow consistent patterns
+#
+# ============================================================================
From 4ca4493084795eb065e03421e0ca8a67e832213a Mon Sep 17 00:00:00 2001
From: aka James4u
Date: Wed, 26 Nov 2025 19:00:10 -0800
Subject: [PATCH 49/63] Add comprehensive unit tests for MetadataService
(dataset metadata CRUD operations and filtering) (#28748)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
 .../services/test_dataset_metadata.py        | 1068 +++++++++++++++++
 1 file changed, 1068 insertions(+)
 create mode 100644 api/tests/unit_tests/services/test_dataset_metadata.py
diff --git a/api/tests/unit_tests/services/test_dataset_metadata.py b/api/tests/unit_tests/services/test_dataset_metadata.py
new file mode 100644
index 0000000000..5ba18d8dc0
--- /dev/null
+++ b/api/tests/unit_tests/services/test_dataset_metadata.py
@@ -0,0 +1,1068 @@
+"""
+Comprehensive unit tests for MetadataService.
+
+This module contains extensive unit tests for the MetadataService class,
+which handles dataset metadata CRUD operations and filtering/querying functionality.
+
+The MetadataService provides methods for:
+- Creating, reading, updating, and deleting metadata fields
+- Managing built-in metadata fields
+- Updating document metadata values
+- Metadata filtering and querying operations
+- Lock management for concurrent metadata operations
+
+Metadata in Dify allows users to add custom fields to datasets and documents,
+enabling rich filtering and search capabilities. Metadata can be of various
+types (string, number, date, boolean, etc.) and can be used to categorize
+and filter documents within a dataset.
+
+This test suite ensures:
+- Correct creation of metadata fields with validation
+- Proper updating of metadata names and values
+- Accurate deletion of metadata fields
+- Built-in field management (enable/disable)
+- Document metadata updates (partial and full)
+- Lock management for concurrent operations
+- Metadata querying and filtering functionality
+
+================================================================================
+ARCHITECTURE OVERVIEW
+================================================================================
+
+The MetadataService is a critical component in the Dify platform's metadata
+management system. It serves as the primary interface for all metadata-related
+operations, including field definitions and document-level metadata values.
+
+Key Concepts:
+1. DatasetMetadata: Defines a metadata field for a dataset. Each metadata
+ field has a name, type, and is associated with a specific dataset.
+
+2. DatasetMetadataBinding: Links metadata fields to documents. This allows
+ tracking which documents have which metadata fields assigned.
+
+3. Document Metadata: The actual metadata values stored on documents. This
+ is stored as a JSON object in the document's doc_metadata field.
+
+4. Built-in Fields: System-defined metadata fields that are automatically
+ available when enabled (document_name, uploader, upload_date, etc.).
+
+5. Lock Management: Redis-based locking to prevent concurrent metadata
+ operations that could cause data corruption.
+
+================================================================================
+TESTING STRATEGY
+================================================================================
+
+This test suite follows a comprehensive testing strategy that covers:
+
+1. CRUD Operations:
+ - Creating metadata fields with validation
+ - Reading/retrieving metadata fields
+ - Updating metadata field names
+ - Deleting metadata fields
+
+2. Built-in Field Management:
+ - Enabling built-in fields
+ - Disabling built-in fields
+ - Getting built-in field definitions
+
+3. Document Metadata Operations:
+ - Updating document metadata (partial and full)
+ - Managing metadata bindings
+ - Handling built-in field updates
+
+4. Lock Management:
+ - Acquiring locks for dataset operations
+ - Acquiring locks for document operations
+ - Handling lock conflicts
+
+5. Error Handling:
+ - Validation errors (name length, duplicates)
+ - Not found errors
+ - Lock conflict errors
+
+================================================================================
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from core.rag.index_processor.constant.built_in_field import BuiltInField
+from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
+from services.entities.knowledge_entities.knowledge_entities import (
+ MetadataArgs,
+ MetadataValue,
+)
+from services.metadata_service import MetadataService
+
+# ============================================================================
+# Test Data Factory
+# ============================================================================
+# The Test Data Factory pattern is used here to centralize the creation of
+# test objects and mock instances. This approach provides several benefits:
+#
+# 1. Consistency: All test objects are created using the same factory methods,
+# ensuring consistent structure across all tests.
+#
+# 2. Maintainability: If the structure of models changes, we only need to
+# update the factory methods rather than every individual test.
+#
+# 3. Reusability: Factory methods can be reused across multiple test classes,
+# reducing code duplication.
+#
+# 4. Readability: Tests become more readable when they use descriptive factory
+# method calls instead of complex object construction logic.
+#
+# ============================================================================
+
+
+class MetadataTestDataFactory:
+ """
+ Factory class for creating test data and mock objects for metadata service tests.
+
+ This factory provides static methods to create mock objects for:
+ - DatasetMetadata instances
+ - DatasetMetadataBinding instances
+ - Dataset instances
+ - Document instances
+ - MetadataArgs and MetadataOperationData entities
+ - User and tenant context
+
+ The factory methods help maintain consistency across tests and reduce
+ code duplication when setting up test scenarios.
+ """
+
+ @staticmethod
+ def create_metadata_mock(
+ metadata_id: str = "metadata-123",
+ dataset_id: str = "dataset-123",
+ tenant_id: str = "tenant-123",
+ name: str = "category",
+ metadata_type: str = "string",
+ created_by: str = "user-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock DatasetMetadata with specified attributes.
+
+ Args:
+ metadata_id: Unique identifier for the metadata field
+ dataset_id: ID of the dataset this metadata belongs to
+ tenant_id: Tenant identifier
+ name: Name of the metadata field
+ metadata_type: Type of metadata (string, number, date, etc.)
+ created_by: ID of the user who created the metadata
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a DatasetMetadata instance
+ """
+ metadata = Mock(spec=DatasetMetadata)
+ metadata.id = metadata_id
+ metadata.dataset_id = dataset_id
+ metadata.tenant_id = tenant_id
+ metadata.name = name
+ metadata.type = metadata_type
+ metadata.created_by = created_by
+ metadata.updated_by = None
+ metadata.updated_at = None
+ for key, value in kwargs.items():
+ setattr(metadata, key, value)
+ return metadata
+
+ @staticmethod
+ def create_metadata_binding_mock(
+ binding_id: str = "binding-123",
+ dataset_id: str = "dataset-123",
+ tenant_id: str = "tenant-123",
+ metadata_id: str = "metadata-123",
+ document_id: str = "document-123",
+ created_by: str = "user-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock DatasetMetadataBinding with specified attributes.
+
+ Args:
+ binding_id: Unique identifier for the binding
+ dataset_id: ID of the dataset
+ tenant_id: Tenant identifier
+ metadata_id: ID of the metadata field
+ document_id: ID of the document
+ created_by: ID of the user who created the binding
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a DatasetMetadataBinding instance
+ """
+ binding = Mock(spec=DatasetMetadataBinding)
+ binding.id = binding_id
+ binding.dataset_id = dataset_id
+ binding.tenant_id = tenant_id
+ binding.metadata_id = metadata_id
+ binding.document_id = document_id
+ binding.created_by = created_by
+ for key, value in kwargs.items():
+ setattr(binding, key, value)
+ return binding
+
+ @staticmethod
+ def create_dataset_mock(
+ dataset_id: str = "dataset-123",
+ tenant_id: str = "tenant-123",
+ built_in_field_enabled: bool = False,
+ doc_metadata: list | None = None,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Dataset with specified attributes.
+
+ Args:
+ dataset_id: Unique identifier for the dataset
+ tenant_id: Tenant identifier
+ built_in_field_enabled: Whether built-in fields are enabled
+ doc_metadata: List of metadata field definitions
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a Dataset instance
+ """
+ dataset = Mock(spec=Dataset)
+ dataset.id = dataset_id
+ dataset.tenant_id = tenant_id
+ dataset.built_in_field_enabled = built_in_field_enabled
+ dataset.doc_metadata = doc_metadata or []
+ for key, value in kwargs.items():
+ setattr(dataset, key, value)
+ return dataset
+
+ @staticmethod
+ def create_document_mock(
+ document_id: str = "document-123",
+ dataset_id: str = "dataset-123",
+ name: str = "Test Document",
+ doc_metadata: dict | None = None,
+ uploader: str = "user-123",
+ data_source_type: str = "upload_file",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Document with specified attributes.
+
+ Args:
+ document_id: Unique identifier for the document
+ dataset_id: ID of the dataset this document belongs to
+ name: Name of the document
+ doc_metadata: Dictionary of metadata values
+ uploader: ID of the user who uploaded the document
+ data_source_type: Type of data source
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a Document instance
+ """
+ document = Mock()
+ document.id = document_id
+ document.dataset_id = dataset_id
+ document.name = name
+ document.doc_metadata = doc_metadata or {}
+ document.uploader = uploader
+ document.data_source_type = data_source_type
+
+ # Mock datetime objects for upload_date and last_update_date
+
+ document.upload_date = Mock()
+ document.upload_date.timestamp.return_value = 1234567890.0
+ document.last_update_date = Mock()
+ document.last_update_date.timestamp.return_value = 1234567890.0
+
+ for key, value in kwargs.items():
+ setattr(document, key, value)
+ return document
+
+ @staticmethod
+ def create_metadata_args_mock(
+ name: str = "category",
+ metadata_type: str = "string",
+ ) -> Mock:
+ """
+ Create a mock MetadataArgs entity.
+
+ Args:
+ name: Name of the metadata field
+ metadata_type: Type of metadata
+
+ Returns:
+ Mock object configured as a MetadataArgs instance
+ """
+ metadata_args = Mock(spec=MetadataArgs)
+ metadata_args.name = name
+ metadata_args.type = metadata_type
+ return metadata_args
+
+ @staticmethod
+ def create_metadata_value_mock(
+ metadata_id: str = "metadata-123",
+ name: str = "category",
+ value: str = "test",
+ ) -> Mock:
+ """
+ Create a mock MetadataValue entity.
+
+ Args:
+ metadata_id: ID of the metadata field
+ name: Name of the metadata field
+ value: Value of the metadata
+
+ Returns:
+ Mock object configured as a MetadataValue instance
+ """
+ metadata_value = Mock(spec=MetadataValue)
+ metadata_value.id = metadata_id
+ metadata_value.name = name
+ metadata_value.value = value
+ return metadata_value
+
+
+# ============================================================================
+# Tests for create_metadata
+# ============================================================================
+
+
+class TestMetadataServiceCreateMetadata:
+ """
+ Comprehensive unit tests for MetadataService.create_metadata method.
+
+ This test class covers the metadata field creation functionality,
+ including validation, duplicate checking, and database operations.
+
+ The create_metadata method:
+ 1. Validates metadata name length (max 255 characters)
+ 2. Checks for duplicate metadata names within the dataset
+ 3. Checks for conflicts with built-in field names
+ 4. Creates a new DatasetMetadata instance
+ 5. Adds it to the database session and commits
+ 6. Returns the created metadata
+
+ Test scenarios include:
+ - Successful creation with valid data
+ - Name length validation
+ - Duplicate name detection
+ - Built-in field name conflicts
+ - Database transaction handling
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Mock database session for testing database operations.
+
+ Provides a mocked database session that can be used to verify:
+ - Query construction and execution
+ - Add operations for new metadata
+ - Commit operations for transaction completion
+ """
+ with patch("services.metadata_service.db.session") as mock_db:
+ yield mock_db
+
+ @pytest.fixture
+ def mock_current_user(self):
+ """
+ Mock current user and tenant context.
+
+ Provides mocked current_account_with_tenant function that returns
+ a user and tenant ID for testing authentication and authorization.
+ """
+ with patch("services.metadata_service.current_account_with_tenant") as mock_get_user:
+ mock_user = Mock()
+ mock_user.id = "user-123"
+ mock_tenant_id = "tenant-123"
+ mock_get_user.return_value = (mock_user, mock_tenant_id)
+ yield mock_get_user
+
+ def test_create_metadata_success(self, mock_db_session, mock_current_user):
+ """
+ Test successful creation of a metadata field.
+
+ Verifies that when all validation passes, a new metadata field
+ is created and persisted to the database.
+
+ This test ensures:
+ - Metadata name validation passes
+ - No duplicate name exists
+ - No built-in field conflict
+ - New metadata is added to database
+ - Transaction is committed
+ - Created metadata is returned
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name="category", metadata_type="string")
+
+ # Mock query to return None (no existing metadata with same name)
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = None
+ mock_db_session.query.return_value = mock_query
+
+ # Mock BuiltInField enum iteration
+ with patch("services.metadata_service.BuiltInField") as mock_builtin:
+ mock_builtin.__iter__ = Mock(return_value=iter([]))
+
+ # Act
+ result = MetadataService.create_metadata(dataset_id, metadata_args)
+
+ # Assert
+ assert result is not None
+ assert isinstance(result, DatasetMetadata)
+
+ # Verify query was made to check for duplicates
+ mock_db_session.query.assert_called()
+ mock_query.filter_by.assert_called()
+
+ # Verify metadata was added and committed
+ mock_db_session.add.assert_called_once()
+ mock_db_session.commit.assert_called_once()
+
+ def test_create_metadata_name_too_long_error(self, mock_db_session, mock_current_user):
+ """
+ Test error handling when metadata name exceeds 255 characters.
+
+ Verifies that when a metadata name is longer than 255 characters,
+ a ValueError is raised with an appropriate message.
+
+ This test ensures:
+ - Name length validation is enforced
+ - Error message is clear and descriptive
+ - No database operations are performed
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ long_name = "a" * 256 # 256 characters (exceeds limit)
+ metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name=long_name, metadata_type="string")
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters"):
+ MetadataService.create_metadata(dataset_id, metadata_args)
+
+ # Verify no database operations were performed
+ mock_db_session.add.assert_not_called()
+ mock_db_session.commit.assert_not_called()
+
+ def test_create_metadata_duplicate_name_error(self, mock_db_session, mock_current_user):
+ """
+ Test error handling when metadata name already exists.
+
+ Verifies that when a metadata field with the same name already exists
+ in the dataset, a ValueError is raised.
+
+ This test ensures:
+ - Duplicate name detection works correctly
+ - Error message is clear
+ - No new metadata is created
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name="category", metadata_type="string")
+
+ # Mock existing metadata with same name
+ existing_metadata = MetadataTestDataFactory.create_metadata_mock(name="category")
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = existing_metadata
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Metadata name already exists"):
+ MetadataService.create_metadata(dataset_id, metadata_args)
+
+ # Verify no new metadata was added
+ mock_db_session.add.assert_not_called()
+ mock_db_session.commit.assert_not_called()
+
+    def test_create_metadata_builtin_field_conflict_error(self, mock_db_session, mock_current_user):
+        """
+        Test error handling when metadata name conflicts with built-in field.
+
+        Verifies that when a metadata name matches a built-in field name,
+        a ValueError is raised.
+
+        This test ensures:
+        - Built-in field name conflicts are detected
+        - Error message is clear
+        - No new metadata is created
+        """
+        # Arrange
+        dataset_id = "dataset-123"
+        # The requested name is the built-in field itself; presumably the
+        # service compares it against each BuiltInField member's value —
+        # TODO confirm whether BuiltInField is a str-enum.
+        metadata_args = MetadataTestDataFactory.create_metadata_args_mock(
+            name=BuiltInField.document_name, metadata_type="string"
+        )
+
+        # Mock query to return None (no duplicate in database), so only the
+        # built-in conflict branch can reject the request.
+        mock_query = Mock()
+        mock_query.filter_by.return_value = mock_query
+        mock_query.first.return_value = None
+        mock_db_session.query.return_value = mock_query
+
+        # Mock BuiltInField to include the conflicting name
+        with patch("services.metadata_service.BuiltInField") as mock_builtin:
+            mock_field = Mock()
+            mock_field.value = BuiltInField.document_name
+            mock_builtin.__iter__ = Mock(return_value=iter([mock_field]))
+
+            # Act & Assert
+            with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields"):
+                MetadataService.create_metadata(dataset_id, metadata_args)
+
+            # Verify no new metadata was added
+            mock_db_session.add.assert_not_called()
+            mock_db_session.commit.assert_not_called()
+
+
+# ============================================================================
+# Tests for update_metadata_name
+# ============================================================================
+
+
+class TestMetadataServiceUpdateMetadataName:
+    """
+    Comprehensive unit tests for MetadataService.update_metadata_name method.
+
+    This test class covers the metadata field name update functionality,
+    including validation, duplicate checking, and document metadata updates.
+
+    The update_metadata_name method:
+    1. Validates new name length (max 255 characters)
+    2. Checks for duplicate names
+    3. Checks for built-in field conflicts
+    4. Acquires a lock for the dataset
+    5. Updates the metadata name
+    6. Updates all related document metadata
+    7. Releases the lock
+    8. Returns the updated metadata
+
+    Test scenarios include:
+    - Successful name update
+    - Name length validation
+    - Duplicate name detection
+    - Built-in field conflicts
+    - Lock management
+    - Document metadata updates
+    """
+
+    @pytest.fixture
+    def mock_db_session(self):
+        """Mock database session for testing."""
+        with patch("services.metadata_service.db.session") as mock_db:
+            yield mock_db
+
+    @pytest.fixture
+    def mock_current_user(self):
+        """Mock current user and tenant context."""
+        with patch("services.metadata_service.current_account_with_tenant") as mock_get_user:
+            mock_user = Mock()
+            mock_user.id = "user-123"
+            mock_tenant_id = "tenant-123"
+            mock_get_user.return_value = (mock_user, mock_tenant_id)
+            yield mock_get_user
+
+    @pytest.fixture
+    def mock_redis_client(self):
+        """Mock Redis client for lock management."""
+        with patch("services.metadata_service.redis_client") as mock_redis:
+            # Default behaviour: the lock is free and can be taken/released.
+            mock_redis.get.return_value = None  # No existing lock
+            mock_redis.set.return_value = True
+            mock_redis.delete.return_value = True
+            yield mock_redis
+
+    def test_update_metadata_name_success(self, mock_db_session, mock_current_user, mock_redis_client):
+        """
+        Test successful update of metadata field name.
+
+        Verifies that when all validation passes, the metadata name is
+        updated and all related document metadata is updated accordingly.
+
+        This test ensures:
+        - Name validation passes
+        - Lock is acquired and released
+        - Metadata name is updated
+        - Related document metadata is updated
+        - Transaction is committed
+        """
+        # Arrange
+        dataset_id = "dataset-123"
+        metadata_id = "metadata-123"
+        new_name = "updated_category"
+
+        existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
+
+        # Mock query for duplicate check (no duplicate).
+        # NOTE: the return_value assigned here is superseded by the
+        # side_effect installed below; mock_query survives only as the
+        # fallback that side_effect returns for non-DatasetMetadata models.
+        mock_query = Mock()
+        mock_query.filter_by.return_value = mock_query
+        mock_query.first.return_value = None
+        mock_db_session.query.return_value = mock_query
+
+        # Mock metadata retrieval: DatasetMetadata lookups find the existing
+        # row; every other model falls back to mock_query (yields None).
+        def query_side_effect(model):
+            if model == DatasetMetadata:
+                mock_meta_query = Mock()
+                mock_meta_query.filter_by.return_value = mock_meta_query
+                mock_meta_query.first.return_value = existing_metadata
+                return mock_meta_query
+            return mock_query
+
+        mock_db_session.query.side_effect = query_side_effect
+
+        # Mock no metadata bindings (no documents to update).
+        # NOTE(review): mock_binding_query is never attached to
+        # mock_db_session — binding lookups are actually served by the
+        # query_side_effect fallback above. Wire it in or remove it; verify
+        # against the service's binding query before relying on it.
+        mock_binding_query = Mock()
+        mock_binding_query.filter_by.return_value = mock_binding_query
+        mock_binding_query.all.return_value = []
+
+        # Mock BuiltInField enum so no built-in name can conflict.
+        with patch("services.metadata_service.BuiltInField") as mock_builtin:
+            mock_builtin.__iter__ = Mock(return_value=iter([]))
+
+            # Act
+            result = MetadataService.update_metadata_name(dataset_id, metadata_id, new_name)
+
+            # Assert
+            assert result is not None
+            assert result.name == new_name
+
+            # Verify lock was acquired and released
+            mock_redis_client.get.assert_called()
+            mock_redis_client.set.assert_called()
+            mock_redis_client.delete.assert_called()
+
+            # Verify metadata was updated and committed
+            mock_db_session.commit.assert_called()
+
+    def test_update_metadata_name_not_found_error(self, mock_db_session, mock_current_user, mock_redis_client):
+        """
+        Test error handling when metadata is not found.
+
+        Verifies that when the metadata ID doesn't exist, a ValueError
+        is raised with an appropriate message.
+
+        This test ensures:
+        - Not found error is handled correctly
+        - Lock is properly released even on error
+        - No updates are committed
+        """
+        # Arrange
+        dataset_id = "dataset-123"
+        metadata_id = "non-existent-metadata"
+        new_name = "updated_category"
+
+        # Mock query for duplicate check (no duplicate)
+        mock_query = Mock()
+        mock_query.filter_by.return_value = mock_query
+        mock_query.first.return_value = None
+        mock_db_session.query.return_value = mock_query
+
+        # Mock metadata retrieval to return None so the service cannot find
+        # the row it is asked to rename.
+        def query_side_effect(model):
+            if model == DatasetMetadata:
+                mock_meta_query = Mock()
+                mock_meta_query.filter_by.return_value = mock_meta_query
+                mock_meta_query.first.return_value = None  # Not found
+                return mock_meta_query
+            return mock_query
+
+        mock_db_session.query.side_effect = query_side_effect
+
+        # Mock BuiltInField enum
+        with patch("services.metadata_service.BuiltInField") as mock_builtin:
+            mock_builtin.__iter__ = Mock(return_value=iter([]))
+
+            # Act & Assert
+            with pytest.raises(ValueError, match="Metadata not found"):
+                MetadataService.update_metadata_name(dataset_id, metadata_id, new_name)
+
+            # Verify lock was released (cleanup must run on the error path too)
+            mock_redis_client.delete.assert_called()
+
+
+# ============================================================================
+# Tests for delete_metadata
+# ============================================================================
+
+
+class TestMetadataServiceDeleteMetadata:
+ """
+ Comprehensive unit tests for MetadataService.delete_metadata method.
+
+ This test class covers the metadata field deletion functionality,
+ including document metadata cleanup and lock management.
+
+ The delete_metadata method:
+ 1. Acquires a lock for the dataset
+ 2. Retrieves the metadata to delete
+ 3. Deletes the metadata from the database
+ 4. Removes metadata from all related documents
+ 5. Releases the lock
+ 6. Returns the deleted metadata
+
+ Test scenarios include:
+ - Successful deletion
+ - Not found error handling
+ - Document metadata cleanup
+ - Lock management
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """Mock database session for testing."""
+ with patch("services.metadata_service.db.session") as mock_db:
+ yield mock_db
+
+ @pytest.fixture
+ def mock_redis_client(self):
+ """Mock Redis client for lock management."""
+ with patch("services.metadata_service.redis_client") as mock_redis:
+ mock_redis.get.return_value = None
+ mock_redis.set.return_value = True
+ mock_redis.delete.return_value = True
+ yield mock_redis
+
+ def test_delete_metadata_success(self, mock_db_session, mock_redis_client):
+ """
+ Test successful deletion of a metadata field.
+
+ Verifies that when the metadata exists, it is deleted and all
+ related document metadata is cleaned up.
+
+ This test ensures:
+ - Lock is acquired and released
+ - Metadata is deleted from database
+ - Related document metadata is removed
+ - Transaction is committed
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_id = "metadata-123"
+
+ existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
+
+ # Mock metadata retrieval
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = existing_metadata
+ mock_db_session.query.return_value = mock_query
+
+ # Mock no metadata bindings (no documents to update)
+ mock_binding_query = Mock()
+ mock_binding_query.filter_by.return_value = mock_binding_query
+ mock_binding_query.all.return_value = []
+
+ # Act
+ result = MetadataService.delete_metadata(dataset_id, metadata_id)
+
+ # Assert
+ assert result == existing_metadata
+
+ # Verify lock was acquired and released
+ mock_redis_client.get.assert_called()
+ mock_redis_client.set.assert_called()
+ mock_redis_client.delete.assert_called()
+
+ # Verify metadata was deleted and committed
+ mock_db_session.delete.assert_called_once_with(existing_metadata)
+ mock_db_session.commit.assert_called()
+
+ def test_delete_metadata_not_found_error(self, mock_db_session, mock_redis_client):
+ """
+ Test error handling when metadata is not found.
+
+ Verifies that when the metadata ID doesn't exist, a ValueError
+ is raised and the lock is properly released.
+
+ This test ensures:
+ - Not found error is handled correctly
+ - Lock is released even on error
+ - No deletion is performed
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_id = "non-existent-metadata"
+
+ # Mock metadata retrieval to return None
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = None
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Metadata not found"):
+ MetadataService.delete_metadata(dataset_id, metadata_id)
+
+ # Verify lock was released
+ mock_redis_client.delete.assert_called()
+
+ # Verify no deletion was performed
+ mock_db_session.delete.assert_not_called()
+
+
+# ============================================================================
+# Tests for get_built_in_fields
+# ============================================================================
+
+
+class TestMetadataServiceGetBuiltInFields:
+ """
+ Comprehensive unit tests for MetadataService.get_built_in_fields method.
+
+ This test class covers the built-in field retrieval functionality.
+
+ The get_built_in_fields method:
+ 1. Returns a list of built-in field definitions
+ 2. Each definition includes name and type
+
+ Test scenarios include:
+ - Successful retrieval of built-in fields
+ - Correct field definitions
+ """
+
+ def test_get_built_in_fields_success(self):
+ """
+ Test successful retrieval of built-in fields.
+
+ Verifies that the method returns the correct list of built-in
+ field definitions with proper structure.
+
+ This test ensures:
+ - All built-in fields are returned
+ - Each field has name and type
+ - Field definitions are correct
+ """
+ # Act
+ result = MetadataService.get_built_in_fields()
+
+ # Assert
+ assert isinstance(result, list)
+ assert len(result) > 0
+
+ # Verify each field has required properties
+ for field in result:
+ assert "name" in field
+ assert "type" in field
+ assert isinstance(field["name"], str)
+ assert isinstance(field["type"], str)
+
+ # Verify specific built-in fields are present
+ field_names = [field["name"] for field in result]
+ assert BuiltInField.document_name in field_names
+ assert BuiltInField.uploader in field_names
+
+
+# ============================================================================
+# Tests for knowledge_base_metadata_lock_check
+# ============================================================================
+
+
+class TestMetadataServiceLockCheck:
+ """
+ Comprehensive unit tests for MetadataService.knowledge_base_metadata_lock_check method.
+
+ This test class covers the lock management functionality for preventing
+ concurrent metadata operations.
+
+ The knowledge_base_metadata_lock_check method:
+ 1. Checks if a lock exists for the dataset or document
+ 2. Raises ValueError if lock exists (operation in progress)
+ 3. Sets a lock with expiration time (3600 seconds)
+ 4. Supports both dataset-level and document-level locks
+
+ Test scenarios include:
+ - Successful lock acquisition
+ - Lock conflict detection
+ - Dataset-level locks
+ - Document-level locks
+ """
+
+ @pytest.fixture
+ def mock_redis_client(self):
+ """Mock Redis client for lock management."""
+ with patch("services.metadata_service.redis_client") as mock_redis:
+ yield mock_redis
+
+ def test_lock_check_dataset_success(self, mock_redis_client):
+ """
+ Test successful lock acquisition for dataset operations.
+
+ Verifies that when no lock exists, a new lock is acquired
+ for the dataset.
+
+ This test ensures:
+ - Lock check passes when no lock exists
+ - Lock is set with correct key and expiration
+ - No error is raised
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ mock_redis_client.get.return_value = None # No existing lock
+
+ # Act (should not raise)
+ MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
+
+ # Assert
+ mock_redis_client.get.assert_called_once_with(f"dataset_metadata_lock_{dataset_id}")
+ mock_redis_client.set.assert_called_once_with(f"dataset_metadata_lock_{dataset_id}", 1, ex=3600)
+
+ def test_lock_check_dataset_conflict_error(self, mock_redis_client):
+ """
+ Test error handling when dataset lock already exists.
+
+ Verifies that when a lock exists for the dataset, a ValueError
+ is raised with an appropriate message.
+
+ This test ensures:
+ - Lock conflict is detected
+ - Error message is clear
+ - No new lock is set
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ mock_redis_client.get.return_value = "1" # Lock exists
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Another knowledge base metadata operation is running"):
+ MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
+
+ # Verify lock was checked but not set
+ mock_redis_client.get.assert_called_once()
+ mock_redis_client.set.assert_not_called()
+
+ def test_lock_check_document_success(self, mock_redis_client):
+ """
+ Test successful lock acquisition for document operations.
+
+ Verifies that when no lock exists, a new lock is acquired
+ for the document.
+
+ This test ensures:
+ - Lock check passes when no lock exists
+ - Lock is set with correct key and expiration
+ - No error is raised
+ """
+ # Arrange
+ document_id = "document-123"
+ mock_redis_client.get.return_value = None # No existing lock
+
+ # Act (should not raise)
+ MetadataService.knowledge_base_metadata_lock_check(None, document_id)
+
+ # Assert
+ mock_redis_client.get.assert_called_once_with(f"document_metadata_lock_{document_id}")
+ mock_redis_client.set.assert_called_once_with(f"document_metadata_lock_{document_id}", 1, ex=3600)
+
+
+# ============================================================================
+# Tests for get_dataset_metadatas
+# ============================================================================
+
+
+class TestMetadataServiceGetDatasetMetadatas:
+    """
+    Comprehensive unit tests for MetadataService.get_dataset_metadatas method.
+
+    This test class covers the metadata retrieval functionality for datasets.
+
+    The get_dataset_metadatas method:
+    1. Retrieves all metadata fields for a dataset
+    2. Excludes built-in fields from the list
+    3. Includes usage count for each metadata field
+    4. Returns built-in field enabled status
+
+    Test scenarios include:
+    - Successful retrieval with metadata fields
+    - Empty metadata list
+    - Built-in field filtering
+    - Usage count calculation
+    """
+
+    @pytest.fixture
+    def mock_db_session(self):
+        """Mock database session for testing."""
+        with patch("services.metadata_service.db.session") as mock_db:
+            yield mock_db
+
+    def test_get_dataset_metadatas_success(self, mock_db_session):
+        """
+        Test successful retrieval of dataset metadata fields.
+
+        Verifies that all metadata fields are returned with correct
+        structure and usage counts.
+
+        This test ensures:
+        - All metadata fields are included
+        - Built-in fields are excluded
+        - Usage counts are calculated correctly
+        - Built-in field status is included
+        """
+        # Arrange: a dataset whose doc_metadata mixes two custom fields with
+        # one entry named after a built-in field ("document_name").
+        dataset = MetadataTestDataFactory.create_dataset_mock(
+            dataset_id="dataset-123",
+            built_in_field_enabled=True,
+            doc_metadata=[
+                {"id": "metadata-1", "name": "category", "type": "string"},
+                {"id": "metadata-2", "name": "priority", "type": "number"},
+                {"id": "built-in", "name": "document_name", "type": "string"},
+            ],
+        )
+
+        # Mock usage count queries: every count() call reports 5 documents.
+        mock_query = Mock()
+        mock_query.filter_by.return_value = mock_query
+        mock_query.count.return_value = 5  # 5 documents use this metadata
+        mock_db_session.query.return_value = mock_query
+
+        # Act
+        result = MetadataService.get_dataset_metadatas(dataset)
+
+        # Assert
+        assert "doc_metadata" in result
+        assert "built_in_field_enabled" in result
+        assert result["built_in_field_enabled"] is True
+
+        # Verify built-in fields are excluded — presumably the service drops
+        # entries whose name matches a built-in field; confirm against the
+        # service's filtering logic.
+        metadata_ids = [meta["id"] for meta in result["doc_metadata"]]
+        assert "built-in" not in metadata_ids
+
+        # Verify all custom metadata fields are included
+        assert len(result["doc_metadata"]) == 2
+
+        # Verify usage counts are included (all 5, per the mock above)
+        for meta in result["doc_metadata"]:
+            assert "count" in meta
+            assert meta["count"] == 5
+
+
+# ============================================================================
+# Additional Documentation and Notes
+# ============================================================================
+#
+# This test suite covers the core metadata CRUD operations and basic
+# filtering functionality. Additional test scenarios that could be added:
+#
+# 1. enable_built_in_field / disable_built_in_field:
+# - Testing built-in field enablement
+# - Testing built-in field disablement
+# - Testing document metadata updates when enabling/disabling
+#
+# 2. update_documents_metadata:
+# - Testing partial updates
+# - Testing full updates
+# - Testing metadata binding creation
+# - Testing built-in field updates
+#
+# 3. Metadata Filtering and Querying:
+# - Testing metadata-based document filtering
+# - Testing complex metadata queries
+# - Testing metadata value retrieval
+#
+# These scenarios are not currently implemented but could be added if needed
+# based on real-world usage patterns or discovered edge cases.
+#
+# ============================================================================
From 8d8800e632a417d21ebaa06e784e66022596a4fc Mon Sep 17 00:00:00 2001
From: majinghe <42570491+majinghe@users.noreply.github.com>
Date: Thu, 27 Nov 2025 11:01:14 +0800
Subject: [PATCH 50/63] upgrade docker compose milvus version to 2.6.3 to fix
installation error (#26618)
Co-authored-by: crazywoola <427733928@qq.com>
---
docker/docker-compose-template.yaml | 2 +-
docker/docker-compose.yaml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml
index 975c92693a..703a60ef67 100644
--- a/docker/docker-compose-template.yaml
+++ b/docker/docker-compose-template.yaml
@@ -676,7 +676,7 @@ services:
milvus-standalone:
container_name: milvus-standalone
- image: milvusdb/milvus:v2.5.15
+ image: milvusdb/milvus:v2.6.3
profiles:
- milvus
command: ["milvus", "run", "standalone"]
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 17f33bbf72..de2e3943fe 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -1311,7 +1311,7 @@ services:
milvus-standalone:
container_name: milvus-standalone
- image: milvusdb/milvus:v2.5.15
+ image: milvusdb/milvus:v2.6.3
profiles:
- milvus
command: ["milvus", "run", "standalone"]
From f9b4c3134441f4c2547ad4613d2fb1800e7e1ab8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Thu, 27 Nov 2025 11:22:49 +0800
Subject: [PATCH 51/63] fix: MCP tool time configuration does not work (#28740)
---
web/app/components/tools/mcp/modal.tsx | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/web/app/components/tools/mcp/modal.tsx b/web/app/components/tools/mcp/modal.tsx
index 68f97703bf..836fc5e0aa 100644
--- a/web/app/components/tools/mcp/modal.tsx
+++ b/web/app/components/tools/mcp/modal.tsx
@@ -99,8 +99,8 @@ const MCPModal = ({
const [appIcon, setAppIcon] = useState(() => getIcon(data))
const [showAppIconPicker, setShowAppIconPicker] = useState(false)
const [serverIdentifier, setServerIdentifier] = React.useState(data?.server_identifier || '')
- const [timeout, setMcpTimeout] = React.useState(data?.timeout || 30)
- const [sseReadTimeout, setSseReadTimeout] = React.useState(data?.sse_read_timeout || 300)
+ const [timeout, setMcpTimeout] = React.useState(data?.configuration?.timeout || 30)
+ const [sseReadTimeout, setSseReadTimeout] = React.useState(data?.configuration?.sse_read_timeout || 300)
const [headers, setHeaders] = React.useState(
Object.entries(data?.masked_headers || {}).map(([key, value]) => ({ id: uuid(), key, value })),
)
@@ -118,8 +118,8 @@ const MCPModal = ({
setUrl(data.server_url || '')
setName(data.name || '')
setServerIdentifier(data.server_identifier || '')
- setMcpTimeout(data.timeout || 30)
- setSseReadTimeout(data.sse_read_timeout || 300)
+ setMcpTimeout(data.configuration?.timeout || 30)
+ setSseReadTimeout(data.configuration?.sse_read_timeout || 300)
setHeaders(Object.entries(data.masked_headers || {}).map(([key, value]) => ({ id: uuid(), key, value })))
setAppIcon(getIcon(data))
setIsDynamicRegistration(data.is_dynamic_registration)
From 6deabfdad38f4f7ed4ff9d2f945e2a8385316ea6 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Thu, 27 Nov 2025 11:23:20 +0800
Subject: [PATCH 52/63] Use naive_utc_now in graph engine tests (#28735)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../event_management/test_event_handlers.py | 5 ++---
.../graph_engine/orchestration/test_dispatcher.py | 10 +++++-----
2 files changed, 7 insertions(+), 8 deletions(-)
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py
index 2b8f04979d..5d17b7a243 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py
@@ -2,8 +2,6 @@
from __future__ import annotations
-from datetime import datetime
-
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
@@ -16,6 +14,7 @@ from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig
from core.workflow.runtime import GraphRuntimeState, VariablePool
+from libs.datetime_utils import naive_utc_now
class _StubEdgeProcessor:
@@ -75,7 +74,7 @@ def test_retry_does_not_emit_additional_start_event() -> None:
execution_id = "exec-1"
node_type = NodeType.CODE
- start_time = datetime.utcnow()
+ start_time = naive_utc_now()
start_event = NodeRunStartedEvent(
id=execution_id,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py
index e6d4508fdf..c1fc4acd73 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py
@@ -3,7 +3,6 @@
from __future__ import annotations
import queue
-from datetime import datetime
from unittest import mock
from core.workflow.entities.pause_reason import SchedulingPause
@@ -18,6 +17,7 @@ from core.workflow.graph_events import (
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult
+from libs.datetime_utils import naive_utc_now
def test_dispatcher_should_consume_remains_events_after_pause():
@@ -109,7 +109,7 @@ def _make_started_event() -> NodeRunStartedEvent:
node_id="node-1",
node_type=NodeType.CODE,
node_title="Test Node",
- start_at=datetime.utcnow(),
+ start_at=naive_utc_now(),
)
@@ -119,7 +119,7 @@ def _make_succeeded_event() -> NodeRunSucceededEvent:
node_id="node-1",
node_type=NodeType.CODE,
node_title="Test Node",
- start_at=datetime.utcnow(),
+ start_at=naive_utc_now(),
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
)
@@ -153,7 +153,7 @@ def test_dispatcher_drain_event_queue():
node_id="node-1",
node_type=NodeType.CODE,
node_title="Code",
- start_at=datetime.utcnow(),
+ start_at=naive_utc_now(),
),
NodeRunPauseRequestedEvent(
id="pause-event",
@@ -165,7 +165,7 @@ def test_dispatcher_drain_event_queue():
id="success-event",
node_id="node-1",
node_type=NodeType.CODE,
- start_at=datetime.utcnow(),
+ start_at=naive_utc_now(),
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
),
]
From 0309545ff15d2a79087a5875d99c036c301ccc74 Mon Sep 17 00:00:00 2001
From: Gritty_dev <101377478+codomposer@users.noreply.github.com>
Date: Wed, 26 Nov 2025 22:23:55 -0500
Subject: [PATCH 53/63] Feat/test script of workflow service (#28726)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
.../services/test_workflow_service.py | 1114 +++++++++++++++++
1 file changed, 1114 insertions(+)
create mode 100644 api/tests/unit_tests/services/test_workflow_service.py
diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py
new file mode 100644
index 0000000000..ae5b194afb
--- /dev/null
+++ b/api/tests/unit_tests/services/test_workflow_service.py
@@ -0,0 +1,1114 @@
+"""
+Unit tests for WorkflowService.
+
+This test suite covers:
+- Workflow creation from template
+- Workflow validation (graph and features structure)
+- Draft/publish transitions
+- Version management
+- Execution triggering
+"""
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.workflow.enums import NodeType
+from libs.datetime_utils import naive_utc_now
+from models.model import App, AppMode
+from models.workflow import Workflow, WorkflowType
+from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError
+from services.errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
+from services.workflow_service import WorkflowService
+
+
+class TestWorkflowAssociatedDataFactory:
+ """
+ Factory class for creating test data and mock objects for workflow service tests.
+
+ This factory provides reusable methods to create mock objects for:
+ - App models with configurable attributes
+ - Workflow models with graph and feature configurations
+ - Account models for user authentication
+ - Valid workflow graph structures for testing
+
+ All factory methods return MagicMock objects that simulate database models
+ without requiring actual database connections.
+ """
+
+ @staticmethod
+ def create_app_mock(
+ app_id: str = "app-123",
+ tenant_id: str = "tenant-456",
+ mode: str = AppMode.WORKFLOW.value,
+ workflow_id: str | None = None,
+ **kwargs,
+ ) -> MagicMock:
+ """
+ Create a mock App with specified attributes.
+
+ Args:
+ app_id: Unique identifier for the app
+ tenant_id: Workspace/tenant identifier
+ mode: App mode (workflow, chat, completion, etc.)
+ workflow_id: Optional ID of the published workflow
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ MagicMock object configured as an App model
+ """
+ app = MagicMock(spec=App)
+ app.id = app_id
+ app.tenant_id = tenant_id
+ app.mode = mode
+ app.workflow_id = workflow_id
+ for key, value in kwargs.items():
+ setattr(app, key, value)
+ return app
+
+ @staticmethod
+ def create_workflow_mock(
+ workflow_id: str = "workflow-789",
+ tenant_id: str = "tenant-456",
+ app_id: str = "app-123",
+ version: str = Workflow.VERSION_DRAFT,
+ workflow_type: str = WorkflowType.WORKFLOW.value,
+ graph: dict | None = None,
+ features: dict | None = None,
+ unique_hash: str | None = None,
+ **kwargs,
+ ) -> MagicMock:
+ """
+ Create a mock Workflow with specified attributes.
+
+ Args:
+ workflow_id: Unique identifier for the workflow
+ tenant_id: Workspace/tenant identifier
+ app_id: Associated app identifier
+ version: Workflow version ("draft" or timestamp-based version)
+ workflow_type: Type of workflow (workflow, chat, rag-pipeline)
+ graph: Workflow graph structure containing nodes and edges
+ features: Feature configuration (file upload, text-to-speech, etc.)
+ unique_hash: Hash for optimistic locking during updates
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ MagicMock object configured as a Workflow model with graph/features
+ """
+ workflow = MagicMock(spec=Workflow)
+ workflow.id = workflow_id
+ workflow.tenant_id = tenant_id
+ workflow.app_id = app_id
+ workflow.version = version
+ workflow.type = workflow_type
+
+ # Set up graph and features with defaults if not provided
+ # Graph contains the workflow structure (nodes and their connections)
+ if graph is None:
+ graph = {"nodes": [], "edges": []}
+ # Features contain app-level configurations like file upload settings
+ if features is None:
+ features = {}
+
+ workflow.graph = json.dumps(graph)
+ workflow.features = json.dumps(features)
+ workflow.graph_dict = graph
+ workflow.features_dict = features
+ workflow.unique_hash = unique_hash or "test-hash-123"
+ workflow.environment_variables = []
+ workflow.conversation_variables = []
+ workflow.rag_pipeline_variables = []
+ workflow.created_by = "user-123"
+ workflow.updated_by = None
+ workflow.created_at = naive_utc_now()
+ workflow.updated_at = naive_utc_now()
+
+ # Mock walk_nodes method to iterate through workflow nodes
+ # This is used by the service to traverse and validate workflow structure
+ def walk_nodes_side_effect(specific_node_type=None):
+ nodes = graph.get("nodes", [])
+ # Filter by node type if specified (e.g., only LLM nodes)
+ if specific_node_type:
+ return (
+ (node["id"], node["data"])
+ for node in nodes
+ if node.get("data", {}).get("type") == specific_node_type.value
+ )
+ # Return all nodes if no filter specified
+ return ((node["id"], node["data"]) for node in nodes)
+
+ workflow.walk_nodes = walk_nodes_side_effect
+
+ for key, value in kwargs.items():
+ setattr(workflow, key, value)
+ return workflow
+
+ @staticmethod
+ def create_account_mock(account_id: str = "user-123", **kwargs) -> MagicMock:
+ """Create a mock Account with specified attributes."""
+ account = MagicMock()
+ account.id = account_id
+ for key, value in kwargs.items():
+ setattr(account, key, value)
+ return account
+
+ @staticmethod
+ def create_valid_workflow_graph(include_start: bool = True, include_trigger: bool = False) -> dict:
+ """
+ Create a valid workflow graph structure for testing.
+
+ Args:
+ include_start: Whether to include a START node (for regular workflows)
+            include_trigger: Whether to include a trigger-style stub node (emitted as "http-request" here)
+
+ Returns:
+ Dictionary containing nodes and edges arrays representing workflow graph
+
+ Note:
+ Start nodes and trigger nodes cannot coexist in the same workflow.
+ This is validated by the workflow service.
+ """
+ nodes = []
+ edges = []
+
+ # Add START node for regular workflows (user-initiated)
+ if include_start:
+ nodes.append(
+ {
+ "id": "start",
+ "data": {
+ "type": NodeType.START.value,
+ "title": "START",
+ "variables": [],
+ },
+ }
+ )
+
+        # Add a trigger-style stub node (uses "http-request", not a real trigger-* type)
+ if include_trigger:
+ nodes.append(
+ {
+ "id": "trigger-1",
+ "data": {
+ "type": "http-request",
+ "title": "HTTP Request Trigger",
+ },
+ }
+ )
+
+ # Add an LLM node as a sample processing node
+ # This represents an AI model interaction in the workflow
+ nodes.append(
+ {
+ "id": "llm-1",
+ "data": {
+ "type": NodeType.LLM.value,
+ "title": "LLM",
+ "model": {
+ "provider": "openai",
+ "name": "gpt-4",
+ },
+ },
+ }
+ )
+
+ return {"nodes": nodes, "edges": edges}
+
+
+class TestWorkflowService:
+ """
+ Comprehensive unit tests for WorkflowService methods.
+
+ This test suite covers:
+ - Workflow creation from template
+ - Workflow validation (graph and features)
+ - Draft/publish transitions
+ - Version management
+ - Workflow deletion and error handling
+ """
+
+ @pytest.fixture
+ def workflow_service(self):
+ """
+ Create a WorkflowService instance with mocked dependencies.
+
+ This fixture patches the database to avoid real database connections
+ during testing. Each test gets a fresh service instance.
+ """
+ with patch("services.workflow_service.db"):
+ service = WorkflowService()
+ return service
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Mock database session for testing database operations.
+
+ Provides mock implementations of:
+ - session.add(): Adding new records
+ - session.commit(): Committing transactions
+ - session.query(): Querying database
+ - session.execute(): Executing SQL statements
+ """
+ with patch("services.workflow_service.db") as mock_db:
+ mock_session = MagicMock()
+ mock_db.session = mock_session
+ mock_session.add = MagicMock()
+ mock_session.commit = MagicMock()
+ mock_session.query = MagicMock()
+ mock_session.execute = MagicMock()
+ yield mock_db
+
+ @pytest.fixture
+ def mock_sqlalchemy_session(self):
+ """
+ Mock SQLAlchemy Session for publish_workflow tests.
+
+ This is a separate fixture because publish_workflow uses
+ SQLAlchemy's Session class directly rather than the Flask-SQLAlchemy
+ db.session object.
+ """
+ mock_session = MagicMock()
+ mock_session.add = MagicMock()
+ mock_session.commit = MagicMock()
+ mock_session.scalar = MagicMock()
+ return mock_session
+
+ # ==================== Workflow Existence Tests ====================
+ # These tests verify the service can check if a draft workflow exists
+
+ def test_is_workflow_exist_returns_true(self, workflow_service, mock_db_session):
+ """
+ Test is_workflow_exist returns True when draft workflow exists.
+
+ Verifies that the service correctly identifies when an app has a draft workflow.
+ This is used to determine whether to create or update a workflow.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+
+ # Mock the database query to return True
+ mock_db_session.session.execute.return_value.scalar_one.return_value = True
+
+ result = workflow_service.is_workflow_exist(app)
+
+ assert result is True
+
+ def test_is_workflow_exist_returns_false(self, workflow_service, mock_db_session):
+ """Test is_workflow_exist returns False when no draft workflow exists."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+
+ # Mock the database query to return False
+ mock_db_session.session.execute.return_value.scalar_one.return_value = False
+
+ result = workflow_service.is_workflow_exist(app)
+
+ assert result is False
+
+ # ==================== Get Draft Workflow Tests ====================
+ # These tests verify retrieval of draft workflows (version="draft")
+
+ def test_get_draft_workflow_success(self, workflow_service, mock_db_session):
+ """
+ Test get_draft_workflow returns draft workflow successfully.
+
+ Draft workflows are the working copy that users edit before publishing.
+ Each app can have only one draft workflow at a time.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock()
+
+ # Mock database query
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ result = workflow_service.get_draft_workflow(app)
+
+ assert result == mock_workflow
+
+ def test_get_draft_workflow_returns_none(self, workflow_service, mock_db_session):
+ """Test get_draft_workflow returns None when no draft exists."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+
+ # Mock database query to return None
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = None
+
+ result = workflow_service.get_draft_workflow(app)
+
+ assert result is None
+
+ def test_get_draft_workflow_with_workflow_id(self, workflow_service, mock_db_session):
+ """Test get_draft_workflow with workflow_id calls get_published_workflow_by_id."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ workflow_id = "workflow-123"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
+
+ # Mock database query
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id)
+
+ assert result == mock_workflow
+
+ # ==================== Get Published Workflow Tests ====================
+ # These tests verify retrieval of published workflows (versioned snapshots)
+
+ def test_get_published_workflow_by_id_success(self, workflow_service, mock_db_session):
+ """Test get_published_workflow_by_id returns published workflow."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ workflow_id = "workflow-123"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
+
+ # Mock database query
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ result = workflow_service.get_published_workflow_by_id(app, workflow_id)
+
+ assert result == mock_workflow
+
+ def test_get_published_workflow_by_id_raises_error_for_draft(self, workflow_service, mock_db_session):
+ """
+ Test get_published_workflow_by_id raises error when workflow is draft.
+
+ This prevents using draft workflows in production contexts where only
+ published, stable versions should be used (e.g., API execution).
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ workflow_id = "workflow-123"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(
+ workflow_id=workflow_id, version=Workflow.VERSION_DRAFT
+ )
+
+ # Mock database query
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ with pytest.raises(IsDraftWorkflowError):
+ workflow_service.get_published_workflow_by_id(app, workflow_id)
+
+ def test_get_published_workflow_by_id_returns_none(self, workflow_service, mock_db_session):
+ """Test get_published_workflow_by_id returns None when workflow not found."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ workflow_id = "nonexistent-workflow"
+
+ # Mock database query to return None
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = None
+
+ result = workflow_service.get_published_workflow_by_id(app, workflow_id)
+
+ assert result is None
+
+ def test_get_published_workflow_success(self, workflow_service, mock_db_session):
+ """Test get_published_workflow returns published workflow."""
+ workflow_id = "workflow-123"
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id)
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
+
+ # Mock database query
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ result = workflow_service.get_published_workflow(app)
+
+ assert result == mock_workflow
+
+ def test_get_published_workflow_returns_none_when_no_workflow_id(self, workflow_service):
+ """Test get_published_workflow returns None when app has no workflow_id."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=None)
+
+ result = workflow_service.get_published_workflow(app)
+
+ assert result is None
+
+ # ==================== Sync Draft Workflow Tests ====================
+ # These tests verify creating and updating draft workflows with validation
+
+ def test_sync_draft_workflow_creates_new_draft(self, workflow_service, mock_db_session):
+ """
+ Test sync_draft_workflow creates new draft workflow when none exists.
+
+ When a user first creates a workflow app, this creates the initial draft.
+ The draft is validated before creation to ensure graph and features are valid.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
+ features = {"file_upload": {"enabled": False}}
+
+ # Mock get_draft_workflow to return None (no existing draft)
+ # This simulates the first time a workflow is created for an app
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = None
+
+ with (
+ patch.object(workflow_service, "validate_features_structure"),
+ patch.object(workflow_service, "validate_graph_structure"),
+ patch("services.workflow_service.app_draft_workflow_was_synced"),
+ ):
+ result = workflow_service.sync_draft_workflow(
+ app_model=app,
+ graph=graph,
+ features=features,
+ unique_hash=None,
+ account=account,
+ environment_variables=[],
+ conversation_variables=[],
+ )
+
+ # Verify workflow was added to session
+ mock_db_session.session.add.assert_called_once()
+ mock_db_session.session.commit.assert_called_once()
+
+ def test_sync_draft_workflow_updates_existing_draft(self, workflow_service, mock_db_session):
+ """
+ Test sync_draft_workflow updates existing draft workflow.
+
+ When users edit their workflow, this updates the existing draft.
+ The unique_hash is used for optimistic locking to prevent conflicts.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
+ features = {"file_upload": {"enabled": False}}
+ unique_hash = "test-hash-123"
+
+ # Mock existing draft workflow
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash=unique_hash)
+
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ with (
+ patch.object(workflow_service, "validate_features_structure"),
+ patch.object(workflow_service, "validate_graph_structure"),
+ patch("services.workflow_service.app_draft_workflow_was_synced"),
+ ):
+ result = workflow_service.sync_draft_workflow(
+ app_model=app,
+ graph=graph,
+ features=features,
+ unique_hash=unique_hash,
+ account=account,
+ environment_variables=[],
+ conversation_variables=[],
+ )
+
+ # Verify workflow was updated
+ assert mock_workflow.graph == json.dumps(graph)
+ assert mock_workflow.features == json.dumps(features)
+ assert mock_workflow.updated_by == account.id
+ mock_db_session.session.commit.assert_called_once()
+
+ def test_sync_draft_workflow_raises_hash_not_equal_error(self, workflow_service, mock_db_session):
+ """
+ Test sync_draft_workflow raises error when hash doesn't match.
+
+ This implements optimistic locking: if the workflow was modified by another
+ user/session since it was loaded, the hash won't match and the update fails.
+ This prevents overwriting concurrent changes.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
+ features = {}
+
+ # Mock existing draft workflow with different hash
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash="old-hash")
+
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ with pytest.raises(WorkflowHashNotEqualError):
+ workflow_service.sync_draft_workflow(
+ app_model=app,
+ graph=graph,
+ features=features,
+ unique_hash="new-hash",
+ account=account,
+ environment_variables=[],
+ conversation_variables=[],
+ )
+
+ # ==================== Workflow Validation Tests ====================
+ # These tests verify graph structure and feature configuration validation
+
+ def test_validate_graph_structure_empty_graph(self, workflow_service):
+ """Test validate_graph_structure accepts empty graph."""
+ graph = {"nodes": []}
+
+ # Should not raise any exception
+ workflow_service.validate_graph_structure(graph)
+
+ def test_validate_graph_structure_valid_graph(self, workflow_service):
+ """Test validate_graph_structure accepts valid graph."""
+ graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
+
+ # Should not raise any exception
+ workflow_service.validate_graph_structure(graph)
+
+ def test_validate_graph_structure_start_and_trigger_coexist_raises_error(self, workflow_service):
+ """
+ Test validate_graph_structure raises error when start and trigger nodes coexist.
+
+ Workflows can be either:
+ - User-initiated (with START node): User provides input to start execution
+ - Event-driven (with trigger nodes): External events trigger execution
+
+ These two patterns cannot be mixed in a single workflow.
+ """
+ # Create a graph with both start and trigger nodes
+ # Use actual trigger node types: trigger-webhook, trigger-schedule, trigger-plugin
+ graph = {
+ "nodes": [
+ {
+ "id": "start",
+ "data": {
+ "type": "start",
+ "title": "START",
+ },
+ },
+ {
+ "id": "trigger-1",
+ "data": {
+ "type": "trigger-webhook",
+ "title": "Webhook Trigger",
+ },
+ },
+ ],
+ "edges": [],
+ }
+
+ with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"):
+ workflow_service.validate_graph_structure(graph)
+
+ def test_validate_features_structure_workflow_mode(self, workflow_service):
+ """
+ Test validate_features_structure for workflow mode.
+
+ Different app modes have different feature configurations.
+ This ensures the features match the expected schema for workflow apps.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.WORKFLOW.value)
+ features = {"file_upload": {"enabled": False}}
+
+ with patch("services.workflow_service.WorkflowAppConfigManager.config_validate") as mock_validate:
+ workflow_service.validate_features_structure(app, features)
+ mock_validate.assert_called_once_with(
+ tenant_id=app.tenant_id, config=features, only_structure_validate=True
+ )
+
+ def test_validate_features_structure_advanced_chat_mode(self, workflow_service):
+ """Test validate_features_structure for advanced chat mode."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.ADVANCED_CHAT.value)
+ features = {"opening_statement": "Hello"}
+
+ with patch("services.workflow_service.AdvancedChatAppConfigManager.config_validate") as mock_validate:
+ workflow_service.validate_features_structure(app, features)
+ mock_validate.assert_called_once_with(
+ tenant_id=app.tenant_id, config=features, only_structure_validate=True
+ )
+
+ def test_validate_features_structure_invalid_mode_raises_error(self, workflow_service):
+ """Test validate_features_structure raises error for invalid mode."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.COMPLETION.value)
+ features = {}
+
+ with pytest.raises(ValueError, match="Invalid app mode"):
+ workflow_service.validate_features_structure(app, features)
+
+ # ==================== Publish Workflow Tests ====================
+ # These tests verify creating published versions from draft workflows
+
+ def test_publish_workflow_success(self, workflow_service, mock_sqlalchemy_session):
+ """
+ Test publish_workflow creates new published version.
+
+ Publishing creates a timestamped snapshot of the draft workflow.
+ This allows users to:
+ - Roll back to previous versions
+ - Use stable versions in production
+ - Continue editing draft without affecting published version
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
+
+ # Mock draft workflow
+ mock_draft = TestWorkflowAssociatedDataFactory.create_workflow_mock(version=Workflow.VERSION_DRAFT, graph=graph)
+ mock_sqlalchemy_session.scalar.return_value = mock_draft
+
+ with (
+ patch.object(workflow_service, "validate_graph_structure"),
+ patch("services.workflow_service.app_published_workflow_was_updated"),
+ patch("services.workflow_service.dify_config") as mock_config,
+ patch("services.workflow_service.Workflow.new") as mock_workflow_new,
+ ):
+ # Disable billing
+ mock_config.BILLING_ENABLED = False
+
+ # Mock Workflow.new to return a new workflow
+ mock_new_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
+ mock_workflow_new.return_value = mock_new_workflow
+
+ result = workflow_service.publish_workflow(
+ session=mock_sqlalchemy_session,
+ app_model=app,
+ account=account,
+ marked_name="Version 1",
+ marked_comment="Initial release",
+ )
+
+ # Verify workflow was added to session
+ mock_sqlalchemy_session.add.assert_called_once_with(mock_new_workflow)
+ assert result == mock_new_workflow
+
+ def test_publish_workflow_no_draft_raises_error(self, workflow_service, mock_sqlalchemy_session):
+ """
+ Test publish_workflow raises error when no draft exists.
+
+ Cannot publish if there's no draft to publish from.
+ Users must create and save a draft before publishing.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+
+ # Mock no draft workflow
+ mock_sqlalchemy_session.scalar.return_value = None
+
+ with pytest.raises(ValueError, match="No valid workflow found"):
+ workflow_service.publish_workflow(session=mock_sqlalchemy_session, app_model=app, account=account)
+
+ def test_publish_workflow_trigger_limit_exceeded(self, workflow_service, mock_sqlalchemy_session):
+ """
+ Test publish_workflow raises error when trigger node limit exceeded in SANDBOX plan.
+
+ Free/sandbox tier users have limits on the number of trigger nodes.
+ This prevents resource abuse while allowing users to test the feature.
+ The limit is enforced at publish time, not during draft editing.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+
+ # Create graph with 3 trigger nodes (exceeds SANDBOX limit of 2)
+ # Trigger nodes enable event-driven automation which consumes resources
+ graph = {
+ "nodes": [
+ {"id": "trigger-1", "data": {"type": "trigger-webhook"}},
+ {"id": "trigger-2", "data": {"type": "trigger-schedule"}},
+ {"id": "trigger-3", "data": {"type": "trigger-plugin"}},
+ ],
+ "edges": [],
+ }
+ mock_draft = TestWorkflowAssociatedDataFactory.create_workflow_mock(version=Workflow.VERSION_DRAFT, graph=graph)
+ mock_sqlalchemy_session.scalar.return_value = mock_draft
+
+ with (
+ patch.object(workflow_service, "validate_graph_structure"),
+ patch("services.workflow_service.dify_config") as mock_config,
+ patch("services.workflow_service.BillingService") as MockBillingService,
+ patch("services.workflow_service.app_published_workflow_was_updated"),
+ ):
+ # Enable billing and set SANDBOX plan
+ mock_config.BILLING_ENABLED = True
+ MockBillingService.get_info.return_value = {"subscription": {"plan": "sandbox"}}
+
+ with pytest.raises(TriggerNodeLimitExceededError):
+ workflow_service.publish_workflow(session=mock_sqlalchemy_session, app_model=app, account=account)
+
+ # ==================== Version Management Tests ====================
+ # These tests verify listing and managing published workflow versions
+
+ def test_get_all_published_workflow_with_pagination(self, workflow_service):
+ """
+ Test get_all_published_workflow returns paginated results.
+
+ Apps can have many published versions over time.
+ Pagination prevents loading all versions at once, improving performance.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id="workflow-123")
+
+ # Mock workflows
+ mock_workflows = [
+ TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=f"workflow-{i}", version=f"v{i}")
+ for i in range(5)
+ ]
+
+ mock_session = MagicMock()
+ mock_session.scalars.return_value.all.return_value = mock_workflows
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+ mock_stmt.order_by.return_value = mock_stmt
+ mock_stmt.limit.return_value = mock_stmt
+ mock_stmt.offset.return_value = mock_stmt
+
+ workflows, has_more = workflow_service.get_all_published_workflow(
+ session=mock_session, app_model=app, page=1, limit=10, user_id=None
+ )
+
+ assert len(workflows) == 5
+ assert has_more is False
+
+ def test_get_all_published_workflow_has_more(self, workflow_service):
+ """
+ Test get_all_published_workflow indicates has_more when results exceed limit.
+
+ The has_more flag tells the UI whether to show a "Load More" button.
+ This is determined by fetching limit+1 records and checking if we got that many.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id="workflow-123")
+
+ # Mock 11 workflows (limit is 10, so has_more should be True)
+ mock_workflows = [
+ TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=f"workflow-{i}", version=f"v{i}")
+ for i in range(11)
+ ]
+
+ mock_session = MagicMock()
+ mock_session.scalars.return_value.all.return_value = mock_workflows
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+ mock_stmt.order_by.return_value = mock_stmt
+ mock_stmt.limit.return_value = mock_stmt
+ mock_stmt.offset.return_value = mock_stmt
+
+ workflows, has_more = workflow_service.get_all_published_workflow(
+ session=mock_session, app_model=app, page=1, limit=10, user_id=None
+ )
+
+ assert len(workflows) == 10
+ assert has_more is True
+
+ def test_get_all_published_workflow_no_workflow_id(self, workflow_service):
+ """Test get_all_published_workflow returns empty when app has no workflow_id."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=None)
+ mock_session = MagicMock()
+
+ workflows, has_more = workflow_service.get_all_published_workflow(
+ session=mock_session, app_model=app, page=1, limit=10, user_id=None
+ )
+
+ assert workflows == []
+ assert has_more is False
+
+ # ==================== Update Workflow Tests ====================
+ # These tests verify updating workflow metadata (name, comments, etc.)
+
+ def test_update_workflow_success(self, workflow_service):
+ """
+ Test update_workflow updates workflow attributes.
+
+ Allows updating metadata like marked_name and marked_comment
+ without creating a new version. Only specific fields are allowed
+ to prevent accidental modification of workflow logic.
+ """
+ workflow_id = "workflow-123"
+ tenant_id = "tenant-456"
+ account_id = "user-123"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id)
+
+ mock_session = MagicMock()
+ mock_session.scalar.return_value = mock_workflow
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ result = workflow_service.update_workflow(
+ session=mock_session,
+ workflow_id=workflow_id,
+ tenant_id=tenant_id,
+ account_id=account_id,
+ data={"marked_name": "Updated Name", "marked_comment": "Updated Comment"},
+ )
+
+ assert result == mock_workflow
+ assert mock_workflow.marked_name == "Updated Name"
+ assert mock_workflow.marked_comment == "Updated Comment"
+ assert mock_workflow.updated_by == account_id
+
+ def test_update_workflow_not_found(self, workflow_service):
+ """Test update_workflow returns None when workflow not found."""
+ mock_session = MagicMock()
+ mock_session.scalar.return_value = None
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ result = workflow_service.update_workflow(
+ session=mock_session,
+ workflow_id="nonexistent",
+ tenant_id="tenant-456",
+ account_id="user-123",
+ data={"marked_name": "Test"},
+ )
+
+ assert result is None
+
+ # ==================== Delete Workflow Tests ====================
+ # These tests verify workflow deletion with safety checks
+
+ def test_delete_workflow_success(self, workflow_service):
+ """
+ Test delete_workflow successfully deletes a published workflow.
+
+ Users can delete old published versions they no longer need.
+ This helps manage storage and keeps the version list clean.
+ """
+ workflow_id = "workflow-123"
+ tenant_id = "tenant-456"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
+
+ mock_session = MagicMock()
+ # Mock successful deletion scenario:
+ # 1. Workflow exists
+ # 2. No app is currently using it
+ # 3. Not published as a tool
+ mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it
+ mock_session.query.return_value.where.return_value.first.return_value = None # no tool provider
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ result = workflow_service.delete_workflow(
+ session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id
+ )
+
+ assert result is True
+ mock_session.delete.assert_called_once_with(mock_workflow)
+
+ def test_delete_workflow_draft_raises_error(self, workflow_service):
+ """
+ Test delete_workflow raises error when trying to delete draft.
+
+ Draft workflows cannot be deleted - they're the working copy.
+ Users can only delete published versions to clean up old snapshots.
+ """
+ workflow_id = "workflow-123"
+ tenant_id = "tenant-456"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(
+ workflow_id=workflow_id, version=Workflow.VERSION_DRAFT
+ )
+
+ mock_session = MagicMock()
+ mock_session.scalar.return_value = mock_workflow
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow"):
+ workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
+
+ def test_delete_workflow_in_use_by_app_raises_error(self, workflow_service):
+ """
+ Test delete_workflow raises error when workflow is in use by app.
+
+ Cannot delete a workflow version that's currently published/active.
+ This would break the app for users. Must publish a different version first.
+ """
+ workflow_id = "workflow-123"
+ tenant_id = "tenant-456"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
+ mock_app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id)
+
+ mock_session = MagicMock()
+ mock_session.scalar.side_effect = [mock_workflow, mock_app]
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ with pytest.raises(WorkflowInUseError, match="currently in use by app"):
+ workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
+
+ def test_delete_workflow_published_as_tool_raises_error(self, workflow_service):
+ """
+ Test delete_workflow raises error when workflow is published as tool.
+
+ Workflows can be published as reusable tools for other workflows.
+ Cannot delete a version that's being used as a tool, as this would
+ break other workflows that depend on it.
+ """
+ workflow_id = "workflow-123"
+ tenant_id = "tenant-456"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
+ mock_tool_provider = MagicMock()
+
+ mock_session = MagicMock()
+ mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it
+ mock_session.query.return_value.where.return_value.first.return_value = mock_tool_provider
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ with pytest.raises(WorkflowInUseError, match="published as a tool"):
+ workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
+
+ def test_delete_workflow_not_found_raises_error(self, workflow_service):
+ """Test delete_workflow raises error when workflow not found."""
+ workflow_id = "nonexistent"
+ tenant_id = "tenant-456"
+
+ mock_session = MagicMock()
+ mock_session.scalar.return_value = None
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ with pytest.raises(ValueError, match="not found"):
+ workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
+
+ # ==================== Get Default Block Config Tests ====================
+ # These tests verify retrieval of default node configurations
+
+ def test_get_default_block_configs(self, workflow_service):
+ """
+ Test get_default_block_configs returns list of default configs.
+
+ Returns default configurations for all available node types.
+ Used by the UI to populate the node palette and provide sensible defaults
+ when users add new nodes to their workflow.
+ """
+ with patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping:
+ # Mock node class with default config
+ mock_node_class = MagicMock()
+ mock_node_class.get_default_config.return_value = {"type": "llm", "config": {}}
+
+ mock_mapping.values.return_value = [{"latest": mock_node_class}]
+
+ with patch("services.workflow_service.LATEST_VERSION", "latest"):
+ result = workflow_service.get_default_block_configs()
+
+ assert len(result) > 0
+
+ def test_get_default_block_config_for_node_type(self, workflow_service):
+ """
+ Test get_default_block_config returns config for specific node type.
+
+ Returns the default configuration for a specific node type (e.g., LLM, HTTP).
+ This includes default values for all required and optional parameters.
+ """
+ with (
+ patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping,
+ patch("services.workflow_service.LATEST_VERSION", "latest"),
+ ):
+ # Mock node class with default config
+ mock_node_class = MagicMock()
+ mock_config = {"type": "llm", "config": {"provider": "openai"}}
+ mock_node_class.get_default_config.return_value = mock_config
+
+ # Create a mock mapping that includes NodeType.LLM
+ mock_mapping.__contains__.return_value = True
+ mock_mapping.__getitem__.return_value = {"latest": mock_node_class}
+
+ result = workflow_service.get_default_block_config(NodeType.LLM.value)
+
+ assert result == mock_config
+ mock_node_class.get_default_config.assert_called_once()
+
+ def test_get_default_block_config_invalid_node_type(self, workflow_service):
+ """Test get_default_block_config returns empty dict for invalid node type."""
+ with patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping:
+ # Mock mapping to not contain the node type
+ mock_mapping.__contains__.return_value = False
+
+ # Use a valid NodeType but one that's not in the mapping
+ result = workflow_service.get_default_block_config(NodeType.LLM.value)
+
+ assert result == {}
+
+ # ==================== Workflow Conversion Tests ====================
+ # These tests verify converting basic apps to workflow apps
+
+ def test_convert_to_workflow_from_chat_app(self, workflow_service):
+ """
+ Test convert_to_workflow converts chat app to workflow.
+
+ Allows users to migrate from simple chat apps to advanced workflow apps.
+ The conversion creates equivalent workflow nodes from the chat configuration,
+ giving users more control and customization options.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.CHAT.value)
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ args = {
+ "name": "Converted Workflow",
+ "icon_type": "emoji",
+ "icon": "🤖",
+ "icon_background": "#FFEAD5",
+ }
+
+ with patch("services.workflow_service.WorkflowConverter") as MockConverter:
+ mock_converter = MockConverter.return_value
+ mock_new_app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.WORKFLOW.value)
+ mock_converter.convert_to_workflow.return_value = mock_new_app
+
+ result = workflow_service.convert_to_workflow(app, account, args)
+
+ assert result == mock_new_app
+ mock_converter.convert_to_workflow.assert_called_once()
+
+ def test_convert_to_workflow_from_completion_app(self, workflow_service):
+ """
+ Test convert_to_workflow converts completion app to workflow.
+
+ Similar to chat conversion, but for completion-style apps.
+ Completion apps are simpler (single prompt-response), so the
+ conversion creates a basic workflow with fewer nodes.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.COMPLETION.value)
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ args = {"name": "Converted Workflow"}
+
+ with patch("services.workflow_service.WorkflowConverter") as MockConverter:
+ mock_converter = MockConverter.return_value
+ mock_new_app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.WORKFLOW.value)
+ mock_converter.convert_to_workflow.return_value = mock_new_app
+
+ result = workflow_service.convert_to_workflow(app, account, args)
+
+ assert result == mock_new_app
+
+ def test_convert_to_workflow_invalid_mode_raises_error(self, workflow_service):
+ """
+ Test convert_to_workflow raises error for invalid app mode.
+
+ Only chat and completion apps can be converted to workflows.
+ Apps that are already workflows or have other modes cannot be converted.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.WORKFLOW.value)
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ args = {}
+
+ with pytest.raises(ValueError, match="not supported convert to workflow"):
+ workflow_service.convert_to_workflow(app, account, args)
From 7a7fea40d9eb5f15f18d8fd55f6ef8dc9166e1bf Mon Sep 17 00:00:00 2001
From: Gritty_dev <101377478+codomposer@users.noreply.github.com>
Date: Thu, 27 Nov 2025 01:39:33 -0500
Subject: [PATCH 54/63] feat: complete test script of dataset retrieval
(#28762)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../unit_tests/core/rag/retrieval/__init__.py | 0
.../rag/retrieval/test_dataset_retrieval.py | 1696 +++++++++++++++++
2 files changed, 1696 insertions(+)
create mode 100644 api/tests/unit_tests/core/rag/retrieval/__init__.py
create mode 100644 api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py
diff --git a/api/tests/unit_tests/core/rag/retrieval/__init__.py b/api/tests/unit_tests/core/rag/retrieval/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py
new file mode 100644
index 0000000000..0163e42992
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py
@@ -0,0 +1,1696 @@
+"""
+Unit tests for dataset retrieval functionality.
+
+This module provides comprehensive test coverage for the RetrievalService class,
+which is responsible for retrieving relevant documents from datasets using various
+search strategies.
+
+Core Retrieval Mechanisms Tested:
+==================================
+1. **Vector Search (Semantic Search)**
+ - Uses embedding vectors to find semantically similar documents
+ - Supports score thresholds and top-k limiting
+ - Can filter by document IDs and metadata
+
+2. **Keyword Search**
+ - Traditional text-based search using keyword matching
+ - Handles special characters and query escaping
+ - Supports document filtering
+
+3. **Full-Text Search**
+ - BM25-based full-text search for text matching
+ - Used in hybrid search scenarios
+
+4. **Hybrid Search**
+ - Combines vector and full-text search results
+ - Implements deduplication to avoid duplicate chunks
+ - Uses DataPostProcessor for score merging with configurable weights
+
+5. **Score Merging Algorithms**
+ - Deduplication based on doc_id
+ - Retains higher-scoring duplicates
+ - Supports weighted score combination
+
+6. **Metadata Filtering**
+ - Filters documents based on metadata conditions
+ - Supports document ID filtering
+
+Test Architecture:
+==================
+- **Fixtures**: Provide reusable mock objects (datasets, documents, Flask app)
+- **Mocking Strategy**: Mock at the method level (embedding_search, keyword_search, etc.)
+ rather than at the class level to properly simulate the ThreadPoolExecutor behavior
+- **Pattern**: All tests follow Arrange-Act-Assert (AAA) pattern
+- **Isolation**: Each test is independent and doesn't rely on external state
+
+Running Tests:
+==============
+ # Run all tests in this module
+ uv run --project api pytest \
+ api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py -v
+
+ # Run a specific test class
+ uv run --project api pytest \
+ api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::TestRetrievalService -v
+
+ # Run a specific test
+ uv run --project api pytest \
+ api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::\
+TestRetrievalService::test_vector_search_basic -v
+
+Notes:
+======
+- The RetrievalService uses ThreadPoolExecutor for concurrent search operations
+- Tests mock the individual search methods to avoid threading complexity
+- All mocked search methods modify the all_documents list in-place
+- Score thresholds and top-k limits are enforced by the search methods
+"""
+
+from unittest.mock import MagicMock, Mock, patch
+from uuid import uuid4
+
+import pytest
+
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from models.dataset import Dataset
+
+# ==================== Helper Functions ====================
+
+
+def create_mock_document(
+ content: str,
+ doc_id: str,
+ score: float = 0.8,
+ provider: str = "dify",
+ additional_metadata: dict | None = None,
+) -> Document:
+ """
+ Create a mock Document object for testing.
+
+ This helper function standardizes document creation across tests,
+ ensuring consistent structure and reducing code duplication.
+
+ Args:
+ content: The text content of the document
+ doc_id: Unique identifier for the document chunk
+ score: Relevance score (0.0 to 1.0)
+ provider: Document provider ("dify" or "external")
+ additional_metadata: Optional extra metadata fields
+
+ Returns:
+ Document: A properly structured Document object
+
+ Example:
+ >>> doc = create_mock_document("Python is great", "doc1", score=0.95)
+ >>> assert doc.metadata["score"] == 0.95
+ """
+ metadata = {
+ "doc_id": doc_id,
+ "document_id": str(uuid4()),
+ "dataset_id": str(uuid4()),
+ "score": score,
+ }
+
+ # Merge additional metadata if provided
+ if additional_metadata:
+ metadata.update(additional_metadata)
+
+ return Document(
+ page_content=content,
+ metadata=metadata,
+ provider=provider,
+ )
+
+
+def create_side_effect_for_search(documents: list[Document]):
+ """
+ Create a side effect function for mocking search methods.
+
+ This helper creates a function that simulates how RetrievalService
+ search methods work - they modify the all_documents list in-place
+ rather than returning values directly.
+
+ Args:
+ documents: List of documents to add to all_documents
+
+ Returns:
+ Callable: A side effect function compatible with mock.side_effect
+
+ Example:
+ >>> mock_search.side_effect = create_side_effect_for_search([doc1, doc2])
+
+ Note:
+ The RetrievalService uses ThreadPoolExecutor which submits tasks that
+ modify a shared all_documents list. This pattern simulates that behavior.
+ """
+
+ def side_effect(flask_app, dataset_id, query, top_k, *args, all_documents, exceptions, **kwargs):
+ """
+ Side effect function that mimics search method behavior.
+
+ Args:
+ flask_app: Flask application context (unused in mock)
+ dataset_id: ID of the dataset being searched
+ query: Search query string
+ top_k: Maximum number of results
+ all_documents: Shared list to append results to
+ exceptions: Shared list to append errors to
+ **kwargs: Additional arguments (score_threshold, document_ids_filter, etc.)
+ """
+ all_documents.extend(documents)
+
+ return side_effect
+
+
+def create_side_effect_with_exception(error_message: str):
+ """
+ Create a side effect function that adds an exception to the exceptions list.
+
+ Used for testing error handling in the RetrievalService.
+
+ Args:
+ error_message: The error message to add to exceptions
+
+ Returns:
+ Callable: A side effect function that simulates an error
+
+ Example:
+ >>> mock_search.side_effect = create_side_effect_with_exception("Search failed")
+ """
+
+ def side_effect(flask_app, dataset_id, query, top_k, *args, all_documents, exceptions, **kwargs):
+ """Add error message to exceptions list."""
+ exceptions.append(error_message)
+
+ return side_effect
+
+
+class TestRetrievalService:
+ """
+ Comprehensive test suite for RetrievalService class.
+
+ This test class validates all retrieval methods and their interactions,
+ including edge cases, error handling, and integration scenarios.
+
+ Test Organization:
+ ==================
+ 1. Fixtures (lines ~190-240)
+ - mock_dataset: Standard dataset configuration
+ - sample_documents: Reusable test documents with varying scores
+ - mock_flask_app: Flask application context
+ - mock_thread_pool: Synchronous executor for deterministic testing
+
+ 2. Vector Search Tests (lines ~240-350)
+ - Basic functionality
+ - Document filtering
+ - Empty results
+ - Metadata filtering
+ - Score thresholds
+
+ 3. Keyword Search Tests (lines ~350-450)
+ - Basic keyword matching
+ - Special character handling
+ - Document filtering
+
+ 4. Hybrid Search Tests (lines ~450-640)
+ - Vector + full-text combination
+ - Deduplication logic
+ - Weighted score merging
+
+ 5. Full-Text Search Tests (lines ~640-680)
+ - BM25-based search
+
+ 6. Score Merging Tests (lines ~680-790)
+ - Deduplication algorithms
+ - Score comparison
+ - Provider-specific handling
+
+ 7. Error Handling Tests (lines ~790-920)
+ - Empty queries
+ - Non-existent datasets
+ - Exception propagation
+
+ 8. Additional Tests (lines ~920-1080)
+ - Query escaping
+ - Reranking integration
+ - Top-K limiting
+
+ Mocking Strategy:
+ =================
+ Tests mock at the method level (embedding_search, keyword_search, etc.)
+ rather than the underlying Vector/Keyword classes. This approach:
+ - Avoids complexity of mocking ThreadPoolExecutor behavior
+ - Provides clearer test intent
+ - Makes tests more maintainable
+ - Properly simulates the in-place list modification pattern
+
+ Common Patterns:
+ ================
+ 1. **Arrange**: Set up mocks with side_effect functions
+ 2. **Act**: Call RetrievalService.retrieve() with specific parameters
+ 3. **Assert**: Verify results, mock calls, and side effects
+
+ Example Test Structure:
+ ```python
+ def test_example(self, mock_get_dataset, mock_search, mock_dataset):
+ # Arrange: Set up test data and mocks
+ mock_get_dataset.return_value = mock_dataset
+ mock_search.side_effect = create_side_effect_for_search([doc1, doc2])
+
+ # Act: Execute the method under test
+ results = RetrievalService.retrieve(...)
+
+ # Assert: Verify expectations
+ assert len(results) == 2
+ mock_search.assert_called_once()
+ ```
+ """
+
+ @pytest.fixture
+ def mock_dataset(self) -> Dataset:
+ """
+ Create a mock Dataset object for testing.
+
+ Returns:
+ Dataset: Mock dataset with standard configuration
+ """
+ dataset = Mock(spec=Dataset)
+ dataset.id = str(uuid4())
+ dataset.tenant_id = str(uuid4())
+ dataset.name = "test_dataset"
+ dataset.indexing_technique = "high_quality"
+ dataset.embedding_model = "text-embedding-ada-002"
+ dataset.embedding_model_provider = "openai"
+ dataset.retrieval_model = {
+ "search_method": RetrievalMethod.SEMANTIC_SEARCH,
+ "reranking_enable": False,
+ "top_k": 4,
+ "score_threshold_enabled": False,
+ }
+ return dataset
+
+ @pytest.fixture
+ def sample_documents(self) -> list[Document]:
+ """
+ Create sample documents for testing retrieval results.
+
+ Returns:
+ list[Document]: List of mock documents with varying scores
+ """
+ return [
+ Document(
+ page_content="Python is a high-level programming language.",
+ metadata={
+ "doc_id": "doc1",
+ "document_id": str(uuid4()),
+ "dataset_id": str(uuid4()),
+ "score": 0.95,
+ },
+ provider="dify",
+ ),
+ Document(
+ page_content="JavaScript is widely used for web development.",
+ metadata={
+ "doc_id": "doc2",
+ "document_id": str(uuid4()),
+ "dataset_id": str(uuid4()),
+ "score": 0.85,
+ },
+ provider="dify",
+ ),
+ Document(
+ page_content="Machine learning is a subset of artificial intelligence.",
+ metadata={
+ "doc_id": "doc3",
+ "document_id": str(uuid4()),
+ "dataset_id": str(uuid4()),
+ "score": 0.75,
+ },
+ provider="dify",
+ ),
+ ]
+
+ @pytest.fixture
+ def mock_flask_app(self):
+ """
+ Create a mock Flask application context.
+
+ Returns:
+ Mock: Flask app mock with app_context
+ """
+ app = MagicMock()
+ app.app_context.return_value.__enter__ = Mock()
+ app.app_context.return_value.__exit__ = Mock()
+ return app
+
+ @pytest.fixture(autouse=True)
+ def mock_thread_pool(self):
+ """
+ Mock ThreadPoolExecutor to run tasks synchronously in tests.
+
+ The RetrievalService uses ThreadPoolExecutor to run search operations
+ concurrently (embedding_search, keyword_search, full_text_index_search).
+ In tests, we want synchronous execution for:
+ - Deterministic behavior
+ - Easier debugging
+ - Avoiding race conditions
+ - Simpler assertions
+
+ How it works:
+ -------------
+ 1. Intercepts ThreadPoolExecutor creation
+ 2. Replaces submit() to execute functions immediately (synchronously)
+ 3. Functions modify shared all_documents list in-place
+ 4. Mocks concurrent.futures.wait() since tasks are already done
+
+ Why this approach:
+ ------------------
+ - RetrievalService.retrieve() creates a ThreadPoolExecutor context
+ - It submits search tasks that modify all_documents list
+ - concurrent.futures.wait() waits for all tasks to complete
+ - By executing synchronously, we avoid threading complexity in tests
+
+ Returns:
+ Mock: Mocked ThreadPoolExecutor that executes tasks synchronously
+ """
+ with patch("core.rag.datasource.retrieval_service.ThreadPoolExecutor") as mock_executor:
+ # Store futures to track submitted tasks (for debugging if needed)
+ futures_list = []
+
+ def sync_submit(fn, *args, **kwargs):
+ """
+ Synchronous replacement for ThreadPoolExecutor.submit().
+
+ Instead of scheduling the function for async execution,
+ we execute it immediately in the current thread.
+
+ Args:
+ fn: The function to execute (e.g., embedding_search)
+ *args, **kwargs: Arguments to pass to the function
+
+ Returns:
+ Mock: A mock Future object
+ """
+ future = Mock()
+ try:
+ # Execute immediately - this modifies all_documents in place
+ # The function signature is: fn(flask_app, dataset_id, query,
+ # top_k, all_documents, exceptions, ...)
+ fn(*args, **kwargs)
+ future.result.return_value = None
+ future.exception.return_value = None
+ except Exception as e:
+ # If function raises, store exception in future
+ future.result.return_value = None
+ future.exception.return_value = e
+
+ futures_list.append(future)
+ return future
+
+ # Set up the mock executor instance
+ mock_executor_instance = Mock()
+ mock_executor_instance.submit = sync_submit
+
+ # Configure context manager behavior (__enter__ and __exit__)
+ mock_executor.return_value.__enter__.return_value = mock_executor_instance
+ mock_executor.return_value.__exit__.return_value = None
+
+ # Mock concurrent.futures.wait to do nothing since tasks are already done
+ # In real code, this waits for all futures to complete
+ # In tests, futures complete immediately, so wait is a no-op
+ with patch("core.rag.datasource.retrieval_service.concurrent.futures.wait"):
+ yield mock_executor
+
+ # ==================== Vector Search Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents):
+ """
+ Test basic vector/semantic search functionality.
+
+ This test validates the core vector search flow:
+ 1. Dataset is retrieved from database
+ 2. embedding_search is called via ThreadPoolExecutor
+ 3. Documents are added to shared all_documents list
+ 4. Results are returned to caller
+
+ Verifies:
+ - Vector search is called with correct parameters
+ - Results are returned in expected format
+ - Score threshold is applied correctly
+ - Documents maintain their metadata and scores
+ """
+ # ==================== ARRANGE ====================
+ # Set up the mock dataset that will be "retrieved" from database
+ mock_get_dataset.return_value = mock_dataset
+
+ # Create a side effect function that simulates embedding_search behavior
+ # In the real implementation, embedding_search:
+ # 1. Gets the dataset
+ # 2. Creates a Vector instance
+ # 3. Calls search_by_vector with embeddings
+ # 4. Extends all_documents with results
+ def side_effect_embedding_search(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ """Simulate embedding_search adding documents to the shared list."""
+ all_documents.extend(sample_documents)
+
+ mock_embedding_search.side_effect = side_effect_embedding_search
+
+ # Define test parameters
+ query = "What is Python?" # Natural language query
+ top_k = 3 # Maximum number of results to return
+ score_threshold = 0.7 # Minimum relevance score (0.0 to 1.0)
+
+ # ==================== ACT ====================
+ # Call the retrieve method with SEMANTIC_SEARCH strategy
+ # This will:
+ # 1. Check if query is empty (early return if so)
+ # 2. Get the dataset using _get_dataset
+ # 3. Create ThreadPoolExecutor
+ # 4. Submit embedding_search task
+ # 5. Wait for completion
+ # 6. Return all_documents list
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query=query,
+ top_k=top_k,
+ score_threshold=score_threshold,
+ )
+
+ # ==================== ASSERT ====================
+ # Verify we got the expected number of documents
+ assert len(results) == 3, "Should return 3 documents from sample_documents"
+
+ # Verify all results are Document objects (type safety)
+ assert all(isinstance(doc, Document) for doc in results), "All results should be Document instances"
+
+ # Verify documents maintain their scores (highest score first in sample_documents)
+ assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents"
+
+ # Verify embedding_search was called exactly once
+ # This confirms the search method was invoked by ThreadPoolExecutor
+ mock_embedding_search.assert_called_once()
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_vector_search_with_document_filter(
+ self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
+ ):
+ """
+ Test vector search with document ID filtering.
+
+ Verifies:
+ - Document ID filter is passed correctly to vector search
+ - Only specified documents are searched
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+ filtered_docs = [sample_documents[0]]
+
+ def side_effect_embedding_search(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(filtered_docs)
+
+ mock_embedding_search.side_effect = side_effect_embedding_search
+ document_ids_filter = [sample_documents[0].metadata["document_id"]]
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=5,
+ document_ids_filter=document_ids_filter,
+ )
+
+ # Assert
+ assert len(results) == 1
+ assert results[0].metadata["doc_id"] == "doc1"
+ # Verify document_ids_filter was passed
+ call_kwargs = mock_embedding_search.call_args.kwargs
+ assert call_kwargs["document_ids_filter"] == document_ids_filter
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset):
+ """
+ Test vector search when no results match the query.
+
+ Verifies:
+ - Empty list is returned when no documents match
+ - No errors are raised
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+ # embedding_search doesn't add anything to all_documents
+ mock_embedding_search.side_effect = lambda *args, **kwargs: None
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="nonexistent query",
+ top_k=5,
+ )
+
+ # Assert
+ assert results == []
+
+ # ==================== Keyword Search Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents):
+ """
+ Test basic keyword search functionality.
+
+ Verifies:
+ - Keyword search is invoked correctly
+ - Query is escaped properly for search
+ - Results are returned in expected format
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ def side_effect_keyword_search(
+ flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None
+ ):
+ all_documents.extend(sample_documents)
+
+ mock_keyword_search.side_effect = side_effect_keyword_search
+
+ query = "Python programming"
+ top_k = 3
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
+ dataset_id=mock_dataset.id,
+ query=query,
+ top_k=top_k,
+ )
+
+ # Assert
+ assert len(results) == 3
+ assert all(isinstance(doc, Document) for doc in results)
+ mock_keyword_search.assert_called_once()
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_keyword_search_with_special_characters(self, mock_get_dataset, mock_keyword_search, mock_dataset):
+ """
+ Test keyword search with special characters in query.
+
+ Verifies:
+ - Special characters are escaped correctly
+ - Search handles quotes and other special chars
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+ mock_keyword_search.side_effect = lambda *args, **kwargs: None
+
+ query = 'Python "programming" language'
+
+ # Act
+ RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
+ dataset_id=mock_dataset.id,
+ query=query,
+ top_k=5,
+ )
+
+ # Assert
+ # Verify that keyword_search was called
+ assert mock_keyword_search.called
+        # The query escaping happens inside the keyword_search method
+ call_args = mock_keyword_search.call_args
+ assert call_args is not None
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_keyword_search_with_document_filter(
+ self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents
+ ):
+ """
+ Test keyword search with document ID filtering.
+
+ Verifies:
+ - Document filter is applied to keyword search
+ - Only filtered documents are returned
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+ filtered_docs = [sample_documents[1]]
+
+ def side_effect_keyword_search(
+ flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None
+ ):
+ all_documents.extend(filtered_docs)
+
+ mock_keyword_search.side_effect = side_effect_keyword_search
+ document_ids_filter = [sample_documents[1].metadata["document_id"]]
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="JavaScript",
+ top_k=5,
+ document_ids_filter=document_ids_filter,
+ )
+
+ # Assert
+ assert len(results) == 1
+ assert results[0].metadata["doc_id"] == "doc2"
+
+ # ==================== Hybrid Search Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.DataPostProcessor")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_hybrid_search_basic(
+ self,
+ mock_get_dataset,
+ mock_embedding_search,
+ mock_fulltext_search,
+ mock_data_processor_class,
+ mock_dataset,
+ sample_documents,
+ ):
+ """
+ Test basic hybrid search combining vector and full-text search.
+
+ Verifies:
+ - Both vector and full-text search are executed
+ - Results are merged and deduplicated
+ - DataPostProcessor is invoked for score merging
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Vector search returns first 2 docs
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(sample_documents[:2])
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ # Full-text search returns last 2 docs (with overlap)
+ def side_effect_fulltext(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(sample_documents[1:])
+
+ mock_fulltext_search.side_effect = side_effect_fulltext
+
+ # Mock DataPostProcessor
+ mock_processor_instance = Mock()
+ mock_processor_instance.invoke.return_value = sample_documents
+ mock_data_processor_class.return_value = mock_processor_instance
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.HYBRID_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="Python programming",
+ top_k=3,
+ score_threshold=0.5,
+ )
+
+ # Assert
+ assert len(results) == 3
+ mock_embedding_search.assert_called_once()
+ mock_fulltext_search.assert_called_once()
+ mock_processor_instance.invoke.assert_called_once()
+
+ @patch("core.rag.datasource.retrieval_service.DataPostProcessor")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_hybrid_search_deduplication(
+ self, mock_get_dataset, mock_embedding_search, mock_fulltext_search, mock_data_processor_class, mock_dataset
+ ):
+ """
+ Test that hybrid search properly deduplicates documents.
+
+ Hybrid search combines results from multiple search methods (vector + full-text).
+ This can lead to duplicate documents when the same chunk is found by both methods.
+
+ Scenario:
+ ---------
+ 1. Vector search finds document "duplicate_doc" with score 0.9
+ 2. Full-text search also finds "duplicate_doc" but with score 0.6
+ 3. Both searches find "unique_doc"
+ 4. Deduplication should keep only the higher-scoring version (0.9)
+
+ Why deduplication matters:
+ --------------------------
+ - Prevents showing the same content multiple times to users
+ - Ensures score consistency (keeps best match)
+ - Improves result quality and user experience
+ - Happens BEFORE reranking to avoid processing duplicates
+
+ Verifies:
+ - Duplicate documents (same doc_id) are removed
+ - Higher scoring duplicate is retained
+ - Deduplication happens before post-processing
+ - Final result count is correct
+ """
+ # ==================== ARRANGE ====================
+ mock_get_dataset.return_value = mock_dataset
+
+ # Create test documents with intentional duplication
+ # Same doc_id but different scores to test score comparison logic
+ doc1_high = Document(
+ page_content="Content 1",
+ metadata={
+ "doc_id": "duplicate_doc", # Same doc_id as doc1_low
+ "score": 0.9, # Higher score - should be kept
+ "document_id": str(uuid4()),
+ },
+ provider="dify",
+ )
+ doc1_low = Document(
+ page_content="Content 1",
+ metadata={
+ "doc_id": "duplicate_doc", # Same doc_id as doc1_high
+ "score": 0.6, # Lower score - should be discarded
+ "document_id": str(uuid4()),
+ },
+ provider="dify",
+ )
+ doc2 = Document(
+ page_content="Content 2",
+ metadata={
+ "doc_id": "unique_doc", # Unique doc_id
+ "score": 0.8,
+ "document_id": str(uuid4()),
+ },
+ provider="dify",
+ )
+
+ # Simulate vector search returning high-score duplicate + unique doc
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ """Vector search finds 2 documents including high-score duplicate."""
+ all_documents.extend([doc1_high, doc2])
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ # Simulate full-text search returning low-score duplicate
+ def side_effect_fulltext(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ """Full-text search finds the same document but with lower score."""
+ all_documents.extend([doc1_low])
+
+ mock_fulltext_search.side_effect = side_effect_fulltext
+
+ # Mock DataPostProcessor to return deduplicated results
+ # In real implementation, _deduplicate_documents is called before this
+ mock_processor_instance = Mock()
+ mock_processor_instance.invoke.return_value = [doc1_high, doc2]
+ mock_data_processor_class.return_value = mock_processor_instance
+
+ # ==================== ACT ====================
+ # Execute hybrid search which should:
+ # 1. Run both embedding_search and full_text_index_search
+ # 2. Collect all results in all_documents (3 docs: 2 unique + 1 duplicate)
+ # 3. Call _deduplicate_documents to remove duplicate (keeps higher score)
+ # 4. Pass deduplicated results to DataPostProcessor
+ # 5. Return final results
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.HYBRID_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test",
+ top_k=5,
+ )
+
+ # ==================== ASSERT ====================
+ # Verify deduplication worked correctly
+ assert len(results) == 2, "Should have 2 unique documents after deduplication (not 3)"
+
+ # Verify the correct documents are present
+ doc_ids = [doc.metadata["doc_id"] for doc in results]
+ assert "duplicate_doc" in doc_ids, "Duplicate doc should be present (higher score version)"
+ assert "unique_doc" in doc_ids, "Unique doc should be present"
+
+ # Implicitly verifies that doc1_low (score 0.6) was discarded
+ # in favor of doc1_high (score 0.9)
+
+ @patch("core.rag.datasource.retrieval_service.DataPostProcessor")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_hybrid_search_with_weights(
+ self,
+ mock_get_dataset,
+ mock_embedding_search,
+ mock_fulltext_search,
+ mock_data_processor_class,
+ mock_dataset,
+ sample_documents,
+ ):
+ """
+ Test hybrid search with custom weights for score merging.
+
+ Verifies:
+ - Weights are passed to DataPostProcessor
+ - Score merging respects weight configuration
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(sample_documents[:2])
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ def side_effect_fulltext(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(sample_documents[1:])
+
+ mock_fulltext_search.side_effect = side_effect_fulltext
+
+ mock_processor_instance = Mock()
+ mock_processor_instance.invoke.return_value = sample_documents
+ mock_data_processor_class.return_value = mock_processor_instance
+
+ weights = {
+ "vector_setting": {
+ "vector_weight": 0.7,
+ "embedding_provider_name": "openai",
+ "embedding_model_name": "text-embedding-ada-002",
+ },
+ "keyword_setting": {"keyword_weight": 0.3},
+ }
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.HYBRID_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=3,
+ weights=weights,
+ reranking_mode="weighted_score",
+ )
+
+ # Assert
+ assert len(results) == 3
+ # Verify DataPostProcessor was created with weights
+ mock_data_processor_class.assert_called_once()
+ # Check that weights were passed (may be in args or kwargs)
+ call_args = mock_data_processor_class.call_args
+ if call_args.kwargs:
+ assert call_args.kwargs.get("weights") == weights
+ else:
+ # Weights might be in positional args (position 3)
+ assert len(call_args.args) >= 4
+
+ # ==================== Full-Text Search Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_fulltext_search_basic(self, mock_get_dataset, mock_fulltext_search, mock_dataset, sample_documents):
+ """
+ Test basic full-text search functionality.
+
+ Verifies:
+ - Full-text search is invoked correctly
+ - Results are returned in expected format
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ def side_effect_fulltext(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(sample_documents)
+
+ mock_fulltext_search.side_effect = side_effect_fulltext
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="programming language",
+ top_k=3,
+ )
+
+ # Assert
+ assert len(results) == 3
+ mock_fulltext_search.assert_called_once()
+
+ # ==================== Score Merging Tests ====================
+
+ def test_deduplicate_documents_basic(self):
+ """
+ Test basic document deduplication logic.
+
+ Verifies:
+ - Documents with same doc_id are deduplicated
+ - First occurrence is kept by default
+ """
+ # Arrange
+ doc1 = Document(
+ page_content="Content 1",
+ metadata={"doc_id": "doc1", "score": 0.8},
+ provider="dify",
+ )
+ doc2 = Document(
+ page_content="Content 2",
+ metadata={"doc_id": "doc2", "score": 0.7},
+ provider="dify",
+ )
+ doc1_duplicate = Document(
+ page_content="Content 1 duplicate",
+ metadata={"doc_id": "doc1", "score": 0.6},
+ provider="dify",
+ )
+
+ documents = [doc1, doc2, doc1_duplicate]
+
+ # Act
+ result = RetrievalService._deduplicate_documents(documents)
+
+ # Assert
+ assert len(result) == 2
+ doc_ids = [doc.metadata["doc_id"] for doc in result]
+ assert doc_ids == ["doc1", "doc2"]
+
+ def test_deduplicate_documents_keeps_higher_score(self):
+ """
+ Test that deduplication keeps document with higher score.
+
+ Verifies:
+ - When duplicates exist, higher scoring version is retained
+ - Score comparison works correctly
+ """
+ # Arrange
+ doc_low = Document(
+ page_content="Content",
+ metadata={"doc_id": "doc1", "score": 0.5},
+ provider="dify",
+ )
+ doc_high = Document(
+ page_content="Content",
+ metadata={"doc_id": "doc1", "score": 0.9},
+ provider="dify",
+ )
+
+ # Low score first
+ documents = [doc_low, doc_high]
+
+ # Act
+ result = RetrievalService._deduplicate_documents(documents)
+
+ # Assert
+ assert len(result) == 1
+ assert result[0].metadata["score"] == 0.9
+
+ def test_deduplicate_documents_empty_list(self):
+ """
+ Test deduplication with empty document list.
+
+ Verifies:
+ - Empty list returns empty list
+ - No errors are raised
+ """
+ # Act
+ result = RetrievalService._deduplicate_documents([])
+
+ # Assert
+ assert result == []
+
+ def test_deduplicate_documents_non_dify_provider(self):
+ """
+ Test deduplication with non-dify provider documents.
+
+ Verifies:
+ - External provider documents use content-based deduplication
+ - Different providers are handled correctly
+ """
+ # Arrange
+ doc1 = Document(
+ page_content="External content",
+ metadata={"score": 0.8},
+ provider="external",
+ )
+ doc2 = Document(
+ page_content="External content",
+ metadata={"score": 0.7},
+ provider="external",
+ )
+
+ documents = [doc1, doc2]
+
+ # Act
+ result = RetrievalService._deduplicate_documents(documents)
+
+ # Assert
+ # External documents without doc_id should use content-based dedup
+ assert len(result) >= 1
+
+ # ==================== Metadata Filtering Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_vector_search_with_metadata_filter(
+ self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
+ ):
+ """
+ Test vector search with metadata-based document filtering.
+
+ Verifies:
+ - Metadata filters are applied correctly
+ - Only documents matching metadata criteria are returned
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Add metadata to documents
+ filtered_doc = sample_documents[0]
+ filtered_doc.metadata["category"] = "programming"
+
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.append(filtered_doc)
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="Python",
+ top_k=5,
+ document_ids_filter=[filtered_doc.metadata["document_id"]],
+ )
+
+ # Assert
+ assert len(results) == 1
+ assert results[0].metadata.get("category") == "programming"
+
+ # ==================== Error Handling Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_retrieve_with_empty_query(self, mock_get_dataset, mock_dataset):
+ """
+ Test retrieval with empty query string.
+
+ Verifies:
+ - Empty query returns empty results
+ - No search operations are performed
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="",
+ top_k=5,
+ )
+
+ # Assert
+ assert results == []
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_retrieve_with_nonexistent_dataset(self, mock_get_dataset):
+ """
+ Test retrieval with non-existent dataset ID.
+
+ Verifies:
+ - Non-existent dataset returns empty results
+ - No errors are raised
+ """
+ # Arrange
+ mock_get_dataset.return_value = None
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id="nonexistent_id",
+ query="test query",
+ top_k=5,
+ )
+
+ # Assert
+ assert results == []
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset):
+ """
+ Test that exceptions during retrieval are properly handled.
+
+ Verifies:
+ - Exceptions are caught and added to exceptions list
+ - ValueError is raised with exception messages
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Make embedding_search add an exception to the exceptions list
+ def side_effect_with_exception(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ exceptions.append("Search failed")
+
+ mock_embedding_search.side_effect = side_effect_with_exception
+
+ # Act & Assert
+ with pytest.raises(ValueError) as exc_info:
+ RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=5,
+ )
+
+ assert "Search failed" in str(exc_info.value)
+
+ # ==================== Score Threshold Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset):
+ """
+ Test vector search with score threshold filtering.
+
+ Verifies:
+ - Score threshold is passed to search method
+ - Documents below threshold are filtered out
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Only return documents above threshold
+ high_score_doc = Document(
+ page_content="High relevance content",
+ metadata={"doc_id": "doc1", "score": 0.85},
+ provider="dify",
+ )
+
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.append(high_score_doc)
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ score_threshold = 0.8
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=5,
+ score_threshold=score_threshold,
+ )
+
+ # Assert
+ assert len(results) == 1
+ assert results[0].metadata["score"] >= score_threshold
+
+ # ==================== Top-K Limiting Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset):
+ """
+ Test that retrieval respects top_k parameter.
+
+ Verifies:
+ - Only top_k documents are returned
+ - Limit is applied correctly
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Create more documents than top_k
+ many_docs = [
+ Document(
+ page_content=f"Content {i}",
+ metadata={"doc_id": f"doc{i}", "score": 0.9 - i * 0.1},
+ provider="dify",
+ )
+ for i in range(10)
+ ]
+
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ # Return only top_k documents
+ all_documents.extend(many_docs[:top_k])
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ top_k = 3
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=top_k,
+ )
+
+ # Assert
+ # Verify top_k was passed to embedding_search
+ assert mock_embedding_search.called
+ call_kwargs = mock_embedding_search.call_args.kwargs
+ assert call_kwargs["top_k"] == top_k
+ # Verify we got the right number of results
+ assert len(results) == top_k
+
+ # ==================== Query Escaping Tests ====================
+
+ def test_escape_query_for_search(self):
+ """
+ Test query escaping for special characters.
+
+ Verifies:
+ - Double quotes are properly escaped
+ - Other characters remain unchanged
+ """
+ # Test cases with expected outputs
+ test_cases = [
+ ("simple query", "simple query"),
+ ('query with "quotes"', 'query with \\"quotes\\"'),
+ ('"quoted phrase"', '\\"quoted phrase\\"'),
+ ("no special chars", "no special chars"),
+ ]
+
+ for input_query, expected_output in test_cases:
+ result = RetrievalService.escape_query_for_search(input_query)
+ assert result == expected_output
+
+ # ==================== Reranking Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_semantic_search_with_reranking(
+ self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
+ ):
+ """
+ Test semantic search with reranking model.
+
+ Verifies:
+ - Reranking is applied when configured
+ - DataPostProcessor is invoked with correct parameters
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Simulate reranking changing order
+ reranked_docs = list(reversed(sample_documents))
+
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ # embedding_search handles reranking internally
+ all_documents.extend(reranked_docs)
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ reranking_model = {
+ "reranking_provider_name": "cohere",
+ "reranking_model_name": "rerank-english-v2.0",
+ }
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=3,
+ reranking_model=reranking_model,
+ )
+
+ # Assert
+ # For semantic search with reranking, reranking_model should be passed
+ assert len(results) == 3
+ call_kwargs = mock_embedding_search.call_args.kwargs
+ assert call_kwargs["reranking_model"] == reranking_model
+
+
+class TestRetrievalMethods:
+ """
+ Test suite for RetrievalMethod enum and utility methods.
+
+ The RetrievalMethod enum defines the available search strategies:
+
+ 1. **SEMANTIC_SEARCH**: Vector-based similarity search using embeddings
+ - Best for: Natural language queries, conceptual similarity
+ - Uses: Embedding models (e.g., text-embedding-ada-002)
+ - Example: "What is machine learning?" matches "AI and ML concepts"
+
+ 2. **FULL_TEXT_SEARCH**: BM25-based text matching
+ - Best for: Exact phrase matching, keyword presence
+ - Uses: BM25 algorithm with sparse vectors
+ - Example: "Python programming" matches documents with those exact terms
+
+ 3. **HYBRID_SEARCH**: Combination of semantic + full-text
+ - Best for: Comprehensive search with both conceptual and exact matching
+ - Uses: Both embedding vectors and BM25, with score merging
+ - Example: Finds both semantically similar and keyword-matching documents
+
+ 4. **KEYWORD_SEARCH**: Traditional keyword-based search (economy mode)
+ - Best for: Simple, fast searches without embeddings
+ - Uses: Jieba tokenization and keyword matching
+ - Example: Basic text search without vector database
+
+ Utility Methods:
+ ================
+ - is_support_semantic_search(): Check if method uses embeddings
+ - is_support_fulltext_search(): Check if method uses BM25
+
+ These utilities help determine which search operations to execute
+ in the RetrievalService.retrieve() method.
+ """
+
+ def test_retrieval_method_values(self):
+ """
+ Test that all retrieval method constants are defined correctly.
+
+ This ensures the enum values match the expected string constants
+ used throughout the codebase for configuration and API calls.
+
+ Verifies:
+ - All expected retrieval methods exist
+ - Values are correct strings (not accidentally changed)
+ - String values match database/config expectations
+ """
+ assert RetrievalMethod.SEMANTIC_SEARCH == "semantic_search"
+ assert RetrievalMethod.FULL_TEXT_SEARCH == "full_text_search"
+ assert RetrievalMethod.HYBRID_SEARCH == "hybrid_search"
+ assert RetrievalMethod.KEYWORD_SEARCH == "keyword_search"
+
+ def test_is_support_semantic_search(self):
+ """
+ Test semantic search support detection.
+
+ Verifies:
+ - Semantic search method is detected
+ - Hybrid search method is detected (includes semantic)
+ - Other methods are not detected
+ """
+ assert RetrievalMethod.is_support_semantic_search(RetrievalMethod.SEMANTIC_SEARCH) is True
+ assert RetrievalMethod.is_support_semantic_search(RetrievalMethod.HYBRID_SEARCH) is True
+ assert RetrievalMethod.is_support_semantic_search(RetrievalMethod.FULL_TEXT_SEARCH) is False
+ assert RetrievalMethod.is_support_semantic_search(RetrievalMethod.KEYWORD_SEARCH) is False
+
+ def test_is_support_fulltext_search(self):
+ """
+ Test full-text search support detection.
+
+ Verifies:
+ - Full-text search method is detected
+ - Hybrid search method is detected (includes full-text)
+ - Other methods are not detected
+ """
+ assert RetrievalMethod.is_support_fulltext_search(RetrievalMethod.FULL_TEXT_SEARCH) is True
+ assert RetrievalMethod.is_support_fulltext_search(RetrievalMethod.HYBRID_SEARCH) is True
+ assert RetrievalMethod.is_support_fulltext_search(RetrievalMethod.SEMANTIC_SEARCH) is False
+ assert RetrievalMethod.is_support_fulltext_search(RetrievalMethod.KEYWORD_SEARCH) is False
+
+
+class TestDocumentModel:
+ """
+ Test suite for Document model used in retrieval.
+
+ The Document class is the core data structure for representing text chunks
+ in the retrieval system. It's based on Pydantic BaseModel for validation.
+
+ Document Structure:
+ ===================
+ - **page_content** (str): The actual text content of the document chunk
+ - **metadata** (dict): Additional information about the document
+ - doc_id: Unique identifier for the chunk
+ - document_id: Parent document ID
+ - dataset_id: Dataset this document belongs to
+ - score: Relevance score from search (0.0 to 1.0)
+ - Custom fields: category, tags, timestamps, etc.
+ - **provider** (str): Source of the document ("dify" or "external")
+ - **vector** (list[float] | None): Embedding vector for semantic search
+ - **children** (list[ChildDocument] | None): Sub-chunks for hierarchical docs
+
+ Document Lifecycle:
+ ===================
+ 1. **Creation**: Documents are created when text is indexed
+ - Content is chunked into manageable pieces
+ - Embeddings are generated for semantic search
+ - Metadata is attached for filtering and tracking
+
+ 2. **Storage**: Documents are stored in vector databases
+ - Vector field stores embeddings
+ - Metadata enables filtering
+ - Provider tracks source (internal vs external)
+
+ 3. **Retrieval**: Documents are returned from search operations
+ - Scores are added during search
+ - Multiple documents may be combined (hybrid search)
+ - Deduplication uses doc_id
+
+ 4. **Post-processing**: Documents may be reranked or filtered
+ - Scores can be recalculated
+ - Content may be truncated or formatted
+ - Metadata is used for display
+
+ Why Test the Document Model:
+ ============================
+ - Ensures data structure integrity
+ - Validates Pydantic model behavior
+ - Confirms default values work correctly
+ - Tests equality comparison for deduplication
+ - Verifies metadata handling
+
+ Related Classes:
+ ================
+ - ChildDocument: For hierarchical document structures
+ - RetrievalSegments: Combines Document with database segment info
+ """
+
+ def test_document_creation_basic(self):
+ """
+ Test basic Document object creation.
+
+ Tests the minimal required fields and default values.
+ Only page_content is required; all other fields have defaults.
+
+ Verifies:
+ - Document can be created with minimal fields
+ - Default values are set correctly
+ - Pydantic validation works
+ - No exceptions are raised
+ """
+ doc = Document(page_content="Test content")
+
+ assert doc.page_content == "Test content"
+ assert doc.metadata == {} # Empty dict by default
+ assert doc.provider == "dify" # Default provider
+ assert doc.vector is None # No embedding by default
+ assert doc.children is None # No child documents by default
+
+ def test_document_creation_with_metadata(self):
+ """
+ Test Document creation with metadata.
+
+ Verifies:
+ - Metadata is stored correctly
+ - Metadata can contain various types
+ """
+ metadata = {
+ "doc_id": "test_doc",
+ "score": 0.95,
+ "dataset_id": str(uuid4()),
+ "category": "test",
+ }
+ doc = Document(page_content="Test content", metadata=metadata)
+
+ assert doc.metadata == metadata
+ assert doc.metadata["score"] == 0.95
+
+ def test_document_creation_with_vector(self):
+ """
+ Test Document creation with embedding vector.
+
+ Verifies:
+ - Vector embeddings can be stored
+ - Vector is optional
+ """
+ vector = [0.1, 0.2, 0.3, 0.4, 0.5]
+ doc = Document(page_content="Test content", vector=vector)
+
+ assert doc.vector == vector
+ assert len(doc.vector) == 5
+
+ def test_document_with_external_provider(self):
+ """
+ Test Document with external provider.
+
+ Verifies:
+ - Provider can be set to external
+ - External documents are handled correctly
+ """
+ doc = Document(page_content="External content", provider="external")
+
+ assert doc.provider == "external"
+
+ def test_document_equality(self):
+ """
+ Test Document equality comparison.
+
+ Verifies:
+ - Documents with same content are considered equal
+ - Metadata affects equality
+ """
+ doc1 = Document(page_content="Content", metadata={"id": "1"})
+ doc2 = Document(page_content="Content", metadata={"id": "1"})
+ doc3 = Document(page_content="Different", metadata={"id": "1"})
+
+ assert doc1 == doc2
+ assert doc1 != doc3
From 58f448a926174fa90a2d971432dacb218a990c11 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Thu, 27 Nov 2025 14:40:06 +0800
Subject: [PATCH 55/63] chore: remove outdated model config doc (#28765)
---
.../en_US/customizable_model_scale_out.md | 308 --------
.../docs/en_US/images/index/image-1.png | Bin 235102 -> 0 bytes
.../docs/en_US/images/index/image-2.png | Bin 210087 -> 0 bytes
.../images/index/image-20231210143654461.png | Bin 379070 -> 0 bytes
.../images/index/image-20231210144229650.png | Bin 115258 -> 0 bytes
.../images/index/image-20231210144814617.png | Bin 111420 -> 0 bytes
.../images/index/image-20231210151548521.png | Bin 71354 -> 0 bytes
.../images/index/image-20231210151628992.png | Bin 76990 -> 0 bytes
.../images/index/image-20231210165243632.png | Bin 554357 -> 0 bytes
.../docs/en_US/images/index/image-3.png | Bin 44778 -> 0 bytes
.../docs/en_US/images/index/image.png | Bin 267979 -> 0 bytes
.../model_runtime/docs/en_US/interfaces.md | 701 -----------------
.../docs/en_US/predefined_model_scale_out.md | 176 -----
.../docs/en_US/provider_scale_out.md | 266 -------
api/core/model_runtime/docs/en_US/schema.md | 208 -----
.../zh_Hans/customizable_model_scale_out.md | 304 -------
.../docs/zh_Hans/images/index/image-1.png | Bin 235102 -> 0 bytes
.../docs/zh_Hans/images/index/image-2.png | Bin 210087 -> 0 bytes
.../images/index/image-20231210143654461.png | Bin 394062 -> 0 bytes
.../images/index/image-20231210144229650.png | Bin 115258 -> 0 bytes
.../images/index/image-20231210144814617.png | Bin 111420 -> 0 bytes
.../images/index/image-20231210151548521.png | Bin 71354 -> 0 bytes
.../images/index/image-20231210151628992.png | Bin 76990 -> 0 bytes
.../images/index/image-20231210165243632.png | Bin 554357 -> 0 bytes
.../docs/zh_Hans/images/index/image-3.png | Bin 44778 -> 0 bytes
.../docs/zh_Hans/images/index/image.png | Bin 267979 -> 0 bytes
.../model_runtime/docs/zh_Hans/interfaces.md | 744 ------------------
.../zh_Hans/predefined_model_scale_out.md | 172 ----
.../docs/zh_Hans/provider_scale_out.md | 192 -----
api/core/model_runtime/docs/zh_Hans/schema.md | 209 -----
30 files changed, 3280 deletions(-)
delete mode 100644 api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-1.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-2.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210143654461.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210144229650.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210144814617.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210151548521.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210151628992.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210165243632.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-3.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image.png
delete mode 100644 api/core/model_runtime/docs/en_US/interfaces.md
delete mode 100644 api/core/model_runtime/docs/en_US/predefined_model_scale_out.md
delete mode 100644 api/core/model_runtime/docs/en_US/provider_scale_out.md
delete mode 100644 api/core/model_runtime/docs/en_US/schema.md
delete mode 100644 api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-1.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-2.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210143654461.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210144229650.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210144814617.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210151548521.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210151628992.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210165243632.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-3.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/interfaces.md
delete mode 100644 api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md
delete mode 100644 api/core/model_runtime/docs/zh_Hans/provider_scale_out.md
delete mode 100644 api/core/model_runtime/docs/zh_Hans/schema.md
diff --git a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
deleted file mode 100644
index 245aa4699c..0000000000
--- a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
+++ /dev/null
@@ -1,308 +0,0 @@
-## Custom Integration of Pre-defined Models
-
-### Introduction
-
-After completing the vendors integration, the next step is to connect the vendor's models. To illustrate the entire connection process, we will use Xinference as an example to demonstrate a complete vendor integration.
-
-It is important to note that for custom models, each model connection requires a complete vendor credential.
-
-Unlike pre-defined models, a custom vendor integration always includes the following two parameters, which do not need to be defined in the vendor YAML file.
-
-
-
-As mentioned earlier, vendors do not need to implement validate_provider_credential. The runtime will automatically call the corresponding model layer's validate_credentials to validate the credentials based on the model type and name selected by the user.
-
-### Writing the Vendor YAML
-
-First, we need to identify the types of models supported by the vendor we are integrating.
-
-Currently supported model types are as follows:
-
-- `llm` Text Generation Models
-
-- `text_embedding` Text Embedding Models
-
-- `rerank` Rerank Models
-
-- `speech2text` Speech-to-Text
-
-- `tts` Text-to-Speech
-
-- `moderation` Moderation
-
-Xinference supports LLM, Text Embedding, and Rerank. So we will start by writing xinference.yaml.
-
-```yaml
-provider: xinference #Define the vendor identifier
-label: # Vendor display name, supports both en_US (English) and zh_Hans (Simplified Chinese). If zh_Hans is not set, it will use en_US by default.
- en_US: Xorbits Inference
-icon_small: # Small icon, refer to other vendors' icons stored in the _assets directory within the vendor implementation directory; follows the same language policy as the label
- en_US: icon_s_en.svg
-icon_large: # Large icon
- en_US: icon_l_en.svg
-help: # Help information
- title:
- en_US: How to deploy Xinference
- zh_Hans: 如何部署 Xinference
- url:
- en_US: https://github.com/xorbitsai/inference
-supported_model_types: # Supported model types. Xinference supports LLM, Text Embedding, and Rerank
-- llm
-- text-embedding
-- rerank
-configurate_methods: # Since Xinference is a locally deployed vendor with no predefined models, users need to deploy whatever models they need according to Xinference documentation. Thus, it only supports custom models.
-- customizable-model
-provider_credential_schema:
- credential_form_schemas:
-```
-
-Then, we need to determine what credentials are required to define a model in Xinference.
-
-- Since it supports three different types of models, we need to specify the model_type to denote the model type. Here is how we can define it:
-
-```yaml
-provider_credential_schema:
- credential_form_schemas:
- - variable: model_type
- type: select
- label:
- en_US: Model type
- zh_Hans: 模型类型
- required: true
- options:
- - value: text-generation
- label:
- en_US: Language Model
- zh_Hans: 语言模型
- - value: embeddings
- label:
- en_US: Text Embedding
- - value: reranking
- label:
- en_US: Rerank
-```
-
-- Next, each model has its own model_name, so we need to define that here:
-
-```yaml
- - variable: model_name
- type: text-input
- label:
- en_US: Model name
- zh_Hans: 模型名称
- required: true
- placeholder:
- zh_Hans: 填写模型名称
- en_US: Input model name
-```
-
-- Specify the Xinference local deployment address:
-
-```yaml
- - variable: server_url
- label:
- zh_Hans: 服务器 URL
- en_US: Server url
- type: text-input
- required: true
- placeholder:
- zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
- en_US: Enter the url of your Xinference, for example https://example.com/xxx
-```
-
-- Each model has a unique model_uid, so we also need to define that here:
-
-```yaml
- - variable: model_uid
- label:
- zh_Hans: 模型 UID
- en_US: Model uid
- type: text-input
- required: true
- placeholder:
- zh_Hans: 在此输入您的 Model UID
- en_US: Enter the model uid
-```
-
-Now, we have completed the basic definition of the vendor.
-
-### Writing the Model Code
-
-Next, let's take the `llm` type as an example and write `xinference.llm.llm.py`.
-
-In `llm.py`, create a Xinference LLM class, we name it `XinferenceAILargeLanguageModel` (this can be arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
-
-- LLM Invocation
-
-Implement the core method for LLM invocation, supporting both stream and synchronous responses.
-
-```python
-def _invoke(self, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None) \
- -> Union[LLMResult, Generator]:
- """
- Invoke large language model
-
- :param model: model name
- :param credentials: model credentials
- :param prompt_messages: prompt messages
- :param model_parameters: model parameters
- :param tools: tools for tool usage
- :param stop: stop words
- :param stream: is the response a stream
- :param user: unique user id
- :return: full response or stream response chunk generator result
- """
-```
-
-When implementing, ensure to use two functions to return data separately for synchronous and stream responses. This is important because Python treats functions containing the `yield` keyword as generator functions, mandating them to return `Generator` types. Here’s an example (note that the example uses simplified parameters; in real implementation, use the parameter list as defined above):
-
-```python
-def _invoke(self, stream: bool, **kwargs) \
- -> Union[LLMResult, Generator]:
- if stream:
- return self._handle_stream_response(**kwargs)
- return self._handle_sync_response(**kwargs)
-
-def _handle_stream_response(self, **kwargs) -> Generator:
- for chunk in response:
- yield chunk
-def _handle_sync_response(self, **kwargs) -> LLMResult:
- return LLMResult(**response)
-```
-
-- Pre-compute Input Tokens
-
-If the model does not provide an interface for pre-computing tokens, you can return 0 directly.
-
-```python
-def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],tools: Optional[list[PromptMessageTool]] = None) -> int:
- """
- Get number of tokens for given prompt messages
-
- :param model: model name
- :param credentials: model credentials
- :param prompt_messages: prompt messages
- :param tools: tools for tool usage
- :return: token count
- """
-```
-
-Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens and ensure environment variable `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` is set to `true`, This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate.
-
-- Model Credentials Validation
-
-Similar to vendor credentials validation, this method validates individual model credentials.
-
-```python
-def validate_credentials(self, model: str, credentials: dict) -> None:
- """
- Validate model credentials
-
- :param model: model name
- :param credentials: model credentials
- :return: None
- """
-```
-
-- Model Parameter Schema
-
-Unlike custom types, since the YAML file does not define which parameters a model supports, we need to dynamically generate the model parameter schema.
-
-For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` parameters.
-
-However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:
-
-```python
- def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
- """
- used to define customizable model schema
- """
- rules = [
- ParameterRule(
- name='temperature', type=ParameterType.FLOAT,
- use_template='temperature',
- label=I18nObject(
- zh_Hans='温度', en_US='Temperature'
- )
- ),
- ParameterRule(
- name='top_p', type=ParameterType.FLOAT,
- use_template='top_p',
- label=I18nObject(
- zh_Hans='Top P', en_US='Top P'
- )
- ),
- ParameterRule(
- name='max_tokens', type=ParameterType.INT,
- use_template='max_tokens',
- min=1,
- default=512,
- label=I18nObject(
- zh_Hans='最大生成长度', en_US='Max Tokens'
- )
- )
- ]
-
- # if model is A, add top_k to rules
- if model == 'A':
- rules.append(
- ParameterRule(
- name='top_k', type=ParameterType.INT,
- use_template='top_k',
- min=1,
- default=50,
- label=I18nObject(
- zh_Hans='Top K', en_US='Top K'
- )
- )
- )
-
- """
- some NOT IMPORTANT code here
- """
-
- entity = AIModelEntity(
- model=model,
- label=I18nObject(
- en_US=model
- ),
- fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
- model_type=model_type,
- model_properties={
- ModelPropertyKey.MODE: ModelType.LLM,
- },
- parameter_rules=rules
- )
-
- return entity
-```
-
-- Exception Error Mapping
-
-When a model invocation error occurs, it should be mapped to the runtime's specified `InvokeError` type, enabling Dify to handle different errors appropriately.
-
-Runtime Errors:
-
-- `InvokeConnectionError` Connection error during invocation
-- `InvokeServerUnavailableError` Service provider unavailable
-- `InvokeRateLimitError` Rate limit reached
-- `InvokeAuthorizationError` Authorization failure
-- `InvokeBadRequestError` Invalid request parameters
-
-```python
- @property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
- """
- Map model invoke error to unified error
- The key is the error type thrown to the caller
- The value is the error type thrown by the model,
- which needs to be converted into a unified error type for the caller.
-
- :return: Invoke error mapping
- """
-```
-
-For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
diff --git a/api/core/model_runtime/docs/en_US/images/index/image-1.png b/api/core/model_runtime/docs/en_US/images/index/image-1.png
deleted file mode 100644
index b158d44b29dcc2a8fa6d6d349ef8d7fb9f7d4cdd..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001
literal 235102
zcmeFZXH-+&);3I0P!Z6ZD&2~7REl&I=>pPwmkyzKLPS6iq&KAn1SwKN?+}WBfD~x~
zlF&OO^p+6vM$dV^an28)*Z<$k7(3Y`ti8utbI*CLd0lfR?_a4aQeI=aMnptJsjT!u
zi-_n7;kf+$3K`*(ByK+wBBJZk4svp@l;z|&UU|6MIyl=95h=Y-(I+?1?xW8(*1CP^
z3J1si!v~S$L|V`7iSrpc6qHE9{;}dHn))e!8OBdvky~;niuHzsm7S*Z6axi!f0IE4
zkE@m}47d`whFzV-UwE#{*bQYFlM!LAX1_>`b|i`CSK*W4ib
zoLT=-=V5~N)SFn^2Uqrj#D;Vhxy97=7p1O&!=6%o<0Fcn@RxLWK*W@Gdzg38?m1NdhC0IE@vwr6q{?o5KzGbM~v>Jb-!_YVl8?v6t-Q%X^Q>D?RwH-6em2?OtOG
z7v4_nqJCwipUldjt04P6WpvW;tx1%vdT|r2DO7bZ
zc?#zwRHe5@LZ=MdDWcv?U0J;04$#Su(ipxG?soef&!sG0?&SxvrPO;tme1T;`t9%7)`TrnoOyG`>di#Ce;iCVki2fDZlRp^N#b;I8-l@i(1FzUUXqnm3$ew{_xa
z*mL)98D@6BXWC8ZUtB&je$K*48$rf>RQdV3gHmQclow)Q(Kgk?ds9>;{DjW-ei+wQ
zXzi|l)cJx*#lZ7N67I5!27=-q7W^cZuPQvyQ%`cAJ~I1q@Jx%5v5+NN`+nmiZV|Ga(M;T{rYcKhMmbJ#OIf3Hio
zqGs=ON3!f*T9mv&?6FY$@{x)pBi~iYcOTZ-fBIMKS++G+^kMN*GVQ^3Q}OS09#w@z
zzQnxYeDsm=#r3=IZlK)c5Y#6fcNZQJi7;e_2U7l6YL;2Jnwzy7@0K;9+#vi+=MIi*
zKr@GCGor{b;li5F&8cW3mVB}86)X2rK*MdfZ%7Lk5j6GPX1Ub7eq%GbRgD=?MhCU8
z+=`-mK~D1RWTGh4@Ll0;VFbDEa_vQ>bic_M`PSgA?K3gywTMlMd9g)14QJ^(c8Tb$
zZ$s9K6=dz-Vy$BgczkqX$@O8@-J492^o8F!$u}b}Zo1A7_463cn`|La=kn9~4>Nf|
zilpjoucLI5Kk0sd-6ilgzK`x8>)j*MOj8r$c#wHA1R`1)wL|t9hRMn}ASY_;#bUA4
zpta68(h3-j^hHFQXC3e9#w*f!#O=11M)}KbXCXv$|9G9|-y)*1Aa-5|GI>M3OH9nc
z7#3i^s_iNj^v*;X9X=Hl+oY89SDs!SdJ{eWxPc1Y
zJ~4LL?2XwjRgl7su&|LEWL9rJ-qDLBU3^<{L*^di`xlGxwD&0cUVMwZ{^{M1XZ;B_
z-;G|=XoYUeQ{CoFI8c)Cpzwd=br+JL@$lW(Jn^DOjxkhthpV~Y4<4nqgtkliIRyTLQV3SGpF>O;{>I9
zq=}?`X|!sK3D=arzqukM#2d}D&tBXsDy~oGKo5w{R9UWI4%^WKo
z@6;MPmP{NMWoZZI9BY~yBh9k%r1eY-+LVA9n5##|o38dk>B4+(QdM~J`94pKsQ*Nd
zN!tqwWePPgk;HsrN?|%;(ibut`YY&ydbGdkM+eDK!6lVy`7sCM=dNIe05QB|Sn&
zLI_*KXmq#cbM1Z&y(4Nw9CGwBBj2TrwUlopyQI|+n|yqitf#`ljWMW)rH8)9K_QJw
zvsWRaB&kfhyjrzTtx#Xz7^Wtw>{|u+g4RT;OT8R6iYrMgsW)&1&r-H($bE>Yh`18L
z%cv!~+#i2}kvk$i49D1DnvWn`pvM+Wmzb9sd4=r-8|)&PXPA4K)E}F;`ps0>6%D-t
z=bq=<$emEOx`Z)JGk4!o6P|G%cO18KwqkMQePm`G)ZftiV)$PBJ^1VJzDG#_?RUlC
z;_qJJ8}4IkFHIk0d>SwrbhbQBzx65nQ{NAfPqm+bBK40ngajVVxdactb?m76USVMQ
z)A^^9gQHl#N4ibfiVIH^-2nH1W6FeK$;qV1WZ&dLDRU`!dL8cO66NCPf}R<5#<_6L
z=)ujZ@6J?idu&fmo7O3!kf`cy9+ZGs4iy{Kod|C#RM(ZRmUr(Xbg8%8bbMb+(2Grq
zOS_GFMtO0&JGupkAw2vxcfWKEQ{h*)(V1H<-Dg8rP&XIuswfD>_@S(JI(>`18dp_3
zDc!R+X4WzNgB@z?b+f1fi7sF=R`SI51j3fhrWGe12hdO*78>Rrrb?yac~=|nJrq4p
zw=O)k6$w*xrQUhI1Dpk&J;f}aZtwbwjN0bAEvF{L1q}lI+cJVAa36ymf|+sRxWRM(
zi-`+rQeKiv#KxrARj8@y?X2yzkfD_Rx?WvDwUxOqyo~uk2qLzA~=xakJW;naAE<
zz5VHS?d{fFbgs6XM%c~HJDsWTyxn3Ng8-EEqKGF!0e-4s0WX5S%|u2d94Cx^FVI{s
z+SMFTM8;G_{i7)SBGevmxW69wNaPWXHC69F*49=$*5Ge3ks0*9BH1@(ZU-x#H;1){
zbzk-U$A!k`{h;z;exosQTXRK
zI1FyD(_^lsT?N~o+VT$grEX+2N>)n@$dueEQT%qZjIOnxYG5`VC(S8tXRHnac%ade
z9%e?fw=3(ODk2hH6J^Apri&b0DPHP_*$p065TWNywM1Nu`XdHKq@ZPEo4P9V*VIwx
zQQXneFICcOKvk2k_m(HQjL}Bwu+p<=(TLz8I{Y8$?
zCa~Ubz4a%_OCfNJmfWDO!kEu7qN3q|uaLU(@~M{ylbW->)Os_YlZkL?|Ao$Qnq5S)*L*3Vr2T`=v45-^H-ip8e<(@>Jjy_`U@JP0fq%hFv;K#~z0y-W&Fx)G-6L8yqud(n@qfDm6Y^F_eSHIo7tTivN
z*-Y8|2x@lPGe1~Y4{*)ET4B)8fxwe9;52zCHC(1jIvbaKzOk{UY_4Rk2dc%+9P%x~
zkryo&!B;i3vROH^eMbt7&M#OiShztSN?0t9wx+a$#sCStJf0gGX^4bAW4%G#;
z_qZLJK}B@z90C_8(qN~My-xY&lE9m$CaSuS)^sa#F3~#(a_5)l>Dg&i7cpEXnKH|`
zgInKqk?q54AmZHFAUXa-UMm8jaJDf}wpCLj;wGG5A-Y7&M0A;OMoc)QiJAYua|Pmu
zM5MnzCm|w=a3H$$?=k9x_g{bUgyYvW|9U6MA^D#XSG00S|L6Sj_g_P~SZwDBZ&%%v
z3_Xd6=FoyUiS^&LjZb2*xGx2N;33-Le
z&CN9{lJND3wFhah-BRFA{NKaxN4C_m?I4raSfPfmK-+0z;OZwkMD?%~`O9dM
zB_EFNmV==!Ju`>0f2c;cwCLcz_p0
zMox~T5IBe9aFtN3Khvjh@ep^>c<KZNxR6()mpi~T
zYx)!BC>UCd<}0IPTusLDEr+A3=wZ-BG!|Ap9P7askk{ZY9RFM+$G|2fB~UKb=+0l@
z#z-hhfJ*l}X25(l=)(U58^PxD#P-K}hn^RtABP-hI?61(FX2$FCTep-e(UwO|8;Wz
zMoNnpB+ME;N7SA>(JP9-LHc{UJ)-=ZA_$96F*fT_lO)bK@wW<9yZ53r?Fwc8<_1jH
z>;=2Uogh+*&{gYkuu|oGjRPR`25dYuzDMGw_y`*PoLnms}pZs-@Mm$J|aDCHeKOi@pK64~Ti0Uuy5n}eQ$Ksn+
z8`yfQ>npy9GXJApwm)LCq)fb2CGpB%+=KX%g1$w9JiWWCiPz7cd|oWcBHW~0@7}#b
zsPL2vQl+cX{nx|$(;ahkLsIn3rZ=+jnL^&)9N~gTPv(NOkx@MSGMtU0dVljo0?aIC
zYAz%0MDuR%;D)(Kxtm1d?~s|;r051*%;ymQP5A$KPck4-A2hS=$1BV9w1fVEcaxNj
zJq2HbMPc-9nZKFD4aUpFQVDmmolRuT&COXPEmI{hFRpb(dDFin(>y%#t@-#gF!ko&
z_<2q=Nzwx~k7)iTd(*2n)|WQ#99hUJDm=YXc+I}}e%kso`mZ1U%FiBDoLafobR$}R
z`h}Q$V>7z-vqb7EnZ`F1Sj5GsrT_7NZbg3t-ouZw)q!}CV2e8~XzMABe8DhVtPh&Y
zQ#88M=kM&&OLFkziH2&elQ-`()lV^+`IeOV_Js?iq;~_>yBRBA@&463KFUg*T)=N>
zTb#ue`ol@N0$Mwku7aEXIPd@X#|XtQaCgzIPV1$w5stTdO3A98UU&HZ@2t3_o}=GO
zQ6m}uS)chSI+m6@>@SSp&T?lY+M3dcP3cj@wp{xAzqo{m#Ir(Hm9Q}PuDUYtC
z5a#!Vcy#mT(hqV09P*qzh`yEpYhP{E
zzq#;1fh=b)k=J90y3-n&mRHs1tRE#~F%SMOdhN|~lB|ivKj~xniq#XS@);p?rRCzr
ze+OY{mwT2B!o3I+w(PZ*lbDv?fBqMb!9tzx;OM>^C%M_S`(+GQH=$?RxFH(_lHSW7
zBZ^&6hqpR0i_VP~f9V6Olf7uM$vDY}qni&z0gX0k>aY@e{xdWk`h;3)Q99tX5OI|=
zYqK1);EP}?fh^|D;oIdLu(YpOlyXDJCWBpMPUW!P&K%v_)741L-pqBg;1xO>R5^Bn
z5)gNS+SW4&HiXWi@N8#YdFap4P%ynGN&w;;xKpXxG(_@eC$r|uc|-&R0!<-3_9@yE
z@#U2H7B}SS#1{T#spZRuVeepOfl`z~n;W1n0|Mt*`IA#T?ujYrxN&W4<62A;!G
zzA1eFUZW5lBTMGWbbiJmGXQgpD9~UDt{Bti6_;YYcyB$%DQt2+
zSp7~aevwlKC!Zd->DPV{*m$;<0U4VZT=B)So-|ZL_XbT2Ci>?~4I9-M%smh2Q6&Sq{7>{5O#9)yD+q)@0*7%M
zBCU0u4_bY`6%2|7CK=m)`f|Rx7|oadqx$?gOBr0_N?$~;<#g51&9-24ebt-S8*gyO
zDi?7FacOk5DMwZ;Uj_;a!Tvos3?`%cjY!&za`T?T*8N6fp!2M4$la*&$Gx8sH+9p$
z{UIXT8LxTah3TxP8p`RVQb}2xW-9I-#8Tyd)F#u+4D9GeS91pU?PO51dgPBu49XB9
zA7EjT0{}^#Eo-9>tA`*IkZT5RGc~DXT0uu2JWHz9l^SNcerF3Kz>tdz>H3+kJk4uh
z&~3i$Mtj2$SA*9JsdZeV)cr;H4~Q`Li*U3wSSw=+JCu7#PYcG(U6NToa7EO|2OVpKFIE#?1Q
z+b2)NmdJUM48(XB%<+-psBw
zZ5qemW!*bYrS<;l-AQq{*`uQ)>``v+G@x$
zbmqfBczBlIzP?I54jmGa87zdVz)5(La;;_uqJ;)OyBgdxoqT?Nm?vnT6Ugb+_*SvO
z1J3H(>;!3+9<&)46!qRJZ7pctScN0X15>H5y;`=>r-IInhJ}ag!e)6H_J^TT>jf|S
z&OMl&c)$`D_M-TpgA|udzmH~Vx*|Z{TUTyEAHP_;Nsfm4n(DOnl$_Fj#hoC#bXSbi
zi2uyDC@(EKZHp9juI|;58D6(eqKvfn7L=Eav|{W|c+E({dP_+YClrNdHR-Z$-_&U@tN=G#usmPV@%Vc++>Mz8iZ(50G-PsrIR>
z2=5ahx{BRog+~y!+y?+l`5q0NAn}2PV81GtYuM9-1>Yh14A(^;;0~3$P@#TU5RQ3H
zrE3x9_^`q;zW0+kv>oY#MM@({D1NNPKHa|Sz}96r!T}|fKHM|WqP#aE1XQ?qPt&Mm
zU9*qbf;1E27{&)hJpQ#QvuW}P-B?`FH!B;k&Z?c;l{v3-1+6_kX*2aLjrJWV
zt6wvV@E9RKQ6`f+X-vDwzyl#+6yw2#Fq`LpO#fslNPJx3ggp3SgEruwaLUgW{?#3{
z-Op(It(`f=K56PD0i?w~XZYLkeHV!FTa6w4D_+
z0gq<~MX-%Xj>`5$b9&=brPW@RKo6UJ1&vQro^uIE#uDCS%HNo3W|8e*(14Svu=~1K
z?G(I0F8WXwvf_7O%Zzn#f;2Wk-SNnQ@|R0>%{hhIwB%a$L|w4qK{74cgVGtFHPaI@
z+|*%ZXbG6n(w%8API5J@21)PQO_y5T9Xo`};8TcS@@FvFZaYk=PU@hBG*Gi25}$+*
z1Dk|IU0|}wgCf)4cY5#Bo0kgBKEVe}mQ_DHNHU`IFm$htQ8|bt6=+#+R+q6ElexgW
z8GF?GM;a&iPl#X68?x(DM9~S$w?=1jt}?%{^rR3-eI*k4nH%uPrGB#d8D&+U#&lhf
zpFlamc=7@wO-bI9F@mi|)?gf!dzxuLwy{&pgvYeGCkis=O<6NS{V!k3yqyD63|fii
zJnc~jOLuJ)lZv?1uGpfJlj3T1eKCe>g#$DW&orh-XfnPQsLj>nV$U?%(g2e@{EHzZ
zdm{2pEA_rj0$?8>CVc-uTM8B1h(+RQ?L&urGKv69m|2=#vv9E&18X9KtQSd2yyBi0R)~w}X2*wV|_)c6Hj)zi9&EPI_
z6@%IzFuTsT9j7E!#n7n&@8UdXqk^s~5OR&x*CFP()dG`)+Tb6W5ls$b^xwmsY<8r4
z+M
zwIt5xFF!r@f!u5dwn7^yZA4CM><@%bvpL$jr#hz6bd8H
z7+ue;(*MvF{+l}=HpxnU{9Nr)S*S8a!{rQESOVUV7b&{k7IE`Kl}nbeHURmf*#Lf%
zGXJ41{p5kDfB({c@d|?S_fyyt>rN
zX~ACP?<{Miipd|R{k$|
z4Pue
z1z5eK%r^kF)_PucY=1E=87oz`eN73K~Ntw!|5i4VELmw*z@0-%Lzz~;FBSPe9A{&qd8`2a$zDq8j5=Khc8
zFhWes{4zK2TUIMIAoPLrjqpL;4L|>8U*{EG06EimC9i_#E4f~7q=G^ho06t*PEJf!
zApa)4E3#vQb=77J>x29WTq+gd1aV;nKuha&@SP!fY|wt}k+7cmeFj?|3DhZ|@EMiJ
zG|j!5wuw~NIt!V|B}IMHq;hJzT;!p-3ZQ*@X(&RQ&EdCXkQw-kY95Ks8Y#M?ykX6r
zL(Sh=1HOkex`1$7*6K1@nE>d)>|qLBzwcJP^Q_)6L0At9yf|YBO4l<3=gy@y{gh#Q
zKzZkVVaW1TW>7Jy)I3-{aV~51#I?FX1?@mD%TXOP%0`Oua5X0m_8+
z58!9if6`|ygf1)i3Wghrn2~D=X@jqX@?tWCt6EW6tl?v3)5tg}>(w6x;k<(Qb2JnW
z6k~9O0#Dr7#nwq1EBtICP;&{0k$$q5-{6ZjQhQc?_vzz84SogxjdfyTsV_y27{r=}
zXFW23J*cqZKrL(Lu#kJ$sv$eTt;cX}?X4)n0*MHyPJsiOdT@KWa9o^J+2WV9edx}j
z%Ji8rrz+U~QI5}1
zB|)~5mlvwpx4&@T41Y$uut0%4;%fRZNy`52=tb0i|B-In@#m?A_bt*pEe8z){9eF!5pC|nnka>_#pu?brv7L0=*r%UVd$jwWa;+1LE}1vklNP}S}1g4;4JXy&~s~lMD~4}xqo-WPU}twf$h11dw$Yz
zt3O2-y*AA=7ES?&1)+X;BDu_^pS*C`#s36L_aOt|Sq|4*A3N?Z_iY65P+*X*qPs`U
z07os5)cxDptLZ!go)$TC^&5#hW((!#9rEaKh6@@+)`=DU!ZmWOb!F87{_13qO@vwH
zXRhGZ-X7+%ox`tuYBP#Vhc}?Hd8gyWgsm$f4&D!{wjK0X10DS1yS-Pq{8^xJc@e$I
zBjC=%3w35{>&p=F`X!UVPu|3>WWU+Vm5V)uTnaoWU@mkw{3ElGwcx$vg_k|4tPW(#
zr8hh^7pEr{vM4QWC`9yQV>dl(rJ1ZKOpAoLI$}{`Rb-VqXXy8-eK2hGzb}&
z=?rd%osCbS*u>naYyu}jSxz8g;oM6OUC{!q%X&sQHOi%L%&_U4;CtFFrUM|{5)v>&
zLP7g)hRj&d&d}x;AZ~GI=(z^->DCWIoQGPtuUr`5in!6Nsfcb3QR8Y|gI)1neqr-k#@g
zTlRpW8rRI-FdN@$boV!l?#cvLU^j0q6f>ro1;rW%6yJvr8RaR@wGA30G=eeFqcsGg
zQlBB@z)_DK^u6Hsfbp!&ad0u7AsJClra7PNa+v+R`)5
z_ph(A@b7Wg5MN+C8P^^@1wWNE3&zhjH_maIK|$*YLQl&TnZEx~ACqOcB;?nbeFxdN
z9Ub~`n~q&1vU7C_C@6n?_`XGxepc+s%6M7Ol2M&h{2g_Kb7Rnme6@}