mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 03:36:36 +08:00
test: migrate rag pipeline import controller tests to testcontainers (#34305)
This commit is contained in:
parent
88863609e9
commit
9b7b432e08
@ -1,5 +1,11 @@
|
|||||||
|
"""Testcontainers integration tests for rag_pipeline_import controller endpoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
||||||
RagPipelineExportApi,
|
RagPipelineExportApi,
|
||||||
@ -18,6 +24,10 @@ def unwrap(func):
|
|||||||
|
|
||||||
|
|
||||||
class TestRagPipelineImportApi:
|
class TestRagPipelineImportApi:
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self, flask_app_with_containers):
|
||||||
|
return flask_app_with_containers
|
||||||
|
|
||||||
def _payload(self, mode="create"):
|
def _payload(self, mode="create"):
|
||||||
return {
|
return {
|
||||||
"mode": mode,
|
"mode": mode,
|
||||||
@ -30,7 +40,6 @@ class TestRagPipelineImportApi:
|
|||||||
method = unwrap(api.post)
|
method = unwrap(api.post)
|
||||||
|
|
||||||
payload = self._payload()
|
payload = self._payload()
|
||||||
|
|
||||||
user = MagicMock()
|
user = MagicMock()
|
||||||
result = MagicMock()
|
result = MagicMock()
|
||||||
result.status = "completed"
|
result.status = "completed"
|
||||||
@ -39,13 +48,6 @@ class TestRagPipelineImportApi:
|
|||||||
service = MagicMock()
|
service = MagicMock()
|
||||||
service.import_rag_pipeline.return_value = result
|
service.import_rag_pipeline.return_value = result
|
||||||
|
|
||||||
fake_db = MagicMock()
|
|
||||||
fake_db.engine = MagicMock()
|
|
||||||
|
|
||||||
session_ctx = MagicMock()
|
|
||||||
session_ctx.__enter__.return_value = MagicMock()
|
|
||||||
session_ctx.__exit__.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
app.test_request_context("/", json=payload),
|
app.test_request_context("/", json=payload),
|
||||||
patch.object(type(console_ns), "payload", payload),
|
patch.object(type(console_ns), "payload", payload),
|
||||||
@ -53,14 +55,6 @@ class TestRagPipelineImportApi:
|
|||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||||
return_value=(user, "tenant"),
|
return_value=(user, "tenant"),
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
|
||||||
fake_db,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
|
||||||
return_value=session_ctx,
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||||
return_value=service,
|
return_value=service,
|
||||||
@ -76,7 +70,6 @@ class TestRagPipelineImportApi:
|
|||||||
method = unwrap(api.post)
|
method = unwrap(api.post)
|
||||||
|
|
||||||
payload = self._payload()
|
payload = self._payload()
|
||||||
|
|
||||||
user = MagicMock()
|
user = MagicMock()
|
||||||
result = MagicMock()
|
result = MagicMock()
|
||||||
result.status = ImportStatus.FAILED
|
result.status = ImportStatus.FAILED
|
||||||
@ -85,13 +78,6 @@ class TestRagPipelineImportApi:
|
|||||||
service = MagicMock()
|
service = MagicMock()
|
||||||
service.import_rag_pipeline.return_value = result
|
service.import_rag_pipeline.return_value = result
|
||||||
|
|
||||||
fake_db = MagicMock()
|
|
||||||
fake_db.engine = MagicMock()
|
|
||||||
|
|
||||||
session_ctx = MagicMock()
|
|
||||||
session_ctx.__enter__.return_value = MagicMock()
|
|
||||||
session_ctx.__exit__.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
app.test_request_context("/", json=payload),
|
app.test_request_context("/", json=payload),
|
||||||
patch.object(type(console_ns), "payload", payload),
|
patch.object(type(console_ns), "payload", payload),
|
||||||
@ -99,14 +85,6 @@ class TestRagPipelineImportApi:
|
|||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||||
return_value=(user, "tenant"),
|
return_value=(user, "tenant"),
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
|
||||||
fake_db,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
|
||||||
return_value=session_ctx,
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||||
return_value=service,
|
return_value=service,
|
||||||
@ -122,7 +100,6 @@ class TestRagPipelineImportApi:
|
|||||||
method = unwrap(api.post)
|
method = unwrap(api.post)
|
||||||
|
|
||||||
payload = self._payload()
|
payload = self._payload()
|
||||||
|
|
||||||
user = MagicMock()
|
user = MagicMock()
|
||||||
result = MagicMock()
|
result = MagicMock()
|
||||||
result.status = ImportStatus.PENDING
|
result.status = ImportStatus.PENDING
|
||||||
@ -131,13 +108,6 @@ class TestRagPipelineImportApi:
|
|||||||
service = MagicMock()
|
service = MagicMock()
|
||||||
service.import_rag_pipeline.return_value = result
|
service.import_rag_pipeline.return_value = result
|
||||||
|
|
||||||
fake_db = MagicMock()
|
|
||||||
fake_db.engine = MagicMock()
|
|
||||||
|
|
||||||
session_ctx = MagicMock()
|
|
||||||
session_ctx.__enter__.return_value = MagicMock()
|
|
||||||
session_ctx.__exit__.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
app.test_request_context("/", json=payload),
|
app.test_request_context("/", json=payload),
|
||||||
patch.object(type(console_ns), "payload", payload),
|
patch.object(type(console_ns), "payload", payload),
|
||||||
@ -145,14 +115,6 @@ class TestRagPipelineImportApi:
|
|||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||||
return_value=(user, "tenant"),
|
return_value=(user, "tenant"),
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
|
||||||
fake_db,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
|
||||||
return_value=session_ctx,
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||||
return_value=service,
|
return_value=service,
|
||||||
@ -165,6 +127,10 @@ class TestRagPipelineImportApi:
|
|||||||
|
|
||||||
|
|
||||||
class TestRagPipelineImportConfirmApi:
|
class TestRagPipelineImportConfirmApi:
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self, flask_app_with_containers):
|
||||||
|
return flask_app_with_containers
|
||||||
|
|
||||||
def test_confirm_success(self, app):
|
def test_confirm_success(self, app):
|
||||||
api = RagPipelineImportConfirmApi()
|
api = RagPipelineImportConfirmApi()
|
||||||
method = unwrap(api.post)
|
method = unwrap(api.post)
|
||||||
@ -177,27 +143,12 @@ class TestRagPipelineImportConfirmApi:
|
|||||||
service = MagicMock()
|
service = MagicMock()
|
||||||
service.confirm_import.return_value = result
|
service.confirm_import.return_value = result
|
||||||
|
|
||||||
fake_db = MagicMock()
|
|
||||||
fake_db.engine = MagicMock()
|
|
||||||
|
|
||||||
session_ctx = MagicMock()
|
|
||||||
session_ctx.__enter__.return_value = MagicMock()
|
|
||||||
session_ctx.__exit__.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
app.test_request_context("/"),
|
app.test_request_context("/"),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||||
return_value=(user, "tenant"),
|
return_value=(user, "tenant"),
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
|
||||||
fake_db,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
|
||||||
return_value=session_ctx,
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||||
return_value=service,
|
return_value=service,
|
||||||
@ -220,27 +171,12 @@ class TestRagPipelineImportConfirmApi:
|
|||||||
service = MagicMock()
|
service = MagicMock()
|
||||||
service.confirm_import.return_value = result
|
service.confirm_import.return_value = result
|
||||||
|
|
||||||
fake_db = MagicMock()
|
|
||||||
fake_db.engine = MagicMock()
|
|
||||||
|
|
||||||
session_ctx = MagicMock()
|
|
||||||
session_ctx.__enter__.return_value = MagicMock()
|
|
||||||
session_ctx.__exit__.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
app.test_request_context("/"),
|
app.test_request_context("/"),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||||
return_value=(user, "tenant"),
|
return_value=(user, "tenant"),
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
|
||||||
fake_db,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
|
||||||
return_value=session_ctx,
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||||
return_value=service,
|
return_value=service,
|
||||||
@ -253,6 +189,10 @@ class TestRagPipelineImportConfirmApi:
|
|||||||
|
|
||||||
|
|
||||||
class TestRagPipelineImportCheckDependenciesApi:
|
class TestRagPipelineImportCheckDependenciesApi:
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self, flask_app_with_containers):
|
||||||
|
return flask_app_with_containers
|
||||||
|
|
||||||
def test_get_success(self, app):
|
def test_get_success(self, app):
|
||||||
api = RagPipelineImportCheckDependenciesApi()
|
api = RagPipelineImportCheckDependenciesApi()
|
||||||
method = unwrap(api.get)
|
method = unwrap(api.get)
|
||||||
@ -264,23 +204,8 @@ class TestRagPipelineImportCheckDependenciesApi:
|
|||||||
service = MagicMock()
|
service = MagicMock()
|
||||||
service.check_dependencies.return_value = result
|
service.check_dependencies.return_value = result
|
||||||
|
|
||||||
fake_db = MagicMock()
|
|
||||||
fake_db.engine = MagicMock()
|
|
||||||
|
|
||||||
session_ctx = MagicMock()
|
|
||||||
session_ctx.__enter__.return_value = MagicMock()
|
|
||||||
session_ctx.__exit__.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
app.test_request_context("/"),
|
app.test_request_context("/"),
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
|
||||||
fake_db,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
|
||||||
return_value=session_ctx,
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||||
return_value=service,
|
return_value=service,
|
||||||
@ -293,6 +218,10 @@ class TestRagPipelineImportCheckDependenciesApi:
|
|||||||
|
|
||||||
|
|
||||||
class TestRagPipelineExportApi:
|
class TestRagPipelineExportApi:
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self, flask_app_with_containers):
|
||||||
|
return flask_app_with_containers
|
||||||
|
|
||||||
def test_get_with_include_secret(self, app):
|
def test_get_with_include_secret(self, app):
|
||||||
api = RagPipelineExportApi()
|
api = RagPipelineExportApi()
|
||||||
method = unwrap(api.get)
|
method = unwrap(api.get)
|
||||||
@ -301,23 +230,8 @@ class TestRagPipelineExportApi:
|
|||||||
service = MagicMock()
|
service = MagicMock()
|
||||||
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
|
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
|
||||||
|
|
||||||
fake_db = MagicMock()
|
|
||||||
fake_db.engine = MagicMock()
|
|
||||||
|
|
||||||
session_ctx = MagicMock()
|
|
||||||
session_ctx.__enter__.return_value = MagicMock()
|
|
||||||
session_ctx.__exit__.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
app.test_request_context("/?include_secret=true"),
|
app.test_request_context("/?include_secret=true"),
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
|
||||||
fake_db,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
|
||||||
return_value=session_ctx,
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||||
return_value=service,
|
return_value=service,
|
||||||
Loading…
Reference in New Issue
Block a user