test: example of passing db.session as a parameter. #37403 (#37471)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-06-16 13:07:01 +09:00 committed by GitHub
parent 8d05185e39
commit 1b2d397fc9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 81 deletions

View File

@ -994,7 +994,7 @@ class RagPipelineTransformApi(Resource):
dataset_id_str = str(dataset_id)
rag_pipeline_transform_service = RagPipelineTransformService()
result = rag_pipeline_transform_service.transform_dataset(dataset_id_str)
result = rag_pipeline_transform_service.transform_dataset(dataset_id_str, db.session)
return result

View File

@ -8,6 +8,7 @@ from uuid import uuid4
import yaml
from flask_login import current_user
from sqlalchemy import select
from sqlalchemy.orm import scoped_session
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
@ -28,8 +29,8 @@ logger = logging.getLogger(__name__)
class RagPipelineTransformService:
def transform_dataset(self, dataset_id: str):
dataset = db.session.get(Dataset, dataset_id)
def transform_dataset(self, dataset_id: str, session: scoped_session):
dataset = session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not found")
if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE:
@ -94,9 +95,9 @@ class RagPipelineTransformService:
dataset.pipeline_id = pipeline.id
# deal document data
self._deal_document_data(dataset)
self._deal_document_data(dataset, session)
db.session.commit()
session.commit()
return {
"pipeline_id": pipeline.id,
"dataset_id": dataset_id,
@ -310,13 +311,13 @@ class RagPipelineTransformService:
"status": "success",
}
def _deal_document_data(self, dataset: Dataset):
def _deal_document_data(self, dataset: Dataset, session: scoped_session):
file_node_id = "1752479895761"
notion_node_id = "1752489759475"
jina_node_id = "1752491761974"
firecrawl_node_id = "1752565402678"
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset.id)).all()
documents = session.scalars(select(Document).where(Document.dataset_id == dataset.id)).all()
for document in documents:
data_source_info_dict = document.data_source_info_dict
@ -326,7 +327,7 @@ class RagPipelineTransformService:
document.data_source_type = DataSourceType.LOCAL_FILE
file_id = data_source_info_dict.get("upload_file_id")
if file_id:
file = db.session.get(UploadFile, file_id)
file = session.get(UploadFile, file_id)
if file:
data_source_info = json.dumps(
{
@ -350,8 +351,8 @@ class RagPipelineTransformService:
datasource_node_id=file_node_id,
)
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
session.add(document)
session.add(document_pipeline_execution_log)
elif document.data_source_type == DataSourceType.NOTION_IMPORT:
document.data_source_type = DataSourceType.ONLINE_DOCUMENT
data_source_info = json.dumps(
@ -378,8 +379,8 @@ class RagPipelineTransformService:
datasource_node_id=notion_node_id,
)
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
session.add(document)
session.add(document_pipeline_execution_log)
elif document.data_source_type == DataSourceType.WEBSITE_CRAWL:
data_source_info = json.dumps(
{
@ -406,5 +407,5 @@ class RagPipelineTransformService:
datasource_node_id=datasource_node_id,
)
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
session.add(document)
session.add(document_pipeline_execution_log)

View File

@ -67,7 +67,7 @@ def test_deal_file_extensions_returns_original_when_empty() -> None:
assert result is node
def test_deal_dependencies_installs_missing_marketplace_plugins(mocker) -> None:
def test_deal_dependencies_installs_missing_marketplace_plugins(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
installer_cls = mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.PluginInstaller")
@ -92,7 +92,7 @@ def test_deal_dependencies_installs_missing_marketplace_plugins(mocker) -> None:
install_mock.assert_called_once_with("tenant-1", ["missing-plugin:1.0.0"])
def test_transform_to_empty_pipeline_updates_dataset_and_commits(mocker) -> None:
def test_transform_to_empty_pipeline_updates_dataset_and_commits(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.current_user",
@ -142,7 +142,7 @@ def test_transform_to_empty_pipeline_updates_dataset_and_commits(mocker) -> None
# --- transform_dataset ---
def test_transform_dataset_returns_early_when_pipeline_exists(mocker) -> None:
def test_transform_dataset_returns_early_when_pipeline_exists(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
dataset = SimpleNamespace(
id="d1",
@ -151,30 +151,21 @@ def test_transform_dataset_returns_early_when_pipeline_exists(mocker) -> None:
)
session_mock = mocker.Mock()
session_mock.get.return_value = dataset
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.db",
new=SimpleNamespace(session=session_mock),
)
result = service.transform_dataset("d1")
result = service.transform_dataset("d1", session_mock)
assert result == {"pipeline_id": "p1", "dataset_id": "d1", "status": "success"}
def test_transform_dataset_raises_for_dataset_not_found(mocker) -> None:
def test_transform_dataset_raises_for_dataset_not_found(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
session_mock = mocker.Mock()
session_mock.get.return_value = None
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.db",
new=SimpleNamespace(session=session_mock),
)
with pytest.raises(ValueError, match="Dataset not found"):
service.transform_dataset("d1")
service.transform_dataset("d1", session_mock)
def test_transform_dataset_raises_for_external_dataset(mocker) -> None:
def test_transform_dataset_raises_for_external_dataset(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
dataset = SimpleNamespace(
id="d1",
@ -184,16 +175,12 @@ def test_transform_dataset_raises_for_external_dataset(mocker) -> None:
)
session_mock = mocker.Mock()
session_mock.get.return_value = dataset
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.db",
new=SimpleNamespace(session=session_mock),
)
with pytest.raises(ValueError, match="External dataset is not supported"):
service.transform_dataset("d1")
service.transform_dataset("d1", session_mock)
def test_transform_dataset_calls_empty_pipeline_when_no_datasource(mocker) -> None:
def test_transform_dataset_calls_empty_pipeline_when_no_datasource(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
dataset = SimpleNamespace(
id="d1",
@ -205,20 +192,16 @@ def test_transform_dataset_calls_empty_pipeline_when_no_datasource(mocker) -> No
)
session_mock = mocker.Mock()
session_mock.get.return_value = dataset
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.db",
new=SimpleNamespace(session=session_mock),
)
empty_result = {"pipeline_id": "p-empty", "dataset_id": "d1", "status": "success"}
mocker.patch.object(service, "_transform_to_empty_pipeline", return_value=empty_result)
result = service.transform_dataset("d1")
result = service.transform_dataset("d1", session_mock)
assert result == empty_result
def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker) -> None:
def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
dataset = SimpleNamespace(
id="d1",
@ -231,15 +214,11 @@ def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker) -> None
)
session_mock = mocker.Mock()
session_mock.get.return_value = dataset
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.db",
new=SimpleNamespace(session=session_mock),
)
empty_result = {"pipeline_id": "p-empty", "dataset_id": "d1", "status": "success"}
mocker.patch.object(service, "_transform_to_empty_pipeline", return_value=empty_result)
result = service.transform_dataset("d1")
result = service.transform_dataset("d1", session_mock)
assert result == empty_result
@ -247,7 +226,7 @@ def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker) -> None
# --- _deal_knowledge_index ---
def test_deal_knowledge_index_high_quality_sets_embedding(mocker) -> None:
def test_deal_knowledge_index_high_quality_sets_embedding(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
dataset = cast(
Dataset,
@ -299,7 +278,7 @@ def test_deal_knowledge_index_high_quality_sets_embedding(mocker) -> None:
# --- _deal_document_data ---
def test_deal_document_data_notion(mocker) -> None:
def test_deal_document_data_notion(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
dataset = SimpleNamespace(id="d1", pipeline_id="p1")
doc = SimpleNamespace(
@ -324,12 +303,8 @@ def test_deal_document_data_notion(mocker) -> None:
session_mock = mocker.Mock()
session_mock.scalars.return_value = scalars_mock
add_mock = session_mock.add
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.db",
new=SimpleNamespace(session=session_mock),
)
service._deal_document_data(cast(Dataset, dataset))
service._deal_document_data(cast(Dataset, dataset), session_mock)
assert doc.data_source_type == "online_document"
assert "page1" in doc.data_source_info
@ -337,7 +312,7 @@ def test_deal_document_data_notion(mocker) -> None:
@pytest.mark.parametrize(("provider", "node_id"), [("firecrawl", "1752565402678"), ("jinareader", "1752491761974")])
def test_deal_document_data_website(mocker, provider: str, node_id: str) -> None:
def test_deal_document_data_website(mocker: MockerFixture, provider: str, node_id: str) -> None:
service = RagPipelineTransformService()
dataset = SimpleNamespace(id="d1", pipeline_id="p1")
doc = SimpleNamespace(
@ -359,12 +334,8 @@ def test_deal_document_data_website(mocker, provider: str, node_id: str) -> None
session_mock = mocker.Mock()
session_mock.scalars.return_value = scalars_mock
add_mock = session_mock.add
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.db",
new=SimpleNamespace(session=session_mock),
)
service._deal_document_data(cast(Dataset, dataset))
service._deal_document_data(cast(Dataset, dataset), session_mock)
assert doc.data_source_type == "website_crawl"
assert "example.com" in doc.data_source_info
@ -376,7 +347,7 @@ def test_deal_document_data_website(mocker, provider: str, node_id: str) -> None
# --- transform_dataset complex flow ---
def test_transform_dataset_full_flow(mocker) -> None:
def test_transform_dataset_full_flow(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
dataset = SimpleNamespace(
id="d1",
@ -398,10 +369,6 @@ def test_transform_dataset_full_flow(mocker) -> None:
session_mock = mocker.Mock()
session_mock.get.return_value = dataset
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.db",
new=SimpleNamespace(session=session_mock),
)
mocker.patch.object(service, "_deal_dependencies")
mocker.patch.object(service, "_deal_document_data")
@ -414,14 +381,14 @@ def test_transform_dataset_full_flow(mocker) -> None:
pipeline = SimpleNamespace(id="p-new")
mocker.patch.object(service, "_create_pipeline", return_value=pipeline)
result = service.transform_dataset("d1")
result = service.transform_dataset("d1", session_mock)
assert result["pipeline_id"] == "p-new"
assert dataset.runtime_mode == "rag_pipeline"
assert dataset.chunk_structure == "text_model"
def test_transform_dataset_raises_for_unsupported_doc_form_after_pipeline_create(mocker) -> None:
def test_transform_dataset_raises_for_unsupported_doc_form_after_pipeline_create(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
dataset = SimpleNamespace(
id="d1",
@ -447,10 +414,10 @@ def test_transform_dataset_raises_for_unsupported_doc_form_after_pipeline_create
mocker.patch.object(service, "_create_pipeline", return_value=SimpleNamespace(id="p-new"))
with pytest.raises(ValueError, match="Unsupported doc form"):
service.transform_dataset("d1")
service.transform_dataset("d1", session_mock)
def test_transform_dataset_raises_when_transform_yaml_missing_workflow(mocker) -> None:
def test_transform_dataset_raises_when_transform_yaml_missing_workflow(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
dataset = SimpleNamespace(
id="d1",
@ -467,15 +434,11 @@ def test_transform_dataset_raises_when_transform_yaml_missing_workflow(mocker) -
)
session_mock = mocker.Mock()
session_mock.get.return_value = dataset
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.db",
new=SimpleNamespace(session=session_mock),
)
mocker.patch.object(service, "_get_transform_yaml", return_value={})
mocker.patch.object(service, "_deal_dependencies")
with pytest.raises(ValueError, match="Missing workflow data for rag pipeline"):
service.transform_dataset("d1")
service.transform_dataset("d1", session_mock)
def test_create_pipeline_raises_when_workflow_data_missing() -> None:
@ -485,7 +448,7 @@ def test_create_pipeline_raises_when_workflow_data_missing() -> None:
service._create_pipeline({"rag_pipeline": {"name": "N"}})
def test_deal_document_data_upload_file_with_existing_file(mocker) -> None:
def test_deal_document_data_upload_file_with_existing_file(mocker: MockerFixture) -> None:
service = RagPipelineTransformService()
dataset = SimpleNamespace(id="d1", pipeline_id="p1")
document = SimpleNamespace(
@ -506,12 +469,8 @@ def test_deal_document_data_upload_file_with_existing_file(mocker) -> None:
session_mock.scalars.return_value = scalars_mock
session_mock.get.return_value = upload_file
add_mock = session_mock.add
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.db",
new=SimpleNamespace(session=session_mock),
)
service._deal_document_data(cast(Dataset, dataset))
service._deal_document_data(cast(Dataset, dataset), session_mock)
assert document.data_source_type == "local_file"
assert "real_file_id" in document.data_source_info
@ -522,7 +481,9 @@ def _make_service():
return RagPipelineTransformService.__new__(RagPipelineTransformService)
def test_deal_dependencies_skips_marketplace_when_disabled(mocker: MockerFixture, caplog) -> None:
def test_deal_dependencies_skips_marketplace_when_disabled(
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
) -> None:
mocker.patch(
"services.rag_pipeline.rag_pipeline_transform_service.dify_config.MARKETPLACE_ENABLED",
False,