mirror of
https://github.com/langgenius/dify.git
synced 2026-06-16 22:11:09 +08:00
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
8d05185e39
commit
1b2d397fc9
@ -994,7 +994,7 @@ class RagPipelineTransformApi(Resource):
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
rag_pipeline_transform_service = RagPipelineTransformService()
|
||||
result = rag_pipeline_transform_service.transform_dataset(dataset_id_str)
|
||||
result = rag_pipeline_transform_service.transform_dataset(dataset_id_str, db.session)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from uuid import uuid4
|
||||
import yaml
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import scoped_session
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
@ -28,8 +29,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagPipelineTransformService:
|
||||
def transform_dataset(self, dataset_id: str):
|
||||
dataset = db.session.get(Dataset, dataset_id)
|
||||
def transform_dataset(self, dataset_id: str, session: scoped_session):
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE:
|
||||
@ -94,9 +95,9 @@ class RagPipelineTransformService:
|
||||
dataset.pipeline_id = pipeline.id
|
||||
|
||||
# deal document data
|
||||
self._deal_document_data(dataset)
|
||||
self._deal_document_data(dataset, session)
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
return {
|
||||
"pipeline_id": pipeline.id,
|
||||
"dataset_id": dataset_id,
|
||||
@ -310,13 +311,13 @@ class RagPipelineTransformService:
|
||||
"status": "success",
|
||||
}
|
||||
|
||||
def _deal_document_data(self, dataset: Dataset):
|
||||
def _deal_document_data(self, dataset: Dataset, session: scoped_session):
|
||||
file_node_id = "1752479895761"
|
||||
notion_node_id = "1752489759475"
|
||||
jina_node_id = "1752491761974"
|
||||
firecrawl_node_id = "1752565402678"
|
||||
|
||||
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset.id)).all()
|
||||
documents = session.scalars(select(Document).where(Document.dataset_id == dataset.id)).all()
|
||||
|
||||
for document in documents:
|
||||
data_source_info_dict = document.data_source_info_dict
|
||||
@ -326,7 +327,7 @@ class RagPipelineTransformService:
|
||||
document.data_source_type = DataSourceType.LOCAL_FILE
|
||||
file_id = data_source_info_dict.get("upload_file_id")
|
||||
if file_id:
|
||||
file = db.session.get(UploadFile, file_id)
|
||||
file = session.get(UploadFile, file_id)
|
||||
if file:
|
||||
data_source_info = json.dumps(
|
||||
{
|
||||
@ -350,8 +351,8 @@ class RagPipelineTransformService:
|
||||
datasource_node_id=file_node_id,
|
||||
)
|
||||
document_pipeline_execution_log.created_at = document.created_at
|
||||
db.session.add(document)
|
||||
db.session.add(document_pipeline_execution_log)
|
||||
session.add(document)
|
||||
session.add(document_pipeline_execution_log)
|
||||
elif document.data_source_type == DataSourceType.NOTION_IMPORT:
|
||||
document.data_source_type = DataSourceType.ONLINE_DOCUMENT
|
||||
data_source_info = json.dumps(
|
||||
@ -378,8 +379,8 @@ class RagPipelineTransformService:
|
||||
datasource_node_id=notion_node_id,
|
||||
)
|
||||
document_pipeline_execution_log.created_at = document.created_at
|
||||
db.session.add(document)
|
||||
db.session.add(document_pipeline_execution_log)
|
||||
session.add(document)
|
||||
session.add(document_pipeline_execution_log)
|
||||
elif document.data_source_type == DataSourceType.WEBSITE_CRAWL:
|
||||
data_source_info = json.dumps(
|
||||
{
|
||||
@ -406,5 +407,5 @@ class RagPipelineTransformService:
|
||||
datasource_node_id=datasource_node_id,
|
||||
)
|
||||
document_pipeline_execution_log.created_at = document.created_at
|
||||
db.session.add(document)
|
||||
db.session.add(document_pipeline_execution_log)
|
||||
session.add(document)
|
||||
session.add(document_pipeline_execution_log)
|
||||
|
||||
@ -67,7 +67,7 @@ def test_deal_file_extensions_returns_original_when_empty() -> None:
|
||||
assert result is node
|
||||
|
||||
|
||||
def test_deal_dependencies_installs_missing_marketplace_plugins(mocker) -> None:
|
||||
def test_deal_dependencies_installs_missing_marketplace_plugins(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
|
||||
installer_cls = mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.PluginInstaller")
|
||||
@ -92,7 +92,7 @@ def test_deal_dependencies_installs_missing_marketplace_plugins(mocker) -> None:
|
||||
install_mock.assert_called_once_with("tenant-1", ["missing-plugin:1.0.0"])
|
||||
|
||||
|
||||
def test_transform_to_empty_pipeline_updates_dataset_and_commits(mocker) -> None:
|
||||
def test_transform_to_empty_pipeline_updates_dataset_and_commits(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.current_user",
|
||||
@ -142,7 +142,7 @@ def test_transform_to_empty_pipeline_updates_dataset_and_commits(mocker) -> None
|
||||
# --- transform_dataset ---
|
||||
|
||||
|
||||
def test_transform_dataset_returns_early_when_pipeline_exists(mocker) -> None:
|
||||
def test_transform_dataset_returns_early_when_pipeline_exists(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = SimpleNamespace(
|
||||
id="d1",
|
||||
@ -151,30 +151,21 @@ def test_transform_dataset_returns_early_when_pipeline_exists(mocker) -> None:
|
||||
)
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = dataset
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
result = service.transform_dataset("d1")
|
||||
result = service.transform_dataset("d1", session_mock)
|
||||
|
||||
assert result == {"pipeline_id": "p1", "dataset_id": "d1", "status": "success"}
|
||||
|
||||
|
||||
def test_transform_dataset_raises_for_dataset_not_found(mocker) -> None:
|
||||
def test_transform_dataset_raises_for_dataset_not_found(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = None
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
service.transform_dataset("d1")
|
||||
service.transform_dataset("d1", session_mock)
|
||||
|
||||
|
||||
def test_transform_dataset_raises_for_external_dataset(mocker) -> None:
|
||||
def test_transform_dataset_raises_for_external_dataset(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = SimpleNamespace(
|
||||
id="d1",
|
||||
@ -184,16 +175,12 @@ def test_transform_dataset_raises_for_external_dataset(mocker) -> None:
|
||||
)
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = dataset
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="External dataset is not supported"):
|
||||
service.transform_dataset("d1")
|
||||
service.transform_dataset("d1", session_mock)
|
||||
|
||||
|
||||
def test_transform_dataset_calls_empty_pipeline_when_no_datasource(mocker) -> None:
|
||||
def test_transform_dataset_calls_empty_pipeline_when_no_datasource(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = SimpleNamespace(
|
||||
id="d1",
|
||||
@ -205,20 +192,16 @@ def test_transform_dataset_calls_empty_pipeline_when_no_datasource(mocker) -> No
|
||||
)
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = dataset
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
empty_result = {"pipeline_id": "p-empty", "dataset_id": "d1", "status": "success"}
|
||||
mocker.patch.object(service, "_transform_to_empty_pipeline", return_value=empty_result)
|
||||
|
||||
result = service.transform_dataset("d1")
|
||||
result = service.transform_dataset("d1", session_mock)
|
||||
|
||||
assert result == empty_result
|
||||
|
||||
|
||||
def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker) -> None:
|
||||
def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = SimpleNamespace(
|
||||
id="d1",
|
||||
@ -231,15 +214,11 @@ def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker) -> None
|
||||
)
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = dataset
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
empty_result = {"pipeline_id": "p-empty", "dataset_id": "d1", "status": "success"}
|
||||
mocker.patch.object(service, "_transform_to_empty_pipeline", return_value=empty_result)
|
||||
|
||||
result = service.transform_dataset("d1")
|
||||
result = service.transform_dataset("d1", session_mock)
|
||||
|
||||
assert result == empty_result
|
||||
|
||||
@ -247,7 +226,7 @@ def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker) -> None
|
||||
# --- _deal_knowledge_index ---
|
||||
|
||||
|
||||
def test_deal_knowledge_index_high_quality_sets_embedding(mocker) -> None:
|
||||
def test_deal_knowledge_index_high_quality_sets_embedding(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = cast(
|
||||
Dataset,
|
||||
@ -299,7 +278,7 @@ def test_deal_knowledge_index_high_quality_sets_embedding(mocker) -> None:
|
||||
# --- _deal_document_data ---
|
||||
|
||||
|
||||
def test_deal_document_data_notion(mocker) -> None:
|
||||
def test_deal_document_data_notion(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = SimpleNamespace(id="d1", pipeline_id="p1")
|
||||
doc = SimpleNamespace(
|
||||
@ -324,12 +303,8 @@ def test_deal_document_data_notion(mocker) -> None:
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.scalars.return_value = scalars_mock
|
||||
add_mock = session_mock.add
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
service._deal_document_data(cast(Dataset, dataset))
|
||||
service._deal_document_data(cast(Dataset, dataset), session_mock)
|
||||
|
||||
assert doc.data_source_type == "online_document"
|
||||
assert "page1" in doc.data_source_info
|
||||
@ -337,7 +312,7 @@ def test_deal_document_data_notion(mocker) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("provider", "node_id"), [("firecrawl", "1752565402678"), ("jinareader", "1752491761974")])
|
||||
def test_deal_document_data_website(mocker, provider: str, node_id: str) -> None:
|
||||
def test_deal_document_data_website(mocker: MockerFixture, provider: str, node_id: str) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = SimpleNamespace(id="d1", pipeline_id="p1")
|
||||
doc = SimpleNamespace(
|
||||
@ -359,12 +334,8 @@ def test_deal_document_data_website(mocker, provider: str, node_id: str) -> None
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.scalars.return_value = scalars_mock
|
||||
add_mock = session_mock.add
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
service._deal_document_data(cast(Dataset, dataset))
|
||||
service._deal_document_data(cast(Dataset, dataset), session_mock)
|
||||
|
||||
assert doc.data_source_type == "website_crawl"
|
||||
assert "example.com" in doc.data_source_info
|
||||
@ -376,7 +347,7 @@ def test_deal_document_data_website(mocker, provider: str, node_id: str) -> None
|
||||
# --- transform_dataset complex flow ---
|
||||
|
||||
|
||||
def test_transform_dataset_full_flow(mocker) -> None:
|
||||
def test_transform_dataset_full_flow(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = SimpleNamespace(
|
||||
id="d1",
|
||||
@ -398,10 +369,6 @@ def test_transform_dataset_full_flow(mocker) -> None:
|
||||
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = dataset
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
mocker.patch.object(service, "_deal_dependencies")
|
||||
mocker.patch.object(service, "_deal_document_data")
|
||||
@ -414,14 +381,14 @@ def test_transform_dataset_full_flow(mocker) -> None:
|
||||
pipeline = SimpleNamespace(id="p-new")
|
||||
mocker.patch.object(service, "_create_pipeline", return_value=pipeline)
|
||||
|
||||
result = service.transform_dataset("d1")
|
||||
result = service.transform_dataset("d1", session_mock)
|
||||
|
||||
assert result["pipeline_id"] == "p-new"
|
||||
assert dataset.runtime_mode == "rag_pipeline"
|
||||
assert dataset.chunk_structure == "text_model"
|
||||
|
||||
|
||||
def test_transform_dataset_raises_for_unsupported_doc_form_after_pipeline_create(mocker) -> None:
|
||||
def test_transform_dataset_raises_for_unsupported_doc_form_after_pipeline_create(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = SimpleNamespace(
|
||||
id="d1",
|
||||
@ -447,10 +414,10 @@ def test_transform_dataset_raises_for_unsupported_doc_form_after_pipeline_create
|
||||
mocker.patch.object(service, "_create_pipeline", return_value=SimpleNamespace(id="p-new"))
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported doc form"):
|
||||
service.transform_dataset("d1")
|
||||
service.transform_dataset("d1", session_mock)
|
||||
|
||||
|
||||
def test_transform_dataset_raises_when_transform_yaml_missing_workflow(mocker) -> None:
|
||||
def test_transform_dataset_raises_when_transform_yaml_missing_workflow(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = SimpleNamespace(
|
||||
id="d1",
|
||||
@ -467,15 +434,11 @@ def test_transform_dataset_raises_when_transform_yaml_missing_workflow(mocker) -
|
||||
)
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = dataset
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
mocker.patch.object(service, "_get_transform_yaml", return_value={})
|
||||
mocker.patch.object(service, "_deal_dependencies")
|
||||
|
||||
with pytest.raises(ValueError, match="Missing workflow data for rag pipeline"):
|
||||
service.transform_dataset("d1")
|
||||
service.transform_dataset("d1", session_mock)
|
||||
|
||||
|
||||
def test_create_pipeline_raises_when_workflow_data_missing() -> None:
|
||||
@ -485,7 +448,7 @@ def test_create_pipeline_raises_when_workflow_data_missing() -> None:
|
||||
service._create_pipeline({"rag_pipeline": {"name": "N"}})
|
||||
|
||||
|
||||
def test_deal_document_data_upload_file_with_existing_file(mocker) -> None:
|
||||
def test_deal_document_data_upload_file_with_existing_file(mocker: MockerFixture) -> None:
|
||||
service = RagPipelineTransformService()
|
||||
dataset = SimpleNamespace(id="d1", pipeline_id="p1")
|
||||
document = SimpleNamespace(
|
||||
@ -506,12 +469,8 @@ def test_deal_document_data_upload_file_with_existing_file(mocker) -> None:
|
||||
session_mock.scalars.return_value = scalars_mock
|
||||
session_mock.get.return_value = upload_file
|
||||
add_mock = session_mock.add
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
service._deal_document_data(cast(Dataset, dataset))
|
||||
service._deal_document_data(cast(Dataset, dataset), session_mock)
|
||||
|
||||
assert document.data_source_type == "local_file"
|
||||
assert "real_file_id" in document.data_source_info
|
||||
@ -522,7 +481,9 @@ def _make_service():
|
||||
return RagPipelineTransformService.__new__(RagPipelineTransformService)
|
||||
|
||||
|
||||
def test_deal_dependencies_skips_marketplace_when_disabled(mocker: MockerFixture, caplog) -> None:
|
||||
def test_deal_dependencies_skips_marketplace_when_disabled(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.dify_config.MARKETPLACE_ENABLED",
|
||||
False,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user