From 36e6205c6c37f0f53685173aeb29a2437f1b3223 Mon Sep 17 00:00:00 2001 From: longbingljw Date: Fri, 7 Nov 2025 17:25:58 +0800 Subject: [PATCH] feat: update latest commit (#51) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test: adding some web tests (#27792) * feat: add validation to prevent saving empty opening statement in conversation opener modal (#27843) * fix(web): improve the consistency of the inputs-form UI (#27837) * fix(web): increase z-index of PortalToFollowElemContent (#27823) * fix: installation_id is missing when in tools page (#27849) * fix: avoid passing empty uniqueIdentifier to InstallFromMarketplace (#27802) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * test: create new test scripts and update some existing test scripts o… (#27850) * feat: change feedback to forum (#27862) * chore: translate i18n files and update type definitions (#27868) Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> * Fix/template transformer line number (#27867) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> * bump vite to 6.4.1 (#27877) * Add WEAVIATE_GRPC_ENDPOINT as designed in weaviate migration guide (#27861) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> * Fix: correct DraftWorkflowApi.post response model (#27289) Signed-off-by: Yongtao Huang Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> * fix Version 2.0.0-beta.2: Chat annotations Api Error #25506 (#27206) Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato * fix jina reader credential migration command (#27883) * fix agent putting out the output of workflow-tool twice (#26835) (#27087) * fix jina reader transform (#27922) * fix: prevent fetch version info in enterprise edition (#27923) * fix(api): fix `VariablePool.get` adding unexpected keys to variable_dictionary (#26767) Co-authored-by: -LAN- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * refactor: implement tenant self queue for rag tasks (#27559) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- * fix: bump brotli to 1.2.0 to resolve CVE-2025-6176 (#27950) Signed-off-by: kenwoodjw --------- Signed-off-by: Yongtao Huang Signed-off-by: kenwoodjw Co-authored-by: aka James4u Co-authored-by: Novice Co-authored-by: yangzheli <43645580+yangzheli@users.noreply.github.com> Co-authored-by: Elliott <105957288+Elliott-byte@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: johnny0120 Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Gritty_dev <101377478+codomposer@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: wangjifeng <163279492+kk-wangjifeng@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Boris
Polonsky Co-authored-by: Yongtao Huang Co-authored-by: Cursx <33718736+Cursx@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Asuka Minato Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: red_sun <56100962+redSun64@users.noreply.github.com> Co-authored-by: NFish Co-authored-by: QuantumGhost Co-authored-by: -LAN- Co-authored-by: hj24 Co-authored-by: kenwoodjw --- api/.env.example | 3 + api/commands.py | 2 +- api/configs/feature/__init__.py | 8 + api/configs/middleware/vdb/weaviate_config.py | 5 + api/controllers/console/app/workflow.py | 13 +- api/controllers/service_api/wraps.py | 24 +- .../app/apps/pipeline/pipeline_generator.py | 36 +- api/core/entities/document_task.py | 15 + .../javascript/javascript_transformer.py | 8 +- .../python3/python3_transformer.py | 4 +- .../vdb/weaviate/weaviate_vector.py | 22 +- api/core/rag/pipeline/__init__.py | 0 api/core/rag/pipeline/queue.py | 79 ++ api/core/tools/__base/tool.py | 5 +- api/core/tools/entities/tool_entities.py | 1 + api/core/tools/tool_engine.py | 3 + api/core/tools/workflow_as_tool/tool.py | 2 +- api/core/workflow/runtime/variable_pool.py | 6 +- api/docker/entrypoint.sh | 2 +- api/services/dataset_service.py | 4 +- api/services/document_indexing_task_proxy.py | 83 ++ .../rag_pipeline/rag_pipeline_task_proxy.py | 106 ++ .../website-crawl-general-economy.yml | 2 +- .../website-crawl-general-high-quality.yml | 2 +- .../transform/website-crawl-parentchild.yml | 2 +- api/tasks/document_indexing_task.py | 79 ++ .../priority_rag_pipeline_run_task.py | 25 + .../rag_pipeline/rag_pipeline_run_task.py | 34 +- .../core/rag/__init__.py | 1 + .../core/rag/pipeline/__init__.py | 0 .../rag/pipeline/test_queue_integration.py | 595 +++++++++++ .../tasks/test_document_indexing_task.py | 414 +++++++- .../tasks/test_rag_pipeline_run_tasks.py | 936 ++++++++++++++++++ .../core/helper/code_executor/__init__.py | 0 .../code_executor/javascript/__init__.py | 0 .../javascript/test_javascript_transformer.py | 12 + .../helper/code_executor/python3/__init__.py | 0 .../python3/test_python3_transformer.py | 12 + .../core/rag/pipeline/test_queue.py | 301 ++++++ .../workflow/entities/test_variable_pool.py | 23 + .../test_document_indexing_task_proxy.py | 317 ++++++ .../services/test_rag_pipeline_task_proxy.py | 483 +++++++++ api/uv.lock | 62 +- dev/start-worker | 2 +- docker/.env.example | 4 + docker/docker-compose.yaml | 2 + web/__tests__/check-i18n.test.ts | 100 ++ web/__tests__/navigation-utils.test.ts | 112 +++ .../components/app-sidebar/app-operations.tsx | 2 +- .../chat-with-history/inputs-form/content.tsx | 2 +- .../embedded-chatbot/inputs-form/content.tsx | 2 +- .../conversation-opener/modal.tsx | 11 +- .../commands/{feedback.tsx => forum.tsx} | 28 +- .../goto-anything/actions/commands/slash.tsx | 6 +- .../header/account-dropdown/support.tsx | 8 +- .../plugin-detail-panel/detail-header.tsx | 7 +- .../components/plugins/plugin-page/index.tsx | 8 +- .../share/text-generation/run-once/index.tsx | 11 +- .../components/before-run-form/form-item.tsx | 6 +- web/app/components/workflow/run/status.tsx | 2 +- web/context/app-context.tsx | 12 +- web/i18n/de-DE/app-debug.ts | 1 - web/i18n/de-DE/common.ts | 3 +- web/i18n/en-US/app-debug.ts | 1 - web/i18n/en-US/common.ts | 2 +- web/i18n/es-ES/app-debug.ts | 1 - web/i18n/es-ES/common.ts | 3 +- web/i18n/fa-IR/app-debug.ts | 1 - web/i18n/fa-IR/common.ts | 3 +- web/i18n/fr-FR/app-debug.ts | 1 - web/i18n/fr-FR/common.ts | 3 +- 
web/i18n/hi-IN/app-debug.ts | 1 - web/i18n/hi-IN/common.ts | 3 +- web/i18n/id-ID/app-debug.ts | 1 - web/i18n/id-ID/common.ts | 3 +- web/i18n/it-IT/app-debug.ts | 1 - web/i18n/it-IT/common.ts | 3 +- web/i18n/ja-JP/app-debug.ts | 1 - web/i18n/ja-JP/common.ts | 3 +- web/i18n/ko-KR/app-debug.ts | 1 - web/i18n/ko-KR/common.ts | 3 +- web/i18n/pl-PL/app-debug.ts | 1 - web/i18n/pl-PL/common.ts | 3 +- web/i18n/pt-BR/app-debug.ts | 1 - web/i18n/pt-BR/common.ts | 3 +- web/i18n/ro-RO/app-debug.ts | 1 - web/i18n/ro-RO/common.ts | 3 +- web/i18n/ru-RU/app-debug.ts | 1 - web/i18n/ru-RU/common.ts | 3 +- web/i18n/sl-SI/app-debug.ts | 1 - web/i18n/sl-SI/common.ts | 3 +- web/i18n/th-TH/app-debug.ts | 1 - web/i18n/th-TH/common.ts | 3 +- web/i18n/tr-TR/app-debug.ts | 1 - web/i18n/tr-TR/common.ts | 3 +- web/i18n/uk-UA/app-debug.ts | 1 - web/i18n/uk-UA/common.ts | 3 +- web/i18n/vi-VN/app-debug.ts | 1 - web/i18n/vi-VN/common.ts | 3 +- web/i18n/zh-Hans/app-debug.ts | 1 - web/i18n/zh-Hans/common.ts | 2 +- web/i18n/zh-Hant/app-debug.ts | 1 - web/i18n/zh-Hant/common.ts | 3 +- web/package.json | 4 +- web/pnpm-lock.yaml | 633 +----------- web/service/_tools_util.spec.ts | 36 + web/utils/app-redirection.spec.ts | 106 ++ web/utils/classnames.spec.ts | 101 ++ web/utils/clipboard.spec.ts | 109 ++ web/utils/completion-params.spec.ts | 230 +++++ web/utils/emoji.spec.ts | 77 ++ web/utils/format.spec.ts | 94 +- web/utils/get-icon.spec.ts | 49 + web/utils/index.spec.ts | 305 ++++++ web/utils/mcp.spec.ts | 88 ++ web/utils/navigation.spec.ts | 297 ++++++ web/utils/permission.spec.ts | 95 ++ web/utils/time.spec.ts | 99 ++ web/utils/tool-call.spec.ts | 79 ++ web/utils/urlValidation.spec.ts | 49 + web/utils/validators.spec.ts | 139 +++ web/utils/var.spec.ts | 236 +++++ 122 files changed, 6139 insertions(+), 825 deletions(-) create mode 100644 api/core/entities/document_task.py create mode 100644 api/core/rag/pipeline/__init__.py create mode 100644 api/core/rag/pipeline/queue.py create mode 100644 api/services/document_indexing_task_proxy.py create mode 100644 api/services/rag_pipeline/rag_pipeline_task_proxy.py create mode 100644 api/tests/test_containers_integration_tests/core/rag/__init__.py create mode 100644 api/tests/test_containers_integration_tests/core/rag/pipeline/__init__.py create mode 100644 api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py create mode 100644 api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py create mode 100644 api/tests/unit_tests/core/helper/code_executor/__init__.py create mode 100644 api/tests/unit_tests/core/helper/code_executor/javascript/__init__.py create mode 100644 api/tests/unit_tests/core/helper/code_executor/javascript/test_javascript_transformer.py create mode 100644 api/tests/unit_tests/core/helper/code_executor/python3/__init__.py create mode 100644 api/tests/unit_tests/core/helper/code_executor/python3/test_python3_transformer.py create mode 100644 api/tests/unit_tests/core/rag/pipeline/test_queue.py create mode 100644 api/tests/unit_tests/services/test_document_indexing_task_proxy.py create mode 100644 api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py rename web/app/components/goto-anything/actions/commands/{feedback.tsx => forum.tsx} (54%) create mode 100644 web/utils/app-redirection.spec.ts create mode 100644 web/utils/clipboard.spec.ts create mode 100644 web/utils/completion-params.spec.ts create mode 100644 web/utils/emoji.spec.ts create mode 100644 web/utils/get-icon.spec.ts create mode 100644 
web/utils/mcp.spec.ts create mode 100644 web/utils/navigation.spec.ts create mode 100644 web/utils/permission.spec.ts create mode 100644 web/utils/time.spec.ts create mode 100644 web/utils/tool-call.spec.ts create mode 100644 web/utils/urlValidation.spec.ts create mode 100644 web/utils/validators.spec.ts create mode 100644 web/utils/var.spec.ts diff --git a/api/.env.example b/api/.env.example index a9097db139..8b82b76bde 100644 --- a/api/.env.example +++ b/api/.env.example @@ -634,5 +634,8 @@ SWAGGER_UI_PATH=/swagger-ui.html # Set to false to export dataset IDs as plain text for easier cross-environment import DSL_EXPORT_ENCRYPT_DATASET_ID=true +# Tenant isolated task queue configuration +TENANT_ISOLATED_TASK_CONCURRENCY=1 + # Maximum number of segments for dataset segments API (0 for unlimited) DATASET_MAX_SEGMENTS_PER_REQUEST=0 diff --git a/api/commands.py b/api/commands.py index 084fd576a1..8698ec3f97 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1601,7 +1601,7 @@ def transform_datasource_credentials(): "integration_secret": api_key, } datasource_provider = DatasourceProvider( - provider="jina", + provider="jinareader", tenant_id=tenant_id, plugin_id=jina_plugin_id, auth_type=api_key_credential_type.value, diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 843e6b6f70..86c37dca25 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1142,6 +1142,13 @@ class SwaggerUIConfig(BaseSettings): ) +class TenantIsolatedTaskQueueConfig(BaseSettings): + TENANT_ISOLATED_TASK_CONCURRENCY: int = Field( + description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant", + default=1, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -1166,6 +1173,7 @@ class FeatureConfig( RagEtlConfig, RepositoryConfig, SecurityConfig, + TenantIsolatedTaskQueueConfig, ToolConfig, UpdateConfig, WorkflowConfig, diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py index 6a79412ab8..aa81c870f6 100644 --- a/api/configs/middleware/vdb/weaviate_config.py +++ b/api/configs/middleware/vdb/weaviate_config.py @@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings): default=True, ) + WEAVIATE_GRPC_ENDPOINT: str | None = Field( + description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')", + default=None, + ) + WEAVIATE_BATCH_SIZE: PositiveInt = Field( description="Number of objects to be processed in a single batch operation (default is 100)", default=100, diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 56771ed420..5f41b65e88 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -102,7 +102,18 @@ class DraftWorkflowApi(Resource): }, ) ) - @api.response(200, "Draft workflow synced successfully", workflow_fields) + @api.response( + 200, + "Draft workflow synced successfully", + api.model( + "SyncDraftWorkflowResponse", + { + "result": fields.String, + "hash": fields.String, + "updated_at": fields.String, + }, + ), + ) @api.response(400, "Invalid workflow configuration") @api.response(403, "Permission denied") @edit_permission_required diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index fe1e2c419b..319b7bd780 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -67,6 +67,7 @@ def 
validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe kwargs["app_model"] = app_model + # If caller needs end-user context, attach EndUser to current_user if fetch_user_arg: if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: user_id = request.args.get("user") @@ -75,7 +76,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: user_id = request.form.get("user") else: - # use default-user user_id = None if not user_id and fetch_user_arg.required: @@ -90,6 +90,28 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe # Set EndUser as current logged-in user for flask_login.current_user current_app.login_manager._update_request_context_with_user(end_user) # type: ignore user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore + else: + # For service API without end-user context, ensure an Account is logged in + # so services relying on current_account_with_tenant() work correctly. + tenant_owner_info = ( + db.session.query(Tenant, Account) + .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) + .join(Account, TenantAccountJoin.account_id == Account.id) + .where( + Tenant.id == app_model.tenant_id, + TenantAccountJoin.role == "owner", + Tenant.status == TenantStatus.NORMAL, + ) + .one_or_none() + ) + + if tenant_owner_info: + tenant_model, account = tenant_owner_info + account.current_tenant = tenant_model + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore + else: + raise Unauthorized("Tenant owner account not found or tenant is not active.") return view_func(*args, **kwargs) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index c36d34f571..a1390ad0be 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -40,20 +40,15 @@ from core.workflow.repositories.draft_variable_repository import DraftVariableSa from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader -from enums.cloud_plan import CloudPlan from extensions.ext_database import db -from extensions.ext_redis import redis_client from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode from services.datasource_provider_service import DatasourceProviderService -from services.feature_service import FeatureService -from services.file_service import FileService +from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService -from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task -from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task logger = logging.getLogger(__name__) @@ -249,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator): ) if rag_pipeline_invoke_entities: - # store the 
rag_pipeline_invoke_entities to object storage - text = [item.model_dump() for item in rag_pipeline_invoke_entities] - name = "rag_pipeline_invoke_entities.json" - # Convert list to proper JSON string - json_text = json.dumps(text) - upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id) - features = FeatureService.get_features(dataset.tenant_id) - if features.billing.enabled and features.billing.subscription.plan == CloudPlan.SANDBOX: - tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}" - tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}" - - if redis_client.get(tenant_pipeline_task_key): - # Add to waiting queue using List operations (lpush) - redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id) - else: - # Set flag and execute task - redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60) - rag_pipeline_run_task.delay( # type: ignore - rag_pipeline_invoke_entities_file_id=upload_file.id, - tenant_id=dataset.tenant_id, - ) - - else: - priority_rag_pipeline_run_task.delay( # type: ignore - rag_pipeline_invoke_entities_file_id=upload_file.id, - tenant_id=dataset.tenant_id, - ) - + RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay() # return batch, dataset, documents return { "batch": batch, diff --git a/api/core/entities/document_task.py b/api/core/entities/document_task.py new file mode 100644 index 0000000000..27ab5c84f7 --- /dev/null +++ b/api/core/entities/document_task.py @@ -0,0 +1,15 @@ +from collections.abc import Sequence +from dataclasses import dataclass + + +@dataclass +class DocumentTask: + """Document task entity for document indexing operations. + + This class represents a document indexing task that can be queued + and processed by the document indexing system. 
+ """ + + tenant_id: str + dataset_id: str + document_ids: Sequence[str] diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py index 62489cdf29..e28f027a3a 100644 --- a/api/core/helper/code_executor/javascript/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript/javascript_transformer.py @@ -6,10 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer class NodeJsTemplateTransformer(TemplateTransformer): @classmethod def get_runner_script(cls) -> str: - runner_script = dedent( - f""" - // declare main function - {cls._code_placeholder} + runner_script = dedent(f""" {cls._code_placeholder} // decode and prepare input object var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8')) @@ -21,6 +18,5 @@ class NodeJsTemplateTransformer(TemplateTransformer): var output_json = JSON.stringify(output_obj) var result = `<>${{output_json}}<>` console.log(result) - """ - ) + """) return runner_script diff --git a/api/core/helper/code_executor/python3/python3_transformer.py b/api/core/helper/code_executor/python3/python3_transformer.py index 836fd273ae..ee866eeb81 100644 --- a/api/core/helper/code_executor/python3/python3_transformer.py +++ b/api/core/helper/code_executor/python3/python3_transformer.py @@ -6,9 +6,7 @@ class Python3TemplateTransformer(TemplateTransformer): @classmethod def get_runner_script(cls) -> str: - runner_script = dedent(f""" - # declare main function - {cls._code_placeholder} + runner_script = dedent(f""" {cls._code_placeholder} import json from base64 import b64decode diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index dceade0af9..39eb91e50e 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -39,11 +39,13 @@ class WeaviateConfig(BaseModel): Attributes: endpoint: Weaviate server endpoint URL + grpc_endpoint: Optional Weaviate gRPC server endpoint URL api_key: Optional API key for authentication batch_size: Number of objects to batch per insert operation """ endpoint: str + grpc_endpoint: str | None = None api_key: str | None = None batch_size: int = 100 @@ -88,9 +90,22 @@ class WeaviateVector(BaseVector): http_secure = p.scheme == "https" http_port = p.port or (443 if http_secure else 80) - grpc_host = host - grpc_secure = http_secure - grpc_port = 443 if grpc_secure else 50051 + # Parse gRPC configuration + if config.grpc_endpoint: + # URLs without a scheme won't be parsed correctly in some Python versions, + # see https://bugs.python.org/issue27657 + grpc_endpoint_with_scheme = ( + config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}" + ) + grpc_p = urlparse(grpc_endpoint_with_scheme) + grpc_host = grpc_p.hostname or "localhost" + grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051) + grpc_secure = grpc_p.scheme == "grpcs" + else: + # Infer from HTTP endpoint as fallback + grpc_host = host + grpc_secure = http_secure + grpc_port = 443 if grpc_secure else 50051 client = weaviate.connect_to_custom( http_host=host, @@ -432,6 +447,7 @@ class WeaviateVectorFactory(AbstractVectorFactory): collection_name=collection_name, config=WeaviateConfig( endpoint=dify_config.WEAVIATE_ENDPOINT or "",
grpc_endpoint=dify_config.WEAVIATE_GRPC_ENDPOINT or "", api_key=dify_config.WEAVIATE_API_KEY, batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), diff --git a/api/core/rag/pipeline/__init__.py b/api/core/rag/pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/pipeline/queue.py b/api/core/rag/pipeline/queue.py new file mode 100644 index 0000000000..3d4d6f588d --- /dev/null +++ b/api/core/rag/pipeline/queue.py @@ -0,0 +1,79 @@ +import json +from collections.abc import Sequence +from typing import Any + +from pydantic import BaseModel, ValidationError + +from extensions.ext_redis import redis_client + +_DEFAULT_TASK_TTL = 60 * 60 # 1 hour + + +class TaskWrapper(BaseModel): + data: Any + + def serialize(self) -> str: + return self.model_dump_json() + + @classmethod + def deserialize(cls, serialized_data: str) -> "TaskWrapper": + return cls.model_validate_json(serialized_data) + + +class TenantIsolatedTaskQueue: + """ + Simple queue for tenant-isolated tasks, used to isolate RAG-related work per tenant. + It uses a Redis list to store tasks and a Redis key as the task-waiting flag. + Supports tasks that can be serialized as JSON. + """ + + def __init__(self, tenant_id: str, unique_key: str): + self._tenant_id = tenant_id + self._unique_key = unique_key + self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}" + self._task_key = f"tenant_{unique_key}_task:{tenant_id}" + + def get_task_key(self): + return redis_client.get(self._task_key) + + def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL): + redis_client.setex(self._task_key, ttl, 1) + + def delete_task_key(self): + redis_client.delete(self._task_key) + + def push_tasks(self, tasks: Sequence[Any]): + serialized_tasks = [] + for task in tasks: + # Store str list directly, maintaining full compatibility for pipeline scenarios + if isinstance(task, str): + serialized_tasks.append(task) + else: + # Use TaskWrapper to do JSON serialization for non-string tasks + wrapper = TaskWrapper(data=task) + serialized_data = wrapper.serialize() + serialized_tasks.append(serialized_data) + + redis_client.lpush(self._queue, *serialized_tasks) + + def pull_tasks(self, count: int = 1) -> Sequence[Any]: + if count <= 0: + return [] + + tasks = [] + for _ in range(count): + serialized_task = redis_client.rpop(self._queue) + if not serialized_task: + break + + if isinstance(serialized_task, bytes): + serialized_task = serialized_task.decode("utf-8") + + try: + wrapper = TaskWrapper.deserialize(serialized_task) + tasks.append(wrapper.data) + except (json.JSONDecodeError, ValidationError, TypeError, ValueError): + # Fall back to raw string for legacy format or invalid JSON + tasks.append(serialized_task) + + return tasks diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 82616596f8..8ca4eabb7a 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -210,12 +210,13 @@ class Tool(ABC): meta=meta, ) - def create_json_message(self, object: dict) -> ToolInvokeMessage: + def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage: """ create a json message """ return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object) + type=ToolInvokeMessage.MessageType.JSON, + message=ToolInvokeMessage.JsonMessage(json_object=object, suppress_output=suppress_output), ) def create_variable_message( diff --git a/api/core/tools/entities/tool_entities.py
b/api/core/tools/entities/tool_entities.py index 15a4f0aafd..5b385f1bb2 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -129,6 +129,7 @@ class ToolInvokeMessage(BaseModel): class JsonMessage(BaseModel): json_object: dict + suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string") class BlobMessage(BaseModel): blob: bytes diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 92d441b5ac..13fd579e20 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -245,6 +245,9 @@ class ToolEngine: + "you do not need to create it, just tell the user to check it now." ) elif response.type == ToolInvokeMessage.MessageType.JSON: + json_message = cast(ToolInvokeMessage.JsonMessage, response.message) + if json_message.suppress_output: + continue json_parts.append( json.dumps( safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object), diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 2cd46647a0..5703c19c88 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -117,7 +117,7 @@ class WorkflowTool(Tool): self._latest_usage = self._derive_usage_from_result(data) yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) - yield self.create_json_message(outputs) + yield self.create_json_message(outputs, suppress_output=True) @property def latest_usage(self) -> LLMUsage: diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index d41a20dfd7..7fbaec9e70 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -153,7 +153,11 @@ class VariablePool(BaseModel): return None node_id, name = self._selector_to_keys(selector) - segment: Segment | None = self.variable_dictionary[node_id].get(name) + node_map = self.variable_dictionary.get(node_id) + if node_map is None: + return None + + segment: Segment | None = node_map.get(name) if segment is None: return None diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 8f6998119e..41b5eb20b5 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ --max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ - -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline} \ + -Q ${CELERY_QUEUES:-dataset,priority_dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline} \ --prefetch-multiplier=1 elif [[ "${MODE}" == "beat" ]]; then diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 2e255c0a9b..78de76df7e 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -50,6 +50,7 @@ from models.model import UploadFile from models.provider_ids import ModelProviderID from models.source import DataSourceOauthBinding from models.workflow import Workflow +from services.document_indexing_task_proxy import DocumentIndexingTaskProxy from services.entities.knowledge_entities.knowledge_entities import ( ChildChunkUpdateArgs, KnowledgeConfig, @@ -79,7 +80,6 @@ from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from 
tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.disable_segments_from_index_task import disable_segments_from_index_task -from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task @@ -1694,7 +1694,7 @@ class DocumentService: # trigger async task if document_ids: - document_indexing_task.delay(dataset.id, document_ids) + DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay() if duplicate_document_ids: duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) diff --git a/api/services/document_indexing_task_proxy.py b/api/services/document_indexing_task_proxy.py new file mode 100644 index 0000000000..861c84b586 --- /dev/null +++ b/api/services/document_indexing_task_proxy.py @@ -0,0 +1,83 @@ +import logging +from collections.abc import Callable, Sequence +from dataclasses import asdict +from functools import cached_property + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from services.feature_service import FeatureService +from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task + +logger = logging.getLogger(__name__) + + +class DocumentIndexingTaskProxy: + def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + self._tenant_id = tenant_id + self._dataset_id = dataset_id + self._document_ids = document_ids + self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") + + @cached_property + def features(self): + return FeatureService.get_features(self._tenant_id) + + def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): + logger.info("send dataset %s to direct queue", self._dataset_id) + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + + def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): + logger.info("send dataset %s to tenant queue", self._dataset_id) + if self._tenant_isolated_task_queue.get_task_key(): + # Add to waiting queue using List operations (lpush) + self._tenant_isolated_task_queue.push_tasks( + [ + asdict( + DocumentTask( + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + ) + ] + ) + logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids) + else: + # Set flag and execute task + self._tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids) + + def _send_to_default_tenant_queue(self): + self._send_to_tenant_queue(normal_document_indexing_task) + + def _send_to_priority_tenant_queue(self): + self._send_to_tenant_queue(priority_document_indexing_task) + + def _send_to_priority_direct_queue(self): + self._send_to_direct_queue(priority_document_indexing_task) + + def _dispatch(self): + logger.info( + "dispatch args: %s - %s - %s", + self._tenant_id, + 
self.features.billing.enabled, + self.features.billing.subscription.plan, + ) + # dispatch to different indexing queue with tenant isolation when billing enabled + if self.features.billing.enabled: + if self.features.billing.subscription.plan == CloudPlan.SANDBOX: + # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan + self._send_to_default_tenant_queue() + else: + # dispatch to priority pipeline queue with tenant self sub queue for other plans + self._send_to_priority_tenant_queue() + else: + # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise + self._send_to_priority_direct_queue() + + def delay(self): + self._dispatch() diff --git a/api/services/rag_pipeline/rag_pipeline_task_proxy.py b/api/services/rag_pipeline/rag_pipeline_task_proxy.py new file mode 100644 index 0000000000..94dd7941da --- /dev/null +++ b/api/services/rag_pipeline/rag_pipeline_task_proxy.py @@ -0,0 +1,106 @@ +import json +import logging +from collections.abc import Callable, Sequence +from functools import cached_property + +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from services.feature_service import FeatureService +from services.file_service import FileService +from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task +from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task + +logger = logging.getLogger(__name__) + + +class RagPipelineTaskProxy: + # Default uploaded file name for rag pipeline invoke entities + _RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME = "rag_pipeline_invoke_entities.json" + + def __init__( + self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: Sequence[RagPipelineInvokeEntity] + ): + self._dataset_tenant_id = dataset_tenant_id + self._user_id = user_id + self._rag_pipeline_invoke_entities = rag_pipeline_invoke_entities + self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(dataset_tenant_id, "pipeline") + + @cached_property + def features(self): + return FeatureService.get_features(self._dataset_tenant_id) + + def _upload_invoke_entities(self) -> str: + text = [item.model_dump() for item in self._rag_pipeline_invoke_entities] + # Convert list to proper JSON string + json_text = json.dumps(text) + upload_file = FileService(db.engine).upload_text( + json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id + ) + return upload_file.id + + def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): + logger.info("send file %s to direct queue", upload_file_id) + task_func.delay( # type: ignore + rag_pipeline_invoke_entities_file_id=upload_file_id, + tenant_id=self._dataset_tenant_id, + ) + + def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): + logger.info("send file %s to tenant queue", upload_file_id) + if self._tenant_isolated_task_queue.get_task_key(): + # Add to waiting queue using List operations (lpush) + self._tenant_isolated_task_queue.push_tasks([upload_file_id]) + logger.info("push tasks: %s", upload_file_id) + else: + # Set flag and execute task + self._tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + rag_pipeline_invoke_entities_file_id=upload_file_id, + tenant_id=self._dataset_tenant_id, + ) + logger.info("init 
tasks: %s", upload_file_id) + + def _send_to_default_tenant_queue(self, upload_file_id: str): + self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task) + + def _send_to_priority_tenant_queue(self, upload_file_id: str): + self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task) + + def _send_to_priority_direct_queue(self, upload_file_id: str): + self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task) + + def _dispatch(self): + upload_file_id = self._upload_invoke_entities() + if not upload_file_id: + raise ValueError("upload_file_id is empty") + + logger.info( + "dispatch args: %s - %s - %s", + self._dataset_tenant_id, + self.features.billing.enabled, + self.features.billing.subscription.plan, + ) + + # dispatch to different pipeline queue with tenant isolation when billing enabled + if self.features.billing.enabled: + if self.features.billing.subscription.plan == CloudPlan.SANDBOX: + # dispatch to normal pipeline queue with tenant isolation for sandbox plan + self._send_to_default_tenant_queue(upload_file_id) + else: + # dispatch to priority pipeline queue with tenant isolation for other plans + self._send_to_priority_tenant_queue(upload_file_id) + else: + # dispatch to priority pipeline queue without tenant isolation for others, e.g.: self-hosted or enterprise + self._send_to_priority_direct_queue(upload_file_id) + + def delay(self): + if not self._rag_pipeline_invoke_entities: + logger.warning( + "Received empty rag pipeline invoke entities, no tasks delivered: %s %s", + self._dataset_tenant_id, + self._user_id, + ) + return + self._dispatch() diff --git a/api/services/rag_pipeline/transform/website-crawl-general-economy.yml b/api/services/rag_pipeline/transform/website-crawl-general-economy.yml index 241d94c95d..a0f4b3bdd8 100644 --- a/api/services/rag_pipeline/transform/website-crawl-general-economy.yml +++ b/api/services/rag_pipeline/transform/website-crawl-general-economy.yml @@ -126,7 +126,7 @@ workflow: type: mixed value: '{{#rag.1752491761974.jina_use_sitemap#}}' plugin_id: langgenius/jina_datasource - provider_name: jina + provider_name: jinareader provider_type: website_crawl selected: false title: Jina Reader diff --git a/api/services/rag_pipeline/transform/website-crawl-general-high-quality.yml b/api/services/rag_pipeline/transform/website-crawl-general-high-quality.yml index 52b8f822c0..f58679fb6c 100644 --- a/api/services/rag_pipeline/transform/website-crawl-general-high-quality.yml +++ b/api/services/rag_pipeline/transform/website-crawl-general-high-quality.yml @@ -126,7 +126,7 @@ workflow: type: mixed value: '{{#rag.1752491761974.jina_use_sitemap#}}' plugin_id: langgenius/jina_datasource - provider_name: jina + provider_name: jinareader provider_type: website_crawl selected: false title: Jina Reader diff --git a/api/services/rag_pipeline/transform/website-crawl-parentchild.yml b/api/services/rag_pipeline/transform/website-crawl-parentchild.yml index 5d609bd12b..85b1cfd87d 100644 --- a/api/services/rag_pipeline/transform/website-crawl-parentchild.yml +++ b/api/services/rag_pipeline/transform/website-crawl-parentchild.yml @@ -419,7 +419,7 @@ workflow: type: mixed value: '{{#rag.1752491761974.jina_use_sitemap#}}' plugin_id: langgenius/jina_datasource - provider_name: jina + provider_name: jinareader provider_type: website_crawl selected: false title: Jina Reader diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 07f469de0e..fee4430612 100644 --- a/api/tasks/document_indexing_task.py +++ 
b/api/tasks/document_indexing_task.py @@ -1,11 +1,14 @@ import logging import time +from collections.abc import Callable, Sequence import click from celery import shared_task from configs import dify_config +from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -22,8 +25,24 @@ def document_indexing_task(dataset_id: str, document_ids: list): :param dataset_id: :param document_ids: + .. warning:: TO BE DEPRECATED + This function will be deprecated and removed in a future version. + Use normal_document_indexing_task or priority_document_indexing_task instead. + Usage: document_indexing_task.delay(dataset_id, document_ids) """ + logger.warning("document indexing legacy mode received: %s - %s", dataset_id, document_ids) + _document_indexing(dataset_id, document_ids) + + +def _document_indexing(dataset_id: str, document_ids: Sequence[str]): + """ + Process documents for indexing tasks + :param dataset_id: + :param document_ids: + + Usage: _document_indexing(dataset_id, document_ids) + """ documents = [] start_at = time.perf_counter() @@ -87,3 +106,63 @@ logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) finally: db.session.close() + + +def _document_indexing_with_tenant_queue( + tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None] +): + try: + _document_indexing(dataset_id, document_ids) + except Exception: + logger.exception("Error processing document indexing %s for tenant %s", dataset_id, tenant_id) + finally: + tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") + + # Check if there are waiting tasks in the queue + # Use rpop to get the next task from the queue (FIFO order) + next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) + + logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks) + + if next_tasks: + for next_task in next_tasks: + document_task = DocumentTask(**next_task) + # Process the next waiting task + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=document_task.tenant_id, + dataset_id=document_task.dataset_id, + document_ids=document_task.document_ids, + ) + else: + # No more waiting tasks, clear the flag + tenant_isolated_task_queue.delete_task_key() + + +@shared_task(queue="dataset") +def normal_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Async process document + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: normal_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("normal document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, normal_document_indexing_task) + + +@shared_task(queue="priority_dataset") +def priority_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Priority async process document + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage:
priority_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("priority document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, priority_document_indexing_task) diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 6de95a3b85..a7f61d9811 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -12,8 +12,10 @@ from celery import shared_task # type: ignore from flask import current_app, g from sqlalchemy.orm import Session, sessionmaker +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.rag.pipeline.queue import TenantIsolatedTaskQueue from core.repositories.factory import DifyCoreRepositoryFactory from extensions.ext_database import db from models import Account, Tenant @@ -22,6 +24,8 @@ from models.enums import WorkflowRunTriggeredFrom from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom from services.file_service import FileService +logger = logging.getLogger(__name__) + @shared_task(queue="priority_pipeline") def priority_rag_pipeline_run_task( @@ -69,6 +73,27 @@ def priority_rag_pipeline_run_task( logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red")) raise finally: + tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline") + + # Check if there are waiting tasks in the queue + # Use rpop to get the next task from the queue (FIFO order) + next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) + logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids) + + if next_file_ids: + for next_file_id in next_file_ids: + # Process the next waiting task + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + priority_rag_pipeline_run_task.delay( # type: ignore + rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8") + if isinstance(next_file_id, bytes) + else next_file_id, + tenant_id=tenant_id, + ) + else: + # No more waiting tasks, clear the flag + tenant_isolated_task_queue.delete_task_key() file_service = FileService(db.engine) file_service.delete_file(rag_pipeline_invoke_entities_file_id) db.session.close() diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index f4a092d97e..92f1dfb73d 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -12,17 +12,20 @@ from celery import shared_task # type: ignore from flask import current_app, g from sqlalchemy.orm import Session, sessionmaker +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.rag.pipeline.queue import TenantIsolatedTaskQueue from core.repositories.factory import DifyCoreRepositoryFactory from extensions.ext_database import db -from extensions.ext_redis import redis_client from models import Account, Tenant from models.dataset import Pipeline from models.enums import 
WorkflowRunTriggeredFrom from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom from services.file_service import FileService +logger = logging.getLogger(__name__) + @shared_task(queue="pipeline") def rag_pipeline_run_task( @@ -70,26 +73,27 @@ def rag_pipeline_run_task( logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red")) raise finally: - tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}" - tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}" + tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline") # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) - next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue) + next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) + logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids) - if next_file_id: - # Process the next waiting task - # Keep the flag set to indicate a task is running - redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1) - rag_pipeline_run_task.delay( # type: ignore - rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8") - if isinstance(next_file_id, bytes) - else next_file_id, - tenant_id=tenant_id, - ) + if next_file_ids: + for next_file_id in next_file_ids: + # Process the next waiting task + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + rag_pipeline_run_task.delay( # type: ignore + rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8") + if isinstance(next_file_id, bytes) + else next_file_id, + tenant_id=tenant_id, + ) else: # No more waiting tasks, clear the flag - redis_client.delete(tenant_pipeline_task_key) + tenant_isolated_task_queue.delete_task_key() file_service = FileService(db.engine) file_service.delete_file(rag_pipeline_invoke_entities_file_id) db.session.close() diff --git a/api/tests/test_containers_integration_tests/core/rag/__init__.py b/api/tests/test_containers_integration_tests/core/rag/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/tests/test_containers_integration_tests/core/rag/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/__init__.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py new file mode 100644 index 0000000000..cdf390b327 --- /dev/null +++ b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py @@ -0,0 +1,595 @@ +""" +Integration tests for TenantIsolatedTaskQueue using testcontainers. + +These tests verify the Redis-based task queue functionality with real Redis instances, +testing tenant isolation, task serialization, and queue operations in a realistic environment. +Includes compatibility tests for migrating from legacy string-only queues. + +All tests use generic naming to avoid coupling to specific business implementations. 
+""" + +import time +from dataclasses import dataclass +from typing import Any +from uuid import uuid4 + +import pytest +from faker import Faker + +from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue +from extensions.ext_redis import redis_client +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole + + +@dataclass +class TestTask: + """Test task data structure for testing complex object serialization.""" + + task_id: str + tenant_id: str + data: dict[str, Any] + metadata: dict[str, Any] + + +class TestTenantIsolatedTaskQueueIntegration: + """Integration tests for TenantIsolatedTaskQueue using testcontainers.""" + + @pytest.fixture + def fake(self): + """Faker instance for generating test data.""" + return Faker() + + @pytest.fixture + def test_tenant_and_account(self, db_session_with_containers, fake): + """Create test tenant and account for testing.""" + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + return tenant, account + + @pytest.fixture + def test_queue(self, test_tenant_and_account): + """Create a generic test queue for testing.""" + tenant, _ = test_tenant_and_account + return TenantIsolatedTaskQueue(tenant.id, "test_queue") + + @pytest.fixture + def secondary_queue(self, test_tenant_and_account): + """Create a secondary test queue for testing isolation.""" + tenant, _ = test_tenant_and_account + return TenantIsolatedTaskQueue(tenant.id, "secondary_queue") + + def test_queue_initialization(self, test_tenant_and_account): + """Test queue initialization with correct key generation.""" + tenant, _ = test_tenant_and_account + queue = TenantIsolatedTaskQueue(tenant.id, "test-key") + + assert queue._tenant_id == tenant.id + assert queue._unique_key == "test-key" + assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}" + assert queue._task_key == f"tenant_test-key_task:{tenant.id}" + + def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake): + """Test that different tenants have isolated queues.""" + tenant1, _ = test_tenant_and_account + + # Create second tenant + tenant2 = Tenant( + name=fake.company(), + status="normal", + ) + db_session_with_containers.add(tenant2) + db_session_with_containers.commit() + + queue1 = TenantIsolatedTaskQueue(tenant1.id, "same-key") + queue2 = TenantIsolatedTaskQueue(tenant2.id, "same-key") + + assert queue1._queue != queue2._queue + assert queue1._task_key != queue2._task_key + assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}" + assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}" + + def test_key_isolation(self, test_tenant_and_account): + """Test that different keys have isolated queues.""" + tenant, _ = test_tenant_and_account + queue1 = TenantIsolatedTaskQueue(tenant.id, "key1") + queue2 = TenantIsolatedTaskQueue(tenant.id, "key2") + + assert queue1._queue != queue2._queue + assert queue1._task_key != queue2._task_key + assert queue1._queue == 
f"tenant_self_key1_task_queue:{tenant.id}" + assert queue2._queue == f"tenant_self_key2_task_queue:{tenant.id}" + + def test_task_key_operations(self, test_queue): + """Test task key operations (get, set, delete).""" + # Initially no task key should exist + assert test_queue.get_task_key() is None + + # Set task waiting time with default TTL + test_queue.set_task_waiting_time() + task_key = test_queue.get_task_key() + # Redis returns bytes, convert to string for comparison + assert task_key in (b"1", "1") + + # Set task waiting time with custom TTL + custom_ttl = 30 + test_queue.set_task_waiting_time(custom_ttl) + task_key = test_queue.get_task_key() + assert task_key in (b"1", "1") + + # Delete task key + test_queue.delete_task_key() + assert test_queue.get_task_key() is None + + def test_push_and_pull_string_tasks(self, test_queue): + """Test pushing and pulling string tasks.""" + tasks = ["task1", "task2", "task3"] + + # Push tasks + test_queue.push_tasks(tasks) + + # Pull tasks (FIFO order) + pulled_tasks = test_queue.pull_tasks(3) + + # Should get tasks in FIFO order (lpush + rpop = FIFO) + assert pulled_tasks == ["task1", "task2", "task3"] + + def test_push_and_pull_multiple_tasks(self, test_queue): + """Test pushing and pulling multiple tasks at once.""" + tasks = ["task1", "task2", "task3", "task4", "task5"] + + # Push tasks + test_queue.push_tasks(tasks) + + # Pull multiple tasks + pulled_tasks = test_queue.pull_tasks(3) + assert len(pulled_tasks) == 3 + assert pulled_tasks == ["task1", "task2", "task3"] + + # Pull remaining tasks + remaining_tasks = test_queue.pull_tasks(5) + assert len(remaining_tasks) == 2 + assert remaining_tasks == ["task4", "task5"] + + def test_push_and_pull_complex_objects(self, test_queue, fake): + """Test pushing and pulling complex object tasks.""" + # Create complex task objects as dictionaries (not dataclass instances) + tasks = [ + { + "task_id": str(uuid4()), + "tenant_id": test_queue._tenant_id, + "data": { + "file_id": str(uuid4()), + "content": fake.text(), + "metadata": {"size": fake.random_int(1000, 10000)}, + }, + "metadata": {"created_at": fake.iso8601(), "tags": fake.words(3)}, + }, + { + "task_id": str(uuid4()), + "tenant_id": test_queue._tenant_id, + "data": { + "file_id": str(uuid4()), + "content": "测试中文内容", + "metadata": {"size": fake.random_int(1000, 10000)}, + }, + "metadata": {"created_at": fake.iso8601(), "tags": ["中文", "测试", "emoji🚀"]}, + }, + ] + + # Push complex tasks + test_queue.push_tasks(tasks) + + # Pull tasks + pulled_tasks = test_queue.pull_tasks(2) + assert len(pulled_tasks) == 2 + + # Verify deserialized tasks match original (FIFO order) + for i, pulled_task in enumerate(pulled_tasks): + original_task = tasks[i] # FIFO order + assert isinstance(pulled_task, dict) + assert pulled_task["task_id"] == original_task["task_id"] + assert pulled_task["tenant_id"] == original_task["tenant_id"] + assert pulled_task["data"] == original_task["data"] + assert pulled_task["metadata"] == original_task["metadata"] + + def test_mixed_task_types(self, test_queue, fake): + """Test pushing and pulling mixed string and object tasks.""" + string_task = "simple_string_task" + object_task = { + "task_id": str(uuid4()), + "dataset_id": str(uuid4()), + "document_ids": [str(uuid4()) for _ in range(3)], + } + + tasks = [string_task, object_task, "another_string"] + + # Push mixed tasks + test_queue.push_tasks(tasks) + + # Pull all tasks + pulled_tasks = test_queue.pull_tasks(3) + assert len(pulled_tasks) == 3 + + # Verify types and content + assert 
pulled_tasks[0] == string_task + assert isinstance(pulled_tasks[1], dict) + assert pulled_tasks[1] == object_task + assert pulled_tasks[2] == "another_string" + + def test_empty_queue_operations(self, test_queue): + """Test operations on empty queue.""" + # Pull from empty queue + tasks = test_queue.pull_tasks(5) + assert tasks == [] + + # Pull zero or negative count + assert test_queue.pull_tasks(0) == [] + assert test_queue.pull_tasks(-1) == [] + + def test_task_ttl_expiration(self, test_queue): + """Test task key TTL expiration.""" + # Set task with short TTL + short_ttl = 2 + test_queue.set_task_waiting_time(short_ttl) + + # Verify task key exists + assert test_queue.get_task_key() in (b"1", "1") + + # Wait for TTL to expire + time.sleep(short_ttl + 1) + + # Verify task key has expired + assert test_queue.get_task_key() is None + + def test_large_task_batch(self, test_queue, fake): + """Test handling large batches of tasks.""" + # Create large batch of tasks + large_batch = [] + for i in range(100): + task = { + "task_id": str(uuid4()), + "index": i, + "data": fake.text(max_nb_chars=100), + "metadata": {"batch_id": str(uuid4())}, + } + large_batch.append(task) + + # Push large batch + test_queue.push_tasks(large_batch) + + # Pull all tasks + pulled_tasks = test_queue.pull_tasks(100) + assert len(pulled_tasks) == 100 + + # Verify all tasks were retrieved correctly (FIFO order) + for i, task in enumerate(pulled_tasks): + assert isinstance(task, dict) + assert task["index"] == i # FIFO order + + def test_queue_operations_isolation(self, test_tenant_and_account, fake): + """Test that different queues for the same tenant are isolated from each other.""" + tenant, _ = test_tenant_and_account + + # Create multiple queues for the same tenant + queue1 = TenantIsolatedTaskQueue(tenant.id, "queue1") + queue2 = TenantIsolatedTaskQueue(tenant.id, "queue2") + + # Push tasks to different queues + queue1.push_tasks(["task1_queue1", "task2_queue1"]) + queue2.push_tasks(["task1_queue2", "task2_queue2"]) + + # Verify queues are isolated + tasks1 = queue1.pull_tasks(2) + tasks2 = queue2.pull_tasks(2) + + assert tasks1 == ["task1_queue1", "task2_queue1"] + assert tasks2 == ["task1_queue2", "task2_queue2"] + assert tasks1 != tasks2 + + def test_task_wrapper_serialization_roundtrip(self, test_queue, fake): + """Test TaskWrapper serialization and deserialization roundtrip.""" + # Create complex nested data + complex_data = { + "id": str(uuid4()), + "nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5], "unicode": "测试中文", "emoji": "🚀"}}, + "metadata": {"created_at": fake.iso8601(), "tags": ["tag1", "tag2", "tag3"]}, + } + + # Create wrapper and serialize + wrapper = TaskWrapper(data=complex_data) + serialized = wrapper.serialize() + + # Verify serialization + assert isinstance(serialized, str) + assert "测试中文" in serialized + assert "🚀" in serialized + + # Deserialize and verify + deserialized_wrapper = TaskWrapper.deserialize(serialized) + assert deserialized_wrapper.data == complex_data + + def test_error_handling_invalid_json(self, test_queue): + """Test error handling for invalid JSON in wrapped tasks.""" + # Manually create invalid JSON task (not a valid TaskWrapper JSON) + invalid_json_task = "invalid json data" + + # Push invalid task directly to Redis + redis_client.lpush(test_queue._queue, invalid_json_task) + + # Pull task - should fall back to string since it's not valid JSON + task = test_queue.pull_tasks(1) + assert task[0] == invalid_json_task + + def
test_real_world_batch_processing_scenario(self, test_queue, fake): + """Test realistic batch processing scenario.""" + # Simulate batch processing tasks + batch_tasks = [] + for i in range(3): + task = { + "file_id": str(uuid4()), + "tenant_id": test_queue._tenant_id, + "user_id": str(uuid4()), + "processing_config": { + "model": fake.random_element(["model_a", "model_b", "model_c"]), + "temperature": fake.random.uniform(0.1, 1.0), + "max_tokens": fake.random_int(1000, 4000), + }, + "metadata": { + "source": fake.random_element(["upload", "api", "webhook"]), + "priority": fake.random_element(["low", "normal", "high"]), + }, + } + batch_tasks.append(task) + + # Push tasks + test_queue.push_tasks(batch_tasks) + + # Process tasks in batches + batch_size = 2 + processed_tasks = [] + + while True: + batch = test_queue.pull_tasks(batch_size) + if not batch: + break + + processed_tasks.extend(batch) + + # Verify all tasks were processed + assert len(processed_tasks) == 3 + + # Verify task structure + for task in processed_tasks: + assert isinstance(task, dict) + assert "file_id" in task + assert "tenant_id" in task + assert "processing_config" in task + assert "metadata" in task + assert task["tenant_id"] == test_queue._tenant_id + + +class TestTenantIsolatedTaskQueueCompatibility: + """Compatibility tests for migrating from legacy string-only queues.""" + + @pytest.fixture + def fake(self): + """Faker instance for generating test data.""" + return Faker() + + @pytest.fixture + def test_tenant_and_account(self, db_session_with_containers, fake): + """Create test tenant and account for testing.""" + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + return tenant, account + + def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake): + """ + Test compatibility with legacy queues containing only string data. + + This simulates the scenario where Redis queues already contain string data + from the old architecture, and we need to ensure the new code can read them. 
+ """ + tenant, _ = test_tenant_and_account + queue = TenantIsolatedTaskQueue(tenant.id, "legacy_queue") + + # Simulate legacy string data in Redis queue (using old format) + legacy_strings = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"] + + # Manually push legacy strings directly to Redis (simulating old system) + for legacy_string in legacy_strings: + redis_client.lpush(queue._queue, legacy_string) + + # Verify new code can read legacy string data + pulled_tasks = queue.pull_tasks(5) + assert len(pulled_tasks) == 5 + + # Verify all tasks are strings (not wrapped) + for task in pulled_tasks: + assert isinstance(task, str) + assert task.startswith("legacy_task_") + + # Verify order (FIFO from Redis list) + expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"] + assert pulled_tasks == expected_order + + def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake): + """ + Test complete migration scenario from legacy to new system. + + This simulates the real-world scenario where: + 1. Legacy system has string data in Redis + 2. New system starts processing the same queue + 3. Both legacy and new tasks coexist during migration + 4. New system can handle both formats seamlessly + """ + tenant, _ = test_tenant_and_account + queue = TenantIsolatedTaskQueue(tenant.id, "migration_queue") + + # Phase 1: Legacy system has data + legacy_tasks = [f"legacy_resource_{i}" for i in range(1, 6)] + redis_client.lpush(queue._queue, *legacy_tasks) + + # Phase 2: New system starts processing legacy data + processed_legacy = [] + while True: + tasks = queue.pull_tasks(1) + if not tasks: + break + processed_legacy.extend(tasks) + + # Verify legacy data was processed correctly + assert len(processed_legacy) == 5 + for task in processed_legacy: + assert isinstance(task, str) + assert task.startswith("legacy_resource_") + + # Phase 3: New system adds new tasks (mixed types) + new_string_tasks = ["new_resource_1", "new_resource_2"] + new_object_tasks = [ + { + "resource_id": str(uuid4()), + "tenant_id": tenant.id, + "processing_type": "new_system", + "metadata": {"version": "2.0", "features": ["ai", "ml"]}, + }, + { + "resource_id": str(uuid4()), + "tenant_id": tenant.id, + "processing_type": "new_system", + "metadata": {"version": "2.0", "features": ["ai", "ml"]}, + }, + ] + + # Push new tasks using new system + queue.push_tasks(new_string_tasks) + queue.push_tasks(new_object_tasks) + + # Phase 4: Process all new tasks + processed_new = [] + while True: + tasks = queue.pull_tasks(1) + if not tasks: + break + processed_new.extend(tasks) + + # Verify new tasks were processed correctly + assert len(processed_new) == 4 + + string_tasks = [task for task in processed_new if isinstance(task, str)] + object_tasks = [task for task in processed_new if isinstance(task, dict)] + + assert len(string_tasks) == 2 + assert len(object_tasks) == 2 + + # Verify string tasks + for task in string_tasks: + assert task.startswith("new_resource_") + + # Verify object tasks + for task in object_tasks: + assert isinstance(task, dict) + assert "resource_id" in task + assert "tenant_id" in task + assert task["tenant_id"] == tenant.id + assert task["processing_type"] == "new_system" + + def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake): + """ + Test error recovery when legacy queue contains malformed data. + + This ensures the new system can gracefully handle corrupted or + malformed legacy data without crashing. 
+ """ + tenant, _ = test_tenant_and_account + queue = TenantIsolatedTaskQueue(tenant.id, "error_recovery_queue") + + # Create mix of valid and malformed legacy data + mixed_legacy_data = [ + "valid_legacy_task_1", + "valid_legacy_task_2", + "malformed_data_string", # This should be treated as string + "valid_legacy_task_3", + "invalid_json_not_taskwrapper_format", # This should fall back to string (not valid TaskWrapper JSON) + "valid_legacy_task_4", + ] + + # Manually push mixed data directly to Redis + redis_client.lpush(queue._queue, *mixed_legacy_data) + + # Process all tasks + processed_tasks = [] + while True: + tasks = queue.pull_tasks(1) + if not tasks: + break + processed_tasks.extend(tasks) + + # Verify all tasks were processed (no crashes) + assert len(processed_tasks) == 6 + + # Verify all tasks are strings (malformed data falls back to string) + for task in processed_tasks: + assert isinstance(task, str) + + # Verify valid tasks are preserved + valid_tasks = [task for task in processed_tasks if task.startswith("valid_legacy_task_")] + assert len(valid_tasks) == 4 + + # Verify malformed data is handled gracefully + malformed_tasks = [task for task in processed_tasks if not task.startswith("valid_legacy_task_")] + assert len(malformed_tasks) == 2 + assert "malformed_data_string" in malformed_tasks + assert "invalid_json_not_taskwrapper_format" in malformed_tasks diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 1329bba082..c015d7ec9c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -1,17 +1,33 @@ +from dataclasses import asdict from unittest.mock import MagicMock, patch import pytest from faker import Faker +from core.entities.document_task import DocumentTask from enums.cloud_plan import CloudPlan from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document -from tasks.document_indexing_task import document_indexing_task +from tasks.document_indexing_task import ( + _document_indexing, # Core function + _document_indexing_with_tenant_queue, # Tenant queue wrapper function + document_indexing_task, # Deprecated old interface + normal_document_indexing_task, # New normal task + priority_document_indexing_task, # New priority task +) -class TestDocumentIndexingTask: - """Integration tests for document_indexing_task using testcontainers.""" +class TestDocumentIndexingTasks: + """Integration tests for document indexing tasks using testcontainers. 
+ + This test class covers: + - Core _document_indexing function + - Deprecated document_indexing_task function + - New normal_document_indexing_task function + - New priority_document_indexing_task function + - Tenant queue wrapper _document_indexing_with_tenant_queue function + """ @pytest.fixture def mock_external_service_dependencies(self): @@ -224,7 +240,7 @@ class TestDocumentIndexingTask: document_ids = [doc.id for doc in documents] # Act: Execute the task - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify the expected outcomes # Verify indexing runner was called correctly @@ -232,10 +248,11 @@ class TestDocumentIndexingTask: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify documents were updated to parsing status - for document in documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args @@ -261,7 +278,7 @@ class TestDocumentIndexingTask: document_ids = [fake.uuid4() for _ in range(3)] # Act: Execute the task with non-existent dataset - document_indexing_task(non_existent_dataset_id, document_ids) + _document_indexing(non_existent_dataset_id, document_ids) # Assert: Verify no processing occurred mock_external_service_dependencies["indexing_runner"].assert_not_called() @@ -291,17 +308,18 @@ class TestDocumentIndexingTask: all_document_ids = existing_document_ids + non_existent_document_ids # Act: Execute the task with mixed document IDs - document_indexing_task(dataset.id, all_document_ids) + _document_indexing(dataset.id, all_document_ids) # Assert: Verify only existing documents were processed mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify only existing documents were updated - for document in documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in existing_document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None # Verify the run method was called with only existing documents call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args @@ -333,7 +351,7 @@ class TestDocumentIndexingTask: ) # Act: Execute the task - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify exception was handled gracefully # The task should complete without raising exceptions @@ -341,10 +359,11 @@ class TestDocumentIndexingTask: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify documents were still updated to parsing status before the exception - for document in 
documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None def test_document_indexing_task_mixed_document_states( self, db_session_with_containers, mock_external_service_dependencies @@ -407,17 +426,18 @@ class TestDocumentIndexingTask: document_ids = [doc.id for doc in all_documents] # Act: Execute the task with mixed document states - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify processing mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify all documents were updated to parsing status - for document in all_documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None # Verify the run method was called with all documents call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args @@ -470,15 +490,16 @@ class TestDocumentIndexingTask: document_ids = [doc.id for doc in all_documents] # Act: Execute the task with too many documents for sandbox plan - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify error handling - for document in all_documents: - db.session.refresh(document) - assert document.indexing_status == "error" - assert document.error is not None - assert "batch upload" in document.error - assert document.stopped_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "error" + assert updated_document.error is not None + assert "batch upload" in updated_document.error + assert updated_document.stopped_at is not None # Verify no indexing runner was called mock_external_service_dependencies["indexing_runner"].assert_not_called() @@ -503,17 +524,18 @@ class TestDocumentIndexingTask: document_ids = [doc.id for doc in documents] # Act: Execute the task with billing disabled - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify successful processing mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify documents were updated to parsing status - for document in documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document =
db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None def test_document_indexing_task_document_is_paused_error( self, db_session_with_containers, mock_external_service_dependencies @@ -541,7 +563,7 @@ class TestDocumentIndexingTask: ) # Act: Execute the task - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify exception was handled gracefully # The task should complete without raising exceptions @@ -549,7 +571,317 @@ class TestDocumentIndexingTask: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify documents were still updated to parsing status before the exception - for document in documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # ==================== NEW TESTS FOR REFACTORED FUNCTIONS ==================== + def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test document_indexing_task basic functionality. + + This test verifies: + - Task function calls the wrapper correctly + - Basic parameter passing works + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the deprecated task (it only takes 2 parameters) + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify processing occurred (core logic is tested in _document_indexing tests) + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_normal_document_indexing_task_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test normal_document_indexing_task basic functionality. + + This test verifies: + - Task function calls the wrapper correctly + - Basic parameter passing works + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + document_ids = [doc.id for doc in documents] + tenant_id = dataset.tenant_id + + # Act: Execute the new normal task + normal_document_indexing_task(tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred (core logic is tested in _document_indexing tests) + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_priority_document_indexing_task_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test priority_document_indexing_task basic functionality. 
+ + This test verifies: + - Task function calls the wrapper correctly + - Basic parameter passing works + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + document_ids = [doc.id for doc in documents] + tenant_id = dataset.tenant_id + + # Act: Execute the new priority task + priority_document_indexing_task(tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred (core logic is tested in _document_indexing tests) + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_document_indexing_with_tenant_queue_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test _document_indexing_with_tenant_queue function with no waiting tasks. + + This test verifies: + - Core indexing logic execution (same as _document_indexing) + - Tenant queue cleanup when no waiting tasks + - Task function parameter passing + - Queue management after processing + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + tenant_id = dataset.tenant_id + + # Mock the task function + from unittest.mock import MagicMock + + mock_task_func = MagicMock() + + # Act: Execute the wrapper function + _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func) + + # Assert: Verify core processing occurred (same as _document_indexing) + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were updated (same as _document_indexing) + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify the run method was called with correct documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] + assert len(processed_documents) == 2 + + # Verify task function was not called (no waiting tasks) + mock_task_func.delay.assert_not_called() + + def test_document_indexing_with_tenant_queue_with_waiting_tasks( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis. 
+ + This test verifies: + - Core indexing logic execution + - Real Redis-based tenant queue processing of waiting tasks + - Task function calls for waiting tasks + - Queue management with multiple tasks using actual Redis operations + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + document_ids = [doc.id for doc in documents] + tenant_id = dataset.tenant_id + dataset_id = dataset.id + + # Mock the task function + from unittest.mock import MagicMock + + mock_task_func = MagicMock() + + # Use real Redis for TenantIsolatedTaskQueue + from core.rag.pipeline.queue import TenantIsolatedTaskQueue + + # Create real queue instance + queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") + + # Add waiting tasks to the real Redis queue + waiting_tasks = [ + DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]), + DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-2"]), + ] + # Convert DocumentTask objects to dictionaries for serialization + waiting_task_dicts = [asdict(task) for task in waiting_tasks] + queue.push_tasks(waiting_task_dicts) + + # Act: Execute the wrapper function + _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func) + + # Assert: Verify core processing occurred + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify task function was dispatched once (only one waiting task is pulled per run) + assert mock_task_func.delay.call_count == 1 + + # Verify correct parameters for the dispatched call + calls = mock_task_func.delay.call_args_list + assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]} + + # Verify one waiting task remains in the queue (only one was pulled) + remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added + assert len(remaining_tasks) == 1 + + def test_document_indexing_with_tenant_queue_error_handling( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling in _document_indexing_with_tenant_queue using real Redis.
+ + This test verifies: + - Exception handling during core processing + - Tenant queue cleanup even on errors using real Redis + - Proper error logging + - Function completes without raising exceptions + - Queue management continues despite core processing errors + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + document_ids = [doc.id for doc in documents] + tenant_id = dataset.tenant_id + dataset_id = dataset.id + + # Mock IndexingRunner to raise an exception + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception("Test error") + + # Mock the task function + from unittest.mock import MagicMock + + mock_task_func = MagicMock() + + # Use real Redis for TenantIsolatedTaskQueue + from core.rag.pipeline.queue import TenantIsolatedTaskQueue + + # Create real queue instance + queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") + + # Add waiting task to the real Redis queue + waiting_task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]) + queue.push_tasks([asdict(waiting_task)]) + + # Act: Execute the wrapper function + _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func) + + # Assert: Verify error was handled gracefully + # The function should not raise exceptions + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were still updated to parsing status before the exception + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify waiting task was still processed despite core processing error + mock_task_func.delay.assert_called_once() + + # Verify correct parameters for the call + call = mock_task_func.delay.call_args + assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]} + + # Verify queue is empty after processing (task was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 0 + + def test_document_indexing_with_tenant_queue_tenant_isolation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tenant isolation in _document_indexing_with_tenant_queue using real Redis. 
+ + This test verifies: + - Different tenants have isolated queues + - Tasks from one tenant don't affect another tenant's queue + - Queue operations are properly scoped to tenant + """ + # Arrange: Create test data for two different tenants + dataset1, documents1 = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + dataset2, documents2 = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + + tenant1_id = dataset1.tenant_id + tenant2_id = dataset2.tenant_id + dataset1_id = dataset1.id + dataset2_id = dataset2.id + document_ids1 = [doc.id for doc in documents1] + document_ids2 = [doc.id for doc in documents2] + + # Mock the task function + from unittest.mock import MagicMock + + mock_task_func = MagicMock() + + # Use real Redis for TenantIsolatedTaskQueue + from core.rag.pipeline.queue import TenantIsolatedTaskQueue + + # Create queue instances for both tenants + queue1 = TenantIsolatedTaskQueue(tenant1_id, "document_indexing") + queue2 = TenantIsolatedTaskQueue(tenant2_id, "document_indexing") + + # Add waiting tasks to both queues + waiting_task1 = DocumentTask(tenant_id=tenant1_id, dataset_id=dataset1.id, document_ids=["tenant1-doc-1"]) + waiting_task2 = DocumentTask(tenant_id=tenant2_id, dataset_id=dataset2.id, document_ids=["tenant2-doc-1"]) + + queue1.push_tasks([asdict(waiting_task1)]) + queue2.push_tasks([asdict(waiting_task2)]) + + # Act: Execute the wrapper function for tenant1 only + _document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func) + + # Assert: Verify core processing occurred for tenant1 + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify only tenant1's waiting task was processed + mock_task_func.delay.assert_called_once() + call = mock_task_func.delay.call_args + assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]} + + # Verify tenant1's queue is empty + remaining_tasks1 = queue1.pull_tasks(count=10) + assert len(remaining_tasks1) == 0 + + # Verify tenant2's queue still has its task (isolation) + remaining_tasks2 = queue2.pull_tasks(count=10) + assert len(remaining_tasks2) == 1 + + # Verify queue keys are different + assert queue1._queue != queue2._queue + assert queue1._task_key != queue2._task_key diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py new file mode 100644 index 0000000000..c82162238c --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -0,0 +1,936 @@ +import json +import uuid +from unittest.mock import patch + +import pytest +from faker import Faker + +from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from extensions.ext_database import db +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Pipeline +from models.workflow import Workflow +from tasks.rag_pipeline.priority_rag_pipeline_run_task import ( + priority_rag_pipeline_run_task, + run_single_rag_pipeline_task, +) +from 
tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task + + +class TestRagPipelineRunTasks: + """Integration tests for RAG pipeline run tasks using testcontainers. + + This test class covers: + - priority_rag_pipeline_run_task function + - rag_pipeline_run_task function + - run_single_rag_pipeline_task function + - Real Redis-based TenantIsolatedTaskQueue operations + - PipelineGenerator._generate method mocking and parameter validation + - File operations and cleanup + - Error handling and queue management + """ + + @pytest.fixture + def mock_pipeline_generator(self): + """Mock PipelineGenerator._generate method.""" + with patch("core.app.apps.pipeline.pipeline_generator.PipelineGenerator._generate") as mock_generate: + # Mock the _generate method to return a simple response + mock_generate.return_value = {"answer": "Test response", "metadata": {"test": "data"}} + yield mock_generate + + @pytest.fixture + def mock_file_service(self): + """Mock FileService for file operations.""" + with ( + patch("services.file_service.FileService.get_file_content") as mock_get_content, + patch("services.file_service.FileService.delete_file") as mock_delete_file, + ): + yield { + "get_content": mock_get_content, + "delete_file": mock_delete_file, + } + + def _create_test_pipeline_and_workflow(self, db_session_with_containers): + """ + Helper method to create test pipeline and workflow for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + tuple: (account, tenant, pipeline, workflow) - Created entities + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create workflow + workflow = Workflow( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + app_id=str(uuid.uuid4()), + type="workflow", + version="draft", + graph="{}", + features="{}", + marked_name=fake.company(), + marked_comment=fake.text(max_nb_chars=100), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db.session.add(workflow) + db.session.commit() + + # Create pipeline + pipeline = Pipeline( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + workflow_id=workflow.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + created_by=account.id, + ) + db.session.add(pipeline) + db.session.commit() + + # Refresh entities to ensure they're properly loaded + db.session.refresh(account) + db.session.refresh(tenant) + db.session.refresh(workflow) + db.session.refresh(pipeline) + + return account, tenant, pipeline, workflow + + def _create_rag_pipeline_invoke_entities(self, account, tenant, pipeline, workflow, count=2): + """ + Helper method to create RAG pipeline invoke entities for testing. 
+ + Args: + account: Account instance + tenant: Tenant instance + pipeline: Pipeline instance + workflow: Workflow instance + count: Number of entities to create + + Returns: + list: List of RagPipelineInvokeEntity instances + """ + fake = Faker() + entities = [] + + for i in range(count): + # Create application generate entity + app_config = { + "app_id": str(uuid.uuid4()), + "app_name": fake.company(), + "mode": "workflow", + "workflow_id": workflow.id, + "tenant_id": tenant.id, + "app_mode": "workflow", + } + + application_generate_entity = { + "task_id": str(uuid.uuid4()), + "app_config": app_config, + "inputs": {"query": f"Test query {i}"}, + "files": [], + "user_id": account.id, + "stream": False, + "invoke_from": "published", + "workflow_execution_id": str(uuid.uuid4()), + "pipeline_config": { + "app_id": str(uuid.uuid4()), + "app_name": fake.company(), + "mode": "workflow", + "workflow_id": workflow.id, + "tenant_id": tenant.id, + "app_mode": "workflow", + }, + "datasource_type": "upload_file", + "datasource_info": {}, + "dataset_id": str(uuid.uuid4()), + "batch": "test_batch", + } + + entity = RagPipelineInvokeEntity( + pipeline_id=pipeline.id, + application_generate_entity=application_generate_entity, + user_id=account.id, + tenant_id=tenant.id, + workflow_id=workflow.id, + streaming=False, + workflow_execution_id=str(uuid.uuid4()), + workflow_thread_pool_id=str(uuid.uuid4()), + ) + entities.append(entity) + + return entities + + def _create_file_content_for_entities(self, entities): + """ + Helper method to create file content for RAG pipeline invoke entities. + + Args: + entities: List of RagPipelineInvokeEntity instances + + Returns: + str: JSON string containing serialized entities + """ + entities_data = [entity.model_dump() for entity in entities] + return json.dumps(entities_data) + + def test_priority_rag_pipeline_run_task_success( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test successful priority RAG pipeline run task execution. 
+ + This test verifies: + - Task execution with multiple RAG pipeline invoke entities + - File content retrieval and parsing + - PipelineGenerator._generate method calls with correct parameters + - Thread pool execution + - File cleanup after execution + - Queue management with no waiting tasks + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=2) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Act: Execute the priority task + priority_rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify expected outcomes + # Verify file operations + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + + # Verify PipelineGenerator._generate was called for each entity + assert mock_pipeline_generator.call_count == 2 + + # Verify call parameters for each entity + calls = mock_pipeline_generator.call_args_list + for call in calls: + call_kwargs = call[1] # Get keyword arguments + assert call_kwargs["pipeline"].id == pipeline.id + assert call_kwargs["workflow_id"] == workflow.id + assert call_kwargs["user"].id == account.id + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["streaming"] == False + assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) + + def test_rag_pipeline_run_task_success( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test successful regular RAG pipeline run task execution. 
+ + This test verifies: + - Task execution with multiple RAG pipeline invoke entities + - File content retrieval and parsing + - PipelineGenerator._generate method calls with correct parameters + - Thread pool execution + - File cleanup after execution + - Queue management with no waiting tasks + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=3) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Act: Execute the regular task + rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify expected outcomes + # Verify file operations + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + + # Verify PipelineGenerator._generate was called for each entity + assert mock_pipeline_generator.call_count == 3 + + # Verify call parameters for each entity + calls = mock_pipeline_generator.call_args_list + for call in calls: + call_kwargs = call[1] # Get keyword arguments + assert call_kwargs["pipeline"].id == pipeline.id + assert call_kwargs["workflow_id"] == workflow.id + assert call_kwargs["user"].id == account.id + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["streaming"] == False + assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) + + def test_priority_rag_pipeline_run_task_with_waiting_tasks( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test priority RAG pipeline run task with waiting tasks in queue using real Redis. 
+ + This test verifies: + - Core task execution + - Real Redis-based tenant queue processing of waiting tasks + - Task function calls for waiting tasks + - Queue management with multiple tasks using actual Redis operations + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting tasks to the real Redis queue + waiting_file_ids = [str(uuid.uuid4()) for _ in range(2)] + queue.push_tasks(waiting_file_ids) + + # Mock the task function calls + with patch( + "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay" + ) as mock_delay: + # Act: Execute the priority task + priority_rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify core processing occurred + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + assert mock_pipeline_generator.call_count == 1 + + # Verify waiting tasks were processed, pull 1 task a time by default + assert mock_delay.call_count == 1 + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0] + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue still has remaining tasks (only 1 was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining + + def test_rag_pipeline_run_task_legacy_compatibility( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility. 
+ + This test simulates the scenario where: + - Old code writes file IDs directly to Redis list using lpush + - New worker processes these legacy queue entries + - Ensures backward compatibility during deployment transition + + Legacy format: redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id) + New format: TenantIsolatedTaskQueue.push_tasks([file_id]) + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Simulate legacy Redis queue format - direct file IDs in Redis list + from extensions.ext_redis import redis_client + + # Legacy queue key format (old code) + legacy_queue_key = f"tenant_self_pipeline_task_queue:{tenant.id}" + legacy_task_key = f"tenant_pipeline_task:{tenant.id}" + + # Add legacy format data to Redis (simulating old code behavior) + legacy_file_ids = [str(uuid.uuid4()) for _ in range(3)] + for file_id_legacy in legacy_file_ids: + redis_client.lpush(legacy_queue_key, file_id_legacy) + + # Set the task key to indicate there are waiting tasks (legacy behavior) + redis_client.set(legacy_task_key, 1, ex=60 * 60) + + # Mock the task function calls + with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Act: Execute the regular task with new code but legacy queue data + rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify core processing occurred + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + assert mock_pipeline_generator.call_count == 1 + + # Verify waiting tasks were processed, pull 1 task a time by default + assert mock_delay.call_count == 1 + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0] + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify that new code can process legacy queue entries + # The new TenantIsolatedTaskQueue should be able to read from the legacy format + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Verify queue still has remaining tasks (only 1 was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining + + # Cleanup: Remove legacy test data + redis_client.delete(legacy_queue_key) + redis_client.delete(legacy_task_key) + + def test_rag_pipeline_run_task_with_waiting_tasks( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test regular RAG pipeline run task with waiting tasks in queue using real Redis.
+ + This test verifies: + - Core task execution + - Real Redis-based tenant queue processing of waiting tasks + - Task function calls for waiting tasks + - Queue management with multiple tasks using actual Redis operations + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting tasks to the real Redis queue + waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)] + queue.push_tasks(waiting_file_ids) + + # Mock the task function calls + with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Act: Execute the regular task + rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify core processing occurred + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + assert mock_pipeline_generator.call_count == 1 + + # Verify waiting tasks were processed, pull 1 task a time by default + assert mock_delay.call_count == 1 + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0] + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue still has remaining tasks (only 1 was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining + + def test_priority_rag_pipeline_run_task_error_handling( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test error handling in priority RAG pipeline run task using real Redis. 
+ + This test verifies: + - Exception handling during core processing + - Tenant queue cleanup even on errors using real Redis + - Proper error logging + - Function completes without raising exceptions + - Queue management continues despite core processing errors + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Mock PipelineGenerator to raise an exception + mock_pipeline_generator.side_effect = Exception("Pipeline generation failed") + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting task to the real Redis queue + waiting_file_id = str(uuid.uuid4()) + queue.push_tasks([waiting_file_id]) + + # Mock the task function calls + with patch( + "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay" + ) as mock_delay: + # Act: Execute the priority task (should not raise exception) + priority_rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify error was handled gracefully + # The function should not raise exceptions + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + assert mock_pipeline_generator.call_count == 1 + + # Verify waiting task was still processed despite core processing error + mock_delay.assert_called_once() + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue is empty after processing (task was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 0 + + def test_rag_pipeline_run_task_error_handling( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test error handling in regular RAG pipeline run task using real Redis. 
+ + This test verifies: + - Exception handling during core processing + - Tenant queue cleanup even on errors using real Redis + - Proper error logging + - Function completes without raising exceptions + - Queue management continues despite core processing errors + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Mock PipelineGenerator to raise an exception + mock_pipeline_generator.side_effect = Exception("Pipeline generation failed") + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting task to the real Redis queue + waiting_file_id = str(uuid.uuid4()) + queue.push_tasks([waiting_file_id]) + + # Mock the task function calls + with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Act: Execute the regular task (should not raise exception) + rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify error was handled gracefully + # The function should not raise exceptions + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + assert mock_pipeline_generator.call_count == 1 + + # Verify waiting task was still processed despite core processing error + mock_delay.assert_called_once() + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue is empty after processing (task was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 0 + + def test_priority_rag_pipeline_run_task_tenant_isolation( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test tenant isolation in priority RAG pipeline run task using real Redis. 
+ + This test verifies: + - Different tenants have isolated queues + - Tasks from one tenant don't affect another tenant's queue + - Queue operations are properly scoped to tenant + """ + # Arrange: Create test data for two different tenants + account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers) + account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers) + + entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1) + entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1) + + file_content1 = self._create_file_content_for_entities(entities1) + file_content2 = self._create_file_content_for_entities(entities2) + + # Mock file service + file_id1 = str(uuid.uuid4()) + file_id2 = str(uuid.uuid4()) + mock_file_service["get_content"].side_effect = [file_content1, file_content2] + + # Use real Redis for TenantIsolatedTaskQueue + queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline") + queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline") + + # Add waiting tasks to both queues + waiting_file_id1 = str(uuid.uuid4()) + waiting_file_id2 = str(uuid.uuid4()) + + queue1.push_tasks([waiting_file_id1]) + queue2.push_tasks([waiting_file_id2]) + + # Mock the task function calls + with patch( + "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay" + ) as mock_delay: + # Act: Execute the priority task for tenant1 only + priority_rag_pipeline_run_task(file_id1, tenant1.id) + + # Assert: Verify core processing occurred for tenant1 + assert mock_file_service["get_content"].call_count == 1 + assert mock_file_service["delete_file"].call_count == 1 + assert mock_pipeline_generator.call_count == 1 + + # Verify only tenant1's waiting task was processed + mock_delay.assert_called_once() + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1 + assert call_kwargs.get("tenant_id") == tenant1.id + + # Verify tenant1's queue is empty + remaining_tasks1 = queue1.pull_tasks(count=10) + assert len(remaining_tasks1) == 0 + + # Verify tenant2's queue still has its task (isolation) + remaining_tasks2 = queue2.pull_tasks(count=10) + assert len(remaining_tasks2) == 1 + + # Verify queue keys are different + assert queue1._queue != queue2._queue + assert queue1._task_key != queue2._task_key + + def test_rag_pipeline_run_task_tenant_isolation( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test tenant isolation in regular RAG pipeline run task using real Redis. 
+ + This test verifies: + - Different tenants have isolated queues + - Tasks from one tenant don't affect another tenant's queue + - Queue operations are properly scoped to tenant + """ + # Arrange: Create test data for two different tenants + account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers) + account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers) + + entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1) + entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1) + + file_content1 = self._create_file_content_for_entities(entities1) + file_content2 = self._create_file_content_for_entities(entities2) + + # Mock file service + file_id1 = str(uuid.uuid4()) + file_id2 = str(uuid.uuid4()) + mock_file_service["get_content"].side_effect = [file_content1, file_content2] + + # Use real Redis for TenantIsolatedTaskQueue + queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline") + queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline") + + # Add waiting tasks to both queues + waiting_file_id1 = str(uuid.uuid4()) + waiting_file_id2 = str(uuid.uuid4()) + + queue1.push_tasks([waiting_file_id1]) + queue2.push_tasks([waiting_file_id2]) + + # Mock the task function calls + with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Act: Execute the regular task for tenant1 only + rag_pipeline_run_task(file_id1, tenant1.id) + + # Assert: Verify core processing occurred for tenant1 + assert mock_file_service["get_content"].call_count == 1 + assert mock_file_service["delete_file"].call_count == 1 + assert mock_pipeline_generator.call_count == 1 + + # Verify only tenant1's waiting task was processed + mock_delay.assert_called_once() + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1 + assert call_kwargs.get("tenant_id") == tenant1.id + + # Verify tenant1's queue is empty + remaining_tasks1 = queue1.pull_tasks(count=10) + assert len(remaining_tasks1) == 0 + + # Verify tenant2's queue still has its task (isolation) + remaining_tasks2 = queue2.pull_tasks(count=10) + assert len(remaining_tasks2) == 1 + + # Verify queue keys are different + assert queue1._queue != queue2._queue + assert queue1._task_key != queue2._task_key + + def test_run_single_rag_pipeline_task_success( + self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + ): + """ + Test successful run_single_rag_pipeline_task execution. 
+ + This test verifies: + - Single RAG pipeline task execution within Flask app context + - Entity validation and database queries + - PipelineGenerator._generate method call with correct parameters + - Proper Flask context handling + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + entity_data = entities[0].model_dump() + + # Act: Execute the single task + with flask_app_with_containers.app_context(): + run_single_rag_pipeline_task(entity_data, flask_app_with_containers) + + # Assert: Verify expected outcomes + # Verify PipelineGenerator._generate was called + assert mock_pipeline_generator.call_count == 1 + + # Verify call parameters + call = mock_pipeline_generator.call_args + call_kwargs = call[1] # Get keyword arguments + assert call_kwargs["pipeline"].id == pipeline.id + assert call_kwargs["workflow_id"] == workflow.id + assert call_kwargs["user"].id == account.id + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["streaming"] == False + assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) + + def test_run_single_rag_pipeline_task_entity_validation_error( + self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + ): + """ + Test run_single_rag_pipeline_task with invalid entity data. + + This test verifies: + - Proper error handling for invalid entity data + - Exception logging + - Function raises ValueError for missing entities + """ + # Arrange: Create entity data with valid UUIDs but non-existent entities + fake = Faker() + invalid_entity_data = { + "pipeline_id": str(uuid.uuid4()), + "application_generate_entity": { + "app_config": { + "app_id": str(uuid.uuid4()), + "app_name": "Test App", + "mode": "workflow", + "workflow_id": str(uuid.uuid4()), + }, + "inputs": {"query": "Test query"}, + "query": "Test query", + "response_mode": "blocking", + "user": str(uuid.uuid4()), + "files": [], + "conversation_id": str(uuid.uuid4()), + }, + "user_id": str(uuid.uuid4()), + "tenant_id": str(uuid.uuid4()), + "workflow_id": str(uuid.uuid4()), + "streaming": False, + "workflow_execution_id": str(uuid.uuid4()), + "workflow_thread_pool_id": str(uuid.uuid4()), + } + + # Act & Assert: Execute the single task with non-existent entities (should raise ValueError) + with flask_app_with_containers.app_context(): + with pytest.raises(ValueError, match="Account .* not found"): + run_single_rag_pipeline_task(invalid_entity_data, flask_app_with_containers) + + # Assert: Pipeline generator should not be called + mock_pipeline_generator.assert_not_called() + + def test_run_single_rag_pipeline_task_database_entity_not_found( + self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + ): + """ + Test run_single_rag_pipeline_task with non-existent database entities. 
+ + This test verifies: + - Proper error handling for missing database entities + - Exception logging + - Function raises ValueError for missing entities + """ + # Arrange: Create test data with non-existent IDs + fake = Faker() + entity_data = { + "pipeline_id": str(uuid.uuid4()), + "application_generate_entity": { + "app_config": { + "app_id": str(uuid.uuid4()), + "app_name": "Test App", + "mode": "workflow", + "workflow_id": str(uuid.uuid4()), + }, + "inputs": {"query": "Test query"}, + "query": "Test query", + "response_mode": "blocking", + "user": str(uuid.uuid4()), + "files": [], + "conversation_id": str(uuid.uuid4()), + }, + "user_id": str(uuid.uuid4()), + "tenant_id": str(uuid.uuid4()), + "workflow_id": str(uuid.uuid4()), + "streaming": False, + "workflow_execution_id": str(uuid.uuid4()), + "workflow_thread_pool_id": str(uuid.uuid4()), + } + + # Act & Assert: Execute the single task with non-existent entities (should raise ValueError) + with flask_app_with_containers.app_context(): + with pytest.raises(ValueError, match="Account .* not found"): + run_single_rag_pipeline_task(entity_data, flask_app_with_containers) + + # Assert: Pipeline generator should not be called + mock_pipeline_generator.assert_not_called() + + def test_priority_rag_pipeline_run_task_file_not_found( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test priority RAG pipeline run task with non-existent file. + + This test verifies: + - Proper error handling for missing files + - Exception logging + - Function raises Exception for file errors + - Queue management continues despite file errors + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + + # Mock file service to raise exception + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].side_effect = Exception("File not found") + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting task to the real Redis queue + waiting_file_id = str(uuid.uuid4()) + queue.push_tasks([waiting_file_id]) + + # Mock the task function calls + with patch( + "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay" + ) as mock_delay: + # Act & Assert: Execute the priority task (should raise Exception) + with pytest.raises(Exception, match="File not found"): + priority_rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify error was handled gracefully + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_pipeline_generator.assert_not_called() + + # Verify waiting task was still processed despite file error + mock_delay.assert_called_once() + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue is empty after processing (task was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 0 + + def test_rag_pipeline_run_task_file_not_found( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test regular RAG pipeline run task with non-existent file. 
+ + This test verifies: + - Proper error handling for missing files + - Exception logging + - Function raises Exception for file errors + - Queue management continues despite file errors + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + + # Mock file service to raise exception + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].side_effect = Exception("File not found") + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting task to the real Redis queue + waiting_file_id = str(uuid.uuid4()) + queue.push_tasks([waiting_file_id]) + + # Mock the task function calls + with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Act & Assert: Execute the regular task (should raise Exception) + with pytest.raises(Exception, match="File not found"): + rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify error was handled gracefully + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_pipeline_generator.assert_not_called() + + # Verify waiting task was still processed despite file error + mock_delay.assert_called_once() + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue is empty after processing (task was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 0 diff --git a/api/tests/unit_tests/core/helper/code_executor/__init__.py b/api/tests/unit_tests/core/helper/code_executor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/helper/code_executor/javascript/__init__.py b/api/tests/unit_tests/core/helper/code_executor/javascript/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/helper/code_executor/javascript/test_javascript_transformer.py b/api/tests/unit_tests/core/helper/code_executor/javascript/test_javascript_transformer.py new file mode 100644 index 0000000000..03f37756d7 --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/javascript/test_javascript_transformer.py @@ -0,0 +1,12 @@ +from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider +from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer + + +def test_get_runner_script(): + code = JavascriptCodeProvider.get_default_code() + inputs = {"arg1": "hello, ", "arg2": "world!"} + script = NodeJsTemplateTransformer.assemble_runner_script(code, inputs) + script_lines = script.splitlines() + code_lines = code.splitlines() + # Check that the first lines of script are exactly the same as code + assert script_lines[: len(code_lines)] == code_lines diff --git a/api/tests/unit_tests/core/helper/code_executor/python3/__init__.py b/api/tests/unit_tests/core/helper/code_executor/python3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/helper/code_executor/python3/test_python3_transformer.py b/api/tests/unit_tests/core/helper/code_executor/python3/test_python3_transformer.py new file mode 100644 index 0000000000..1166cb8892 --- /dev/null +++ 
b/api/tests/unit_tests/core/helper/code_executor/python3/test_python3_transformer.py @@ -0,0 +1,12 @@ +from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider +from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer + + +def test_get_runner_script(): + code = Python3CodeProvider.get_default_code() + inputs = {"arg1": "hello, ", "arg2": "world!"} + script = Python3TemplateTransformer.assemble_runner_script(code, inputs) + script_lines = script.splitlines() + code_lines = code.splitlines() + # Check that the first lines of script are exactly the same as code + assert script_lines[: len(code_lines)] == code_lines diff --git a/api/tests/unit_tests/core/rag/pipeline/test_queue.py b/api/tests/unit_tests/core/rag/pipeline/test_queue.py new file mode 100644 index 0000000000..cfdbb30f8f --- /dev/null +++ b/api/tests/unit_tests/core/rag/pipeline/test_queue.py @@ -0,0 +1,301 @@ +""" +Unit tests for TenantIsolatedTaskQueue. + +These tests verify the Redis-based task queue functionality for tenant-specific +task management with proper serialization and deserialization. +""" + +import json +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue + + +class TestTaskWrapper: + """Test cases for TaskWrapper serialization/deserialization.""" + + def test_serialize_simple_data(self): + """Test serialization of simple data types.""" + data = {"key": "value", "number": 42, "list": [1, 2, 3]} + wrapper = TaskWrapper(data=data) + + serialized = wrapper.serialize() + assert isinstance(serialized, str) + + # Verify it's valid JSON + parsed = json.loads(serialized) + assert parsed["data"] == data + + def test_serialize_complex_data(self): + """Test serialization of complex nested data.""" + data = { + "nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5]}}, + "unicode": "测试中文", + "special_chars": "!@#$%^&*()", + } + wrapper = TaskWrapper(data=data) + + serialized = wrapper.serialize() + parsed = json.loads(serialized) + assert parsed["data"] == data + + def test_deserialize_valid_data(self): + """Test deserialization of valid JSON data.""" + original_data = {"key": "value", "number": 42} + # Serialize using TaskWrapper to get the correct format + wrapper = TaskWrapper(data=original_data) + serialized = wrapper.serialize() + + wrapper = TaskWrapper.deserialize(serialized) + assert wrapper.data == original_data + + def test_deserialize_invalid_json(self): + """Test deserialization handles invalid JSON gracefully.""" + invalid_json = "{invalid json}" + + # Pydantic will raise ValidationError for invalid JSON + with pytest.raises(ValidationError): + TaskWrapper.deserialize(invalid_json) + + def test_serialize_ensure_ascii_false(self): + """Test that serialization preserves Unicode characters.""" + data = {"chinese": "中文测试", "emoji": "🚀"} + wrapper = TaskWrapper(data=data) + + serialized = wrapper.serialize() + assert "中文测试" in serialized + assert "🚀" in serialized + + +class TestTenantIsolatedTaskQueue: + """Test cases for TenantIsolatedTaskQueue functionality.""" + + @pytest.fixture + def mock_redis_client(self): + """Mock Redis client for testing.""" + mock_redis = MagicMock() + return mock_redis + + @pytest.fixture + def sample_queue(self, mock_redis_client): + """Create a sample TenantIsolatedTaskQueue instance.""" + return TenantIsolatedTaskQueue("tenant-123", "test-key") + + def 
test_initialization(self, sample_queue): + """Test queue initialization with correct key generation.""" + assert sample_queue._tenant_id == "tenant-123" + assert sample_queue._unique_key == "test-key" + assert sample_queue._queue == "tenant_self_test-key_task_queue:tenant-123" + assert sample_queue._task_key == "tenant_test-key_task:tenant-123" + + @patch("core.rag.pipeline.queue.redis_client") + def test_get_task_key_exists(self, mock_redis, sample_queue): + """Test getting task key when it exists.""" + mock_redis.get.return_value = "1" + + result = sample_queue.get_task_key() + + assert result == "1" + mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123") + + @patch("core.rag.pipeline.queue.redis_client") + def test_get_task_key_not_exists(self, mock_redis, sample_queue): + """Test getting task key when it doesn't exist.""" + mock_redis.get.return_value = None + + result = sample_queue.get_task_key() + + assert result is None + mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123") + + @patch("core.rag.pipeline.queue.redis_client") + def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue): + """Test setting task waiting flag with default TTL.""" + sample_queue.set_task_waiting_time() + + mock_redis.setex.assert_called_once_with( + "tenant_test-key_task:tenant-123", + 3600, # DEFAULT_TASK_TTL + 1, + ) + + @patch("core.rag.pipeline.queue.redis_client") + def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue): + """Test setting task waiting flag with custom TTL.""" + custom_ttl = 1800 + sample_queue.set_task_waiting_time(custom_ttl) + + mock_redis.setex.assert_called_once_with("tenant_test-key_task:tenant-123", custom_ttl, 1) + + @patch("core.rag.pipeline.queue.redis_client") + def test_delete_task_key(self, mock_redis, sample_queue): + """Test deleting task key.""" + sample_queue.delete_task_key() + + mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123") + + @patch("core.rag.pipeline.queue.redis_client") + def test_push_tasks_string_list(self, mock_redis, sample_queue): + """Test pushing string tasks directly.""" + tasks = ["task1", "task2", "task3"] + + sample_queue.push_tasks(tasks) + + mock_redis.lpush.assert_called_once_with( + "tenant_self_test-key_task_queue:tenant-123", "task1", "task2", "task3" + ) + + @patch("core.rag.pipeline.queue.redis_client") + def test_push_tasks_mixed_types(self, mock_redis, sample_queue): + """Test pushing mixed string and object tasks.""" + tasks = ["string_task", {"object_task": "data", "id": 123}, "another_string"] + + sample_queue.push_tasks(tasks) + + # Verify lpush was called + mock_redis.lpush.assert_called_once() + call_args = mock_redis.lpush.call_args + + # Check queue name + assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123" + + # Check serialized tasks + serialized_tasks = call_args[0][1:] + assert len(serialized_tasks) == 3 + assert serialized_tasks[0] == "string_task" + assert serialized_tasks[2] == "another_string" + + # Check object task is serialized as TaskWrapper JSON (without prefix) + # It should be a valid JSON string that can be deserialized by TaskWrapper + wrapper = TaskWrapper.deserialize(serialized_tasks[1]) + assert wrapper.data == {"object_task": "data", "id": 123} + + @patch("core.rag.pipeline.queue.redis_client") + def test_push_tasks_empty_list(self, mock_redis, sample_queue): + """Test pushing empty task list.""" + sample_queue.push_tasks([]) + + 
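+ # push_tasks() is expected to forward even an empty list as a bare LPUSH;
+ # with the mocked client this merely records the call (a real Redis LPUSH
+ # with no values would be rejected), so the assertion below documents the
+ # current push_tasks behaviour rather than Redis semantics.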
mock_redis.lpush.assert_called_once_with("tenant_self_test-key_task_queue:tenant-123") + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_default_count(self, mock_redis, sample_queue): + """Test pulling tasks with default count (1).""" + mock_redis.rpop.side_effect = ["task1", None] + + result = sample_queue.pull_tasks() + + assert result == ["task1"] + assert mock_redis.rpop.call_count == 1 + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_custom_count(self, mock_redis, sample_queue): + """Test pulling tasks with custom count.""" + # First test: pull 3 tasks + mock_redis.rpop.side_effect = ["task1", "task2", "task3", None] + + result = sample_queue.pull_tasks(3) + + assert result == ["task1", "task2", "task3"] + assert mock_redis.rpop.call_count == 3 + + # Reset mock for second test + mock_redis.reset_mock() + mock_redis.rpop.side_effect = ["task1", "task2", None] + + result = sample_queue.pull_tasks(3) + + assert result == ["task1", "task2"] + assert mock_redis.rpop.call_count == 3 + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_zero_count(self, mock_redis, sample_queue): + """Test pulling tasks with zero count returns empty list.""" + result = sample_queue.pull_tasks(0) + + assert result == [] + mock_redis.rpop.assert_not_called() + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_negative_count(self, mock_redis, sample_queue): + """Test pulling tasks with negative count returns empty list.""" + result = sample_queue.pull_tasks(-1) + + assert result == [] + mock_redis.rpop.assert_not_called() + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue): + """Test pulling tasks that include wrapped objects.""" + # Create a wrapped task + task_data = {"task_id": 123, "data": "test"} + wrapper = TaskWrapper(data=task_data) + wrapped_task = wrapper.serialize() + + mock_redis.rpop.side_effect = [ + "string_task", + wrapped_task.encode("utf-8"), # Simulate bytes from Redis + None, + ] + + result = sample_queue.pull_tasks(2) + + assert len(result) == 2 + assert result[0] == "string_task" + assert result[1] == {"task_id": 123, "data": "test"} + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue): + """Test pulling tasks with invalid JSON falls back to string.""" + # Invalid JSON string that cannot be deserialized + invalid_json = "invalid json data" + mock_redis.rpop.side_effect = [invalid_json, None] + + result = sample_queue.pull_tasks(1) + + assert result == [invalid_json] + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue): + """Test pulling tasks handles bytes from Redis correctly.""" + mock_redis.rpop.side_effect = [ + b"task1", # bytes + "task2", # string + None, + ] + + result = sample_queue.pull_tasks(2) + + assert result == ["task1", "task2"] + + @patch("core.rag.pipeline.queue.redis_client") + def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue): + """Test complex object serialization and deserialization roundtrip.""" + complex_task = { + "id": uuid4().hex, + "data": {"nested": {"deep": [1, 2, 3], "unicode": "测试中文", "special": "!@#$%^&*()"}}, + "metadata": {"created_at": "2024-01-01T00:00:00Z", "tags": ["tag1", "tag2", "tag3"]}, + } + + # Push the complex task + sample_queue.push_tasks([complex_task]) + + # Verify it was serialized as TaskWrapper JSON + 
call_args = mock_redis.lpush.call_args + wrapped_task = call_args[0][1] + # Verify it's a valid TaskWrapper JSON (starts with {"data":) + assert wrapped_task.startswith('{"data":') + + # Verify it can be deserialized + wrapper = TaskWrapper.deserialize(wrapped_task) + assert wrapper.data == complex_task + + # Simulate pulling it back + mock_redis.rpop.return_value = wrapped_task + result = sample_queue.pull_tasks(1) + + assert len(result) == 1 + assert result[0] == complex_task diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py index f9de456b19..18f6753b05 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -111,3 +111,26 @@ class TestVariablePoolGetAndNestedAttribute: assert segment_false is not None assert isinstance(segment_false, BooleanSegment) assert segment_false.value is False + + +class TestVariablePoolGetNotModifyVariableDictionary: + _NODE_ID = "start" + _VAR_NAME = "name" + + def test_convert_to_template_should_not_introduce_extra_keys(self): + pool = VariablePool.empty() + pool.add([self._NODE_ID, self._VAR_NAME], 0) + pool.convert_template("The start.name is {{#start.name#}}") + assert "The start" not in pool.variable_dictionary + + def test_get_should_not_modify_variable_dictionary(self): + pool = VariablePool.empty() + pool.get([self._NODE_ID, self._VAR_NAME]) + assert len(pool.variable_dictionary) == 1 # only contains `sys` node id + assert "start" not in pool.variable_dictionary + + pool = VariablePool.empty() + pool.add([self._NODE_ID, self._VAR_NAME], "Joe") + pool.get([self._NODE_ID, "count"]) + start_subdict = pool.variable_dictionary[self._NODE_ID] + assert "count" not in start_subdict diff --git a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py new file mode 100644 index 0000000000..d9183be9fb --- /dev/null +++ b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py @@ -0,0 +1,317 @@ +from unittest.mock import Mock, patch + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from services.document_indexing_task_proxy import DocumentIndexingTaskProxy + + +class DocumentIndexingTaskProxyTestDataFactory: + """Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests.""" + + @staticmethod + def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: + """Create mock features with billing configuration.""" + features = Mock() + features.billing = Mock() + features.billing.enabled = billing_enabled + features.billing.subscription = Mock() + features.billing.subscription.plan = plan + return features + + @staticmethod + def create_mock_tenant_queue(has_task_key: bool = False) -> Mock: + """Create mock TenantIsolatedTaskQueue.""" + queue = Mock(spec=TenantIsolatedTaskQueue) + queue.get_task_key.return_value = "task_key" if has_task_key else None + queue.push_tasks = Mock() + queue.set_task_waiting_time = Mock() + return queue + + @staticmethod + def create_document_task_proxy( + tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None + ) -> DocumentIndexingTaskProxy: + """Create DocumentIndexingTaskProxy instance for testing.""" + if document_ids is None: + 
document_ids = ["doc-1", "doc-2", "doc-3"] + return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + +class TestDocumentIndexingTaskProxy: + """Test cases for DocumentIndexingTaskProxy class.""" + + def test_initialization(self): + """Test DocumentIndexingTaskProxy initialization.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1", "doc-2", "doc-3"] + + # Act + proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue) + assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id + assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing" + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_features_property(self, mock_feature_service): + """Test cached_property features.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features() + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + + # Act + features1 = proxy.features + features2 = proxy.features # Second call should use cached property + + # Assert + assert features1 == mock_features + assert features2 == mock_features + assert features1 is features2 # Should be the same instance due to caching + mock_feature_service.get_features.assert_called_once_with("tenant-123") + + @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + def test_send_to_direct_queue(self, mock_task): + """Test _send_to_direct_queue method.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + mock_task.delay = Mock() + + # Act + proxy._send_to_direct_queue(mock_task) + + # Assert + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + + @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): + """Test _send_to_tenant_queue when task key exists.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=True + ) + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.push_tasks.assert_called_once() + pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] + assert len(pushed_tasks) == 1 + assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask) + assert pushed_tasks[0]["tenant_id"] == "tenant-123" + assert pushed_tasks[0]["dataset_id"] == "dataset-456" + assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] + mock_task.delay.assert_not_called() + + @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + def test_send_to_tenant_queue_without_task_key(self, mock_task): + """Test _send_to_tenant_queue when no task key exists.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=False + ) + mock_task.delay = Mock() 
+ + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() + + @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + def test_send_to_default_tenant_queue(self, mock_task): + """Test _send_to_default_tenant_queue method.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_default_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + + @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + def test_send_to_priority_tenant_queue(self, mock_task): + """Test _send_to_priority_tenant_queue method.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_priority_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + + @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + def test_send_to_priority_direct_queue(self, mock_task): + """Test _send_to_priority_direct_queue method.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_direct_queue = Mock() + + # Act + proxy._send_to_priority_direct_queue() + + # Assert + proxy._send_to_direct_queue.assert_called_once_with(mock_task) + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with sandbox plan.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_default_tenant_queue.assert_called_once() + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with non-sandbox plan.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.TEAM + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # If billing enabled with non sandbox plan, should send to priority tenant queue + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_dispatch_with_billing_disabled(self, mock_feature_service): + """Test _dispatch method when billing is disabled.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) + mock_feature_service.get_features.return_value = mock_features + proxy = 
DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_priority_direct_queue = Mock() + + # Act + proxy._dispatch() + + # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue + proxy._send_to_priority_direct_queue.assert_called_once() + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_delay_method(self, mock_feature_service): + """Test delay method integration.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy.delay() + + # Assert + # If billing enabled with sandbox plan, should send to default tenant queue + proxy._send_to_default_tenant_queue.assert_called_once() + + def test_document_task_dataclass(self): + """Test DocumentTask dataclass.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1", "doc-2"] + + # Act + task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) + + # Assert + assert task.tenant_id == tenant_id + assert task.dataset_id == dataset_id + assert task.document_ids == document_ids + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_dispatch_edge_case_empty_plan(self, mock_feature_service): + """Test _dispatch method with empty plan string.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="") + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_dispatch_edge_case_none_plan(self, mock_feature_service): + """Test _dispatch method with None plan.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None) + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + def test_initialization_with_empty_document_ids(self): + """Test initialization with empty document_ids list.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = [] + + # Act + proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + + def test_initialization_with_single_document_id(self): + """Test initialization with single document_id.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1"] + + # Act + proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids diff --git a/api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py 
b/api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py new file mode 100644 index 0000000000..f5a48b1416 --- /dev/null +++ b/api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py @@ -0,0 +1,483 @@ +import json +from unittest.mock import Mock, patch + +import pytest + +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy + + +class RagPipelineTaskProxyTestDataFactory: + """Factory class for creating test data and mock objects for RagPipelineTaskProxy tests.""" + + @staticmethod + def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: + """Create mock features with billing configuration.""" + features = Mock() + features.billing = Mock() + features.billing.enabled = billing_enabled + features.billing.subscription = Mock() + features.billing.subscription.plan = plan + return features + + @staticmethod + def create_mock_tenant_queue(has_task_key: bool = False) -> Mock: + """Create mock TenantIsolatedTaskQueue.""" + queue = Mock(spec=TenantIsolatedTaskQueue) + queue.get_task_key.return_value = "task_key" if has_task_key else None + queue.push_tasks = Mock() + queue.set_task_waiting_time = Mock() + return queue + + @staticmethod + def create_rag_pipeline_invoke_entity( + pipeline_id: str = "pipeline-123", + user_id: str = "user-456", + tenant_id: str = "tenant-789", + workflow_id: str = "workflow-101", + streaming: bool = True, + workflow_execution_id: str | None = None, + workflow_thread_pool_id: str | None = None, + ) -> RagPipelineInvokeEntity: + """Create RagPipelineInvokeEntity instance for testing.""" + return RagPipelineInvokeEntity( + pipeline_id=pipeline_id, + application_generate_entity={"key": "value"}, + user_id=user_id, + tenant_id=tenant_id, + workflow_id=workflow_id, + streaming=streaming, + workflow_execution_id=workflow_execution_id, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + @staticmethod + def create_rag_pipeline_task_proxy( + dataset_tenant_id: str = "tenant-123", + user_id: str = "user-456", + rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None, + ) -> RagPipelineTaskProxy: + """Create RagPipelineTaskProxy instance for testing.""" + if rag_pipeline_invoke_entities is None: + rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()] + return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities) + + @staticmethod + def create_mock_upload_file(file_id: str = "file-123") -> Mock: + """Create mock upload file.""" + upload_file = Mock() + upload_file.id = file_id + return upload_file + + +class TestRagPipelineTaskProxy: + """Test cases for RagPipelineTaskProxy class.""" + + def test_initialization(self): + """Test RagPipelineTaskProxy initialization.""" + # Arrange + dataset_tenant_id = "tenant-123" + user_id = "user-456" + rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()] + + # Act + proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities) + + # Assert + assert proxy._dataset_tenant_id == dataset_tenant_id + assert proxy._user_id == user_id + assert proxy._rag_pipeline_invoke_entities == rag_pipeline_invoke_entities + assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue) + assert 
proxy._tenant_isolated_task_queue._tenant_id == dataset_tenant_id + assert proxy._tenant_isolated_task_queue._unique_key == "pipeline" + + def test_initialization_with_empty_entities(self): + """Test initialization with empty rag_pipeline_invoke_entities.""" + # Arrange + dataset_tenant_id = "tenant-123" + user_id = "user-456" + rag_pipeline_invoke_entities = [] + + # Act + proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities) + + # Assert + assert proxy._dataset_tenant_id == dataset_tenant_id + assert proxy._user_id == user_id + assert proxy._rag_pipeline_invoke_entities == [] + + def test_initialization_with_multiple_entities(self): + """Test initialization with multiple rag_pipeline_invoke_entities.""" + # Arrange + dataset_tenant_id = "tenant-123" + user_id = "user-456" + rag_pipeline_invoke_entities = [ + RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"), + RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"), + RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3"), + ] + + # Act + proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities) + + # Assert + assert len(proxy._rag_pipeline_invoke_entities) == 3 + assert proxy._rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1" + assert proxy._rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2" + assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3" + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + def test_features_property(self, mock_feature_service): + """Test cached_property features.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features() + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + + # Act + features1 = proxy.features + features2 = proxy.features # Second call should use cached property + + # Assert + assert features1 == mock_features + assert features2 == mock_features + assert features1 is features2 # Should be the same instance due to caching + mock_feature_service.get_features.assert_called_once_with("tenant-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_upload_invoke_entities(self, mock_db, mock_file_service_class): + """Test _upload_invoke_entities method.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + result = proxy._upload_invoke_entities() + + # Assert + assert result == "file-123" + mock_file_service_class.assert_called_once_with(mock_db.engine) + + # Verify upload_text was called with correct parameters + mock_file_service.upload_text.assert_called_once() + call_args = mock_file_service.upload_text.call_args + json_text, name, user_id, tenant_id = call_args[0] + + assert name == "rag_pipeline_invoke_entities.json" + assert user_id == "user-456" + assert tenant_id == "tenant-123" + + # Verify JSON content + parsed_json = json.loads(json_text) + assert len(parsed_json) == 1 + assert parsed_json[0]["pipeline_id"] == "pipeline-123" + + 
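+ # The upload test above, together with the multiple-entities variant that
+ # follows, pins down the assumed contract of _upload_invoke_entities: dump the
+ # invoke entities to a JSON array, store it via FileService.upload_text, and
+ # return only the upload-file id that the Celery tasks later read back.
+ # A minimal sketch of that assumed shape (the real body is not part of this
+ # diff, so details such as model_dump() and ensure_ascii are illustrative):
+ #
+ #     def _upload_invoke_entities(self) -> str:
+ #         text = json.dumps(
+ #             [entity.model_dump() for entity in self._rag_pipeline_invoke_entities],
+ #             ensure_ascii=False,
+ #         )
+ #         upload_file = FileService(db.engine).upload_text(
+ #             text,
+ #             "rag_pipeline_invoke_entities.json",
+ #             self._user_id,
+ #             self._dataset_tenant_id,
+ #         )
+ #         return upload_file.id
+ 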
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class): + """Test _upload_invoke_entities method with multiple entities.""" + # Arrange + entities = [ + RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"), + RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"), + ] + proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities) + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + result = proxy._upload_invoke_entities() + + # Assert + assert result == "file-456" + + # Verify JSON content contains both entities + call_args = mock_file_service.upload_text.call_args + json_text = call_args[0][0] + parsed_json = json.loads(json_text) + assert len(parsed_json) == 2 + assert parsed_json[0]["pipeline_id"] == "pipeline-1" + assert parsed_json[1]["pipeline_id"] == "pipeline-2" + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") + def test_send_to_direct_queue(self, mock_task): + """Test _send_to_direct_queue method.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue() + upload_file_id = "file-123" + mock_task.delay = Mock() + + # Act + proxy._send_to_direct_queue(upload_file_id, mock_task) + + # If sent to direct queue, tenant_isolated_task_queue should not be called + proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() + + # Celery should be called directly + mock_task.delay.assert_called_once_with( + rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123" + ) + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") + def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): + """Test _send_to_tenant_queue when task key exists.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=True + ) + upload_file_id = "file-123" + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(upload_file_id, mock_task) + + # If task key exists, should push tasks to the queue + proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with([upload_file_id]) + # Celery should not be called directly + mock_task.delay.assert_not_called() + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") + def test_send_to_tenant_queue_without_task_key(self, mock_task): + """Test _send_to_tenant_queue when no task key exists.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=False + ) + upload_file_id = "file-123" + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(upload_file_id, mock_task) + + # If no task key, should set task waiting time key first + proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() + mock_task.delay.assert_called_once_with( + 
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123" + ) + + # The first task should be sent to celery directly, so push tasks should not be called + proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") + def test_send_to_default_tenant_queue(self, mock_task): + """Test _send_to_default_tenant_queue method.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_tenant_queue = Mock() + upload_file_id = "file-123" + + # Act + proxy._send_to_default_tenant_queue(upload_file_id) + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task) + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task") + def test_send_to_priority_tenant_queue(self, mock_task): + """Test _send_to_priority_tenant_queue method.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_tenant_queue = Mock() + upload_file_id = "file-123" + + # Act + proxy._send_to_priority_tenant_queue(upload_file_id) + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task) + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task") + def test_send_to_priority_direct_queue(self, mock_task): + """Test _send_to_priority_direct_queue method.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_direct_queue = Mock() + upload_file_id = "file-123" + + # Act + proxy._send_to_priority_direct_queue(upload_file_id) + + # Assert + proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task) + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service): + """Test _dispatch method when billing is enabled with sandbox plan.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy._dispatch() + + # If billing is enabled with sandbox plan, should send to default tenant queue + proxy._send_to_default_tenant_queue.assert_called_once_with("file-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_with_billing_enabled_non_sandbox_plan( + self, mock_db, mock_file_service_class, mock_feature_service + ): + """Test _dispatch method when billing is enabled with non-sandbox plan.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.TEAM + ) + mock_feature_service.get_features.return_value = 
mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy._dispatch() + + # If billing is enabled with non-sandbox plan, should send to priority tenant queue + proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service): + """Test _dispatch method when billing is disabled.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_priority_direct_queue = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy._dispatch() + + # If billing is disabled, for example: self-hosted or enterprise, should send to priority direct queue + proxy._send_to_priority_direct_queue.assert_called_once_with("file-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class): + """Test _dispatch method when upload_file_id is empty.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = Mock() + mock_upload_file.id = "" # Empty file ID + mock_file_service.upload_text.return_value = mock_upload_file + + # Act & Assert + with pytest.raises(ValueError, match="upload_file_id is empty"): + proxy._dispatch() + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service): + """Test _dispatch method with empty plan string.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="") + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + 
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service): + """Test _dispatch method with None plan.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None) + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service): + """Test delay method integration.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._dispatch = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy.delay() + + # Assert + proxy._dispatch.assert_called_once() + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.logger") + def test_delay_method_with_empty_entities(self, mock_logger): + """Test delay method with empty rag_pipeline_invoke_entities.""" + # Arrange + proxy = RagPipelineTaskProxy("tenant-123", "user-456", []) + + # Act + proxy.delay() + + # Assert + mock_logger.warning.assert_called_once_with( + "Received empty rag pipeline invoke entities, no tasks delivered: %s %s", "tenant-123", "user-456" + ) diff --git a/api/uv.lock b/api/uv.lock index 98316d229c..09d469fe80 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -666,44 +666,30 @@ wheels = [ [[package]] name = "brotli" -version = "1.1.0" +version = "1.2.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2f/c2/f9e977608bdf958650638c3f1e28f85a1b075f075ebbe77db8555463787b/Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724", size = 7372270 } +sdist = { url = "https://files.pythonhosted.org/packages/f7/16/c92ca344d646e71a43b8bb353f0a6490d7f6e06210f8554c8f874e454285/brotli-1.2.0.tar.gz", hash = "sha256:e310f77e41941c13340a95976fe66a8a95b01e783d430eeaf7a2f87e0a57dd0a", size = 7388632 } wheels = [ - { url = "https://files.pythonhosted.org/packages/96/12/ad41e7fadd5db55459c4c401842b47f7fee51068f86dd2894dd0dcfc2d2a/Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc", size = 873068 }, - { url = 
"https://files.pythonhosted.org/packages/95/4e/5afab7b2b4b61a84e9c75b17814198ce515343a44e2ed4488fac314cd0a9/Brotli-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c8146669223164fc87a7e3de9f81e9423c67a79d6b3447994dfb9c95da16e2d6", size = 446244 }, - { url = "https://files.pythonhosted.org/packages/9d/e6/f305eb61fb9a8580c525478a4a34c5ae1a9bcb12c3aee619114940bc513d/Brotli-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30924eb4c57903d5a7526b08ef4a584acc22ab1ffa085faceb521521d2de32dd", size = 2906500 }, - { url = "https://files.pythonhosted.org/packages/3e/4f/af6846cfbc1550a3024e5d3775ede1e00474c40882c7bf5b37a43ca35e91/Brotli-1.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ceb64bbc6eac5a140ca649003756940f8d6a7c444a68af170b3187623b43bebf", size = 2943950 }, - { url = "https://files.pythonhosted.org/packages/b3/e7/ca2993c7682d8629b62630ebf0d1f3bb3d579e667ce8e7ca03a0a0576a2d/Brotli-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a469274ad18dc0e4d316eefa616d1d0c2ff9da369af19fa6f3daa4f09671fd61", size = 2918527 }, - { url = "https://files.pythonhosted.org/packages/b3/96/da98e7bedc4c51104d29cc61e5f449a502dd3dbc211944546a4cc65500d3/Brotli-1.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:524f35912131cc2cabb00edfd8d573b07f2d9f21fa824bd3fb19725a9cf06327", size = 2845489 }, - { url = "https://files.pythonhosted.org/packages/e8/ef/ccbc16947d6ce943a7f57e1a40596c75859eeb6d279c6994eddd69615265/Brotli-1.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5b3cc074004d968722f51e550b41a27be656ec48f8afaeeb45ebf65b561481dd", size = 2914080 }, - { url = "https://files.pythonhosted.org/packages/80/d6/0bd38d758d1afa62a5524172f0b18626bb2392d717ff94806f741fcd5ee9/Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9", size = 2813051 }, - { url = "https://files.pythonhosted.org/packages/14/56/48859dd5d129d7519e001f06dcfbb6e2cf6db92b2702c0c2ce7d97e086c1/Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265", size = 2938172 }, - { url = "https://files.pythonhosted.org/packages/3d/77/a236d5f8cd9e9f4348da5acc75ab032ab1ab2c03cc8f430d24eea2672888/Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8", size = 2933023 }, - { url = "https://files.pythonhosted.org/packages/f1/87/3b283efc0f5cb35f7f84c0c240b1e1a1003a5e47141a4881bf87c86d0ce2/Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f", size = 2935871 }, - { url = "https://files.pythonhosted.org/packages/f3/eb/2be4cc3e2141dc1a43ad4ca1875a72088229de38c68e842746b342667b2a/Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757", size = 2847784 }, - { url = "https://files.pythonhosted.org/packages/66/13/b58ddebfd35edde572ccefe6890cf7c493f0c319aad2a5badee134b4d8ec/Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0", size = 3034905 }, - { url = "https://files.pythonhosted.org/packages/84/9c/bc96b6c7db824998a49ed3b38e441a2cae9234da6fa11f6ed17e8cf4f147/Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b", size = 2929467 }, - { url = "https://files.pythonhosted.org/packages/e7/71/8f161dee223c7ff7fea9d44893fba953ce97cf2c3c33f78ba260a91bcff5/Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50", size = 333169 }, - { url = "https://files.pythonhosted.org/packages/02/8a/fece0ee1057643cb2a5bbf59682de13f1725f8482b2c057d4e799d7ade75/Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1", size = 357253 }, - { url = "https://files.pythonhosted.org/packages/5c/d0/5373ae13b93fe00095a58efcbce837fd470ca39f703a235d2a999baadfbc/Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28", size = 815693 }, - { url = "https://files.pythonhosted.org/packages/8e/48/f6e1cdf86751300c288c1459724bfa6917a80e30dbfc326f92cea5d3683a/Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f", size = 422489 }, - { url = "https://files.pythonhosted.org/packages/06/88/564958cedce636d0f1bed313381dfc4b4e3d3f6015a63dae6146e1b8c65c/Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409", size = 873081 }, - { url = "https://files.pythonhosted.org/packages/58/79/b7026a8bb65da9a6bb7d14329fd2bd48d2b7f86d7329d5cc8ddc6a90526f/Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2", size = 446244 }, - { url = "https://files.pythonhosted.org/packages/e5/18/c18c32ecea41b6c0004e15606e274006366fe19436b6adccc1ae7b2e50c2/Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451", size = 2906505 }, - { url = "https://files.pythonhosted.org/packages/08/c8/69ec0496b1ada7569b62d85893d928e865df29b90736558d6c98c2031208/Brotli-1.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7f4bf76817c14aa98cc6697ac02f3972cb8c3da93e9ef16b9c66573a68014f91", size = 2944152 }, - { url = "https://files.pythonhosted.org/packages/ab/fb/0517cea182219d6768113a38167ef6d4eb157a033178cc938033a552ed6d/Brotli-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0c5516f0aed654134a2fc936325cc2e642f8a0e096d075209672eb321cff408", size = 2919252 }, - { url = "https://files.pythonhosted.org/packages/c7/53/73a3431662e33ae61a5c80b1b9d2d18f58dfa910ae8dd696e57d39f1a2f5/Brotli-1.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c3020404e0b5eefd7c9485ccf8393cfb75ec38ce75586e046573c9dc29967a0", size = 2845955 }, - { url = "https://files.pythonhosted.org/packages/55/ac/bd280708d9c5ebdbf9de01459e625a3e3803cce0784f47d633562cf40e83/Brotli-1.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4ed11165dd45ce798d99a136808a794a748d5dc38511303239d4e2363c0695dc", size = 2914304 }, - { url = "https://files.pythonhosted.org/packages/76/58/5c391b41ecfc4527d2cc3350719b02e87cb424ef8ba2023fb662f9bf743c/Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180", size = 2814452 }, - { url = 
"https://files.pythonhosted.org/packages/c7/4e/91b8256dfe99c407f174924b65a01f5305e303f486cc7a2e8a5d43c8bec3/Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248", size = 2938751 }, - { url = "https://files.pythonhosted.org/packages/5a/a6/e2a39a5d3b412938362bbbeba5af904092bf3f95b867b4a3eb856104074e/Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966", size = 2933757 }, - { url = "https://files.pythonhosted.org/packages/13/f0/358354786280a509482e0e77c1a5459e439766597d280f28cb097642fc26/Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9", size = 2936146 }, - { url = "https://files.pythonhosted.org/packages/80/f7/daf538c1060d3a88266b80ecc1d1c98b79553b3f117a485653f17070ea2a/Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb", size = 2848055 }, - { url = "https://files.pythonhosted.org/packages/ad/cf/0eaa0585c4077d3c2d1edf322d8e97aabf317941d3a72d7b3ad8bce004b0/Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111", size = 3035102 }, - { url = "https://files.pythonhosted.org/packages/d8/63/1c1585b2aa554fe6dbce30f0c18bdbc877fa9a1bf5ff17677d9cca0ac122/Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839", size = 2930029 }, - { url = "https://files.pythonhosted.org/packages/5f/3b/4e3fd1893eb3bbfef8e5a80d4508bec17a57bb92d586c85c12d28666bb13/Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0", size = 333276 }, - { url = "https://files.pythonhosted.org/packages/3d/d5/942051b45a9e883b5b6e98c041698b1eb2012d25e5948c58d6bf85b1bb43/Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951", size = 357255 }, + { url = "https://files.pythonhosted.org/packages/7a/ef/f285668811a9e1ddb47a18cb0b437d5fc2760d537a2fe8a57875ad6f8448/brotli-1.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:15b33fe93cedc4caaff8a0bd1eb7e3dab1c61bb22a0bf5bdfdfd97cd7da79744", size = 863110 }, + { url = "https://files.pythonhosted.org/packages/50/62/a3b77593587010c789a9d6eaa527c79e0848b7b860402cc64bc0bc28a86c/brotli-1.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:898be2be399c221d2671d29eed26b6b2713a02c2119168ed914e7d00ceadb56f", size = 445438 }, + { url = "https://files.pythonhosted.org/packages/cd/e1/7fadd47f40ce5549dc44493877db40292277db373da5053aff181656e16e/brotli-1.2.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:350c8348f0e76fff0a0fd6c26755d2653863279d086d3aa2c290a6a7251135dd", size = 1534420 }, + { url = "https://files.pythonhosted.org/packages/12/8b/1ed2f64054a5a008a4ccd2f271dbba7a5fb1a3067a99f5ceadedd4c1d5a7/brotli-1.2.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2e1ad3fda65ae0d93fec742a128d72e145c9c7a99ee2fcd667785d99eb25a7fe", size = 1632619 }, + { url = "https://files.pythonhosted.org/packages/89/5a/7071a621eb2d052d64efd5da2ef55ecdac7c3b0c6e4f9d519e9c66d987ef/brotli-1.2.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:40d918bce2b427a0c4ba189df7a006ac0c7277c180aee4617d99e9ccaaf59e6a", size = 1426014 }, + { url = "https://files.pythonhosted.org/packages/26/6d/0971a8ea435af5156acaaccec1a505f981c9c80227633851f2810abd252a/brotli-1.2.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2a7f1d03727130fc875448b65b127a9ec5d06d19d0148e7554384229706f9d1b", size = 1489661 }, + { url = "https://files.pythonhosted.org/packages/f3/75/c1baca8b4ec6c96a03ef8230fab2a785e35297632f402ebb1e78a1e39116/brotli-1.2.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:9c79f57faa25d97900bfb119480806d783fba83cd09ee0b33c17623935b05fa3", size = 1599150 }, + { url = "https://files.pythonhosted.org/packages/0d/1a/23fcfee1c324fd48a63d7ebf4bac3a4115bdb1b00e600f80f727d850b1ae/brotli-1.2.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:844a8ceb8483fefafc412f85c14f2aae2fb69567bf2a0de53cdb88b73e7c43ae", size = 1493505 }, + { url = "https://files.pythonhosted.org/packages/36/e5/12904bbd36afeef53d45a84881a4810ae8810ad7e328a971ebbfd760a0b3/brotli-1.2.0-cp311-cp311-win32.whl", hash = "sha256:aa47441fa3026543513139cb8926a92a8e305ee9c71a6209ef7a97d91640ea03", size = 334451 }, + { url = "https://files.pythonhosted.org/packages/02/8b/ecb5761b989629a4758c394b9301607a5880de61ee2ee5fe104b87149ebc/brotli-1.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:022426c9e99fd65d9475dce5c195526f04bb8be8907607e27e747893f6ee3e24", size = 369035 }, + { url = "https://files.pythonhosted.org/packages/11/ee/b0a11ab2315c69bb9b45a2aaed022499c9c24a205c3a49c3513b541a7967/brotli-1.2.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:35d382625778834a7f3061b15423919aa03e4f5da34ac8e02c074e4b75ab4f84", size = 861543 }, + { url = "https://files.pythonhosted.org/packages/e1/2f/29c1459513cd35828e25531ebfcbf3e92a5e49f560b1777a9af7203eb46e/brotli-1.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7a61c06b334bd99bc5ae84f1eeb36bfe01400264b3c352f968c6e30a10f9d08b", size = 444288 }, + { url = "https://files.pythonhosted.org/packages/3d/6f/feba03130d5fceadfa3a1bb102cb14650798c848b1df2a808356f939bb16/brotli-1.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:acec55bb7c90f1dfc476126f9711a8e81c9af7fb617409a9ee2953115343f08d", size = 1528071 }, + { url = "https://files.pythonhosted.org/packages/2b/38/f3abb554eee089bd15471057ba85f47e53a44a462cfce265d9bf7088eb09/brotli-1.2.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:260d3692396e1895c5034f204f0db022c056f9e2ac841593a4cf9426e2a3faca", size = 1626913 }, + { url = "https://files.pythonhosted.org/packages/03/a7/03aa61fbc3c5cbf99b44d158665f9b0dd3d8059be16c460208d9e385c837/brotli-1.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:072e7624b1fc4d601036ab3f4f27942ef772887e876beff0301d261210bca97f", size = 1419762 }, + { url = "https://files.pythonhosted.org/packages/21/1b/0374a89ee27d152a5069c356c96b93afd1b94eae83f1e004b57eb6ce2f10/brotli-1.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:adedc4a67e15327dfdd04884873c6d5a01d3e3b6f61406f99b1ed4865a2f6d28", size = 1484494 }, + { url = "https://files.pythonhosted.org/packages/cf/57/69d4fe84a67aef4f524dcd075c6eee868d7850e85bf01d778a857d8dbe0a/brotli-1.2.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:7a47ce5c2288702e09dc22a44d0ee6152f2c7eda97b3c8482d826a1f3cfc7da7", size = 1593302 }, + { url = 
"https://files.pythonhosted.org/packages/d5/3b/39e13ce78a8e9a621c5df3aeb5fd181fcc8caba8c48a194cd629771f6828/brotli-1.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:af43b8711a8264bb4e7d6d9a6d004c3a2019c04c01127a868709ec29962b6036", size = 1487913 }, + { url = "https://files.pythonhosted.org/packages/62/28/4d00cb9bd76a6357a66fcd54b4b6d70288385584063f4b07884c1e7286ac/brotli-1.2.0-cp312-cp312-win32.whl", hash = "sha256:e99befa0b48f3cd293dafeacdd0d191804d105d279e0b387a32054c1180f3161", size = 334362 }, + { url = "https://files.pythonhosted.org/packages/1c/4e/bc1dcac9498859d5e353c9b153627a3752868a9d5f05ce8dedd81a2354ab/brotli-1.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:b35c13ce241abdd44cb8ca70683f20c0c079728a36a996297adb5334adfc1c44", size = 369115 }, ] [[package]] @@ -2430,6 +2416,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684 }, { url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647 }, { url = "https://files.pythonhosted.org/packages/3f/cc/b07000438a29ac5cfb2194bfc128151d52f333cee74dd7dfe3fb733fc16c/greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa", size = 1142073 }, + { url = "https://files.pythonhosted.org/packages/67/24/28a5b2fa42d12b3d7e5614145f0bd89714c34c08be6aabe39c14dd52db34/greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c", size = 1548385 }, + { url = "https://files.pythonhosted.org/packages/6a/05/03f2f0bdd0b0ff9a4f7b99333d57b53a7709c27723ec8123056b084e69cd/greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5", size = 1613329 }, { url = "https://files.pythonhosted.org/packages/d8/0f/30aef242fcab550b0b3520b8e3561156857c94288f0332a79928c31a52cf/greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9", size = 299100 }, { url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079 }, { url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997 }, @@ -2439,6 +2427,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586 }, { url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = 
"sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281 }, { url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142 }, + { url = "https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846 }, + { url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814 }, { url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899 }, ] diff --git a/dev/start-worker b/dev/start-worker index 83d7bf0f3c..9cf448c9c6 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -7,4 +7,4 @@ cd "$SCRIPT_DIR/.." uv --directory api run \ celery -A app.celery worker \ - -P gevent -c 1 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline + -P gevent -c 1 --loglevel INFO -Q dataset,priority_dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline diff --git a/docker/.env.example b/docker/.env.example index e6b578f52e..4c3be634a7 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -552,6 +552,7 @@ VECTOR_INDEX_NAME_PREFIX=Vector_index # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. WEAVIATE_ENDPOINT=http://weaviate:8080 WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih +WEAVIATE_GRPC_ENDPOINT=grpc://weaviate:50051 # The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`. 
QDRANT_URL=http://qdrant:6333 @@ -1404,3 +1405,6 @@ ENABLE_CLEAN_MESSAGES=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true + +# Tenant isolated task queue configuration +TENANT_ISOLATED_TASK_CONCURRENCY=1 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 7f17411dd0..7979f4aebd 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -178,6 +178,7 @@ x-shared-env: &shared-api-worker-env VECTOR_INDEX_NAME_PREFIX: ${VECTOR_INDEX_NAME_PREFIX:-Vector_index} WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} + WEAVIATE_GRPC_ENDPOINT: ${WEAVIATE_GRPC_ENDPOINT:-grpc://weaviate:50051} QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333} QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456} QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20} @@ -621,6 +622,7 @@ x-shared-env: &shared-api-worker-env ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false} ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-false} ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true} + TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1} services: # API service diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts index b579f22d4b..7773edcdbb 100644 --- a/web/__tests__/check-i18n.test.ts +++ b/web/__tests__/check-i18n.test.ts @@ -759,4 +759,104 @@ export default translation` expect(result).not.toContain('Zbuduj inteligentnego agenta') }) }) + + describe('Performance and Scalability', () => { + it('should handle large translation files efficiently', async () => { + // Create a large translation file with 1000 keys + const largeContent = `const translation = { +${Array.from({ length: 1000 }, (_, i) => ` key${i}: 'value${i}',`).join('\n')} +} + +export default translation` + + fs.writeFileSync(path.join(testEnDir, 'large.ts'), largeContent) + + const startTime = Date.now() + const keys = await getKeysFromLanguage('en-US') + const endTime = Date.now() + + expect(keys.length).toBe(1000) + expect(endTime - startTime).toBeLessThan(1000) // Should complete in under 1 second + }) + + it('should handle multiple translation files concurrently', async () => { + // Create multiple files + for (let i = 0; i < 10; i++) { + const content = `const translation = { + key${i}: 'value${i}', + nested${i}: { + subkey: 'subvalue' + } +} + +export default translation` + fs.writeFileSync(path.join(testEnDir, `file${i}.ts`), content) + } + + const startTime = Date.now() + const keys = await getKeysFromLanguage('en-US') + const endTime = Date.now() + + expect(keys.length).toBe(20) // 10 files * 2 keys each + expect(endTime - startTime).toBeLessThan(500) + }) + }) + + describe('Unicode and Internationalization', () => { + it('should handle Unicode characters in keys and values', async () => { + const unicodeContent = `const translation = { + '中文键': '中文值', + 'العربية': 'قيمة', + 'emoji_😀': 'value with emoji 🎉', + 'mixed_中文_English': 'mixed value' +} + +export default translation` + + fs.writeFileSync(path.join(testEnDir, 'unicode.ts'), unicodeContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('unicode.中文键') + expect(keys).toContain('unicode.العربية') + expect(keys).toContain('unicode.emoji_😀') + expect(keys).toContain('unicode.mixed_中文_English') + }) + + it('should handle RTL language files', async () 
=> { + const rtlContent = `const translation = { + مرحبا: 'Hello', + العالم: 'World', + nested: { + مفتاح: 'key' + } +} + +export default translation` + + fs.writeFileSync(path.join(testEnDir, 'rtl.ts'), rtlContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('rtl.مرحبا') + expect(keys).toContain('rtl.العالم') + expect(keys).toContain('rtl.nested.مفتاح') + }) + }) + + describe('Error Recovery', () => { + it('should handle syntax errors in translation files gracefully', async () => { + const invalidContent = `const translation = { + validKey: 'valid value', + invalidKey: 'missing quote, + anotherKey: 'another value' +} + +export default translation` + + fs.writeFileSync(path.join(testEnDir, 'invalid.ts'), invalidContent) + + await expect(getKeysFromLanguage('en-US')).rejects.toThrow() + }) + }) }) diff --git a/web/__tests__/navigation-utils.test.ts b/web/__tests__/navigation-utils.test.ts index fa4986e63d..3eeba52943 100644 --- a/web/__tests__/navigation-utils.test.ts +++ b/web/__tests__/navigation-utils.test.ts @@ -286,4 +286,116 @@ describe('Navigation Utilities', () => { expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-set/documents?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc') }) }) + + describe('Edge Cases and Error Handling', () => { + test('handles special characters in query parameters', () => { + Object.defineProperty(window, 'location', { + value: { search: '?keyword=hello%20world&filter=type%3Apdf&tag=%E4%B8%AD%E6%96%87' }, + writable: true, + }) + + const path = createNavigationPath('/datasets/123/documents') + expect(path).toContain('hello+world') + expect(path).toContain('type%3Apdf') + expect(path).toContain('%E4%B8%AD%E6%96%87') + }) + + test('handles duplicate query parameters', () => { + Object.defineProperty(window, 'location', { + value: { search: '?tag=tag1&tag=tag2&tag=tag3' }, + writable: true, + }) + + const params = extractQueryParams(['tag']) + // URLSearchParams.get() returns the first value + expect(params.tag).toBe('tag1') + }) + + test('handles very long query strings', () => { + const longValue = 'a'.repeat(1000) + Object.defineProperty(window, 'location', { + value: { search: `?data=${longValue}` }, + writable: true, + }) + + const path = createNavigationPath('/datasets/123/documents') + expect(path).toContain(longValue) + expect(path.length).toBeGreaterThan(1000) + }) + + test('handles empty string values in query parameters', () => { + const path = createNavigationPathWithParams('/datasets/123/documents', { + page: 1, + keyword: '', + filter: '', + sort: 'name', + }) + + expect(path).toBe('/datasets/123/documents?page=1&sort=name') + expect(path).not.toContain('keyword=') + expect(path).not.toContain('filter=') + }) + + test('handles null and undefined values in mergeQueryParams', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=1&limit=10&keyword=test' }, + writable: true, + }) + + const merged = mergeQueryParams({ + keyword: null, + filter: undefined, + sort: 'name', + }) + const result = merged.toString() + + expect(result).toContain('page=1') + expect(result).toContain('limit=10') + expect(result).not.toContain('keyword') + expect(result).toContain('sort=name') + }) + + test('handles navigation with hash fragments', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=1', hash: '#section-2' }, + writable: true, + }) + + const path = createNavigationPath('/datasets/123/documents') + // Should preserve query params but not hash + 
expect(path).toBe('/datasets/123/documents?page=1') + }) + + test('handles malformed query strings gracefully', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=1&invalid&limit=10&=value&key=' }, + writable: true, + }) + + const params = extractQueryParams(['page', 'limit', 'invalid', 'key']) + expect(params.page).toBe('1') + expect(params.limit).toBe('10') + // Malformed params should be handled by URLSearchParams + expect(params.invalid).toBe('') // for `&invalid` + expect(params.key).toBe('') // for `&key=` + }) + }) + + describe('Performance Tests', () => { + test('handles large number of query parameters efficiently', () => { + const manyParams = Array.from({ length: 50 }, (_, i) => `param${i}=value${i}`).join('&') + Object.defineProperty(window, 'location', { + value: { search: `?${manyParams}` }, + writable: true, + }) + + const startTime = Date.now() + const path = createNavigationPath('/datasets/123/documents') + const endTime = Date.now() + + expect(endTime - startTime).toBeLessThan(50) // Should be fast + expect(path).toContain('param0=value0') + expect(path).toContain('param49=value49') + }) + }) }) diff --git a/web/app/components/app-sidebar/app-operations.tsx b/web/app/components/app-sidebar/app-operations.tsx index 79c460419d..ab1069a29f 100644 --- a/web/app/components/app-sidebar/app-operations.tsx +++ b/web/app/components/app-sidebar/app-operations.tsx @@ -124,7 +124,7 @@ const AppOperations = ({ operations, gap }: { {t('common.operation.more')} - +
{moreOperations.map(item =>
{
{form.label}
          {!form.required && (
-            {t('appDebug.variableTable.optional')}
+            {t('workflow.panel.optional')}
          )}
)} diff --git a/web/app/components/base/chat/embedded-chatbot/inputs-form/content.tsx b/web/app/components/base/chat/embedded-chatbot/inputs-form/content.tsx index dd65f0ce72..caf4e363ff 100644 --- a/web/app/components/base/chat/embedded-chatbot/inputs-form/content.tsx +++ b/web/app/components/base/chat/embedded-chatbot/inputs-form/content.tsx @@ -49,7 +49,7 @@ const InputsFormContent = ({ showTip }: Props) => {
{form.label}
          {!form.required && (
-            {t('appDebug.variableTable.optional')}
+            {t('workflow.panel.optional')}
          )}
)} diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx index f0af893f0d..8ab007e66b 100644 --- a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx @@ -1,4 +1,4 @@ -import React, { useCallback, useEffect, useState } from 'react' +import React, { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import { produce } from 'immer' @@ -45,7 +45,13 @@ const OpeningSettingModal = ({ const [isShowConfirmAddVar, { setTrue: showConfirmAddVar, setFalse: hideConfirmAddVar }] = useBoolean(false) const [notIncludeKeys, setNotIncludeKeys] = useState([]) + const isSaveDisabled = useMemo(() => !tempValue.trim(), [tempValue]) + const handleSave = useCallback((ignoreVariablesCheck?: boolean) => { + // Prevent saving if opening statement is empty + if (isSaveDisabled) + return + if (!ignoreVariablesCheck) { const keys = getInputKeys(tempValue) const promptKeys = promptVariables.map(item => item.key) @@ -75,7 +81,7 @@ const OpeningSettingModal = ({ } }) onSave(newOpening) - }, [data, onSave, promptVariables, workflowVariables, showConfirmAddVar, tempSuggestedQuestions, tempValue]) + }, [data, onSave, promptVariables, workflowVariables, showConfirmAddVar, tempSuggestedQuestions, tempValue, isSaveDisabled]) const cancelAutoAddVar = useCallback(() => { hideConfirmAddVar() @@ -217,6 +223,7 @@ const OpeningSettingModal = ({ diff --git a/web/app/components/goto-anything/actions/commands/feedback.tsx b/web/app/components/goto-anything/actions/commands/forum.tsx similarity index 54% rename from web/app/components/goto-anything/actions/commands/feedback.tsx rename to web/app/components/goto-anything/actions/commands/forum.tsx index cce0aeb5f4..66237cb348 100644 --- a/web/app/components/goto-anything/actions/commands/feedback.tsx +++ b/web/app/components/goto-anything/actions/commands/forum.tsx @@ -4,27 +4,27 @@ import { RiFeedbackLine } from '@remixicon/react' import i18n from '@/i18n-config/i18next-config' import { registerCommands, unregisterCommands } from './command-bus' -// Feedback command dependency types -type FeedbackDeps = Record +// Forum command dependency types +type ForumDeps = Record /** - * Feedback command - Opens GitHub feedback discussions + * Forum command - Opens Dify community forum */ -export const feedbackCommand: SlashCommandHandler = { - name: 'feedback', - description: 'Open feedback discussions', +export const forumCommand: SlashCommandHandler = { + name: 'forum', + description: 'Open Dify community forum', mode: 'direct', // Direct execution function execute: () => { - const url = 'https://github.com/langgenius/dify/discussions/categories/feedbacks' + const url = 'https://forum.dify.ai' window.open(url, '_blank', 'noopener,noreferrer') }, async search(args: string, locale: string = 'en') { return [{ - id: 'feedback', - title: i18n.t('common.userProfile.communityFeedback', { lng: locale }), + id: 'forum', + title: i18n.t('common.userProfile.forum', { lng: locale }), description: i18n.t('app.gotoAnything.actions.feedbackDesc', { lng: locale }) || 'Open community feedback discussions', type: 'command' as const, icon: ( @@ -32,20 +32,20 @@ export const feedbackCommand: SlashCommandHandler = {
), - data: { command: 'navigation.feedback', args: { url: 'https://github.com/langgenius/dify/discussions/categories/feedbacks' } }, + data: { command: 'navigation.forum', args: { url: 'https://forum.dify.ai' } }, }] }, - register(_deps: FeedbackDeps) { + register(_deps: ForumDeps) { registerCommands({ - 'navigation.feedback': async (args) => { - const url = args?.url || 'https://github.com/langgenius/dify/discussions/categories/feedbacks' + 'navigation.forum': async (args) => { + const url = args?.url || 'https://forum.dify.ai' window.open(url, '_blank', 'noopener,noreferrer') }, }) }, unregister() { - unregisterCommands(['navigation.feedback']) + unregisterCommands(['navigation.forum']) }, } diff --git a/web/app/components/goto-anything/actions/commands/slash.tsx b/web/app/components/goto-anything/actions/commands/slash.tsx index e0d03d5019..b99215255f 100644 --- a/web/app/components/goto-anything/actions/commands/slash.tsx +++ b/web/app/components/goto-anything/actions/commands/slash.tsx @@ -7,7 +7,7 @@ import { useTheme } from 'next-themes' import { setLocaleOnClient } from '@/i18n-config' import { themeCommand } from './theme' import { languageCommand } from './language' -import { feedbackCommand } from './feedback' +import { forumCommand } from './forum' import { docsCommand } from './docs' import { communityCommand } from './community' import { accountCommand } from './account' @@ -34,7 +34,7 @@ export const registerSlashCommands = (deps: Record) => { // Register command handlers to the registry system with their respective dependencies slashCommandRegistry.register(themeCommand, { setTheme: deps.setTheme }) slashCommandRegistry.register(languageCommand, { setLocale: deps.setLocale }) - slashCommandRegistry.register(feedbackCommand, {}) + slashCommandRegistry.register(forumCommand, {}) slashCommandRegistry.register(docsCommand, {}) slashCommandRegistry.register(communityCommand, {}) slashCommandRegistry.register(accountCommand, {}) @@ -44,7 +44,7 @@ export const unregisterSlashCommands = () => { // Remove command handlers from registry system (automatically calls each command's unregister method) slashCommandRegistry.unregister('theme') slashCommandRegistry.unregister('language') - slashCommandRegistry.unregister('feedback') + slashCommandRegistry.unregister('forum') slashCommandRegistry.unregister('docs') slashCommandRegistry.unregister('community') slashCommandRegistry.unregister('account') diff --git a/web/app/components/header/account-dropdown/support.tsx b/web/app/components/header/account-dropdown/support.tsx index b165c5fcca..f354cc4ab0 100644 --- a/web/app/components/header/account-dropdown/support.tsx +++ b/web/app/components/header/account-dropdown/support.tsx @@ -1,5 +1,5 @@ import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react' -import { RiArrowRightSLine, RiArrowRightUpLine, RiChatSmile2Line, RiDiscordLine, RiFeedbackLine, RiMailSendLine, RiQuestionLine } from '@remixicon/react' +import { RiArrowRightSLine, RiArrowRightUpLine, RiChatSmile2Line, RiDiscordLine, RiDiscussLine, RiMailSendLine, RiQuestionLine } from '@remixicon/react' import { Fragment } from 'react' import Link from 'next/link' import { useTranslation } from 'react-i18next' @@ -86,10 +86,10 @@ export default function Support({ closeAccountDropdown }: SupportProps) { className={cn(itemClassName, 'group justify-between', 'data-[active]:bg-state-base-hover', )} - href='https://github.com/langgenius/dify/discussions/categories/feedbacks' + href='https://forum.dify.ai/' 
target='_blank' rel='noopener noreferrer'> - -
{t('common.userProfile.communityFeedback')}
+ +
{t('common.userProfile.forum')}
diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header.tsx index ccd0d8be1b..9f326fa198 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header.tsx +++ b/web/app/components/plugins/plugin-detail-panel/detail-header.tsx @@ -73,7 +73,7 @@ const DetailHeader = ({ const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures) const { - installation_id, + id, source, tenant_id, version, @@ -197,7 +197,7 @@ const DetailHeader = ({ const handleDelete = useCallback(async () => { showDeleting() - const res = await uninstallPlugin(installation_id) + const res = await uninstallPlugin(id) hideDeleting() if (res.success) { hideDeleteConfirm() @@ -207,7 +207,7 @@ const DetailHeader = ({ if (PluginType.tool.includes(category)) invalidateAllToolProviders() } - }, [showDeleting, installation_id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders]) + }, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders]) return (
@@ -351,7 +351,6 @@ const DetailHeader = ({ content={
{t(`${i18nPrefix}.deleteContentLeft`)}{label[locale]}{t(`${i18nPrefix}.deleteContentRight`)}
- {/* {usedInApps > 0 && t(`${i18nPrefix}.usedInApps`, { num: usedInApps })} */}
} onCancel={hideDeleteConfirm} diff --git a/web/app/components/plugins/plugin-page/index.tsx b/web/app/components/plugins/plugin-page/index.tsx index d326fdf6e4..4b8444ab34 100644 --- a/web/app/components/plugins/plugin-page/index.tsx +++ b/web/app/components/plugins/plugin-page/index.tsx @@ -72,6 +72,8 @@ const PluginPage = ({ } }, [searchParams]) + const [uniqueIdentifier, setUniqueIdentifier] = useState(null) + const [dependencies, setDependencies] = useState([]) const bundleInfo = useMemo(() => { const info = searchParams.get(BUNDLE_INFO_KEY) @@ -99,6 +101,7 @@ const PluginPage = ({ useEffect(() => { (async () => { + setUniqueIdentifier(null) await sleep(100) if (packageId) { const { data } = await fetchManifestFromMarketPlace(encodeURIComponent(packageId)) @@ -108,6 +111,7 @@ const PluginPage = ({ version: version.version, icon: `${MARKETPLACE_API_PREFIX}/plugins/${plugin.org}/${plugin.name}/icon`, }) + setUniqueIdentifier(packageId) showInstallFromMarketplace() return } @@ -283,10 +287,10 @@ const PluginPage = ({ )} { - isShowInstallFromMarketplace && ( + isShowInstallFromMarketplace && uniqueIdentifier && ( = ({ : promptConfig.prompt_variables.map(item => (
{item.type !== 'checkbox' && ( - +
+
{item.name}
+ {!item.required && {t('workflow.panel.optional')}} +
)}
{item.type === 'select' && ( @@ -115,7 +118,7 @@ const RunOnce: FC = ({ {item.type === 'string' && ( ) => { handleInputsChange({ ...inputsRef.current, [item.key]: e.target.value }) }} maxLength={item.max_length || DEFAULT_VALUE_MAX_LEN} @@ -124,7 +127,7 @@ const RunOnce: FC = ({ {item.type === 'paragraph' && (