test: migrate rag pipeline workflow controller tests to testcontainers (#34306)

This commit is contained in:
YBoy 2026-03-31 07:58:14 +03:00 committed by GitHub
parent 9b7b432e08
commit cc68f0e640
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,13 @@
"""Testcontainers integration tests for rag_pipeline_workflow controller endpoints."""
from __future__ import annotations
from datetime import datetime from datetime import datetime
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound
import services import services
@ -38,6 +44,10 @@ def unwrap(func):
class TestDraftWorkflowApi: class TestDraftWorkflowApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_draft_success(self, app): def test_get_draft_success(self, app):
api = DraftRagPipelineApi() api = DraftRagPipelineApi()
method = unwrap(api.get) method = unwrap(api.get)
@ -200,6 +210,10 @@ class TestDraftWorkflowApi:
class TestDraftRunNodes: class TestDraftRunNodes:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_iteration_node_success(self, app): def test_iteration_node_success(self, app):
api = RagPipelineDraftRunIterationNodeApi() api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post) method = unwrap(api.post)
@ -275,6 +289,10 @@ class TestDraftRunNodes:
class TestPipelineRunApis: class TestPipelineRunApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_draft_run_success(self, app): def test_draft_run_success(self, app):
api = DraftRagPipelineRunApi() api = DraftRagPipelineRunApi()
method = unwrap(api.post) method = unwrap(api.post)
@ -337,6 +355,10 @@ class TestPipelineRunApis:
class TestDraftNodeRun: class TestDraftNodeRun:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_execution_not_found(self, app): def test_execution_not_found(self, app):
api = RagPipelineDraftNodeRunApi() api = RagPipelineDraftNodeRunApi()
method = unwrap(api.post) method = unwrap(api.post)
@ -364,45 +386,43 @@ class TestDraftNodeRun:
class TestPublishedPipelineApis: class TestPublishedPipelineApis:
def test_publish_success(self, app): @pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_publish_success(self, app, db_session_with_containers: Session):
from models.dataset import Pipeline
api = PublishedRagPipelineApi() api = PublishedRagPipelineApi()
method = unwrap(api.post) method = unwrap(api.post)
pipeline = MagicMock() tenant_id = str(uuid4())
pipeline = Pipeline(
tenant_id=tenant_id,
name="test-pipeline",
description="test",
created_by=str(uuid4()),
)
db_session_with_containers.add(pipeline)
db_session_with_containers.commit()
db_session_with_containers.expire_all()
user = MagicMock(id="u1") user = MagicMock(id="u1")
workflow = MagicMock( workflow = MagicMock(
id="w1", id=str(uuid4()),
created_at=naive_utc_now(), created_at=naive_utc_now(),
) )
session = MagicMock()
session.merge.return_value = pipeline
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
service = MagicMock() service = MagicMock()
service.publish_workflow.return_value = workflow service.publish_workflow.return_value = workflow
fake_db = MagicMock()
fake_db.engine = MagicMock()
with ( with (
app.test_request_context("/"), app.test_request_context("/"),
patch( patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"), return_value=(user, "t"),
), ),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch( patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service, return_value=service,
@ -415,6 +435,10 @@ class TestPublishedPipelineApis:
class TestMiscApis: class TestMiscApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_task_stop(self, app): def test_task_stop(self, app):
api = RagPipelineTaskStopApi() api = RagPipelineTaskStopApi()
method = unwrap(api.post) method = unwrap(api.post)
@ -471,6 +495,10 @@ class TestMiscApis:
class TestPublishedRagPipelineRunApi: class TestPublishedRagPipelineRunApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_published_run_success(self, app): def test_published_run_success(self, app):
api = PublishedRagPipelineRunApi() api = PublishedRagPipelineRunApi()
method = unwrap(api.post) method = unwrap(api.post)
@ -536,6 +564,10 @@ class TestPublishedRagPipelineRunApi:
class TestDefaultBlockConfigApi: class TestDefaultBlockConfigApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_block_config_success(self, app): def test_get_block_config_success(self, app):
api = DefaultRagPipelineBlockConfigApi() api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get) method = unwrap(api.get)
@ -567,6 +599,10 @@ class TestDefaultBlockConfigApi:
class TestPublishedAllRagPipelineApi: class TestPublishedAllRagPipelineApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_published_workflows_success(self, app): def test_get_published_workflows_success(self, app):
api = PublishedAllRagPipelineApi() api = PublishedAllRagPipelineApi()
method = unwrap(api.get) method = unwrap(api.get)
@ -577,28 +613,12 @@ class TestPublishedAllRagPipelineApi:
service = MagicMock() service = MagicMock()
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False) service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with ( with (
app.test_request_context("/"), app.test_request_context("/"),
patch( patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"), return_value=(user, "t"),
), ),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch( patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service, return_value=service,
@ -628,6 +648,10 @@ class TestPublishedAllRagPipelineApi:
class TestRagPipelineByIdApi: class TestRagPipelineByIdApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_patch_success(self, app): def test_patch_success(self, app):
api = RagPipelineByIdApi() api = RagPipelineByIdApi()
method = unwrap(api.patch) method = unwrap(api.patch)
@ -640,14 +664,6 @@ class TestRagPipelineByIdApi:
service = MagicMock() service = MagicMock()
service.update_workflow.return_value = workflow service.update_workflow.return_value = workflow
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
payload = {"marked_name": "test"} payload = {"marked_name": "test"}
with ( with (
@ -657,14 +673,6 @@ class TestRagPipelineByIdApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"), return_value=(user, "t"),
), ),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch( patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service, return_value=service,
@ -700,24 +708,8 @@ class TestRagPipelineByIdApi:
workflow_service = MagicMock() workflow_service = MagicMock()
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with ( with (
app.test_request_context("/", method="DELETE"), app.test_request_context("/", method="DELETE"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch( patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService", "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService",
return_value=workflow_service, return_value=workflow_service,
@ -725,12 +717,7 @@ class TestRagPipelineByIdApi:
): ):
result = method(api, pipeline, "old-workflow") result = method(api, pipeline, "old-workflow")
workflow_service.delete_workflow.assert_called_once_with( workflow_service.delete_workflow.assert_called_once()
session=session,
workflow_id="old-workflow",
tenant_id="t1",
)
session.commit.assert_called_once()
assert result == (None, 204) assert result == (None, 204)
def test_delete_active_workflow_rejected(self, app): def test_delete_active_workflow_rejected(self, app):
@ -745,6 +732,10 @@ class TestRagPipelineByIdApi:
class TestRagPipelineWorkflowLastRunApi: class TestRagPipelineWorkflowLastRunApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_last_run_success(self, app): def test_last_run_success(self, app):
api = RagPipelineWorkflowLastRunApi() api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get) method = unwrap(api.get)
@ -788,6 +779,10 @@ class TestRagPipelineWorkflowLastRunApi:
class TestRagPipelineDatasourceVariableApi: class TestRagPipelineDatasourceVariableApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_set_datasource_variables_success(self, app): def test_set_datasource_variables_success(self, app):
api = RagPipelineDatasourceVariableApi() api = RagPipelineDatasourceVariableApi()
method = unwrap(api.post) method = unwrap(api.post)