mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 05:29:50 +08:00
test: unit test cases for console.datasets module (#32179)
Co-authored-by: akashseth-ifp <akash.seth@infocusp.com>
This commit is contained in:
parent
03dcbeafdf
commit
8906ab8e52
@ -807,7 +807,7 @@ class DatasetApiKeyApi(Resource):
|
||||
console_ns.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
custom="max_keys_exceeded",
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
|
||||
@ -0,0 +1,817 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_auth import (
|
||||
DatasourceAuth,
|
||||
DatasourceAuthDefaultApi,
|
||||
DatasourceAuthDeleteApi,
|
||||
DatasourceAuthListApi,
|
||||
DatasourceAuthOauthCustomClient,
|
||||
DatasourceAuthUpdateApi,
|
||||
DatasourceHardCodeAuthListApi,
|
||||
DatasourceOAuthCallback,
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
DatasourceUpdateProviderNameApi,
|
||||
)
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDatasourcePluginOAuthAuthorizationUrl:
|
||||
def test_get_success(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=cred-1"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"create_proxy_context",
|
||||
return_value="ctx-1",
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_authorization_url",
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_no_oauth_config(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_without_credential_id_sets_cookie(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"create_proxy_context",
|
||||
return_value="ctx-123",
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_authorization_url",
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "context_id" in response.headers.get("Set-Cookie")
|
||||
|
||||
|
||||
class TestDatasourceOAuthCallback:
|
||||
def test_callback_success_new_credential(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {"name": "test"}
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
def test_callback_missing_context(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_invalid_context(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=bad"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_oauth_config_not_found(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
context = {"user_id": "u", "tenant_id": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_reauthorize_existing_credential(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {} # avatar + name missing
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": "cred-1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"reauthorize_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "/oauth-callback" in response.location
|
||||
|
||||
def test_callback_context_id_from_cookie(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {}
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
|
||||
class TestDatasourceAuth:
|
||||
def test_post_success(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "val"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_invalid_credentials(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "bad"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
side_effect=CredentialsValidateFailedError("invalid"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"]
|
||||
|
||||
def test_post_missing_credentials(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_empty_list(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
|
||||
class TestDatasourceAuthDeleteApi:
|
||||
def test_delete_success(self, app):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_missing_credential_id(self, app):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
|
||||
class TestDatasourceAuthUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_credentials_none(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": None}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
def test_update_name_only(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_empty_credentials_dict(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
|
||||
class TestDatasourceAuthListApi:
|
||||
def test_list_success(self, app):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_auth_list_empty(self, app):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
def test_hardcode_list_empty(self, app):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
|
||||
class TestDatasourceHardCodeAuthListApi:
|
||||
def test_list_success(self, app):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDatasourceAuthOauthCustomClient:
|
||||
def test_post_success(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"client_params": {}, "enable_oauth_custom_client": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_empty_payload(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_disabled_flag(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"client_params": {"a": 1},
|
||||
"enable_oauth_custom_client": False,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
) as setup_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
setup_mock.assert_called_once()
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDatasourceAuthDefaultApi:
|
||||
def test_set_default_success(self, app):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"set_default_datasource_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_default_missing_id(self, app):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
|
||||
class TestDatasourceUpdateProviderNameApi:
|
||||
def test_update_name_success(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_provider_name",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_update_name_too_long(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credential_id": "id",
|
||||
"name": "x" * 101,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_update_name_missing_credential_id(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "Valid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
@ -0,0 +1,143 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_content_preview import (
|
||||
DataSourceContentPreviewApi,
|
||||
)
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDataSourceContentPreviewApi:
|
||||
def _valid_payload(self):
|
||||
return {
|
||||
"inputs": {"query": "hello"},
|
||||
"datasource_type": "notion",
|
||||
"credential_id": "cred-1",
|
||||
}
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
node_id = "node-1"
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
preview_result = {"content": "preview data"}
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = preview_result
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, node_id)
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once_with(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=payload["inputs"],
|
||||
account=account,
|
||||
datasource_type=payload["datasource_type"],
|
||||
is_published=True,
|
||||
credential_id=payload["credential_id"],
|
||||
)
|
||||
assert status == 200
|
||||
assert response == preview_result
|
||||
|
||||
def test_post_forbidden_non_account_user(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
MagicMock(), # NOT Account
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"inputs": {"query": "hello"},
|
||||
# datasource_type missing
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_without_credential_id(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"inputs": {"query": "hello"},
|
||||
"datasource_type": "notion",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = {"ok": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, "node-1")
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once()
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
@ -0,0 +1,187 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
CustomizedPipelineTemplateApi,
|
||||
PipelineTemplateDetailApi,
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
templates = [{"id": "t1"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=built-in&language=en-US"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates",
|
||||
return_value=templates,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response == templates
|
||||
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
template = {"id": "tpl-1"}
|
||||
|
||||
service = MagicMock()
|
||||
service.get_pipeline_template_detail.return_value = template
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=built-in"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == template
|
||||
|
||||
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
def test_patch_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"description": "Desc",
|
||||
"icon_info": {"icon": "📘"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template"
|
||||
) as update_mock,
|
||||
):
|
||||
response = method(api, "tpl-1")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert response == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template"
|
||||
) as delete_mock,
|
||||
):
|
||||
response = method(api, "tpl-1")
|
||||
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
assert response == 200
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
template = MagicMock()
|
||||
template.yaml_content = "yaml-data"
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = template
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": "yaml-data"}
|
||||
|
||||
def test_post_template_not_found(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "tpl-1")
|
||||
|
||||
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
def test_post_success(self, app):
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"description": "Desc",
|
||||
"icon_info": {"icon": "📘"},
|
||||
}
|
||||
|
||||
service = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response = method(api, "pipeline-1")
|
||||
|
||||
service.publish_customized_pipeline_template.assert_called_once()
|
||||
assert response == {"result": "success"}
|
||||
@ -0,0 +1,187 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import (
|
||||
CreateEmptyRagPipelineDatasetApi,
|
||||
CreateRagPipelineDatasetApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestCreateRagPipelineDatasetApi:
|
||||
def _valid_payload(self):
|
||||
return {"yaml_content": "name: test"}
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
import_info = {"dataset_id": "ds-1"}
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.return_value = import_info
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response == import_info
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
def test_post_dataset_name_duplicate(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCreateEmptyRagPipelineDatasetApi:
|
||||
def test_post_success(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
dataset = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal",
|
||||
return_value={"id": "ds-1"},
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response == {"id": "ds-1"}
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
@ -0,0 +1,324 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import (
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
RagPipelineSystemVariableCollectionApi,
|
||||
RagPipelineVariableApi,
|
||||
RagPipelineVariableCollectionApi,
|
||||
RagPipelineVariableResetApi,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db():
|
||||
db = MagicMock()
|
||||
db.engine = MagicMock()
|
||||
db.session.return_value = MagicMock()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def editor_user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.has_edit_permission = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restx_config(app):
|
||||
return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"})
|
||||
|
||||
|
||||
class TestRagPipelineVariableCollectionApi:
|
||||
def test_get_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.is_workflow_exist.return_value = True
|
||||
|
||||
# IMPORTANT: RESTX expects .variables
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
draft_srv = MagicMock()
|
||||
draft_srv.list_variables_without_values.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=10"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=draft_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.is_workflow_exist.return_value = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_delete_variables_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 204
|
||||
|
||||
|
||||
class TestRagPipelineNodeVariableCollectionApi:
|
||||
def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
srv = MagicMock()
|
||||
srv.list_node_variables.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_get_node_variables_invalid_node(self, app, editor_user):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class TestRagPipelineVariableApi:
|
||||
def test_get_variable_not_found(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, MagicMock(), "v1")
|
||||
|
||||
def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(id="p1", tenant_id="t1")
|
||||
variable = MagicMock(app_id="p1", value_type=SegmentType.FILE)
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
|
||||
payload = {"value": "invalid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, pipeline, "v1")
|
||||
|
||||
def test_delete_variable_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
variable = MagicMock(app_id="p1")
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
|
||||
assert result.status_code == 204
|
||||
|
||||
|
||||
class TestRagPipelineVariableResetApi:
|
||||
def test_reset_variable_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableResetApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
workflow = MagicMock()
|
||||
variable = MagicMock(app_id="p1")
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
srv.reset_variable.return_value = variable
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal",
|
||||
return_value={"id": "v1"},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
|
||||
assert result == {"id": "v1"}
|
||||
|
||||
|
||||
class TestSystemAndEnvironmentVariablesApi:
|
||||
def test_system_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineSystemVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
srv = MagicMock()
|
||||
srv.list_system_variables.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_environment_variables_success(self, app, editor_user):
|
||||
api = RagPipelineEnvironmentVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
env_var = MagicMock(
|
||||
id="e1",
|
||||
name="ENV",
|
||||
description="d",
|
||||
selector="s",
|
||||
value_type=MagicMock(value="string"),
|
||||
value="x",
|
||||
)
|
||||
|
||||
workflow = MagicMock(environment_variables=[env_var])
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert len(result["items"]) == 1
|
||||
@ -0,0 +1,329 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
||||
RagPipelineExportApi,
|
||||
RagPipelineImportApi,
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
RagPipelineImportConfirmApi,
|
||||
)
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestRagPipelineImportApi:
|
||||
def _payload(self, mode="create"):
|
||||
return {
|
||||
"mode": mode,
|
||||
"yaml_content": "content",
|
||||
"name": "Test",
|
||||
}
|
||||
|
||||
def test_post_success_200(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
result.model_dump.return_value = {"status": "success"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"status": "success"}
|
||||
|
||||
def test_post_failed_400(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
result.model_dump.return_value = {"status": "failed"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 400
|
||||
assert response == {"status": "failed"}
|
||||
|
||||
def test_post_pending_202(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.PENDING
|
||||
result.model_dump.return_value = {"status": "pending"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 202
|
||||
assert response == {"status": "pending"}
|
||||
|
||||
|
||||
class TestRagPipelineImportConfirmApi:
|
||||
def test_confirm_success(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
result.model_dump.return_value = {"ok": True}
|
||||
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
|
||||
def test_confirm_failed(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
result.model_dump.return_value = {"ok": False}
|
||||
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
|
||||
assert status == 400
|
||||
assert response == {"ok": False}
|
||||
|
||||
|
||||
class TestRagPipelineImportCheckDependenciesApi:
|
||||
def test_get_success(self, app):
|
||||
api = RagPipelineImportCheckDependenciesApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
result = MagicMock()
|
||||
result.model_dump.return_value = {"deps": []}
|
||||
|
||||
service = MagicMock()
|
||||
service.check_dependencies.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"deps": []}
|
||||
|
||||
|
||||
class TestRagPipelineExportApi:
|
||||
def test_get_with_include_secret(self, app):
|
||||
api = RagPipelineExportApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
service = MagicMock()
|
||||
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/?include_secret=true"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": {"yaml": "data"}}
|
||||
@ -0,0 +1,688 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
|
||||
DefaultRagPipelineBlockConfigApi,
|
||||
DraftRagPipelineApi,
|
||||
DraftRagPipelineRunApi,
|
||||
PublishedAllRagPipelineApi,
|
||||
PublishedRagPipelineApi,
|
||||
PublishedRagPipelineRunApi,
|
||||
RagPipelineByIdApi,
|
||||
RagPipelineDatasourceVariableApi,
|
||||
RagPipelineDraftNodeRunApi,
|
||||
RagPipelineDraftRunIterationNodeApi,
|
||||
RagPipelineDraftRunLoopNodeApi,
|
||||
RagPipelineRecommendedPluginApi,
|
||||
RagPipelineTaskStopApi,
|
||||
RagPipelineTransformApi,
|
||||
RagPipelineWorkflowLastRunApi,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
def test_get_draft_success(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == workflow
|
||||
|
||||
def test_get_draft_not_exist(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_hash_not_match(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"graph": {}, "features": {}}),
|
||||
patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_invalid_text_plain(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
assert status == 400
|
||||
|
||||
|
||||
class TestDraftRunNodes:
|
||||
def test_iteration_node_success(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node")
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_iteration_node_conversation_not_exists(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
side_effect=services.errors.conversation.ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
def test_loop_node_success(self, app):
|
||||
api = RagPipelineDraftRunLoopNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline, "node") == {"ok": True}
|
||||
|
||||
|
||||
class TestPipelineRunApis:
|
||||
def test_draft_run_success(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline) == {"ok": True}
|
||||
|
||||
def test_draft_run_rate_limit(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}
|
||||
),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
{"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestDraftNodeRun:
|
||||
def test_execution_not_found(self, app):
|
||||
api = RagPipelineDraftNodeRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.run_draft_workflow_node.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
|
||||
class TestPublishedPipelineApis:
|
||||
def test_publish_success(self, app):
|
||||
api = PublishedRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock(
|
||||
id="w1",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.merge.return_value = pipeline
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
service = MagicMock()
|
||||
service.publish_workflow.return_value = workflow
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert "created_at" in result
|
||||
|
||||
|
||||
class TestMiscApis:
|
||||
def test_task_stop(self, app):
|
||||
api = RagPipelineTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag"
|
||||
) as stop_mock,
|
||||
):
|
||||
result = method(api, pipeline, "task-1")
|
||||
stop_mock.assert_called_once()
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_transform_forbidden(self, app):
|
||||
api = RagPipelineTransformApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(has_edit_permission=False, is_dataset_operator=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds1")
|
||||
|
||||
def test_recommended_plugins(self, app):
|
||||
api = RagPipelineRecommendedPluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
service = MagicMock()
|
||||
service.get_recommended_plugins.return_value = [{"id": "p1"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=all"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
assert result == [{"id": "p1"}]
|
||||
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
def test_published_run_success(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
"response_mode": "blocking",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_published_run_rate_limit(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestDefaultBlockConfigApi:
|
||||
def test_get_block_config_success(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_default_block_config.return_value = {"k": "v"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?q={}"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "llm")
|
||||
assert result == {"k": "v"}
|
||||
|
||||
def test_get_block_config_invalid_json(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
with app.test_request_context("/?q=bad-json"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "llm")
|
||||
|
||||
|
||||
class TestPublishedAllRagPipelineApi:
|
||||
def test_get_published_workflows_success(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == [{"id": "w1"}]
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_published_workflows_forbidden(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?user_id=u2"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestRagPipelineByIdApi:
|
||||
def test_patch_success(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(tenant_id="t1")
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.update_workflow.return_value = workflow
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
payload = {"marked_name": "test"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "w1")
|
||||
|
||||
assert result == workflow
|
||||
|
||||
def test_patch_no_fields(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch.object(type(console_ns), "payload", {}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, pipeline, "w1")
|
||||
assert status == 400
|
||||
|
||||
|
||||
class TestRagPipelineWorkflowLastRunApi:
|
||||
def test_last_run_success(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
node_exec = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
service.get_node_last_run.return_value = node_exec
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
assert result == node_exec
|
||||
|
||||
def test_last_run_not_found(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node1")
|
||||
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
def test_set_datasource_variables_success(self, app):
|
||||
api = RagPipelineDatasourceVariableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"datasource_type": "db",
|
||||
"datasource_info": {},
|
||||
"start_node_id": "n1",
|
||||
"start_node_title": "Node",
|
||||
}
|
||||
|
||||
service = MagicMock()
|
||||
service.set_datasource_variables.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result is not None
|
||||
@ -0,0 +1,444 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.datasets import data_source
|
||||
from controllers.console.datasets.data_source import (
|
||||
DataSourceApi,
|
||||
DataSourceNotionApi,
|
||||
DataSourceNotionDatasetSyncApi,
|
||||
DataSourceNotionDocumentSyncApi,
|
||||
DataSourceNotionListApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_ctx():
|
||||
return (MagicMock(id="u1"), "tenant-1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_tenant(tenant_ctx):
|
||||
with patch(
|
||||
"controllers.console.datasets.data_source.current_account_with_tenant",
|
||||
return_value=tenant_ctx,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
with patch.object(
|
||||
type(data_source.db),
|
||||
"engine",
|
||||
new_callable=PropertyMock,
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestDataSourceApi:
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
binding = MagicMock(
|
||||
id="b1",
|
||||
provider="notion",
|
||||
created_at="now",
|
||||
disabled=False,
|
||||
source_info={},
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.db.session.scalars",
|
||||
return_value=MagicMock(all=lambda: [binding]),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"][0]["is_bound"] is True
|
||||
|
||||
def test_get_no_bindings(self, app, patch_tenant):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.db.session.scalars",
|
||||
return_value=MagicMock(all=lambda: []),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"] == []
|
||||
|
||||
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
patch("controllers.console.datasets.data_source.db.session.add"),
|
||||
patch("controllers.console.datasets.data_source.db.session.commit"),
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
response, status = method(api, "b1", "enable")
|
||||
|
||||
assert status == 200
|
||||
assert binding.disabled is False
|
||||
|
||||
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
patch("controllers.console.datasets.data_source.db.session.add"),
|
||||
patch("controllers.console.datasets.data_source.db.session.commit"),
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
response, status = method(api, "b1", "disable")
|
||||
|
||||
assert status == 200
|
||||
assert binding.disabled is True
|
||||
|
||||
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "b1", "enable")
|
||||
|
||||
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "b1", "enable")
|
||||
|
||||
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "b1", "disable")
|
||||
|
||||
|
||||
class TestDataSourceNotionListApi:
|
||||
def test_get_credential_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
|
||||
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
page = MagicMock(
|
||||
page_id="p1",
|
||||
page_name="Page 1",
|
||||
type="page",
|
||||
parent_id="parent",
|
||||
page_icon=None,
|
||||
)
|
||||
|
||||
online_document_message = MagicMock(
|
||||
result=[
|
||||
MagicMock(
|
||||
workspace_id="w1",
|
||||
workspace_name="My Workspace",
|
||||
workspace_icon="icon",
|
||||
pages=[page],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
),
|
||||
patch(
|
||||
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
|
||||
return_value=MagicMock(
|
||||
get_online_document_pages=lambda **kw: iter([online_document_message]),
|
||||
datasource_provider_type=lambda: None,
|
||||
),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
page = MagicMock(
|
||||
page_id="p1",
|
||||
page_name="Page 1",
|
||||
type="page",
|
||||
parent_id="parent",
|
||||
page_icon=None,
|
||||
)
|
||||
|
||||
online_document_message = MagicMock(
|
||||
result=[
|
||||
MagicMock(
|
||||
workspace_id="w1",
|
||||
workspace_name="My Workspace",
|
||||
workspace_icon="icon",
|
||||
pages=[page],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
dataset = MagicMock(data_source_type="notion_import")
|
||||
document = MagicMock(data_source_info='{"notion_page_id": "p1"}')
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
patch(
|
||||
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
|
||||
return_value=MagicMock(
|
||||
get_online_document_pages=lambda **kw: iter([online_document_message]),
|
||||
datasource_provider_type=lambda: None,
|
||||
),
|
||||
),
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalars.return_value.all.return_value = [document]
|
||||
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
dataset = MagicMock(data_source_type="other_type")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch("controllers.console.datasets.data_source.Session"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestDataSourceNotionApi:
|
||||
def test_get_preview_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")])
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"integration_secret": "t"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.NotionExtractor",
|
||||
return_value=extractor,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "p1", "page")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_indexing_estimate_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"notion_info_list": [
|
||||
{
|
||||
"workspace_id": "w1",
|
||||
"credential_id": "c1",
|
||||
"pages": [{"page_id": "p1", "type": "page"}],
|
||||
}
|
||||
],
|
||||
"process_rule": {"rules": {}},
|
||||
"doc_form": "text_model",
|
||||
"doc_language": "English",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="POST", json=payload, headers={"Content-Type": "application/json"}),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.estimate_args_validate",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.IndexingRunner.indexing_estimate",
|
||||
return_value=MagicMock(model_dump=lambda: {"total_pages": 1}),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDataSourceNotionDatasetSyncApi:
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.get_document_by_dataset_id",
|
||||
return_value=[MagicMock(id="d1")],
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_dataset_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1")
|
||||
|
||||
|
||||
class TestDataSourceNotionDocumentSyncApi:
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.get_document",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_document_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.get_document",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
1926
api/tests/unit_tests/controllers/console/datasets/test_datasets.py
Normal file
1926
api/tests/unit_tests/controllers/console/datasets/test_datasets.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,399 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.external import (
|
||||
BedrockRetrievalApi,
|
||||
ExternalApiTemplateApi,
|
||||
ExternalApiTemplateListApi,
|
||||
ExternalDatasetCreateApi,
|
||||
ExternalKnowledgeHitTestingApi,
|
||||
)
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_external_dataset")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user():
|
||||
user = MagicMock()
|
||||
user.id = "user-1"
|
||||
user.is_dataset_editor = True
|
||||
user.has_edit_permission = True
|
||||
user.is_dataset_operator = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth(mocker, current_user):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.external.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
)
|
||||
|
||||
|
||||
class TestExternalApiTemplateListApi:
|
||||
def test_get_success(self, app):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
api_item = MagicMock()
|
||||
api_item.to_dict.return_value = {"id": "1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=20"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"get_external_knowledge_apis",
|
||||
return_value=([api_item], 1),
|
||||
),
|
||||
):
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 1
|
||||
assert resp["data"][0]["id"] == "1"
|
||||
|
||||
def test_post_forbidden(self, app, current_user):
|
||||
current_user.is_dataset_editor = False
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "x", "settings": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(ExternalDatasetService, "validate_api_list"),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
def test_post_duplicate_name(self, app):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "x", "settings": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(ExternalDatasetService, "validate_api_list"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"create_external_knowledge_api",
|
||||
side_effect=services.errors.dataset.DatasetNameDuplicateError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestExternalApiTemplateApi:
|
||||
def test_get_not_found(self, app):
|
||||
api = ExternalApiTemplateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"get_external_knowledge_api",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "api-id")
|
||||
|
||||
def test_delete_forbidden(self, app, current_user):
|
||||
current_user.has_edit_permission = False
|
||||
current_user.is_dataset_operator = False
|
||||
|
||||
api = ExternalApiTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "api-id")
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApi:
|
||||
def test_create_success(self, app):
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api",
|
||||
"external_knowledge_id": "kid",
|
||||
"name": "dataset",
|
||||
}
|
||||
|
||||
dataset = MagicMock()
|
||||
|
||||
dataset.embedding_available = False
|
||||
dataset.built_in_field_enabled = False
|
||||
dataset.is_published = False
|
||||
dataset.enable_api = False
|
||||
dataset.enable_qa = False
|
||||
dataset.enable_vector_store = False
|
||||
dataset.vector_store_setting = None
|
||||
dataset.is_multimodal = False
|
||||
|
||||
dataset.retrieval_model_dict = {}
|
||||
dataset.tags = []
|
||||
dataset.external_knowledge_info = None
|
||||
dataset.external_retrieval_model = None
|
||||
dataset.doc_metadata = []
|
||||
dataset.icon_info = None
|
||||
|
||||
dataset.summary_index_setting = MagicMock()
|
||||
dataset.summary_index_setting.enable = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"create_external_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
):
|
||||
_, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_create_forbidden(self, app, current_user):
|
||||
current_user.is_dataset_editor = False
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api",
|
||||
"external_knowledge_id": "kid",
|
||||
"name": "dataset",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestExternalKnowledgeHitTestingApi:
|
||||
def test_hit_testing_dataset_not_found(self, app):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "dataset-id")
|
||||
|
||||
def test_hit_testing_success(self, app):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"query": "hello"}
|
||||
|
||||
dataset = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(DatasetService, "get_dataset", return_value=dataset),
|
||||
patch.object(DatasetService, "check_dataset_permission"),
|
||||
patch.object(
|
||||
HitTestingService,
|
||||
"external_retrieve",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
resp = method(api, "dataset-id")
|
||||
|
||||
assert resp["ok"] is True
|
||||
|
||||
|
||||
class TestBedrockRetrievalApi:
|
||||
def test_bedrock_retrieval(self, app):
|
||||
api = BedrockRetrievalApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"retrieval_setting": {},
|
||||
"query": "hello",
|
||||
"knowledge_id": "kid",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(
|
||||
ExternalDatasetTestService,
|
||||
"knowledge_retrieval",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
resp, status = method()
|
||||
|
||||
assert status == 200
|
||||
assert resp["ok"] is True
|
||||
|
||||
|
||||
class TestExternalApiTemplateListApiAdvanced:
|
||||
def test_post_duplicate_name_error(self, app, mock_auth, current_user):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "duplicate_api", "settings": {"key": "value"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch("controllers.console.datasets.external.ExternalDatasetService.validate_api_list"),
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetService.create_external_knowledge_api",
|
||||
side_effect=services.errors.dataset.DatasetNameDuplicateError("Duplicate"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
def test_get_with_pagination(self, app, mock_auth, current_user):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
templates = [MagicMock(id=f"api-{i}") for i in range(3)]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=20"),
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
|
||||
return_value=(templates, 25),
|
||||
),
|
||||
):
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 25
|
||||
assert len(resp["data"]) == 3
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApiAdvanced:
|
||||
def test_create_forbidden(self, app, mock_auth, current_user):
|
||||
"""Test creating external dataset without permission"""
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
current_user.is_dataset_editor = False
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api-1",
|
||||
"external_knowledge_id": "ek-1",
|
||||
"name": "new_dataset",
|
||||
"description": "A dataset",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestExternalKnowledgeHitTestingApiAdvanced:
|
||||
def test_hit_testing_dataset_not_found(self, app, mock_auth, current_user):
|
||||
"""Test hit testing on non-existent dataset"""
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "test query",
|
||||
"external_retrieval_model": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.external.DatasetService.get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1")
|
||||
|
||||
def test_hit_testing_with_custom_retrieval_model(self, app, mock_auth, current_user):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
dataset = MagicMock()
|
||||
payload = {
|
||||
"query": "test query",
|
||||
"external_retrieval_model": {"type": "bm25"},
|
||||
"metadata_filtering_conditions": {"status": "active"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.external.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"),
|
||||
patch(
|
||||
"controllers.console.datasets.external.HitTestingService.external_retrieve",
|
||||
return_value={"results": []},
|
||||
),
|
||||
):
|
||||
resp = method(api, "ds-1")
|
||||
|
||||
assert resp["results"] == []
|
||||
|
||||
|
||||
class TestBedrockRetrievalApiAdvanced:
|
||||
def test_bedrock_retrieval_with_invalid_setting(self, app, mock_auth, current_user):
|
||||
api = BedrockRetrievalApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"retrieval_setting": {},
|
||||
"query": "test",
|
||||
"knowledge_id": "k-1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetTestService.knowledge_retrieval",
|
||||
side_effect=ValueError("Invalid settings"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method()
|
||||
@ -0,0 +1,160 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.hit_testing import HitTestingApi
|
||||
from controllers.console.datasets.hit_testing_base import HitTestingPayload
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_hit_testing")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
return MagicMock(id="dataset-1")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_decorators(mocker):
|
||||
"""Bypass all decorators on the API method."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.setup_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.login_required",
|
||||
return_value=lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.account_initialization_required",
|
||||
return_value=lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check",
|
||||
return_value=lambda *_: (lambda f: f),
|
||||
)
|
||||
|
||||
|
||||
class TestHitTestingApi:
|
||||
def test_hit_testing_success(self, app, dataset, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "what is vector search",
|
||||
"top_k": 3,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingPayload,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: payload),
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"hit_testing_args_check",
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"perform_hit_testing",
|
||||
return_value={"query": "what is vector search", "records": []},
|
||||
),
|
||||
):
|
||||
result = method(api, dataset_id)
|
||||
|
||||
assert "query" in result
|
||||
assert "records" in result
|
||||
assert result["records"] == []
|
||||
|
||||
def test_hit_testing_dataset_not_found(self, app, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "test",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
side_effect=NotFound("Dataset not found"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
method(api, dataset_id)
|
||||
|
||||
def test_hit_testing_invalid_args(self, app, dataset, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingPayload,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: payload),
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"hit_testing_args_check",
|
||||
side_effect=ValueError("Invalid parameters"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Invalid parameters"):
|
||||
method(api, dataset_id)
|
||||
@ -0,0 +1,207 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from controllers.console.datasets.hit_testing_base import (
|
||||
DatasetsHitTestingBase,
|
||||
)
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def account():
|
||||
acc = MagicMock(spec=Account)
|
||||
return acc
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_current_user(mocker, account):
|
||||
"""Patch current_user to a valid Account."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing_base.current_user",
|
||||
account,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
return MagicMock(id="dataset-1")
|
||||
|
||||
|
||||
class TestGetAndValidateDataset:
|
||||
def test_success(self, dataset):
|
||||
with (
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
):
|
||||
result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
|
||||
assert result == dataset
|
||||
|
||||
def test_dataset_not_found(self):
|
||||
with patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
|
||||
def test_permission_denied(self, dataset):
|
||||
with (
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
side_effect=services.errors.account.NoPermissionError("no access"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden, match="no access"):
|
||||
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
|
||||
|
||||
class TestHitTestingArgsCheck:
|
||||
def test_args_check_called(self):
|
||||
args = {"query": "test"}
|
||||
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"hit_testing_args_check",
|
||||
) as check_mock:
|
||||
DatasetsHitTestingBase.hit_testing_args_check(args)
|
||||
|
||||
check_mock.assert_called_once_with(args)
|
||||
|
||||
|
||||
class TestParseArgs:
|
||||
def test_parse_args_success(self):
|
||||
payload = {"query": "hello"}
|
||||
|
||||
result = DatasetsHitTestingBase.parse_args(payload)
|
||||
|
||||
assert result["query"] == "hello"
|
||||
|
||||
def test_parse_args_invalid(self):
|
||||
payload = {"query": "x" * 300}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
DatasetsHitTestingBase.parse_args(payload)
|
||||
|
||||
|
||||
class TestPerformHitTesting:
|
||||
def test_success(self, dataset):
|
||||
response = {
|
||||
"query": "hello",
|
||||
"records": [],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
return_value=response,
|
||||
):
|
||||
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
assert result["query"] == "hello"
|
||||
assert result["records"] == []
|
||||
|
||||
def test_index_not_initialized(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=services.errors.index.IndexNotInitializedError(),
|
||||
):
|
||||
with pytest.raises(DatasetNotInitializedError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_provider_token_not_init(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ProviderTokenNotInitError("token missing"),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_quota_exceeded(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=QuotaExceededError(),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_model_not_supported(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_llm_bad_request(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=LLMBadRequestError("bad request"),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_invoke_error(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=InvokeError("invoke failed"),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_value_error(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ValueError("bad args"),
|
||||
):
|
||||
with pytest.raises(ValueError, match="bad args"):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_unexpected_error(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=Exception("boom"),
|
||||
):
|
||||
with pytest.raises(InternalServerError, match="boom"):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
@ -0,0 +1,362 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.metadata import (
|
||||
DatasetMetadataApi,
|
||||
DatasetMetadataBuiltInFieldActionApi,
|
||||
DatasetMetadataBuiltInFieldApi,
|
||||
DatasetMetadataCreateApi,
|
||||
DocumentMetadataEditApi,
|
||||
)
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_dataset_metadata")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user():
|
||||
user = MagicMock()
|
||||
user.id = "user-1"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
ds = MagicMock()
|
||||
ds.id = "dataset-1"
|
||||
return ds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata_id():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_decorators(mocker):
|
||||
"""Bypass setup/login/license decorators."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.setup_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.login_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.account_initialization_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.enterprise_license_required",
|
||||
lambda f: f,
|
||||
)
|
||||
|
||||
|
||||
class TestDatasetMetadataCreateApi:
|
||||
def test_create_metadata_success(self, app, current_user, dataset, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "author"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
MetadataArgs,
|
||||
"model_validate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"create_metadata",
|
||||
return_value={"id": "m1", "name": "author"},
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
|
||||
assert status == 201
|
||||
assert result["name"] == "author"
|
||||
|
||||
def test_create_metadata_dataset_not_found(self, app, current_user, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
valid_payload = {
|
||||
"type": "string",
|
||||
"name": "author",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=valid_payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
MetadataArgs,
|
||||
"model_validate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
method(api, dataset_id)
|
||||
|
||||
|
||||
class TestDatasetMetadataGetApi:
|
||||
def test_get_metadata_success(self, app, dataset, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"get_dataset_metadatas",
|
||||
return_value=[{"id": "m1"}],
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
|
||||
assert status == 200
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_get_metadata_dataset_not_found(self, app, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, dataset_id)
|
||||
|
||||
|
||||
class TestDatasetMetadataApi:
|
||||
def test_update_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
|
||||
api = DatasetMetadataApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {"name": "updated-name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"update_metadata_name",
|
||||
return_value={"id": "m1", "name": "updated-name"},
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, metadata_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["name"] == "updated-name"
|
||||
|
||||
def test_delete_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
|
||||
api = DatasetMetadataApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"delete_metadata",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, metadata_id)
|
||||
|
||||
assert status == 204
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestDatasetMetadataBuiltInFieldApi:
|
||||
def test_get_built_in_fields(self, app):
|
||||
api = DatasetMetadataBuiltInFieldApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"get_built_in_fields",
|
||||
return_value=["title", "source"],
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["fields"] == ["title", "source"]
|
||||
|
||||
|
||||
class TestDatasetMetadataBuiltInFieldActionApi:
|
||||
def test_enable_built_in_field(self, app, current_user, dataset, dataset_id):
|
||||
api = DatasetMetadataBuiltInFieldActionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"enable_built_in_field",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, "enable")
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestDocumentMetadataEditApi:
|
||||
def test_update_document_metadata_success(self, app, current_user, dataset, dataset_id):
|
||||
api = DocumentMetadataEditApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"operation": "add", "metadata": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataOperationData,
|
||||
"model_validate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"update_documents_metadata",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
@ -0,0 +1,233 @@
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import WebsiteCrawlError
|
||||
from controllers.console.datasets.website import (
|
||||
WebsiteCrawlApi,
|
||||
WebsiteCrawlStatusApi,
|
||||
)
|
||||
from services.website_service import (
|
||||
WebsiteCrawlApiRequest,
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
WebsiteService,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_website_crawl")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_auth_and_setup(mocker):
|
||||
"""Bypass setup/login/account decorators."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.login_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.setup_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.account_initialization_required",
|
||||
lambda f: f,
|
||||
)
|
||||
|
||||
|
||||
class TestWebsiteCrawlApi:
|
||||
def test_crawl_success(self, app, mocker):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"provider": "firecrawl",
|
||||
"url": "https://example.com",
|
||||
"options": {"depth": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
mock_request = Mock(spec=WebsiteCrawlApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"crawl_url",
|
||||
return_value={"job_id": "job-1"},
|
||||
)
|
||||
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["job_id"] == "job-1"
|
||||
|
||||
def test_crawl_invalid_payload(self, app, mocker):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"provider": "firecrawl",
|
||||
"url": "bad-url",
|
||||
"options": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlApiRequest,
|
||||
"from_args",
|
||||
side_effect=ValueError("invalid payload"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="invalid payload"):
|
||||
method(api)
|
||||
|
||||
def test_crawl_service_error(self, app, mocker):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"provider": "firecrawl",
|
||||
"url": "https://example.com",
|
||||
"options": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
mock_request = Mock(spec=WebsiteCrawlApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"crawl_url",
|
||||
side_effect=Exception("crawl failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="crawl failed"):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestWebsiteCrawlStatusApi:
|
||||
def test_get_status_success(self, app, mocker):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
job_id = "job-123"
|
||||
args = {"provider": "firecrawl"}
|
||||
|
||||
with app.test_request_context("/?provider=firecrawl"):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.request.args.to_dict",
|
||||
return_value=args,
|
||||
)
|
||||
|
||||
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"get_crawl_status_typed",
|
||||
return_value={"status": "completed"},
|
||||
)
|
||||
|
||||
result, status = method(api, job_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["status"] == "completed"
|
||||
|
||||
def test_get_status_invalid_provider(self, app, mocker):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
job_id = "job-123"
|
||||
args = {"provider": "firecrawl"}
|
||||
|
||||
with app.test_request_context("/?provider=firecrawl"):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.request.args.to_dict",
|
||||
return_value=args,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
"from_args",
|
||||
side_effect=ValueError("invalid provider"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="invalid provider"):
|
||||
method(api, job_id)
|
||||
|
||||
def test_get_status_service_error(self, app, mocker):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
job_id = "job-123"
|
||||
args = {"provider": "firecrawl"}
|
||||
|
||||
with app.test_request_context("/?provider=firecrawl"):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.request.args.to_dict",
|
||||
return_value=args,
|
||||
)
|
||||
|
||||
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"get_crawl_status_typed",
|
||||
side_effect=Exception("status lookup failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="status lookup failed"):
|
||||
method(api, job_id)
|
||||
117
api/tests/unit_tests/controllers/console/datasets/test_wraps.py
Normal file
117
api/tests/unit_tests/controllers/console/datasets/test_wraps.py
Normal file
@ -0,0 +1,117 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
class TestGetRagPipeline:
|
||||
def test_missing_pipeline_id(self):
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(ValueError, match="missing pipeline_id"):
|
||||
dummy_view()
|
||||
|
||||
def test_pipeline_not_found(self, mocker):
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return "ok"
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
with pytest.raises(PipelineNotFoundError):
|
||||
dummy_view(pipeline_id="pipeline-1")
|
||||
|
||||
def test_pipeline_found_and_injected(self, mocker):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
pipeline.id = "pipeline-1"
|
||||
pipeline.tenant_id = "tenant-1"
|
||||
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return kwargs["pipeline"]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = pipeline
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id="pipeline-1")
|
||||
|
||||
assert result is pipeline
|
||||
|
||||
def test_pipeline_id_removed_from_kwargs(self, mocker):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
assert "pipeline_id" not in kwargs
|
||||
return "ok"
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = pipeline
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id="pipeline-1")
|
||||
|
||||
assert result == "ok"
|
||||
|
||||
def test_pipeline_id_cast_to_string(self, mocker):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return kwargs["pipeline"]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
def where_side_effect(*args, **kwargs):
|
||||
assert args[0].right.value == "123"
|
||||
return Mock(first=lambda: pipeline)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.side_effect = where_side_effect
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id=123)
|
||||
|
||||
assert result is pipeline
|
||||
Loading…
Reference in New Issue
Block a user