From ad870de55439e5d88edae5a8854d9c7a3ba31d1d Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 12 Sep 2025 15:35:13 +0800 Subject: [PATCH 01/12] add dataset service api enable --- api/controllers/console/datasets/datasets.py | 12 ++++++ .../service_api/dataset/metadata.py | 4 +- api/controllers/service_api/wraps.py | 41 +++++++++++++++++++ ..._1429-0b2ca375fabe_add_pipeline_info_18.py | 35 ++++++++++++++++ api/models/dataset.py | 1 + api/services/dataset_service.py | 11 +++++ 6 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 api/migrations/versions/2025_09_12_1429-0b2ca375fabe_add_pipeline_info_18.py diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 3834daa007..ef1fc5a958 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -650,6 +650,17 @@ class DatasetApiDeleteApi(Resource): return {"result": "success"}, 204 +class DatasetEnableApiApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, dataset_id, status): + dataset_id_str = str(dataset_id) + + DatasetService.update_dataset_api_status(dataset_id_str, status == "enable") + + return {"result": "success"}, 200 + class DatasetApiBaseUrlApi(Resource): @setup_required @@ -816,6 +827,7 @@ api.add_resource(DatasetRelatedAppListApi, "/datasets//related- api.add_resource(DatasetIndexingStatusApi, "/datasets//indexing-status") api.add_resource(DatasetApiKeyApi, "/datasets/api-keys") api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/") +api.add_resource(DatasetEnableApiApi, "/datasets//") api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/") diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index c2df97eaec..c6032048e6 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -133,7 +133,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): return 204 -@service_api_ns.route("/datasets/metadata/built-in") +@service_api_ns.route("/datasets//metadata/built-in") class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): @service_api_ns.doc("get_built_in_fields") @service_api_ns.doc(description="Get all built-in metadata fields") @@ -143,7 +143,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) - def get(self, tenant_id): + def get(self, tenant_id, dataset_id): """Get all built-in metadata fields.""" built_in_fields = MetadataService.get_built_in_fields() return {"fields": built_in_fields}, 200 diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 14291578d5..e8816c74a9 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -193,6 +193,47 @@ def validate_dataset_token(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): + # get url path dataset_id from positional args or kwargs + # Flask passes URL path parameters as positional arguments + dataset_id = None + + # First try to get from kwargs (explicit parameter) + dataset_id = kwargs.get("dataset_id") + + # If not in kwargs, try to extract from positional args + if not dataset_id and args: + # For class methods: args[0] is self, 
args[1] is dataset_id (if exists) + # Check if first arg is likely a class instance (has __dict__ or __class__) + if len(args) > 1 and hasattr(args[0], '__dict__'): + # This is a class method, dataset_id should be in args[1] + potential_id = args[1] + # Validate it's a string-like UUID, not another object + try: + # Try to convert to string and check if it's a valid UUID format + str_id = str(potential_id) + # Basic check: UUIDs are 36 chars with hyphens + if len(str_id) == 36 and str_id.count('-') == 4: + dataset_id = str_id + except: + pass + elif len(args) > 0: + # Not a class method, check if args[0] looks like a UUID + potential_id = args[0] + try: + str_id = str(potential_id) + if len(str_id) == 36 and str_id.count('-') == 4: + dataset_id = str_id + except: + pass + + # Validate dataset if dataset_id is provided + if dataset_id: + dataset_id = str(dataset_id) + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + raise NotFound("Dataset not found.") + if not dataset.enable_api: + raise Forbidden("Dataset api access is not enabled.") api_token = validate_and_get_api_token("dataset") tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) diff --git a/api/migrations/versions/2025_09_12_1429-0b2ca375fabe_add_pipeline_info_18.py b/api/migrations/versions/2025_09_12_1429-0b2ca375fabe_add_pipeline_info_18.py new file mode 100644 index 0000000000..4d8be75b5a --- /dev/null +++ b/api/migrations/versions/2025_09_12_1429-0b2ca375fabe_add_pipeline_info_18.py @@ -0,0 +1,35 @@ +"""add_pipeline_info_18 + +Revision ID: 0b2ca375fabe +Revises: b45e25c2d166 +Create Date: 2025-09-12 14:29:38.078589 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '0b2ca375fabe' +down_revision = 'b45e25c2d166' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('enable_api', sa.Boolean(), server_default=sa.text('true'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('enable_api') + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 4674ef81e6..0cd53138cc 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -72,6 +72,7 @@ class Dataset(Base): runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying")) pipeline_id = db.Column(StringUUID, nullable=True) chunk_structure = db.Column(db.String(255), nullable=True) + enable_api = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @property def total_documents(self): diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f0b800842c..f0157db0f9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -916,6 +916,17 @@ class DatasetService: .all() ) + @staticmethod + def update_dataset_api_status(dataset_id: str, status: bool): + dataset = DatasetService.get_dataset(dataset_id) + if dataset is None: + raise NotFound("Dataset not found.") + dataset.enable_api = status + dataset.updated_by = current_user.id + dataset.updated_at = naive_utc_now() + db.session.commit() + + @staticmethod def get_dataset_auto_disable_logs(dataset_id: str): assert isinstance(current_user, Account) From 80c32a130ff558062eb17f10fb5af564122311fc Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Sun, 14 Sep 2025 20:43:49 +0800 Subject: [PATCH 02/12] add dataset service api enable --- .../dataset/rag_pipeline/__init__.py | 0 .../rag_pipeline/rag_pipeline_workflow.py | 234 ++++++++++++++++++ .../app/apps/pipeline/pipeline_generator.py | 37 ++- .../entities/rag_pipeline_invoke_entities.py | 14 ++ api/services/file_service.py | 23 +- .../entity/pipeline_service_api_entities.py | 19 ++ .../priority_rag_pipeline_run_task.py | 167 +++++++++++++ .../rag_pipeline/rag_pipeline_run_task.py | 223 ++++++++++------- 8 files changed, 625 insertions(+), 92 deletions(-) create mode 100644 api/controllers/service_api/dataset/rag_pipeline/__init__.py create mode 100644 api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py create mode 100644 api/core/app/entities/rag_pipeline_invoke_entities.py create mode 100644 api/services/rag_pipeline/entity/pipeline_service_api_entities.py create mode 100644 api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py diff --git a/api/controllers/service_api/dataset/rag_pipeline/__init__.py b/api/controllers/service_api/dataset/rag_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py new file mode 100644 index 0000000000..75e2ba9f63 --- /dev/null +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -0,0 +1,234 @@ +import string +import uuid +from collections.abc import Generator +from typing import Any + +from flask import request +from flask_restx import reqparse +from flask_restx.reqparse import ParseResult, RequestParser +from werkzeug.exceptions import Forbidden + +import services +from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError +from controllers.service_api import service_api_ns +from controllers.service_api.dataset.error import PipelineRunError +from controllers.service_api.wraps import DatasetApiResource +from 
core.app.apps.pipeline.pipeline_generator import PipelineGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from libs import helper +from libs.login import current_user +from models.account import Account +from models.dataset import Pipeline +from models.engine import db +from services.errors.file import FileTooLargeError, UnsupportedFileTypeError +from services.file_service import FileService +from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity +from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins") +class DatasourcePluginsApi(DatasetApiResource): + """Resource for datasource plugins.""" + + @service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins") + @service_api_ns.doc(description="List all datasource plugins for a rag pipeline") + @service_api_ns.doc( + path={ + "dataset_id": "Dataset ID", + } + ) + @service_api_ns.doc( + params={ + "is_published": "Whether to get published or draft datasource plugins " + "(true for published, false for draft, default: true)" + } + ) + @service_api_ns.doc( + responses={ + 200: "Datasource plugins retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) + def get(self, tenant_id: str, dataset_id: str): + """Resource for getting datasource plugins.""" + # Get query parameter to determine published or draft + is_published: bool = request.args.get("is_published", default=True, type=bool) + + rag_pipeline_service: RagPipelineService = RagPipelineService() + datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins( + tenant_id=tenant_id, + dataset_id=dataset_id, + is_published=is_published + ) + return datasource_plugins, 200 + +@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run") +class DatasourceNodeRunApi(DatasetApiResource): + """Resource for datasource node run.""" + + @service_api_ns.doc(shortcut="pipeline_datasource_node_run") + @service_api_ns.doc(description="Run a datasource node for a rag pipeline") + @service_api_ns.doc( + path={ + "dataset_id": "Dataset ID", + } + ) + @service_api_ns.doc( + body={ + "inputs": "User input variables", + "datasource_type": "Datasource type, e.g. 
online_document", + "credential_id": "Credential ID", + "is_published": "Whether to get published or draft datasource plugins " + "(true for published, false for draft, default: true)" + } + ) + @service_api_ns.doc( + responses={ + 200: "Datasource node run successfully", + 401: "Unauthorized - invalid API token", + } + ) + def post(self, tenant_id: str, dataset_id: str, node_id: str): + """Resource for getting datasource plugins.""" + # Get query parameter to determine published or draft + parser: RequestParser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + parser.add_argument("credential_id", type=str, required=False, location="json") + parser.add_argument("is_published", type=bool, required=True, location="json") + args: ParseResult = parser.parse_args() + + datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args) + assert isinstance(current_user, Account) + rag_pipeline_service: RagPipelineService = RagPipelineService() + pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) + return helper.compact_generate_response( + PipelineGenerator.convert_to_event_stream( + rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, + node_id=node_id, + user_inputs=datasource_node_run_api_entity.inputs, + account=current_user, + datasource_type=datasource_node_run_api_entity.datasource_type, + is_published=datasource_node_run_api_entity.is_published, + credential_id=datasource_node_run_api_entity.credential_id, + ) + ) + ) + + +@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run") +class PipelineRunApi(DatasetApiResource): + """Resource for datasource node run.""" + + @service_api_ns.doc(shortcut="pipeline_datasource_node_run") + @service_api_ns.doc(description="Run a datasource node for a rag pipeline") + @service_api_ns.doc( + path={ + "dataset_id": "Dataset ID", + } + ) + @service_api_ns.doc( + body={ + "inputs": "User input variables", + "datasource_type": "Datasource type, e.g. 
online_document", + "datasource_info_list": "Datasource info list", + "start_node_id": "Start node ID", + "is_published": "Whether to get published or draft datasource plugins " + "(true for published, false for draft, default: true)", + "streaming": "Whether to stream the response(streaming or blocking), default: streaming" + } + ) + @service_api_ns.doc( + responses={ + 200: "Pipeline run successfully", + 401: "Unauthorized - invalid API token", + } + ) + def post(self, tenant_id: str, dataset_id: str): + """Resource for running a rag pipeline.""" + parser: RequestParser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + parser.add_argument("datasource_info_list", type=list, required=True, location="json") + parser.add_argument("start_node_id", type=str, required=True, location="json") + parser.add_argument("is_published", type=bool, required=True, default=True, location="json") + parser.add_argument("response_mode", type=str, required=True, choices=["streaming", "blocking"], default="blocking", location="json") + args: ParseResult = parser.parse_args() + + if not isinstance(current_user, Account): + raise Forbidden() + + rag_pipeline_service: RagPipelineService = RagPipelineService() + pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) + try: + response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate( + pipeline=pipeline, + user=current_user, + args=args, + invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER, + streaming=args.get("response_mode") == "streaming", + ) + + return helper.compact_generate_response(response) + except Exception as ex: + raise PipelineRunError(description=str(ex)) + + +@service_api_ns.route("/datasets/pipeline/file-upload") +class KnowledgebasePipelineFileUploadApi(DatasetApiResource): + """Resource for uploading a file to a knowledgebase pipeline.""" + + @service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload") + @service_api_ns.doc(description="Upload a file to a knowledgebase pipeline") + @service_api_ns.doc( + responses={ + 201: "File uploaded successfully", + 400: "Bad request - no file or invalid file", + 401: "Unauthorized - invalid API token", + 413: "File too large", + 415: "Unsupported file type", + + } + ) + def post(self, tenant_id: str): + """Upload a file for use in conversations. + + Accepts a single file upload via multipart/form-data. 
+ """ + # check file + if "file" not in request.files: + raise NoFileUploadedError() + + if len(request.files) > 1: + raise TooManyFilesError() + + file = request.files["file"] + if not file.mimetype: + raise UnsupportedFileTypeError() + + if not file.filename: + raise FilenameNotExistsError + + try: + upload_file = FileService(db.engine).upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "mime_type": upload_file.mime_type, + "created_by": upload_file.created_by, + "created_at": upload_file.created_at, + }, 201 diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 161ef3201f..34a27b954c 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -25,6 +25,7 @@ from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.datasource.entities.datasource_entities import ( DatasourceProviderType, @@ -41,6 +42,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db +from extensions.ext_redis import redis_client from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline @@ -48,7 +50,10 @@ from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode from services.dataset_service import DocumentService from services.datasource_provider_service import DatasourceProviderService +from services.feature_service import FeatureService +from services.file_service import FileService from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService +from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task logger = logging.getLogger(__name__) @@ -147,6 +152,7 @@ class PipelineGenerator(BaseAppGenerator): db.session.commit() # run in child thread + rag_pipeline_invoke_entities = [] for i, datasource_info in enumerate(datasource_info_list): workflow_run_id = str(uuid.uuid4()) document_id = None @@ -223,7 +229,7 @@ class PipelineGenerator(BaseAppGenerator): workflow_thread_pool_id=workflow_thread_pool_id, ) else: - rag_pipeline_run_task.delay( # type: ignore + rag_pipeline_invoke_entities.append(RagPipelineInvokeEntity( pipeline_id=pipeline.id, 
user_id=user.id, tenant_id=pipeline.tenant_id, @@ -232,7 +238,36 @@ class PipelineGenerator(BaseAppGenerator): workflow_execution_id=workflow_run_id, workflow_thread_pool_id=workflow_thread_pool_id, application_generate_entity=application_generate_entity.model_dump(), + )) + + if rag_pipeline_invoke_entities: + # store the rag_pipeline_invoke_entities to object storage + text = [item.model_dump() for item in rag_pipeline_invoke_entities] + name = "rag_pipeline_invoke_entities.json" + # Convert list to proper JSON string + json_text = json.dumps(text) + upload_file = FileService(db.engine).upload_text(json_text, name) + features = FeatureService.get_features(dataset.tenant_id) + if features.billing.subscription.plan == "sandbox": + tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}" + tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}" + + if redis_client.get(tenant_pipeline_task_key): + # Add to waiting queue using List operations (lpush) + redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id) + else: + # Set flag and execute task + redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60) + rag_pipeline_run_task.delay( # type: ignore + rag_pipeline_invoke_entities_file_id=upload_file.id, + tenant_id=dataset.tenant_id, + ) + + else: + priority_rag_pipeline_run_task.delay( # type: ignore + rag_pipeline_invoke_entities_file_id=upload_file.id, ) + # return batch, dataset, documents return { "batch": batch, diff --git a/api/core/app/entities/rag_pipeline_invoke_entities.py b/api/core/app/entities/rag_pipeline_invoke_entities.py new file mode 100644 index 0000000000..b26f496c8a --- /dev/null +++ b/api/core/app/entities/rag_pipeline_invoke_entities.py @@ -0,0 +1,14 @@ +from typing import Any + +from pydantic import BaseModel + + +class RagPipelineInvokeEntity(BaseModel): + pipeline_id: str + application_generate_entity: dict[str, Any] + user_id: str + tenant_id: str + workflow_id: str + streaming: bool + workflow_execution_id: str | None = None + workflow_thread_pool_id: str | None = None \ No newline at end of file diff --git a/api/services/file_service.py b/api/services/file_service.py index 68fca8020f..894b485cce 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -120,8 +120,7 @@ class FileService: return file_size <= file_size_limit - @staticmethod - def upload_text(text: str, text_name: str) -> UploadFile: + def upload_text(self, text: str, text_name: str) -> UploadFile: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -225,3 +224,23 @@ class FileService: generator = storage.load(upload_file.key) return generator, upload_file.mime_type + + def get_file_content(self, file_id: str) -> str: + with self._session_maker(expire_on_commit=False) as session: + upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first() + + if not upload_file: + raise NotFound("File not found") + content = storage.load(upload_file.key) + + return content.decode("utf-8") + + def delete_file(self, file_id: str): + with self._session_maker(expire_on_commit=False) as session: + upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first() + + if not upload_file: + return + storage.delete(upload_file.key) + session.delete(upload_file) + session.commit() \ No newline at end of file diff --git a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py 
b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py new file mode 100644 index 0000000000..f1718e3cc8 --- /dev/null +++ b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py @@ -0,0 +1,19 @@ +from typing import Any, Mapping, Optional +from pydantic import BaseModel + + +class DatasourceNodeRunApiEntity(BaseModel): + pipeline_id: str + node_id: str + inputs: Mapping[str, Any] + datasource_type: str + credential_id: Optional[str] = None + is_published: bool + +class PipelineRunApiEntity(BaseModel): + inputs: Mapping[str, Any] + datasource_type: str + datasource_info_list: list[Mapping[str, Any]] + start_node_id: str + is_published: bool + response_mode: str \ No newline at end of file diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py new file mode 100644 index 0000000000..a22d77ec17 --- /dev/null +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -0,0 +1,167 @@ +import contextvars +import json +import logging +import threading +import time +import uuid +from collections.abc import Mapping +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +import click +from celery import shared_task # type: ignore +from flask import current_app, g +from sqlalchemy.orm import Session, sessionmaker + +from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from extensions.ext_database import db +from models.account import Account, Tenant +from models.dataset import Pipeline +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom +from services.file_service import FileService + + +@shared_task(queue="priority_pipeline") +def priority_rag_pipeline_run_task( + rag_pipeline_invoke_entities_file_id: str, + tenant_id: str, +): + """ + Async Run rag pipeline + :param rag_pipeline_invoke_entities: Rag pipeline invoke entities + rag_pipeline_invoke_entities include: + :param pipeline_id: Pipeline ID + :param user_id: User ID + :param tenant_id: Tenant ID + :param workflow_id: Workflow ID + :param invoke_from: Invoke source (debugger, published, etc.) 
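+    :param application_generate_entity: Serialized RagPipelineGenerateEntity dict, used to rebuild the entity in the worker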
+ :param streaming: Whether to stream results + :param datasource_type: Type of datasource + :param datasource_info: Datasource information dict + :param batch: Batch identifier + :param document_id: Document ID (optional) + :param start_node_id: Starting node ID + :param inputs: Input parameters dict + :param workflow_execution_id: Workflow execution ID + :param workflow_thread_pool_id: Thread pool ID for workflow execution + """ + # run with threading, thread pool size is 10 + + try: + start_at = time.perf_counter() + rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id) + rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [] + for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities: + # Submit task to thread pool + future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity) + futures.append(future) + + # Wait for all tasks to complete + for future in futures: + try: + future.result() # This will raise any exceptions that occurred in the thread + except Exception: + logging.exception("Error in pipeline task") + end_at = time.perf_counter() + logging.info( + click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green") + ) + except Exception: + logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red")) + raise + finally: + file_service = FileService(db.engine) + file_service.delete_file(rag_pipeline_invoke_entities_file_id) + db.session.close() + +def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any]): + rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) + user_id = rag_pipeline_invoke_entity_model.user_id + tenant_id = rag_pipeline_invoke_entity_model.tenant_id + pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id + workflow_id = rag_pipeline_invoke_entity_model.workflow_id + streaming = rag_pipeline_invoke_entity_model.streaming + workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id + workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id + application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity + with Session(db.engine) as session: + account = session.query(Account).filter(Account.id == user_id).first() + if not account: + raise ValueError(f"Account {user_id} not found") + tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() + if not tenant: + raise ValueError(f"Tenant {tenant_id} not found") + account.current_tenant = tenant + + pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + if not pipeline: + raise ValueError(f"Pipeline {pipeline_id} not found") + + workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + + if not workflow: + raise ValueError(f"Workflow {pipeline.workflow_id} not found") + + if workflow_execution_id is None: + workflow_execution_id = str(uuid.uuid4()) + + # Create application generate entity from dict + entity = RagPipelineGenerateEntity(**application_generate_entity) + + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + 
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, + ) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, + ) + # Use app context to ensure Flask globals work properly + with current_app.app_context(): + # Set the user directly in g for preserve_flask_contexts + g._login_user = account + + # Copy context for thread (after setting user) + context = contextvars.copy_context() + + # Get Flask app object in the main thread where app context exists + flask_app = current_app._get_current_object() # type: ignore + + # Create a wrapper function that passes user context + def _run_with_user_context(): + # Don't create a new app context here - let _generate handle it + # Just ensure the user is available in contextvars + from core.app.apps.pipeline.pipeline_generator import PipelineGenerator + + pipeline_generator = PipelineGenerator() + pipeline_generator._generate( + flask_app=flask_app, + context=context, + pipeline=pipeline, + workflow_id=workflow_id, + user=account, + application_generate_entity=entity, + invoke_from=InvokeFrom.PUBLISHED, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + # Create and start worker thread + worker_thread = threading.Thread(target=_run_with_user_context) + worker_thread.start() + worker_thread.join() # Wait for worker thread to complete diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 182639d115..d9b6bf5d5a 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -1,8 +1,12 @@ import contextvars +import json import logging import threading import time import uuid +from collections.abc import Mapping +from concurrent.futures import ThreadPoolExecutor +from typing import Any import click from celery import shared_task # type: ignore @@ -10,6 +14,7 @@ from flask import current_app, g from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from extensions.ext_database import db @@ -18,21 +23,18 @@ from models.account import Account, Tenant from models.dataset import Pipeline from models.enums import WorkflowRunTriggeredFrom from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom +from services.file_service import FileService @shared_task(queue="pipeline") def rag_pipeline_run_task( - pipeline_id: str, - application_generate_entity: dict, - user_id: str, + rag_pipeline_invoke_entities_file_id: str, tenant_id: str, - workflow_id: str, - streaming: bool, - workflow_execution_id: str | None = None, - workflow_thread_pool_id: str | None = None, ): """ Async Run rag pipeline + :param rag_pipeline_invoke_entities: Rag pipeline invoke entities + rag_pipeline_invoke_entities include: :param pipeline_id: Pipeline ID :param user_id: User ID :param tenant_id: Tenant ID @@ -48,94 +50,137 @@ def 
rag_pipeline_run_task( :param workflow_execution_id: Workflow execution ID :param workflow_thread_pool_id: Thread pool ID for workflow execution """ - logging.info(click.style(f"Start run rag pipeline: {pipeline_id}", fg="green")) - start_at = time.perf_counter() - indexing_cache_key = f"rag_pipeline_run_{pipeline_id}_{user_id}" + # run with threading, thread pool size is 10 try: - with Session(db.engine) as session: - account = session.query(Account).filter(Account.id == user_id).first() - if not account: - raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() - if not tenant: - raise ValueError(f"Tenant {tenant_id} not found") - account.current_tenant = tenant - - pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() - if not pipeline: - raise ValueError(f"Pipeline {pipeline_id} not found") - - workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() - - if not workflow: - raise ValueError(f"Workflow {pipeline.workflow_id} not found") - - if workflow_execution_id is None: - workflow_execution_id = str(uuid.uuid4()) - - # Create application generate entity from dict - entity = RagPipelineGenerateEntity(**application_generate_entity) - - # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( - session_factory=session_factory, - user=account, - app_id=entity.app_config.app_id, - triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, - ) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, - user=account, - app_id=entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, - ) - # Use app context to ensure Flask globals work properly - with current_app.app_context(): - # Set the user directly in g for preserve_flask_contexts - g._login_user = account - - # Copy context for thread (after setting user) - context = contextvars.copy_context() - - # Get Flask app object in the main thread where app context exists - flask_app = current_app._get_current_object() # type: ignore - - # Create a wrapper function that passes user context - def _run_with_user_context(): - # Don't create a new app context here - let _generate handle it - # Just ensure the user is available in contextvars - from core.app.apps.pipeline.pipeline_generator import PipelineGenerator - - pipeline_generator = PipelineGenerator() - pipeline_generator._generate( - flask_app=flask_app, - context=context, - pipeline=pipeline, - workflow_id=workflow_id, - user=account, - application_generate_entity=entity, - invoke_from=InvokeFrom.PUBLISHED, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=streaming, - workflow_thread_pool_id=workflow_thread_pool_id, - ) - - # Create and start worker thread - worker_thread = threading.Thread(target=_run_with_user_context) - worker_thread.start() - worker_thread.join() # Wait for worker thread to complete - + start_at = time.perf_counter() + rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id) + rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [] + for rag_pipeline_invoke_entity in 
rag_pipeline_invoke_entities: + # Submit task to thread pool + future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity) + futures.append(future) + + # Wait for all tasks to complete + for future in futures: + try: + future.result() # This will raise any exceptions that occurred in the thread + except Exception: + logging.exception("Error in pipeline task") end_at = time.perf_counter() logging.info( - click.style(f"Rag pipeline run: {pipeline_id} completed. Latency: {end_at - start_at}s", fg="green") + click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green") ) except Exception: - logging.exception(click.style(f"Error running rag pipeline {pipeline_id}", fg="red")) + logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red")) raise finally: - redis_client.delete(indexing_cache_key) + tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}" + tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}" + + # Check if there are waiting tasks in the queue + # Use rpop to get the next task from the queue (FIFO order) + next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue) + + if next_file_id: + # Process the next waiting task + # Keep the flag set to indicate a task is running + redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1) + rag_pipeline_run_task.delay( # type: ignore + rag_pipeline_invoke_entities_file_id=next_file_id.decode('utf-8') if isinstance(next_file_id, bytes) else next_file_id, + tenant_id=tenant_id, + ) + else: + # No more waiting tasks, clear the flag + redis_client.delete(tenant_pipeline_task_key) + file_service = FileService(db.engine) + file_service.delete_file(rag_pipeline_invoke_entities_file_id) db.session.close() + +def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any]): + rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) + user_id = rag_pipeline_invoke_entity_model.user_id + tenant_id = rag_pipeline_invoke_entity_model.tenant_id + pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id + workflow_id = rag_pipeline_invoke_entity_model.workflow_id + streaming = rag_pipeline_invoke_entity_model.streaming + workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id + workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id + application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity + with Session(db.engine) as session: + account = session.query(Account).filter(Account.id == user_id).first() + if not account: + raise ValueError(f"Account {user_id} not found") + tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() + if not tenant: + raise ValueError(f"Tenant {tenant_id} not found") + account.current_tenant = tenant + + pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + if not pipeline: + raise ValueError(f"Pipeline {pipeline_id} not found") + + workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + + if not workflow: + raise ValueError(f"Workflow {pipeline.workflow_id} not found") + + if workflow_execution_id is None: + workflow_execution_id = str(uuid.uuid4()) + + # Create application generate entity from dict + entity = RagPipelineGenerateEntity(**application_generate_entity) + + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) 
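+        # Both repositories below share this factory; expire_on_commit=False keeps the ORM objects they return readable after commit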
+ workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, + ) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, + ) + # Use app context to ensure Flask globals work properly + with current_app.app_context(): + # Set the user directly in g for preserve_flask_contexts + g._login_user = account + + # Copy context for thread (after setting user) + context = contextvars.copy_context() + + # Get Flask app object in the main thread where app context exists + flask_app = current_app._get_current_object() # type: ignore + + # Create a wrapper function that passes user context + def _run_with_user_context(): + # Don't create a new app context here - let _generate handle it + # Just ensure the user is available in contextvars + from core.app.apps.pipeline.pipeline_generator import PipelineGenerator + + pipeline_generator = PipelineGenerator() + pipeline_generator._generate( + flask_app=flask_app, + context=context, + pipeline=pipeline, + workflow_id=workflow_id, + user=account, + application_generate_entity=entity, + invoke_from=InvokeFrom.PUBLISHED, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + # Create and start worker thread + worker_thread = threading.Thread(target=_run_with_user_context) + worker_thread.start() + worker_thread.join() # Wait for worker thread to complete From 815e5568c3187fd19e14a3f52473a104aacfe631 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Sun, 14 Sep 2025 21:53:32 +0800 Subject: [PATCH 03/12] add dataset service api enable --- .../priority_rag_pipeline_run_task.py | 100 +++++++-------- .../rag_pipeline/rag_pipeline_run_task.py | 115 ++++++++++-------- 2 files changed, 112 insertions(+), 103 deletions(-) diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index a22d77ec17..3d7f713258 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -55,11 +55,15 @@ def priority_rag_pipeline_run_task( start_at = time.perf_counter() rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + + # Get Flask app object for thread context + flask_app = current_app._get_current_object() # type: ignore + with ThreadPoolExecutor(max_workers=10) as executor: futures = [] for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities: - # Submit task to thread pool - future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity) + # Submit task to thread pool with Flask app + future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app) futures.append(future) # Wait for all tasks to complete @@ -80,66 +84,64 @@ def priority_rag_pipeline_run_task( file_service.delete_file(rag_pipeline_invoke_entities_file_id) db.session.close() -def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any]): - 
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) - user_id = rag_pipeline_invoke_entity_model.user_id - tenant_id = rag_pipeline_invoke_entity_model.tenant_id - pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id - workflow_id = rag_pipeline_invoke_entity_model.workflow_id - streaming = rag_pipeline_invoke_entity_model.streaming - workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id - workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id - application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity - with Session(db.engine) as session: - account = session.query(Account).filter(Account.id == user_id).first() - if not account: - raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() - if not tenant: - raise ValueError(f"Tenant {tenant_id} not found") - account.current_tenant = tenant +def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app): + # Create Flask application context for this thread + with flask_app.app_context(): + rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) + user_id = rag_pipeline_invoke_entity_model.user_id + tenant_id = rag_pipeline_invoke_entity_model.tenant_id + pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id + workflow_id = rag_pipeline_invoke_entity_model.workflow_id + streaming = rag_pipeline_invoke_entity_model.streaming + workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id + workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id + application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity + with Session(db.engine) as session: + account = session.query(Account).filter(Account.id == user_id).first() + if not account: + raise ValueError(f"Account {user_id} not found") + tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() + if not tenant: + raise ValueError(f"Tenant {tenant_id} not found") + account.current_tenant = tenant - pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() - if not pipeline: - raise ValueError(f"Pipeline {pipeline_id} not found") + pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + if not pipeline: + raise ValueError(f"Pipeline {pipeline_id} not found") - workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() - if not workflow: - raise ValueError(f"Workflow {pipeline.workflow_id} not found") + if not workflow: + raise ValueError(f"Workflow {pipeline.workflow_id} not found") - if workflow_execution_id is None: - workflow_execution_id = str(uuid.uuid4()) + if workflow_execution_id is None: + workflow_execution_id = str(uuid.uuid4()) - # Create application generate entity from dict - entity = RagPipelineGenerateEntity(**application_generate_entity) + # Create application generate entity from dict + entity = RagPipelineGenerateEntity(**application_generate_entity) - # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( - session_factory=session_factory, - user=account, - app_id=entity.app_config.app_id, - triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, - 
) + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, + ) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, - user=account, - app_id=entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, - ) - # Use app context to ensure Flask globals work properly - with current_app.app_context(): + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, + ) + # Set the user directly in g for preserve_flask_contexts g._login_user = account # Copy context for thread (after setting user) context = contextvars.copy_context() - # Get Flask app object in the main thread where app context exists - flask_app = current_app._get_current_object() # type: ignore - # Create a wrapper function that passes user context def _run_with_user_context(): # Don't create a new app context here - let _generate handle it diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index d9b6bf5d5a..1af1c9a675 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -54,15 +54,21 @@ def rag_pipeline_run_task( try: start_at = time.perf_counter() - rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id) + print("asdadsdadaddadadadadadasdsa") + rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content( + rag_pipeline_invoke_entities_file_id) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + + # Get Flask app object for thread context + flask_app = current_app._get_current_object() # type: ignore + with ThreadPoolExecutor(max_workers=10) as executor: futures = [] for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities: - # Submit task to thread pool - future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity) + # Submit task to thread pool with Flask app + future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app) futures.append(future) - + # Wait for all tasks to complete for future in futures: try: @@ -71,7 +77,8 @@ def rag_pipeline_run_task( logging.exception("Error in pipeline task") end_at = time.perf_counter() logging.info( - click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green") + click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. 
Latency: {end_at - start_at}s", + fg="green") ) except Exception: logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red")) @@ -83,13 +90,14 @@ def rag_pipeline_run_task( # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue) - + if next_file_id: # Process the next waiting task # Keep the flag set to indicate a task is running redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1) rag_pipeline_run_task.delay( # type: ignore - rag_pipeline_invoke_entities_file_id=next_file_id.decode('utf-8') if isinstance(next_file_id, bytes) else next_file_id, + rag_pipeline_invoke_entities_file_id=next_file_id.decode('utf-8') if isinstance(next_file_id, + bytes) else next_file_id, tenant_id=tenant_id, ) else: @@ -99,66 +107,65 @@ def rag_pipeline_run_task( file_service.delete_file(rag_pipeline_invoke_entities_file_id) db.session.close() -def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any]): - rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) - user_id = rag_pipeline_invoke_entity_model.user_id - tenant_id = rag_pipeline_invoke_entity_model.tenant_id - pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id - workflow_id = rag_pipeline_invoke_entity_model.workflow_id - streaming = rag_pipeline_invoke_entity_model.streaming - workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id - workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id - application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity - with Session(db.engine) as session: - account = session.query(Account).filter(Account.id == user_id).first() - if not account: - raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() - if not tenant: - raise ValueError(f"Tenant {tenant_id} not found") - account.current_tenant = tenant - pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() - if not pipeline: - raise ValueError(f"Pipeline {pipeline_id} not found") +def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app): + # Create Flask application context for this thread + with flask_app.app_context(): + rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) + user_id = rag_pipeline_invoke_entity_model.user_id + tenant_id = rag_pipeline_invoke_entity_model.tenant_id + pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id + workflow_id = rag_pipeline_invoke_entity_model.workflow_id + streaming = rag_pipeline_invoke_entity_model.streaming + workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id + workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id + application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity + with Session(db.engine) as session: + account = session.query(Account).filter(Account.id == user_id).first() + if not account: + raise ValueError(f"Account {user_id} not found") + tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() + if not tenant: + raise ValueError(f"Tenant {tenant_id} not found") + account.current_tenant = tenant - workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + pipeline = session.query(Pipeline).filter(Pipeline.id == 
pipeline_id).first() + if not pipeline: + raise ValueError(f"Pipeline {pipeline_id} not found") - if not workflow: - raise ValueError(f"Workflow {pipeline.workflow_id} not found") + workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() - if workflow_execution_id is None: - workflow_execution_id = str(uuid.uuid4()) + if not workflow: + raise ValueError(f"Workflow {pipeline.workflow_id} not found") - # Create application generate entity from dict - entity = RagPipelineGenerateEntity(**application_generate_entity) + if workflow_execution_id is None: + workflow_execution_id = str(uuid.uuid4()) - # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( - session_factory=session_factory, - user=account, - app_id=entity.app_config.app_id, - triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, - ) + # Create application generate entity from dict + entity = RagPipelineGenerateEntity(**application_generate_entity) + + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, + ) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, + ) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, - user=account, - app_id=entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, - ) - # Use app context to ensure Flask globals work properly - with current_app.app_context(): # Set the user directly in g for preserve_flask_contexts g._login_user = account # Copy context for thread (after setting user) context = contextvars.copy_context() - # Get Flask app object in the main thread where app context exists - flask_app = current_app._get_current_object() # type: ignore - # Create a wrapper function that passes user context def _run_with_user_context(): # Don't create a new app context here - let _generate handle it From c08a60021a99f330521a136e3c9954ac8234e197 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Sun, 14 Sep 2025 22:06:32 +0800 Subject: [PATCH 04/12] add dataset service api enable --- .../priority_rag_pipeline_run_task.py | 123 +++++++++--------- .../rag_pipeline/rag_pipeline_run_task.py | 112 ++++++++-------- 2 files changed, 116 insertions(+), 119 deletions(-) diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 3d7f713258..5ccc51a66a 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -1,7 +1,6 @@ import contextvars import json import logging -import threading import time import uuid from collections.abc import Mapping @@ -85,67 +84,69 @@ def priority_rag_pipeline_run_task( db.session.close() def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app): + """Run a single RAG pipeline task within Flask app context.""" # Create Flask application context for this thread with 
flask_app.app_context(): - rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) - user_id = rag_pipeline_invoke_entity_model.user_id - tenant_id = rag_pipeline_invoke_entity_model.tenant_id - pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id - workflow_id = rag_pipeline_invoke_entity_model.workflow_id - streaming = rag_pipeline_invoke_entity_model.streaming - workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id - workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id - application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity - with Session(db.engine) as session: - account = session.query(Account).filter(Account.id == user_id).first() - if not account: - raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() - if not tenant: - raise ValueError(f"Tenant {tenant_id} not found") - account.current_tenant = tenant - - pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() - if not pipeline: - raise ValueError(f"Pipeline {pipeline_id} not found") - - workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() - - if not workflow: - raise ValueError(f"Workflow {pipeline.workflow_id} not found") - - if workflow_execution_id is None: - workflow_execution_id = str(uuid.uuid4()) - - # Create application generate entity from dict - entity = RagPipelineGenerateEntity(**application_generate_entity) - - # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( - session_factory=session_factory, - user=account, - app_id=entity.app_config.app_id, - triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, - ) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, - user=account, - app_id=entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, - ) + try: + rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) + user_id = rag_pipeline_invoke_entity_model.user_id + tenant_id = rag_pipeline_invoke_entity_model.tenant_id + pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id + workflow_id = rag_pipeline_invoke_entity_model.workflow_id + streaming = rag_pipeline_invoke_entity_model.streaming + workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id + workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id + application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity - # Set the user directly in g for preserve_flask_contexts - g._login_user = account + with Session(db.engine) as session: + # Load required entities + account = session.query(Account).filter(Account.id == user_id).first() + if not account: + raise ValueError(f"Account {user_id} not found") + + tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() + if not tenant: + raise ValueError(f"Tenant {tenant_id} not found") + account.current_tenant = tenant - # Copy context for thread (after setting user) - context = contextvars.copy_context() + pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + if not pipeline: + raise ValueError(f"Pipeline {pipeline_id} not found") - # Create a wrapper function that passes user context 
- def _run_with_user_context(): - # Don't create a new app context here - let _generate handle it - # Just ensure the user is available in contextvars + workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + if not workflow: + raise ValueError(f"Workflow {pipeline.workflow_id} not found") + + if workflow_execution_id is None: + workflow_execution_id = str(uuid.uuid4()) + + # Create application generate entity from dict + entity = RagPipelineGenerateEntity(**application_generate_entity) + + # Create workflow repositories + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, + ) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, + ) + + # Set the user directly in g for preserve_flask_contexts + g._login_user = account + + # Copy context for passing to pipeline generator + context = contextvars.copy_context() + + # Direct execution without creating another thread + # Since we're already in a thread pool, no need for nested threading from core.app.apps.pipeline.pipeline_generator import PipelineGenerator pipeline_generator = PipelineGenerator() @@ -162,8 +163,6 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, ) - - # Create and start worker thread - worker_thread = threading.Thread(target=_run_with_user_context) - worker_thread.start() - worker_thread.join() # Wait for worker thread to complete + except Exception: + logging.exception("Error in priority pipeline task") + raise diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 1af1c9a675..5d64177e7e 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -1,7 +1,6 @@ import contextvars import json import logging -import threading import time import uuid from collections.abc import Mapping @@ -54,7 +53,6 @@ def rag_pipeline_run_task( try: start_at = time.perf_counter() - print("asdadsdadaddadadadadadasdsa") rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content( rag_pipeline_invoke_entities_file_id) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) @@ -109,67 +107,69 @@ def rag_pipeline_run_task( def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app): + """Run a single RAG pipeline task within Flask app context.""" # Create Flask application context for this thread with flask_app.app_context(): - rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) - user_id = rag_pipeline_invoke_entity_model.user_id - tenant_id = rag_pipeline_invoke_entity_model.tenant_id - pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id - workflow_id = rag_pipeline_invoke_entity_model.workflow_id - streaming = rag_pipeline_invoke_entity_model.streaming - workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id - workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id - application_generate_entity = 
rag_pipeline_invoke_entity_model.application_generate_entity - with Session(db.engine) as session: - account = session.query(Account).filter(Account.id == user_id).first() - if not account: - raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() - if not tenant: - raise ValueError(f"Tenant {tenant_id} not found") - account.current_tenant = tenant + try: + rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) + user_id = rag_pipeline_invoke_entity_model.user_id + tenant_id = rag_pipeline_invoke_entity_model.tenant_id + pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id + workflow_id = rag_pipeline_invoke_entity_model.workflow_id + streaming = rag_pipeline_invoke_entity_model.streaming + workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id + workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id + application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity + + with Session(db.engine) as session: + # Load required entities + account = session.query(Account).filter(Account.id == user_id).first() + if not account: + raise ValueError(f"Account {user_id} not found") + + tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() + if not tenant: + raise ValueError(f"Tenant {tenant_id} not found") + account.current_tenant = tenant - pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() - if not pipeline: - raise ValueError(f"Pipeline {pipeline_id} not found") + pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + if not pipeline: + raise ValueError(f"Pipeline {pipeline_id} not found") - workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + if not workflow: + raise ValueError(f"Workflow {pipeline.workflow_id} not found") - if not workflow: - raise ValueError(f"Workflow {pipeline.workflow_id} not found") + if workflow_execution_id is None: + workflow_execution_id = str(uuid.uuid4()) - if workflow_execution_id is None: - workflow_execution_id = str(uuid.uuid4()) + # Create application generate entity from dict + entity = RagPipelineGenerateEntity(**application_generate_entity) - # Create application generate entity from dict - entity = RagPipelineGenerateEntity(**application_generate_entity) + # Create workflow repositories + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, + ) - # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( - session_factory=session_factory, - user=account, - app_id=entity.app_config.app_id, - triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, - ) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=account, + app_id=entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, + ) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, - user=account, - 
app_id=entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, - ) + # Set the user directly in g for preserve_flask_contexts + g._login_user = account - # Set the user directly in g for preserve_flask_contexts - g._login_user = account + # Copy context for passing to pipeline generator + context = contextvars.copy_context() - # Copy context for thread (after setting user) - context = contextvars.copy_context() - - # Create a wrapper function that passes user context - def _run_with_user_context(): - # Don't create a new app context here - let _generate handle it - # Just ensure the user is available in contextvars + # Direct execution without creating another thread + # Since we're already in a thread pool, no need for nested threading from core.app.apps.pipeline.pipeline_generator import PipelineGenerator pipeline_generator = PipelineGenerator() @@ -186,8 +186,6 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, ) - - # Create and start worker thread - worker_thread = threading.Thread(target=_run_with_user_context) - worker_thread.start() - worker_thread.join() # Wait for worker thread to complete + except Exception: + logging.exception("Error in pipeline task") + raise From 7eb8259e3dc4026e3062abaa70cc64f2802908c4 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 15 Sep 2025 11:44:13 +0800 Subject: [PATCH 05/12] fix priority task --- .../app/apps/pipeline/pipeline_generator.py | 1 + api/services/datasource_provider_service.py | 54 +++++++++---------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 34a27b954c..fdfceeb148 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -266,6 +266,7 @@ class PipelineGenerator(BaseAppGenerator): else: priority_rag_pipeline_run_task.delay( # type: ignore rag_pipeline_invoke_entities_file_id=upload_file.id, + tenant_id=dataset.tenant_id, ) # return batch, dataset, documents diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 41884661b2..1b5077df7b 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -345,19 +345,18 @@ class DatasourceProviderService: def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool: """ check if tenant oauth params is enabled - """ - with Session(db.engine).no_autoflush as session: - return ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - enabled=True, - ) - .count() - > 0 + """ + return ( + db.session.query(DatasourceOauthTenantParamConfig) + .filter_by( + tenant_id=tenant_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + enabled=True, ) + .count() + > 0 + ) def get_tenant_oauth_client( self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False @@ -365,23 +364,22 @@ class DatasourceProviderService: """ get tenant oauth client """ - with Session(db.engine).no_autoflush as session: - tenant_oauth_client_params = ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - 
provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - ) - .first() + tenant_oauth_client_params = ( + db.session.query(DatasourceOauthTenantParamConfig) + .filter_by( + tenant_id=tenant_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, ) - if tenant_oauth_client_params: - encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) - if mask: - return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params)) - else: - return encrypter.decrypt(tenant_oauth_client_params.client_params) - return None + .first() + ) + if tenant_oauth_client_params: + encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) + if mask: + return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params)) + else: + return encrypter.decrypt(tenant_oauth_client_params.client_params) + return None def get_oauth_encrypter( self, tenant_id: str, datasource_provider_id: DatasourceProviderID From 70a362ed3b8b20e049e225f22184c89bb1e09886 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 15 Sep 2025 18:52:01 +0800 Subject: [PATCH 06/12] fix priority task --- api/services/dataset_service.py | 89 +++++++++++++++++++++++ api/services/rag_pipeline/rag_pipeline.py | 3 +- 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f0157db0f9..5ab71d2c20 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -9,6 +9,7 @@ from collections import Counter from typing import Any, Literal, Optional import sqlalchemy as sa +import yaml from sqlalchemy import exists, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -46,6 +47,7 @@ from models.dataset import ( ) from models.model import UploadFile from models.provider_ids import ModelProviderID +from models.workflow import Workflow from services.entities.knowledge_entities.knowledge_entities import ( ChildChunkUpdateArgs, KnowledgeConfig, @@ -56,6 +58,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, RagPipelineDatasetCreateEntity, + RetrievalSetting, ) from services.errors.account import NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError @@ -64,6 +67,7 @@ from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService +from services.rag_pipeline.rag_pipeline import RagPipelineService from services.tag_service import TagService from services.vector_service import VectorService from tasks.add_document_to_index_task import add_document_to_index_task @@ -523,12 +527,97 @@ class DatasetService: db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) db.session.commit() + # update pipeline knowledge base node data + DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id) + # Trigger vector index task if indexing technique changed if action: deal_dataset_vector_index_task.delay(dataset.id, action) return dataset + @staticmethod + def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str): + """ + Update pipeline knowledge base node 
data. + """ + if dataset.runtime_mode != "rag_pipeline": + return + + pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first() + if not pipeline: + return + + try: + rag_pipeline_service = RagPipelineService() + published_workflow = rag_pipeline_service.get_published_workflow(pipeline) + draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline) + + # update knowledge nodes + def update_knowledge_nodes(workflow_graph: str) -> str: + """Update knowledge-index nodes in workflow graph.""" + data: dict[str, Any] = json.loads(workflow_graph) + + nodes = data.get("nodes", []) + updated = False + + for node in nodes: + if node.get("data", {}).get("type") == "knowledge-index": + try: + knowledge_index_node_data = node.get("data", {}) + knowledge_index_node_data["embedding_model"] = dataset.embedding_model + knowledge_index_node_data["embedding_model_provider"] = dataset.embedding_model_provider + knowledge_index_node_data["retrieval_model"] = dataset.retrieval_model + knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure + knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue] + knowledge_index_node_data["keyword_number"] = dataset.keyword_number + node["data"] = knowledge_index_node_data + updated = True + except Exception: + logging.exception("Failed to update knowledge node") + continue + + if updated: + data["nodes"] = nodes + return json.dumps(data) + return workflow_graph + + # Update published workflow + if published_workflow: + updated_graph = update_knowledge_nodes(published_workflow.graph) + if updated_graph != published_workflow.graph: + # Create new workflow version + workflow = Workflow.new( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + type=published_workflow.type, + version=str(datetime.datetime.now(datetime.UTC).replace(tzinfo=None)), + graph=updated_graph, + features=published_workflow.features, + created_by=updata_user_id, + environment_variables=published_workflow.environment_variables, + conversation_variables=published_workflow.conversation_variables, + rag_pipeline_variables=published_workflow.rag_pipeline_variables, + marked_name="", + marked_comment="", + ) + db.session.add(workflow) + + # Update draft workflow + if draft_workflow: + updated_graph = update_knowledge_nodes(draft_workflow.graph) + if updated_graph != draft_workflow.graph: + draft_workflow.graph = updated_graph + db.session.add(draft_workflow) + + # Commit all changes in one transaction + db.session.commit() + + except Exception: + logging.exception("Failed to update pipeline knowledge base node data") + db.session.rollback() + raise + @staticmethod def _handle_indexing_technique_change(dataset, data, filtered_data): """ diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index e28aa02593..0b43404b3d 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -65,7 +65,6 @@ from models.workflow import ( WorkflowType, ) from repositories.factory import DifyAPIRepositoryFactory -from services.dataset_service import DatasetService from services.datasource_provider_service import DatasourceProviderService from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, @@ -346,6 +345,8 @@ class RagPipelineService: graph = workflow.graph_dict nodes = graph.get("nodes", []) + from services.dataset_service import DatasetService + for node in nodes: if node.get("data", 
{}).get("type") == "knowledge-index": knowledge_configuration = node.get("data", {}) From 8346506978f84b2b2f6f9d17e27ab108e8e2553a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 16 Sep 2025 14:14:09 +0800 Subject: [PATCH 07/12] fix document retry --- .../service_api/dataset/document.py | 6 ++- .../app/apps/pipeline/pipeline_generator.py | 12 ++++-- api/services/dataset_service.py | 22 +++++------ api/services/file_service.py | 12 +++--- api/services/rag_pipeline/rag_pipeline.py | 39 ++++++++++++++++++- api/tasks/retry_document_indexing_task.py | 28 ++++++++++--- 6 files changed, 87 insertions(+), 32 deletions(-) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 49a6ea7b5f..4bce64e0a1 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -123,7 +123,8 @@ class DocumentAddByTextApi(DatasetApiResource): args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), ) - upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name)) + upload_file = FileService(db.engine).upload_text(text=str(text), + text_name=str(name), user_id=current_user.id, tenant_id=tenant_id) data_source = { "type": "upload_file", "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, @@ -201,7 +202,8 @@ class DocumentUpdateByTextApi(DatasetApiResource): name = args.get("name") if text is None or name is None: raise ValueError("Both text and name must be strings.") - upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name)) + upload_file = FileService(db.engine).upload_text(text=str(text), + text_name=str(name), user_id=current_user.id, tenant_id=tenant_id) data_source = { "type": "upload_file", "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index fdfceeb148..8751197767 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -48,7 +48,6 @@ from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFro from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode -from services.dataset_service import DocumentService from services.datasource_provider_service import DatasourceProviderService from services.feature_service import FeatureService from services.file_service import FileService @@ -72,6 +71,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: Literal[True], call_depth: int, workflow_thread_pool_id: Optional[str], + is_retry: bool = False, ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ... @overload @@ -86,6 +86,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: Literal[False], call_depth: int, workflow_thread_pool_id: Optional[str], + is_retry: bool = False, ) -> Mapping[str, Any]: ... @overload @@ -100,6 +101,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: bool, call_depth: int, workflow_thread_pool_id: Optional[str], + is_retry: bool = False, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... 
def generate( @@ -113,6 +115,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, + is_retry: bool = False, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: # Add null check for dataset @@ -132,7 +135,8 @@ class PipelineGenerator(BaseAppGenerator): pipeline=pipeline, workflow=workflow, start_node_id=start_node_id ) documents = [] - if invoke_from == InvokeFrom.PUBLISHED: + if invoke_from == InvokeFrom.PUBLISHED and not is_retry: + from services.dataset_service import DocumentService for datasource_info in datasource_info_list: position = DocumentService.get_documents_position(dataset.id) document = self._build_document( @@ -156,7 +160,7 @@ class PipelineGenerator(BaseAppGenerator): for i, datasource_info in enumerate(datasource_info_list): workflow_run_id = str(uuid.uuid4()) document_id = None - if invoke_from == InvokeFrom.PUBLISHED: + if invoke_from == InvokeFrom.PUBLISHED and not is_retry: document_id = documents[i].id document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document_id, @@ -246,7 +250,7 @@ class PipelineGenerator(BaseAppGenerator): name = "rag_pipeline_invoke_entities.json" # Convert list to proper JSON string json_text = json.dumps(text) - upload_file = FileService(db.engine).upload_text(json_text, name) + upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id) features = FeatureService.get_features(dataset.tenant_id) if features.billing.subscription.plan == "sandbox": tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}" diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 5ab71d2c20..03757fe4a5 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -543,24 +543,24 @@ class DatasetService: """ if dataset.runtime_mode != "rag_pipeline": return - + pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first() if not pipeline: return - + try: rag_pipeline_service = RagPipelineService() published_workflow = rag_pipeline_service.get_published_workflow(pipeline) draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline) - + # update knowledge nodes def update_knowledge_nodes(workflow_graph: str) -> str: """Update knowledge-index nodes in workflow graph.""" data: dict[str, Any] = json.loads(workflow_graph) - + nodes = data.get("nodes", []) updated = False - + for node in nodes: if node.get("data", {}).get("type") == "knowledge-index": try: @@ -576,12 +576,12 @@ class DatasetService: except Exception: logging.exception("Failed to update knowledge node") continue - + if updated: data["nodes"] = nodes return json.dumps(data) return workflow_graph - + # Update published workflow if published_workflow: updated_graph = update_knowledge_nodes(published_workflow.graph) @@ -602,17 +602,17 @@ class DatasetService: marked_comment="", ) db.session.add(workflow) - + # Update draft workflow if draft_workflow: updated_graph = update_knowledge_nodes(draft_workflow.graph) if updated_graph != draft_workflow.graph: draft_workflow.graph = updated_graph db.session.add(draft_workflow) - + # Commit all changes in one transaction db.session.commit() - + except Exception: logging.exception("Failed to update pipeline knowledge base node data") db.session.rollback() @@ -1360,7 +1360,7 @@ class DocumentService: redis_client.setex(retry_indexing_cache_key, 600, 1) # trigger async task document_ids = [document.id for document in 
documents] - retry_document_indexing_task.delay(dataset_id, document_ids) + retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id) @staticmethod def sync_website_document(dataset_id: str, document: Document): diff --git a/api/services/file_service.py b/api/services/file_service.py index 894b485cce..f9d4eb5686 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -120,33 +120,31 @@ class FileService: return file_size <= file_size_limit - def upload_text(self, text: str, text_name: str) -> UploadFile: - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile: if len(text_name) > 200: text_name = text_name[:200] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt" + file_key = "upload_files/" + tenant_id + "/" + file_uuid + ".txt" # save file to storage storage.save(file_key, text.encode("utf-8")) # save file to db upload_file = UploadFile( - tenant_id=current_user.current_tenant_id, + tenant_id=tenant_id, storage_type=dify_config.STORAGE_TYPE, key=file_key, name=text_name, size=len(text), extension="txt", mime_type="text/plain", - created_by=current_user.id, + created_by=user_id, created_by_role=CreatorUserRole.ACCOUNT, created_at=naive_utc_now(), used=True, - used_by=current_user.id, + used_by=user_id, used_at=naive_utc_now(), ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 0b43404b3d..a9aca31439 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -5,8 +5,9 @@ import threading import time from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Optional, cast +from typing import Any, Optional, Union, cast from uuid import uuid4 +import uuid from flask_login import current_user from sqlalchemy import func, or_, select @@ -14,6 +15,8 @@ from sqlalchemy.orm import Session, sessionmaker import contexts from configs import dify_config +from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager +from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( DatasourceMessage, @@ -54,7 +57,7 @@ from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account -from models.dataset import Document, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin # type: ignore +from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin # type: ignore from models.enums import WorkflowRunTriggeredFrom from models.model import EndUser from models.workflow import ( @@ -1312,3 +1315,35 @@ class RagPipelineService: "installed_recommended_plugins": installed_plugin_list, "uninstalled_recommended_plugins": uninstalled_plugin_list, } + + def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]): + """ + Retry error document + """ + document_pipeline_excution_log = db.session.query(DocumentPipelineExecutionLog).filter( + DocumentPipelineExecutionLog.document_id == 
document.id).first()
+        if not document_pipeline_excution_log:
+            raise ValueError("Document pipeline execution log not found")
+        pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
+        if not pipeline:
+            raise ValueError("Pipeline not found")
+        # convert to app config
+        workflow = self.get_published_workflow(pipeline)
+        if not workflow:
+            raise ValueError("Workflow not found")
+        PipelineGenerator().generate(
+            pipeline=pipeline,
+            workflow=workflow,
+            user=user,
+            args={
+                "inputs": document_pipeline_excution_log.input_data,
+                "start_node_id": document_pipeline_excution_log.datasource_node_id,
+                "datasource_type": document_pipeline_excution_log.datasource_type,
+                "datasource_info_list": [json.loads(document_pipeline_excution_log.datasource_info)],
+            },
+            invoke_from=InvokeFrom.PUBLISHED,
+            streaming=False,
+            call_depth=0,
+            workflow_thread_pool_id=None,
+            is_retry=True,
+        )
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index c52218caae..1899f93ff7 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -9,32 +9,43 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.datetime_utils import naive_utc_now
+from models.account import Account, Tenant
 from models.dataset import Dataset, Document, DocumentSegment
 from services.feature_service import FeatureService
+from services.rag_pipeline.rag_pipeline import RagPipelineService
 
 logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="dataset")
-def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
+def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str):
     """
     Async process document
     :param dataset_id:
     :param document_ids:
+    :param user_id:
 
-    Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
+    Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
     """
     start_at = time.perf_counter()
     try:
         dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
         if not dataset:
             logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
             return
-        tenant_id = dataset.tenant_id
+        user = db.session.query(Account).where(Account.id == user_id).first()
+        if not user:
+            logger.info(click.style(f"User not found: {user_id}", fg="red"))
+            return
+        tenant = db.session.query(Tenant).filter(Tenant.id == dataset.tenant_id).first()
+        if not tenant:
+            raise ValueError("Tenant not found")
+        user.current_tenant = tenant
+
         for document_id in document_ids:
             retry_indexing_cache_key = f"document_{document_id}_is_retried"
             # check document limit
-            features = FeatureService.get_features(tenant_id)
+            features = FeatureService.get_features(tenant.id)
             try:
                 if features.billing.enabled:
                     vector_space = features.vector_space
@@ -84,8 +95,12 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
                 db.session.add(document)
                 db.session.commit()
 
-                indexing_runner = IndexingRunner()
-                indexing_runner.run([document])
+                if dataset.runtime_mode == "rag_pipeline":
+                    rag_pipeline_service = RagPipelineService()
+                    rag_pipeline_service.retry_error_document(dataset, document, user)
+                else:
+                    indexing_runner = IndexingRunner()
+                    indexing_runner.run([document])
                 redis_client.delete(retry_indexing_cache_key)
             except Exception as ex:
document.indexing_status = "error" From c4ddc6420aff9251170e8c2b34b146fc01c36997 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 16 Sep 2025 14:18:26 +0800 Subject: [PATCH 08/12] fix document retry --- api/fields/dataset_fields.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index f639fb2ea9..73002b6736 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -95,6 +95,7 @@ dataset_detail_fields = { "is_published": fields.Boolean, "total_documents": fields.Integer, "total_available_documents": fields.Integer, + "enable_api": fields.Boolean, } dataset_query_detail_fields = { From c463f31f560abc76420af9684549dd6b6ba24298 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 16 Sep 2025 14:52:33 +0800 Subject: [PATCH 09/12] fix document retry --- api/core/app/apps/pipeline/pipeline_generator.py | 10 +++++++--- api/services/rag_pipeline/rag_pipeline.py | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 8751197767..0e13599c30 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -72,6 +72,7 @@ class PipelineGenerator(BaseAppGenerator): call_depth: int, workflow_thread_pool_id: Optional[str], is_retry: bool = False, + document_id: Optional[str] = None, ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ... @overload @@ -87,6 +88,7 @@ class PipelineGenerator(BaseAppGenerator): call_depth: int, workflow_thread_pool_id: Optional[str], is_retry: bool = False, + document_id: Optional[str] = None, ) -> Mapping[str, Any]: ... @overload @@ -102,6 +104,7 @@ class PipelineGenerator(BaseAppGenerator): call_depth: int, workflow_thread_pool_id: Optional[str], is_retry: bool = False, + document_id: Optional[str] = None, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... 
def generate( @@ -116,6 +119,7 @@ class PipelineGenerator(BaseAppGenerator): call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, is_retry: bool = False, + documents: list[Document] = [], ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: # Add null check for dataset @@ -134,7 +138,6 @@ class PipelineGenerator(BaseAppGenerator): pipeline_config = PipelineConfigManager.get_pipeline_config( pipeline=pipeline, workflow=workflow, start_node_id=start_node_id ) - documents = [] if invoke_from == InvokeFrom.PUBLISHED and not is_retry: from services.dataset_service import DocumentService for datasource_info in datasource_info_list: @@ -160,8 +163,9 @@ class PipelineGenerator(BaseAppGenerator): for i, datasource_info in enumerate(datasource_info_list): workflow_run_id = str(uuid.uuid4()) document_id = None - if invoke_from == InvokeFrom.PUBLISHED and not is_retry: + if documents: document_id = documents[i].id + if invoke_from == InvokeFrom.PUBLISHED and not is_retry: document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document_id, datasource_type=datasource_type, @@ -218,7 +222,7 @@ class PipelineGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, ) - if invoke_from == InvokeFrom.DEBUGGER: + if invoke_from == InvokeFrom.DEBUGGER or is_retry: return self._generate( flask_app=current_app._get_current_object(), # type: ignore context=contextvars.copy_context(), diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a9aca31439..f9ef050c52 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1346,4 +1346,5 @@ class RagPipelineService: call_depth=0, workflow_thread_pool_id=None, is_retry=True, + documents=[document], ) From 610f0414dbfab7081a7d3d8877da2a4769f9df47 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 16 Sep 2025 15:29:19 +0800 Subject: [PATCH 10/12] fix document retry --- api/controllers/console/datasets/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 74fb07f897..91e2a90e5e 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -739,7 +739,7 @@ class DatasetApiDeleteApi(Resource): db.session.commit() return {"result": "success"}, 204 - +@console_ns.route("/datasets//api-keys/") class DatasetEnableApiApi(Resource): @setup_required @login_required From 05aec664246be2e5a27044a4a4865eaae369822a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 16 Sep 2025 16:05:01 +0800 Subject: [PATCH 11/12] fix re-chunk document --- .../rag_pipeline/rag_pipeline_workflow.py | 66 +++++++------------ .../app/apps/pipeline/pipeline_generator.py | 13 ++-- api/core/app/apps/pipeline/pipeline_runner.py | 1 + api/core/app/entities/app_invoke_entities.py | 1 + api/core/workflow/enums.py | 1 + .../knowledge_index/knowledge_index_node.py | 17 ++++- api/core/workflow/system_variable.py | 3 + 7 files changed, 50 insertions(+), 52 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index c70343ec95..d00be3a573 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ 
b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -62,7 +62,7 @@ class DraftRagPipelineApi(Resource): Get draft rag pipeline's workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() # fetch draft workflow by app_model @@ -84,7 +84,7 @@ class DraftRagPipelineApi(Resource): Sync draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() content_type = request.headers.get("Content-Type", "") @@ -119,9 +119,6 @@ class DraftRagPipelineApi(Resource): else: abort(415) - if not isinstance(current_user, Account): - raise Forbidden() - try: environment_variables_list = args.get("environment_variables") or [] environment_variables = [ @@ -161,10 +158,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource): Run draft workflow iteration node """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -198,10 +192,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource): Run draft workflow loop node """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -235,10 +226,7 @@ class DraftRagPipelineRunApi(Resource): Run draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -272,10 +260,7 @@ class PublishedRagPipelineRunApi(Resource): Run published workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -285,6 +270,7 @@ class PublishedRagPipelineRunApi(Resource): parser.add_argument("start_node_id", type=str, required=True, location="json") parser.add_argument("is_preview", type=bool, required=True, location="json", default=False) parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming") + parser.add_argument("original_document_id", type=str, required=False, location="json") args = parser.parse_args() streaming = args["response_mode"] == "streaming" @@ -394,10 +380,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): Run rag pipeline datasource """ # The role of the current user in the ta table 
must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -439,7 +422,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): Run rag pipeline datasource """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -482,7 +465,7 @@ class RagPipelineDraftNodeRunApi(Resource): Run draft workflow node """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -514,7 +497,7 @@ class RagPipelineTaskStopApi(Resource): Stop workflow task """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) @@ -533,7 +516,7 @@ class PublishedRagPipelineApi(Resource): Get published pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() if not pipeline.is_published: return None @@ -553,7 +536,7 @@ class PublishedRagPipelineApi(Resource): Publish workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() rag_pipeline_service = RagPipelineService() @@ -587,7 +570,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource): Get default block config """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() # Get default block configs @@ -605,10 +588,7 @@ class DefaultRagPipelineBlockConfigApi(Resource): Get default block config """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -651,7 +631,7 @@ class PublishedAllRagPipelineApi(Resource): """ Get published workflows """ - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -700,7 +680,7 @@ class RagPipelineByIdApi(Resource): Update workflow attributes """ # Check permission - if not isinstance(current_user, 
Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -756,7 +736,7 @@ class PublishedRagPipelineSecondStepApi(Resource): Get second step parameters of rag pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") @@ -781,7 +761,7 @@ class PublishedRagPipelineFirstStepApi(Resource): Get first step parameters of rag pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") @@ -806,7 +786,7 @@ class DraftRagPipelineFirstStepApi(Resource): Get first step parameters of rag pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") @@ -831,7 +811,7 @@ class DraftRagPipelineSecondStepApi(Resource): Get second step parameters of rag pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") @@ -953,7 +933,7 @@ class RagPipelineTransformApi(Resource): if not isinstance(current_user, Account): raise Forbidden() - if not (current_user.is_editor or current_user.is_dataset_operator): + if not (current_user.has_edit_permission or current_user.is_dataset_operator): raise Forbidden() dataset_id = str(dataset_id) @@ -972,7 +952,7 @@ class RagPipelineDatasourceVariableApi(Resource): """ Set datasource variables """ - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 0e13599c30..c9daead0ba 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -72,7 +72,6 @@ class PipelineGenerator(BaseAppGenerator): call_depth: int, workflow_thread_pool_id: Optional[str], is_retry: bool = False, - document_id: Optional[str] = None, ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ... @overload @@ -88,7 +87,6 @@ class PipelineGenerator(BaseAppGenerator): call_depth: int, workflow_thread_pool_id: Optional[str], is_retry: bool = False, - document_id: Optional[str] = None, ) -> Mapping[str, Any]: ... 
@overload @@ -104,7 +102,6 @@ class PipelineGenerator(BaseAppGenerator): call_depth: int, workflow_thread_pool_id: Optional[str], is_retry: bool = False, - document_id: Optional[str] = None, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... def generate( @@ -119,7 +116,6 @@ class PipelineGenerator(BaseAppGenerator): call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, is_retry: bool = False, - documents: list[Document] = [], ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: # Add null check for dataset @@ -138,7 +134,8 @@ class PipelineGenerator(BaseAppGenerator): pipeline_config = PipelineConfigManager.get_pipeline_config( pipeline=pipeline, workflow=workflow, start_node_id=start_node_id ) - if invoke_from == InvokeFrom.PUBLISHED and not is_retry: + documents: list[Document] = [] + if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"): from services.dataset_service import DocumentService for datasource_info in datasource_info_list: position = DocumentService.get_documents_position(dataset.id) @@ -162,10 +159,9 @@ class PipelineGenerator(BaseAppGenerator): rag_pipeline_invoke_entities = [] for i, datasource_info in enumerate(datasource_info_list): workflow_run_id = str(uuid.uuid4()) - document_id = None - if documents: - document_id = documents[i].id + document_id = args.get("original_document_id") or None if invoke_from == InvokeFrom.PUBLISHED and not is_retry: + document_id = document_id or documents[i].id document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document_id, datasource_type=datasource_type, @@ -184,6 +180,7 @@ class PipelineGenerator(BaseAppGenerator): datasource_type=datasource_type, datasource_info=datasource_info, dataset_id=dataset.id, + original_document_id=args.get("original_document_id"), start_node_id=start_node_id, batch=batch, document_id=document_id, diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 12506049aa..f2f01d1ee7 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -122,6 +122,7 @@ class PipelineRunner(WorkflowBasedAppRunner): workflow_id=app_config.workflow_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id, document_id=self.application_generate_entity.document_id, + original_document_id=self.application_generate_entity.original_document_id, batch=self.application_generate_entity.batch, dataset_id=self.application_generate_entity.dataset_id, datasource_type=self.application_generate_entity.datasource_type, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 6ed596bfb8..1c055fe8b6 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -257,6 +257,7 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): dataset_id: str batch: str document_id: Optional[str] = None + original_document_id: Optional[str] = None start_node_id: Optional[str] = None class SingleIterationRunEntity(BaseModel): diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index c5be9be02a..00a125660a 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -24,6 +24,7 @@ class SystemVariableKey(StrEnum): WORKFLOW_EXECUTION_ID = "workflow_run_id" # RAG Pipeline DOCUMENT_ID = "document_id" + ORIGINAL_DOCUMENT_ID = "original_document_id" BATCH = "batch" 
DATASET_ID = "dataset_id" DATASOURCE_TYPE = "datasource_type" diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 850ea4a9cf..d970d7480c 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -4,7 +4,7 @@ import time from collections.abc import Mapping from typing import Any, Optional, cast -from sqlalchemy import func +from sqlalchemy import func, select from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -128,6 +128,8 @@ class KnowledgeIndexNode(Node): document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if not document_id: raise KnowledgeIndexNodeError("Document ID is required.") + original_document_id = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID]) + batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) if not batch: raise KnowledgeIndexNodeError("Batch is required.") @@ -137,6 +139,19 @@ class KnowledgeIndexNode(Node): # chunk nodes by chunk size indexing_start_at = time.perf_counter() index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor() + if original_document_id: + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + db.session.delete(segment) + db.session.commit() index_processor.index(dataset, document, chunks) indexing_end_at = time.perf_counter() document.indexing_latency = indexing_end_at - indexing_start_at diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py index cd3388de7e..6716e745cd 100644 --- a/api/core/workflow/system_variable.py +++ b/api/core/workflow/system_variable.py @@ -44,6 +44,7 @@ class SystemVariable(BaseModel): conversation_id: str | None = None dialogue_count: int | None = None document_id: str | None = None + original_document_id: str | None = None dataset_id: str | None = None batch: str | None = None datasource_type: str | None = None @@ -94,6 +95,8 @@ class SystemVariable(BaseModel): d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count if self.document_id is not None: d[SystemVariableKey.DOCUMENT_ID] = self.document_id + if self.original_document_id is not None: + d[SystemVariableKey.ORIGINAL_DOCUMENT_ID] = self.original_document_id if self.dataset_id is not None: d[SystemVariableKey.DATASET_ID] = self.dataset_id if self.batch is not None: From 0ec037b803eba58e8031d9c57145ab27fbe120a7 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 16 Sep 2025 16:08:04 +0800 Subject: [PATCH 12/12] dev/reformat --- api/controllers/console/datasets/datasets.py | 2 ++ .../service_api/dataset/document.py | 10 ++++--- .../rag_pipeline/rag_pipeline_workflow.py | 29 +++++++++++-------- api/controllers/service_api/wraps.py | 6 ++-- .../app/apps/pipeline/pipeline_generator.py | 23 ++++++++------- .../entities/rag_pipeline_invoke_entities.py | 2 +- .../rag/datasource/keyword/jieba/jieba.py | 12 ++------ api/services/dataset_service.py | 3 -- api/services/datasource_provider_service.py | 2 +- api/services/file_service.py | 4 +-- 
.../entity/pipeline_service_api_entities.py | 7 +++-- api/services/rag_pipeline/rag_pipeline.py | 18 ++++++++---- .../priority_rag_pipeline_run_task.py | 19 +++++++----- .../rag_pipeline/rag_pipeline_run_task.py | 17 ++++++----- 14 files changed, 87 insertions(+), 67 deletions(-) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 91e2a90e5e..3a13bb6b67 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -739,6 +739,8 @@ class DatasetApiDeleteApi(Resource): db.session.commit() return {"result": "success"}, 204 + + @console_ns.route("/datasets//api-keys/") class DatasetEnableApiApi(Resource): @setup_required diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 6b635bcfbd..a9f7608733 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -124,8 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource): args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), ) - upload_file = FileService(db.engine).upload_text(text=str(text), - text_name=str(name), user_id=current_user.id, tenant_id=tenant_id) + upload_file = FileService(db.engine).upload_text( + text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id + ) data_source = { "type": "upload_file", "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, @@ -203,8 +204,9 @@ class DocumentUpdateByTextApi(DatasetApiResource): name = args.get("name") if text is None or name is None: raise ValueError("Both text and name must be strings.") - upload_file = FileService(db.engine).upload_text(text=str(text), - text_name=str(name), user_id=current_user.id, tenant_id=tenant_id) + upload_file = FileService(db.engine).upload_text( + text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id + ) data_source = { "type": "upload_file", "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 75e2ba9f63..55bfdde009 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -41,7 +41,7 @@ class DatasourcePluginsApi(DatasetApiResource): @service_api_ns.doc( params={ "is_published": "Whether to get published or draft datasource plugins " - "(true for published, false for draft, default: true)" + "(true for published, false for draft, default: true)" } ) @service_api_ns.doc( @@ -54,15 +54,14 @@ class DatasourcePluginsApi(DatasetApiResource): """Resource for getting datasource plugins.""" # Get query parameter to determine published or draft is_published: bool = request.args.get("is_published", default=True, type=bool) - + rag_pipeline_service: RagPipelineService = RagPipelineService() datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins( - tenant_id=tenant_id, - dataset_id=dataset_id, - is_published=is_published + tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published ) return datasource_plugins, 200 + @service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run") class DatasourceNodeRunApi(DatasetApiResource): 
"""Resource for datasource node run.""" @@ -80,7 +79,7 @@ class DatasourceNodeRunApi(DatasetApiResource): "datasource_type": "Datasource type, e.g. online_document", "credential_id": "Credential ID", "is_published": "Whether to get published or draft datasource plugins " - "(true for published, false for draft, default: true)" + "(true for published, false for draft, default: true)", } ) @service_api_ns.doc( @@ -136,8 +135,8 @@ class PipelineRunApi(DatasetApiResource): "datasource_info_list": "Datasource info list", "start_node_id": "Start node ID", "is_published": "Whether to get published or draft datasource plugins " - "(true for published, false for draft, default: true)", - "streaming": "Whether to stream the response(streaming or blocking), default: streaming" + "(true for published, false for draft, default: true)", + "streaming": "Whether to stream the response(streaming or blocking), default: streaming", } ) @service_api_ns.doc( @@ -154,9 +153,16 @@ class PipelineRunApi(DatasetApiResource): parser.add_argument("datasource_info_list", type=list, required=True, location="json") parser.add_argument("start_node_id", type=str, required=True, location="json") parser.add_argument("is_published", type=bool, required=True, default=True, location="json") - parser.add_argument("response_mode", type=str, required=True, choices=["streaming", "blocking"], default="blocking", location="json") + parser.add_argument( + "response_mode", + type=str, + required=True, + choices=["streaming", "blocking"], + default="blocking", + location="json", + ) args: ParseResult = parser.parse_args() - + if not isinstance(current_user, Account): raise Forbidden() @@ -173,7 +179,7 @@ class PipelineRunApi(DatasetApiResource): return helper.compact_generate_response(response) except Exception as ex: - raise PipelineRunError(description=str(ex)) + raise PipelineRunError(description=str(ex)) @service_api_ns.route("/datasets/pipeline/file-upload") @@ -189,7 +195,6 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource): 401: "Unauthorized - invalid API token", 413: "File too large", 415: "Unsupported file type", - } ) def post(self, tenant_id: str): diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 246d3750d1..ee8e1d105b 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -204,7 +204,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): if not dataset_id and args: # For class methods: args[0] is self, args[1] is dataset_id (if exists) # Check if first arg is likely a class instance (has __dict__ or __class__) - if len(args) > 1 and hasattr(args[0], '__dict__'): + if len(args) > 1 and hasattr(args[0], "__dict__"): # This is a class method, dataset_id should be in args[1] potential_id = args[1] # Validate it's a string-like UUID, not another object @@ -212,7 +212,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): # Try to convert to string and check if it's a valid UUID format str_id = str(potential_id) # Basic check: UUIDs are 36 chars with hyphens - if len(str_id) == 36 and str_id.count('-') == 4: + if len(str_id) == 36 and str_id.count("-") == 4: dataset_id = str_id except: pass @@ -221,7 +221,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): potential_id = args[0] try: str_id = str(potential_id) - if len(str_id) == 36 and str_id.count('-') == 4: + if len(str_id) == 36 and str_id.count("-") == 4: dataset_id = str_id except: 
pass diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index c9daead0ba..574d9c71bf 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -137,6 +137,7 @@ class PipelineGenerator(BaseAppGenerator): documents: list[Document] = [] if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"): from services.dataset_service import DocumentService + for datasource_info in datasource_info_list: position = DocumentService.get_documents_position(dataset.id) document = self._build_document( @@ -234,16 +235,18 @@ class PipelineGenerator(BaseAppGenerator): workflow_thread_pool_id=workflow_thread_pool_id, ) else: - rag_pipeline_invoke_entities.append(RagPipelineInvokeEntity( - pipeline_id=pipeline.id, - user_id=user.id, - tenant_id=pipeline.tenant_id, - workflow_id=workflow.id, - streaming=streaming, - workflow_execution_id=workflow_run_id, - workflow_thread_pool_id=workflow_thread_pool_id, - application_generate_entity=application_generate_entity.model_dump(), - )) + rag_pipeline_invoke_entities.append( + RagPipelineInvokeEntity( + pipeline_id=pipeline.id, + user_id=user.id, + tenant_id=pipeline.tenant_id, + workflow_id=workflow.id, + streaming=streaming, + workflow_execution_id=workflow_run_id, + workflow_thread_pool_id=workflow_thread_pool_id, + application_generate_entity=application_generate_entity.model_dump(), + ) + ) if rag_pipeline_invoke_entities: # store the rag_pipeline_invoke_entities to object storage diff --git a/api/core/app/entities/rag_pipeline_invoke_entities.py b/api/core/app/entities/rag_pipeline_invoke_entities.py index b26f496c8a..992b8da893 100644 --- a/api/core/app/entities/rag_pipeline_invoke_entities.py +++ b/api/core/app/entities/rag_pipeline_invoke_entities.py @@ -11,4 +11,4 @@ class RagPipelineInvokeEntity(BaseModel): workflow_id: str streaming: bool workflow_execution_id: str | None = None - workflow_thread_pool_id: str | None = None \ No newline at end of file + workflow_thread_pool_id: str | None = None diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 3d69d86b65..97052717db 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -29,9 +29,7 @@ class Jieba(BaseKeyword): with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() - keyword_number = ( - self.dataset.keyword_number or self._config.max_keywords_per_chunk - ) + keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk for text in texts: keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number) @@ -52,9 +50,7 @@ class Jieba(BaseKeyword): keyword_table = self._get_dataset_keyword_table() keywords_list = kwargs.get("keywords_list") - keyword_number = ( - self.dataset.keyword_number or self._config.max_keywords_per_chunk - ) + keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk for i in range(len(texts)): text = texts[i] if keywords_list: @@ -239,9 +235,7 @@ class Jieba(BaseKeyword): keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"] ) else: - keyword_number = ( - self.dataset.keyword_number or self._config.max_keywords_per_chunk - ) + keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk keywords = 
keyword_table_handler.extract_keywords(segment.content, keyword_number) segment.keywords = list(keywords) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index d0d3c2d426..ed2301e172 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -10,7 +10,6 @@ from collections.abc import Sequence from typing import Any, Literal import sqlalchemy as sa -import yaml from sqlalchemy import exists, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -60,7 +59,6 @@ from services.entities.knowledge_entities.knowledge_entities import ( from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, RagPipelineDatasetCreateEntity, - RetrievalSetting, ) from services.errors.account import NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError @@ -1020,7 +1018,6 @@ class DatasetService: dataset.updated_at = naive_utc_now() db.session.commit() - @staticmethod def get_dataset_auto_disable_logs(dataset_id: str): assert isinstance(current_user, Account) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 1b5077df7b..870360ceb6 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -345,7 +345,7 @@ class DatasourceProviderService: def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool: """ check if tenant oauth params is enabled - """ + """ return ( db.session.query(DatasourceOauthTenantParamConfig) .filter_by( diff --git a/api/services/file_service.py b/api/services/file_service.py index 5708efba3c..f0bb68766d 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -19,7 +19,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id -from libs.login import current_user from models.account import Account from models.enums import CreatorUserRole from models.model import EndUser, UploadFile @@ -121,7 +120,6 @@ class FileService: return file_size <= file_size_limit def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile: - if len(text_name) > 200: text_name = text_name[:200] # user uuid as file name @@ -241,4 +239,4 @@ class FileService: return storage.delete(upload_file.key) session.delete(upload_file) - session.commit() \ No newline at end of file + session.commit() diff --git a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py index f1718e3cc8..35005fad71 100644 --- a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py +++ b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py @@ -1,4 +1,6 @@ -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any, Optional + from pydantic import BaseModel @@ -10,10 +12,11 @@ class DatasourceNodeRunApiEntity(BaseModel): credential_id: Optional[str] = None is_published: bool + class PipelineRunApiEntity(BaseModel): inputs: Mapping[str, Any] datasource_type: str datasource_info_list: list[Mapping[str, Any]] start_node_id: str is_published: bool - response_mode: str \ No newline at end of file + response_mode: str diff --git a/api/services/rag_pipeline/rag_pipeline.py 
b/api/services/rag_pipeline/rag_pipeline.py index f9ef050c52..4f97e0f9bc 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -7,7 +7,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime from typing import Any, Optional, Union, cast from uuid import uuid4 -import uuid from flask_login import current_user from sqlalchemy import func, or_, select @@ -15,7 +14,6 @@ from sqlalchemy.orm import Session, sessionmaker import contexts from configs import dify_config -from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( @@ -57,7 +55,14 @@ from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account -from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin # type: ignore +from models.dataset import ( # type: ignore + Dataset, + Document, + DocumentPipelineExecutionLog, + Pipeline, + PipelineCustomizedTemplate, + PipelineRecommendedPlugin, +) from models.enums import WorkflowRunTriggeredFrom from models.model import EndUser from models.workflow import ( @@ -1320,8 +1325,11 @@ class RagPipelineService: """ Retry error document """ - document_pipeline_excution_log = db.session.query(DocumentPipelineExecutionLog).filter( - DocumentPipelineExecutionLog.document_id == document.id).first() + document_pipeline_excution_log = ( + db.session.query(DocumentPipelineExecutionLog) + .filter(DocumentPipelineExecutionLog.document_id == document.id) + .first() + ) if not document_pipeline_excution_log: raise ValueError("Document pipeline execution log not found") pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first() diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 5ccc51a66a..7021ddab38 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -52,19 +52,21 @@ def priority_rag_pipeline_run_task( try: start_at = time.perf_counter() - rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id) + rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content( + rag_pipeline_invoke_entities_file_id + ) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) - + # Get Flask app object for thread context flask_app = current_app._get_current_object() # type: ignore - + with ThreadPoolExecutor(max_workers=10) as executor: futures = [] for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities: # Submit task to thread pool with Flask app future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app) futures.append(future) - + # Wait for all tasks to complete for future in futures: try: @@ -73,7 +75,9 @@ def priority_rag_pipeline_run_task( logging.exception("Error in pipeline task") end_at = time.perf_counter() logging.info( - click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. 
Latency: {end_at - start_at}s", fg="green") + click.style( + f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" + ) ) except Exception: logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red")) @@ -83,6 +87,7 @@ def priority_rag_pipeline_run_task( file_service.delete_file(rag_pipeline_invoke_entities_file_id) db.session.close() + def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app): """Run a single RAG pipeline task within Flask app context.""" # Create Flask application context for this thread @@ -97,13 +102,13 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity - + with Session(db.engine) as session: # Load required entities account = session.query(Account).filter(Account.id == user_id).first() if not account: raise ValueError(f"Account {user_id} not found") - + tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() if not tenant: raise ValueError(f"Tenant {tenant_id} not found") diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 5d64177e7e..d71a305b14 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -54,7 +54,8 @@ def rag_pipeline_run_task( try: start_at = time.perf_counter() rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content( - rag_pipeline_invoke_entities_file_id) + rag_pipeline_invoke_entities_file_id + ) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) # Get Flask app object for thread context @@ -75,8 +76,9 @@ def rag_pipeline_run_task( logging.exception("Error in pipeline task") end_at = time.perf_counter() logging.info( - click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", - fg="green") + click.style( + f"tenant_id: {tenant_id} , Rag pipeline run completed. 
Latency: {end_at - start_at}s", fg="green" + ) ) except Exception: logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red")) @@ -94,8 +96,9 @@ def rag_pipeline_run_task( # Keep the flag set to indicate a task is running redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1) rag_pipeline_run_task.delay( # type: ignore - rag_pipeline_invoke_entities_file_id=next_file_id.decode('utf-8') if isinstance(next_file_id, - bytes) else next_file_id, + rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8") + if isinstance(next_file_id, bytes) + else next_file_id, tenant_id=tenant_id, ) else: @@ -120,13 +123,13 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity - + with Session(db.engine) as session: # Load required entities account = session.query(Account).filter(Account.id == user_id).first() if not account: raise ValueError(f"Account {user_id} not found") - + tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() if not tenant: raise ValueError(f"Tenant {tenant_id} not found")
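
As context for the two task modules above: both rag_pipeline_run_task and priority_rag_pipeline_run_task follow the same fan-out shape — the serialized RagPipelineInvokeEntity list is loaded from file storage, decoded with json.loads, and each entity is handed to a worker thread that pushes its own Flask application context before doing any database work. The sketch below is a minimal illustration of that shape only, not the project's implementation; run_one, run_batch, and the entity fields referenced here are hypothetical stand-ins.

import json
import logging
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

from flask import Flask

logger = logging.getLogger(__name__)


def run_one(entity: Mapping[str, Any], flask_app: Flask) -> None:
    # Each worker thread needs its own app context before touching
    # Flask-bound resources (db sessions, current_app, etc.).
    with flask_app.app_context():
        # Hypothetical stand-in for invoking the pipeline generator
        # with the deserialized application_generate_entity.
        logger.info("running pipeline for workflow %s", entity.get("workflow_id"))


def run_batch(serialized_entities: str, flask_app: Flask, max_workers: int = 10) -> None:
    # The tasks above read this JSON payload back from file storage;
    # here it is simply passed in as a string.
    entities: list[Mapping[str, Any]] = json.loads(serialized_entities)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(run_one, entity, flask_app) for entity in entities]
        for future in futures:
            try:
                future.result()  # re-raises exceptions from the worker thread
            except Exception:
                logger.exception("pipeline invocation failed")

Catching the exception around future.result() and continuing, as both tasks do, keeps one failing entity from silently discarding the rest of the batch while still surfacing the error in the logs.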