From ad5bade45f2c5b82cff4a0f93062cfd7293b4c0d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?=
Date: Wed, 17 Jun 2026 18:11:37 +0800
Subject: [PATCH] fix(api): enforce document creation limits in pipeline generator (#37586)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
Reviewer note: this patch was restored from a copy whose line breaks and
indentation had been collapsed. Blank lines were placed using the hunk line
counts in each "@@" header. Indentation depth of context lines, and the
position of the blank context line at the start of the first test hunk, are
inferred -- verify against the target files (or apply with
`git apply --ignore-whitespace`) before relying on it.

 .../app/apps/pipeline/pipeline_generator.py   |  4 ++
 api/services/dataset_service.py               | 32 +++++++------
 .../apps/pipeline/test_pipeline_generator.py  | 47 +++++++++++++++++++
 3 files changed, 68 insertions(+), 15 deletions(-)

diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index 47b950ca08..255740b86a 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -138,6 +138,10 @@ class PipelineGenerator(BaseAppGenerator):
         documents: list[Document] = []
         if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"):
             from services.dataset_service import DocumentService
+            from services.feature_service import FeatureService
+
+            features = FeatureService.get_features(pipeline.tenant_id)
+            DocumentService.check_document_creation_limits(len(datasource_info_list), features)
 
             for datasource_info in datasource_info_list:
                 position = DocumentService.get_documents_position(dataset.id)
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index af50a0e318..364d9b36b9 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -2032,14 +2032,7 @@ class DocumentService:
                         website_info = knowledge_config.data_source.info_list.website_info_list
                         assert website_info
                         count = len(website_info.urls)
-                    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
-
-                    if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
-                        raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
-                    if count > batch_upload_limit:
-                        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
-
-                    DocumentService.check_documents_upload_quota(count, features)
+                    DocumentService.check_document_creation_limits(count, features)
 
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type and knowledge_config.data_source:
@@ -2603,6 +2596,21 @@ class DocumentService:
                 f"You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded."
             )
 
+    @staticmethod
+    def check_document_creation_limits(count: int, features: FeatureModel):
+        """Validate billing-backed document creation limits before document rows are created."""
+        if not features.billing.enabled:
+            return
+
+        if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
+            raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
+
+        batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+        if count > batch_upload_limit:
+            raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+
+        DocumentService.check_documents_upload_quota(count, features)
+
     @staticmethod
     def build_document(
         dataset: Dataset,
@@ -2824,13 +2832,7 @@ class DocumentService:
                 website_info = knowledge_config.data_source.info_list.website_info_list
                 if website_info:
                     count = len(website_info.urls)
-            if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
-                raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
-            batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
-            if count > batch_upload_limit:
-                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
-
-            DocumentService.check_documents_upload_quota(count, features)
+            DocumentService.check_document_creation_limits(count, features)
 
         dataset_collection_binding_id = None
         retrieval_model = None
diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py
index 06fd9e4806..67cea55711 100644
--- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py
+++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py
@@ -173,6 +173,9 @@ def test_generate_published_pipeline_creates_documents_and_delay(generator, mock
     mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
 
     mocker.patch("services.dataset_service.DocumentService.get_documents_position", return_value=1)
+    features = SimpleNamespace()
+    mocker.patch("services.feature_service.FeatureService.get_features", return_value=features)
+    check_limits = mocker.patch("services.dataset_service.DocumentService.check_document_creation_limits")
 
     document1 = SimpleNamespace(
         id="doc1",
@@ -226,9 +229,53 @@ def test_generate_published_pipeline_creates_documents_and_delay(generator, mock
 
     assert result["batch"]
     assert len(result["documents"]) == 2
+    check_limits.assert_called_once_with(len(datasource_info_list), features)
     task_proxy.delay.assert_called_once()
 
 
+def test_generate_published_pipeline_rejects_when_document_creation_limits_exceeded(generator, mocker: MockerFixture):
+    pipeline = _build_pipeline()
+    workflow = _build_workflow()
+
+    session = DummySession()
+    _patch_session(mocker, session)
+
+    datasource_info_list = [{"name": "file1"}, {"name": "file2"}]
+    mocker.patch.object(
+        generator,
+        "_format_datasource_info_list",
+        return_value=datasource_info_list,
+    )
+    mocker.patch.object(
+        module.PipelineConfigManager,
+        "get_pipeline_config",
+        return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
+    )
+
+    features = SimpleNamespace()
+    mocker.patch("services.feature_service.FeatureService.get_features", return_value=features)
+    check_limits = mocker.patch(
+        "services.dataset_service.DocumentService.check_document_creation_limits",
+        side_effect=ValueError("document limit exceeded"),
+    )
+
+    db_session = MagicMock()
+    mocker.patch.object(module.db, "session", db_session)
+
+    with pytest.raises(ValueError, match="document limit exceeded"):
+        generator.generate(
+            pipeline=pipeline,
+            workflow=workflow,
+            user=_build_user(),
+            args=_build_args(),
+            invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
+            streaming=False,
+        )
+
+    check_limits.assert_called_once_with(len(datasource_info_list), features)
+    db_session.add.assert_not_called()
+
+
 def test_generate_is_retry_calls_generate(generator, mocker: MockerFixture):
     pipeline = _build_pipeline()
     workflow = _build_workflow()