diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index e9d26974c1..7c941e1436 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -994,7 +994,7 @@ class RagPipelineTransformApi(Resource): dataset_id_str = str(dataset_id) rag_pipeline_transform_service = RagPipelineTransformService() - result = rag_pipeline_transform_service.transform_dataset(dataset_id_str) + result = rag_pipeline_transform_service.transform_dataset(dataset_id_str, db.session) return result diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index ca755d0b91..dc3eeae201 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -8,6 +8,7 @@ from uuid import uuid4 import yaml from flask_login import current_user from sqlalchemy import select +from sqlalchemy.orm import scoped_session from configs import dify_config from constants import DOCUMENT_EXTENSIONS @@ -28,8 +29,8 @@ logger = logging.getLogger(__name__) class RagPipelineTransformService: - def transform_dataset(self, dataset_id: str): - dataset = db.session.get(Dataset, dataset_id) + def transform_dataset(self, dataset_id: str, session: scoped_session): + dataset = session.get(Dataset, dataset_id) if not dataset: raise ValueError("Dataset not found") if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE: @@ -94,9 +95,9 @@ class RagPipelineTransformService: dataset.pipeline_id = pipeline.id # deal document data - self._deal_document_data(dataset) + self._deal_document_data(dataset, session) - db.session.commit() + session.commit() return { "pipeline_id": pipeline.id, "dataset_id": dataset_id, @@ -310,13 +311,13 @@ class RagPipelineTransformService: "status": "success", } - def _deal_document_data(self, dataset: Dataset): + def _deal_document_data(self, dataset: Dataset, session: scoped_session): file_node_id = "1752479895761" notion_node_id = "1752489759475" jina_node_id = "1752491761974" firecrawl_node_id = "1752565402678" - documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset.id)).all() + documents = session.scalars(select(Document).where(Document.dataset_id == dataset.id)).all() for document in documents: data_source_info_dict = document.data_source_info_dict @@ -326,7 +327,7 @@ class RagPipelineTransformService: document.data_source_type = DataSourceType.LOCAL_FILE file_id = data_source_info_dict.get("upload_file_id") if file_id: - file = db.session.get(UploadFile, file_id) + file = session.get(UploadFile, file_id) if file: data_source_info = json.dumps( { @@ -350,8 +351,8 @@ class RagPipelineTransformService: datasource_node_id=file_node_id, ) document_pipeline_execution_log.created_at = document.created_at - db.session.add(document) - db.session.add(document_pipeline_execution_log) + session.add(document) + session.add(document_pipeline_execution_log) elif document.data_source_type == DataSourceType.NOTION_IMPORT: document.data_source_type = DataSourceType.ONLINE_DOCUMENT data_source_info = json.dumps( @@ -378,8 +379,8 @@ class RagPipelineTransformService: datasource_node_id=notion_node_id, ) document_pipeline_execution_log.created_at = document.created_at - db.session.add(document) - db.session.add(document_pipeline_execution_log) + session.add(document) + session.add(document_pipeline_execution_log) elif document.data_source_type == DataSourceType.WEBSITE_CRAWL: data_source_info = json.dumps( { @@ -406,5 +407,5 @@ class RagPipelineTransformService: datasource_node_id=datasource_node_id, ) document_pipeline_execution_log.created_at = document.created_at - db.session.add(document) - db.session.add(document_pipeline_execution_log) + session.add(document) + session.add(document_pipeline_execution_log) diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_transform_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_transform_service.py index 3f511a109a..f6a3f524fe 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_transform_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_transform_service.py @@ -67,7 +67,7 @@ def test_deal_file_extensions_returns_original_when_empty() -> None: assert result is node -def test_deal_dependencies_installs_missing_marketplace_plugins(mocker) -> None: +def test_deal_dependencies_installs_missing_marketplace_plugins(mocker: MockerFixture) -> None: service = RagPipelineTransformService() installer_cls = mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.PluginInstaller") @@ -92,7 +92,7 @@ def test_deal_dependencies_installs_missing_marketplace_plugins(mocker) -> None: install_mock.assert_called_once_with("tenant-1", ["missing-plugin:1.0.0"]) -def test_transform_to_empty_pipeline_updates_dataset_and_commits(mocker) -> None: +def test_transform_to_empty_pipeline_updates_dataset_and_commits(mocker: MockerFixture) -> None: service = RagPipelineTransformService() mocker.patch( "services.rag_pipeline.rag_pipeline_transform_service.current_user", @@ -142,7 +142,7 @@ def test_transform_to_empty_pipeline_updates_dataset_and_commits(mocker) -> None # --- transform_dataset --- -def test_transform_dataset_returns_early_when_pipeline_exists(mocker) -> None: +def test_transform_dataset_returns_early_when_pipeline_exists(mocker: MockerFixture) -> None: service = RagPipelineTransformService() dataset = SimpleNamespace( id="d1", @@ -151,30 +151,21 @@ def test_transform_dataset_returns_early_when_pipeline_exists(mocker) -> None: ) session_mock = mocker.Mock() session_mock.get.return_value = dataset - mocker.patch( - "services.rag_pipeline.rag_pipeline_transform_service.db", - new=SimpleNamespace(session=session_mock), - ) - result = service.transform_dataset("d1") + result = service.transform_dataset("d1", session_mock) assert result == {"pipeline_id": "p1", "dataset_id": "d1", "status": "success"} -def test_transform_dataset_raises_for_dataset_not_found(mocker) -> None: +def test_transform_dataset_raises_for_dataset_not_found(mocker: MockerFixture) -> None: service = RagPipelineTransformService() session_mock = mocker.Mock() session_mock.get.return_value = None - mocker.patch( - "services.rag_pipeline.rag_pipeline_transform_service.db", - new=SimpleNamespace(session=session_mock), - ) - with pytest.raises(ValueError, match="Dataset not found"): - service.transform_dataset("d1") + service.transform_dataset("d1", session_mock) -def test_transform_dataset_raises_for_external_dataset(mocker) -> None: +def test_transform_dataset_raises_for_external_dataset(mocker: MockerFixture) -> None: service = RagPipelineTransformService() dataset = SimpleNamespace( id="d1", @@ -184,16 +175,12 @@ def test_transform_dataset_raises_for_external_dataset(mocker) -> None: ) session_mock = mocker.Mock() session_mock.get.return_value = dataset - mocker.patch( - "services.rag_pipeline.rag_pipeline_transform_service.db", - new=SimpleNamespace(session=session_mock), - ) with pytest.raises(ValueError, match="External dataset is not supported"): - service.transform_dataset("d1") + service.transform_dataset("d1", session_mock) -def test_transform_dataset_calls_empty_pipeline_when_no_datasource(mocker) -> None: +def test_transform_dataset_calls_empty_pipeline_when_no_datasource(mocker: MockerFixture) -> None: service = RagPipelineTransformService() dataset = SimpleNamespace( id="d1", @@ -205,20 +192,16 @@ def test_transform_dataset_calls_empty_pipeline_when_no_datasource(mocker) -> No ) session_mock = mocker.Mock() session_mock.get.return_value = dataset - mocker.patch( - "services.rag_pipeline.rag_pipeline_transform_service.db", - new=SimpleNamespace(session=session_mock), - ) empty_result = {"pipeline_id": "p-empty", "dataset_id": "d1", "status": "success"} mocker.patch.object(service, "_transform_to_empty_pipeline", return_value=empty_result) - result = service.transform_dataset("d1") + result = service.transform_dataset("d1", session_mock) assert result == empty_result -def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker) -> None: +def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker: MockerFixture) -> None: service = RagPipelineTransformService() dataset = SimpleNamespace( id="d1", @@ -231,15 +214,11 @@ def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker) -> None ) session_mock = mocker.Mock() session_mock.get.return_value = dataset - mocker.patch( - "services.rag_pipeline.rag_pipeline_transform_service.db", - new=SimpleNamespace(session=session_mock), - ) empty_result = {"pipeline_id": "p-empty", "dataset_id": "d1", "status": "success"} mocker.patch.object(service, "_transform_to_empty_pipeline", return_value=empty_result) - result = service.transform_dataset("d1") + result = service.transform_dataset("d1", session_mock) assert result == empty_result @@ -247,7 +226,7 @@ def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker) -> None # --- _deal_knowledge_index --- -def test_deal_knowledge_index_high_quality_sets_embedding(mocker) -> None: +def test_deal_knowledge_index_high_quality_sets_embedding(mocker: MockerFixture) -> None: service = RagPipelineTransformService() dataset = cast( Dataset, @@ -299,7 +278,7 @@ def test_deal_knowledge_index_high_quality_sets_embedding(mocker) -> None: # --- _deal_document_data --- -def test_deal_document_data_notion(mocker) -> None: +def test_deal_document_data_notion(mocker: MockerFixture) -> None: service = RagPipelineTransformService() dataset = SimpleNamespace(id="d1", pipeline_id="p1") doc = SimpleNamespace( @@ -324,12 +303,8 @@ def test_deal_document_data_notion(mocker) -> None: session_mock = mocker.Mock() session_mock.scalars.return_value = scalars_mock add_mock = session_mock.add - mocker.patch( - "services.rag_pipeline.rag_pipeline_transform_service.db", - new=SimpleNamespace(session=session_mock), - ) - service._deal_document_data(cast(Dataset, dataset)) + service._deal_document_data(cast(Dataset, dataset), session_mock) assert doc.data_source_type == "online_document" assert "page1" in doc.data_source_info @@ -337,7 +312,7 @@ def test_deal_document_data_notion(mocker) -> None: @pytest.mark.parametrize(("provider", "node_id"), [("firecrawl", "1752565402678"), ("jinareader", "1752491761974")]) -def test_deal_document_data_website(mocker, provider: str, node_id: str) -> None: +def test_deal_document_data_website(mocker: MockerFixture, provider: str, node_id: str) -> None: service = RagPipelineTransformService() dataset = SimpleNamespace(id="d1", pipeline_id="p1") doc = SimpleNamespace( @@ -359,12 +334,8 @@ def test_deal_document_data_website(mocker, provider: str, node_id: str) -> None session_mock = mocker.Mock() session_mock.scalars.return_value = scalars_mock add_mock = session_mock.add - mocker.patch( - "services.rag_pipeline.rag_pipeline_transform_service.db", - new=SimpleNamespace(session=session_mock), - ) - service._deal_document_data(cast(Dataset, dataset)) + service._deal_document_data(cast(Dataset, dataset), session_mock) assert doc.data_source_type == "website_crawl" assert "example.com" in doc.data_source_info @@ -376,7 +347,7 @@ def test_deal_document_data_website(mocker, provider: str, node_id: str) -> None # --- transform_dataset complex flow --- -def test_transform_dataset_full_flow(mocker) -> None: +def test_transform_dataset_full_flow(mocker: MockerFixture) -> None: service = RagPipelineTransformService() dataset = SimpleNamespace( id="d1", @@ -398,10 +369,6 @@ def test_transform_dataset_full_flow(mocker) -> None: session_mock = mocker.Mock() session_mock.get.return_value = dataset - mocker.patch( - "services.rag_pipeline.rag_pipeline_transform_service.db", - new=SimpleNamespace(session=session_mock), - ) mocker.patch.object(service, "_deal_dependencies") mocker.patch.object(service, "_deal_document_data") @@ -414,14 +381,14 @@ def test_transform_dataset_full_flow(mocker) -> None: pipeline = SimpleNamespace(id="p-new") mocker.patch.object(service, "_create_pipeline", return_value=pipeline) - result = service.transform_dataset("d1") + result = service.transform_dataset("d1", session_mock) assert result["pipeline_id"] == "p-new" assert dataset.runtime_mode == "rag_pipeline" assert dataset.chunk_structure == "text_model" -def test_transform_dataset_raises_for_unsupported_doc_form_after_pipeline_create(mocker) -> None: +def test_transform_dataset_raises_for_unsupported_doc_form_after_pipeline_create(mocker: MockerFixture) -> None: service = RagPipelineTransformService() dataset = SimpleNamespace( id="d1", @@ -447,10 +414,10 @@ def test_transform_dataset_raises_for_unsupported_doc_form_after_pipeline_create mocker.patch.object(service, "_create_pipeline", return_value=SimpleNamespace(id="p-new")) with pytest.raises(ValueError, match="Unsupported doc form"): - service.transform_dataset("d1") + service.transform_dataset("d1", session_mock) -def test_transform_dataset_raises_when_transform_yaml_missing_workflow(mocker) -> None: +def test_transform_dataset_raises_when_transform_yaml_missing_workflow(mocker: MockerFixture) -> None: service = RagPipelineTransformService() dataset = SimpleNamespace( id="d1", @@ -467,15 +434,11 @@ def test_transform_dataset_raises_when_transform_yaml_missing_workflow(mocker) - ) session_mock = mocker.Mock() session_mock.get.return_value = dataset - mocker.patch( - "services.rag_pipeline.rag_pipeline_transform_service.db", - new=SimpleNamespace(session=session_mock), - ) mocker.patch.object(service, "_get_transform_yaml", return_value={}) mocker.patch.object(service, "_deal_dependencies") with pytest.raises(ValueError, match="Missing workflow data for rag pipeline"): - service.transform_dataset("d1") + service.transform_dataset("d1", session_mock) def test_create_pipeline_raises_when_workflow_data_missing() -> None: @@ -485,7 +448,7 @@ def test_create_pipeline_raises_when_workflow_data_missing() -> None: service._create_pipeline({"rag_pipeline": {"name": "N"}}) -def test_deal_document_data_upload_file_with_existing_file(mocker) -> None: +def test_deal_document_data_upload_file_with_existing_file(mocker: MockerFixture) -> None: service = RagPipelineTransformService() dataset = SimpleNamespace(id="d1", pipeline_id="p1") document = SimpleNamespace( @@ -506,12 +469,8 @@ def test_deal_document_data_upload_file_with_existing_file(mocker) -> None: session_mock.scalars.return_value = scalars_mock session_mock.get.return_value = upload_file add_mock = session_mock.add - mocker.patch( - "services.rag_pipeline.rag_pipeline_transform_service.db", - new=SimpleNamespace(session=session_mock), - ) - service._deal_document_data(cast(Dataset, dataset)) + service._deal_document_data(cast(Dataset, dataset), session_mock) assert document.data_source_type == "local_file" assert "real_file_id" in document.data_source_info @@ -522,7 +481,9 @@ def _make_service(): return RagPipelineTransformService.__new__(RagPipelineTransformService) -def test_deal_dependencies_skips_marketplace_when_disabled(mocker: MockerFixture, caplog) -> None: +def test_deal_dependencies_skips_marketplace_when_disabled( + mocker: MockerFixture, caplog: pytest.LogCaptureFixture +) -> None: mocker.patch( "services.rag_pipeline.rag_pipeline_transform_service.dify_config.MARKETPLACE_ENABLED", False,