Feat/add dataset service api enable (#25765)

This commit is contained in:
Jyong 2025-09-16 16:09:57 +08:00 committed by GitHub
commit 1bf0dbc5d6
26 changed files with 1011 additions and 205 deletions

View File

@ -741,6 +741,19 @@ class DatasetApiDeleteApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<str:status>")
class DatasetEnableApiApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, status):
dataset_id_str = str(dataset_id)
DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
return {"result": "success"}, 200
@console_ns.route("/datasets/api-base-info")
class DatasetApiBaseUrlApi(Resource):
@api.doc("get_dataset_api_base_info")

View File

@ -62,7 +62,7 @@ class DraftRagPipelineApi(Resource):
Get draft rag pipeline's workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model
@ -84,7 +84,7 @@ class DraftRagPipelineApi(Resource):
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
content_type = request.headers.get("Content-Type", "")
@ -119,9 +119,6 @@ class DraftRagPipelineApi(Resource):
else:
abort(415)
if not isinstance(current_user, Account):
raise Forbidden()
try:
environment_variables_list = args.get("environment_variables") or []
environment_variables = [
@ -161,10 +158,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -198,10 +192,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -235,10 +226,7 @@ class DraftRagPipelineRunApi(Resource):
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -272,10 +260,7 @@ class PublishedRagPipelineRunApi(Resource):
Run published workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -285,6 +270,7 @@ class PublishedRagPipelineRunApi(Resource):
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
parser.add_argument("original_document_id", type=str, required=False, location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
@ -394,10 +380,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -439,7 +422,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -482,7 +465,7 @@ class RagPipelineDraftNodeRunApi(Resource):
Run draft workflow node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -514,7 +497,7 @@ class RagPipelineTaskStopApi(Resource):
Stop workflow task
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@ -533,7 +516,7 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
if not pipeline.is_published:
return None
@ -553,7 +536,7 @@ class PublishedRagPipelineApi(Resource):
Publish workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
rag_pipeline_service = RagPipelineService()
@ -587,7 +570,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs
@ -605,10 +588,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -651,7 +631,7 @@ class PublishedAllRagPipelineApi(Resource):
"""
Get published workflows
"""
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -700,7 +680,7 @@ class RagPipelineByIdApi(Resource):
Update workflow attributes
"""
# Check permission
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -756,7 +736,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
@ -781,7 +761,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
@ -806,7 +786,7 @@ class DraftRagPipelineFirstStepApi(Resource):
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
@ -831,7 +811,7 @@ class DraftRagPipelineSecondStepApi(Resource):
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
@ -953,7 +933,7 @@ class RagPipelineTransformApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
if not (current_user.is_editor or current_user.is_dataset_operator):
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
dataset_id = str(dataset_id)
@ -972,7 +952,7 @@ class RagPipelineDatasourceVariableApi(Resource):
"""
Set datasource variables
"""
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()

View File

@ -124,7 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource):
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)
upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name))
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@ -202,7 +204,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name))
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},

View File

@ -133,7 +133,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
return 204
@service_api_ns.route("/datasets/metadata/built-in")
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
@service_api_ns.doc("get_built_in_fields")
@service_api_ns.doc(description="Get all built-in metadata fields")
@ -143,7 +143,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
def get(self, tenant_id):
def get(self, tenant_id, dataset_id):
"""Get all built-in metadata fields."""
built_in_fields = MetadataService.get_built_in_fields()
return {"fields": built_in_fields}, 200

View File

@ -0,0 +1,239 @@
import string
import uuid
from collections.abc import Generator
from typing import Any
from flask import request
from flask_restx import reqparse
from flask_restx.reqparse import ParseResult, RequestParser
from werkzeug.exceptions import Forbidden
import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.wraps import DatasetApiResource
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models.account import Account
from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins."""
@service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins")
@service_api_ns.doc(description="List all datasource plugins for a rag pipeline")
@service_api_ns.doc(
path={
"dataset_id": "Dataset ID",
}
)
@service_api_ns.doc(
params={
"is_published": "Whether to get published or draft datasource plugins "
"(true for published, false for draft, default: true)"
}
)
@service_api_ns.doc(
responses={
200: "Datasource plugins retrieved successfully",
401: "Unauthorized - invalid API token",
}
)
def get(self, tenant_id: str, dataset_id: str):
"""Resource for getting datasource plugins."""
# Get query parameter to determine published or draft
is_published: bool = request.args.get("is_published", default=True, type=bool)
rag_pipeline_service: RagPipelineService = RagPipelineService()
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
)
return datasource_plugins, 200
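For illustration, listing the datasource plugins of a published pipeline might look like this (base URL, token, and IDs are assumptions, not taken from the diff):
import requests
SERVICE_BASE = "http://localhost:5001/v1"                     # hypothetical service API base
HEADERS = {"Authorization": "Bearer <dataset api token>"}     # hypothetical token
DATASET_ID = "d9f3c1a2-0000-4000-8000-000000000000"           # hypothetical dataset UUID
resp = requests.get(
    f"{SERVICE_BASE}/datasets/{DATASET_ID}/pipeline/datasource-plugins",
    params={"is_published": "true"},  # true = published pipeline, false = draft
    headers=HEADERS,
)
print(resp.json())  # list of datasource plugin descriptors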
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
class DatasourceNodeRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
@service_api_ns.doc(
path={
"dataset_id": "Dataset ID",
}
)
@service_api_ns.doc(
body={
"inputs": "User input variables",
"datasource_type": "Datasource type, e.g. online_document",
"credential_id": "Credential ID",
"is_published": "Whether to get published or draft datasource plugins "
"(true for published, false for draft, default: true)",
}
)
@service_api_ns.doc(
responses={
200: "Datasource node run successfully",
401: "Unauthorized - invalid API token",
}
)
def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins."""
# Get query parameter to determine published or draft
parser: RequestParser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
parser.add_argument("is_published", type=bool, required=True, location="json")
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
return helper.compact_generate_response(
PipelineGenerator.convert_to_event_stream(
rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline,
node_id=node_id,
user_inputs=datasource_node_run_api_entity.inputs,
account=current_user,
datasource_type=datasource_node_run_api_entity.datasource_type,
is_published=datasource_node_run_api_entity.is_published,
credential_id=datasource_node_run_api_entity.credential_id,
)
)
)
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
class PipelineRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
@service_api_ns.doc(
path={
"dataset_id": "Dataset ID",
}
)
@service_api_ns.doc(
body={
"inputs": "User input variables",
"datasource_type": "Datasource type, e.g. online_document",
"datasource_info_list": "Datasource info list",
"start_node_id": "Start node ID",
"is_published": "Whether to get published or draft datasource plugins "
"(true for published, false for draft, default: true)",
"streaming": "Whether to stream the response(streaming or blocking), default: streaming",
}
)
@service_api_ns.doc(
responses={
200: "Pipeline run successfully",
401: "Unauthorized - invalid API token",
}
)
def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline."""
parser: RequestParser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_published", type=bool, required=True, default=True, location="json")
parser.add_argument(
"response_mode",
type=str,
required=True,
choices=["streaming", "blocking"],
default="blocking",
location="json",
)
args: ParseResult = parser.parse_args()
if not isinstance(current_user, Account):
raise Forbidden()
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
try:
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
streaming=args.get("response_mode") == "streaming",
)
return helper.compact_generate_response(response)
except Exception as ex:
raise PipelineRunError(description=str(ex))
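A hedged request sketch for this endpoint; the host, /v1 prefix, token, IDs, and the shape of datasource_info_list all depend on the deployment and the datasource plugin, so treat every value as a placeholder:
import requests
SERVICE_BASE = "http://localhost:5001/v1"                     # hypothetical service API base
HEADERS = {"Authorization": "Bearer <dataset api token>"}     # hypothetical token
DATASET_ID = "d9f3c1a2-0000-4000-8000-000000000000"           # hypothetical dataset UUID
payload = {
    "inputs": {},                                             # pipeline input variables
    "datasource_type": "online_document",
    "datasource_info_list": [{"page_id": "placeholder"}],     # shape depends on the datasource plugin
    "start_node_id": "datasource-node-id",                    # placeholder node ID
    "is_published": True,                                     # run the published pipeline
    "response_mode": "blocking",                              # or "streaming"
}
resp = requests.post(f"{SERVICE_BASE}/datasets/{DATASET_ID}/pipeline/run", json=payload, headers=HEADERS)
print(resp.json())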
@service_api_ns.route("/datasets/pipeline/file-upload")
class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
"""Resource for uploading a file to a knowledgebase pipeline."""
@service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload")
@service_api_ns.doc(description="Upload a file to a knowledgebase pipeline")
@service_api_ns.doc(
responses={
201: "File uploaded successfully",
400: "Bad request - no file or invalid file",
401: "Unauthorized - invalid API token",
413: "File too large",
415: "Unsupported file type",
}
)
def post(self, tenant_id: str):
"""Upload a file for use in conversations.
Accepts a single file upload via multipart/form-data.
"""
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.mimetype:
raise UnsupportedFileTypeError()
if not file.filename:
raise FilenameNotExistsError
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201
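An illustrative upload call for this endpoint (URL, token, and file are placeholders); exactly one file must be sent as multipart/form-data under the field name "file":
import requests
SERVICE_BASE = "http://localhost:5001/v1"                     # hypothetical service API base
HEADERS = {"Authorization": "Bearer <dataset api token>"}     # hypothetical token
with open("handbook.pdf", "rb") as f:                         # placeholder local file
    resp = requests.post(
        f"{SERVICE_BASE}/datasets/pipeline/file-upload",
        headers=HEADERS,
        files={"file": ("handbook.pdf", f, "application/pdf")},
    )
print(resp.status_code, resp.json())  # expected: 201 with id, name, size, extension, mime_type, ...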

View File

@ -193,6 +193,47 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def decorator(view: Callable[Concatenate[T, P], R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
# get url path dataset_id from positional args or kwargs
# Flask passes URL path parameters as positional arguments
dataset_id = None
# First try to get from kwargs (explicit parameter)
dataset_id = kwargs.get("dataset_id")
# If not in kwargs, try to extract from positional args
if not dataset_id and args:
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
# Check if first arg is likely a class instance (has __dict__ or __class__)
if len(args) > 1 and hasattr(args[0], "__dict__"):
# This is a class method, dataset_id should be in args[1]
potential_id = args[1]
# Validate it's a string-like UUID, not another object
try:
# Try to convert to string and check if it's a valid UUID format
str_id = str(potential_id)
# Basic check: UUIDs are 36 chars with hyphens
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except:
pass
elif len(args) > 0:
# Not a class method, check if args[0] looks like a UUID
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except:
pass
# Validate dataset if dataset_id is provided
if dataset_id:
dataset_id = str(dataset_id)
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
api_token = validate_and_get_api_token("dataset")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
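The 36-character/4-hyphen heuristic above will accept some strings that are not UUIDs; a stricter check could parse the candidate with uuid.UUID, as in this standalone sketch (an alternative illustration, not code from the diff):
import uuid
def looks_like_uuid(value: object) -> bool:
    """Return True only if value is a canonical hyphenated UUID string."""
    try:
        return str(uuid.UUID(str(value))) == str(value).lower()
    except (ValueError, TypeError):
        return False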

View File

@ -25,6 +25,7 @@ from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
@ -41,14 +42,17 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
@ -67,6 +71,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
is_retry: bool = False,
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
@overload
@ -81,6 +86,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
is_retry: bool = False,
) -> Mapping[str, Any]: ...
@overload
@ -95,6 +101,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
@ -108,6 +115,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# Add null check for dataset
@ -126,8 +134,10 @@ class PipelineGenerator(BaseAppGenerator):
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents = []
if invoke_from == InvokeFrom.PUBLISHED:
documents: list[Document] = []
if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
from services.dataset_service import DocumentService
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id)
document = self._build_document(
@ -147,11 +157,12 @@ class PipelineGenerator(BaseAppGenerator):
db.session.commit()
# run in child thread
rag_pipeline_invoke_entities = []
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = None
if invoke_from == InvokeFrom.PUBLISHED:
document_id = documents[i].id
document_id = args.get("original_document_id") or None
if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
document_id = document_id or documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,
datasource_type=datasource_type,
@ -170,6 +181,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=dataset.id,
original_document_id=args.get("original_document_id"),
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
@ -208,7 +220,7 @@ class PipelineGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
if invoke_from == InvokeFrom.DEBUGGER:
if invoke_from == InvokeFrom.DEBUGGER or is_retry:
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
context=contextvars.copy_context(),
@ -223,16 +235,48 @@ class PipelineGenerator(BaseAppGenerator):
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
rag_pipeline_run_task.delay( # type: ignore
pipeline_id=pipeline.id,
user_id=user.id,
tenant_id=pipeline.tenant_id,
workflow_id=workflow.id,
streaming=streaming,
workflow_execution_id=workflow_run_id,
workflow_thread_pool_id=workflow_thread_pool_id,
application_generate_entity=application_generate_entity.model_dump(),
rag_pipeline_invoke_entities.append(
RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
user_id=user.id,
tenant_id=pipeline.tenant_id,
workflow_id=workflow.id,
streaming=streaming,
workflow_execution_id=workflow_run_id,
workflow_thread_pool_id=workflow_thread_pool_id,
application_generate_entity=application_generate_entity.model_dump(),
)
)
if rag_pipeline_invoke_entities:
# store the rag_pipeline_invoke_entities to object storage
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
if redis_client.get(tenant_pipeline_task_key):
# Add to waiting queue using List operations (lpush)
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
else:
# Set flag and execute task
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
else:
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
# return batch, dataset, documents
return {
"batch": batch,

View File

@ -122,6 +122,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
document_id=self.application_generate_entity.document_id,
original_document_id=self.application_generate_entity.original_document_id,
batch=self.application_generate_entity.batch,
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,

View File

@ -257,6 +257,7 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
dataset_id: str
batch: str
document_id: Optional[str] = None
original_document_id: Optional[str] = None
start_node_id: Optional[str] = None
class SingleIterationRunEntity(BaseModel):

View File

@ -0,0 +1,14 @@
from typing import Any
from pydantic import BaseModel
class RagPipelineInvokeEntity(BaseModel):
pipeline_id: str
application_generate_entity: dict[str, Any]
user_id: str
tenant_id: str
workflow_id: str
streaming: bool
workflow_execution_id: str | None = None
workflow_thread_pool_id: str | None = None
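The generator serializes a list of these entities to JSON and stores it through FileService.upload_text; the worker task later loads the file and rebuilds the models. A minimal round-trip sketch with placeholder values (file storage elided):
import json
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
entities = [
    RagPipelineInvokeEntity(
        pipeline_id="pipeline-uuid",          # placeholder
        application_generate_entity={},       # RagPipelineGenerateEntity.model_dump() in the real flow
        user_id="account-uuid",               # placeholder
        tenant_id="tenant-uuid",              # placeholder
        workflow_id="workflow-uuid",          # placeholder
        streaming=False,
    )
]
payload = json.dumps([e.model_dump() for e in entities])                      # what upload_text receives
restored = [RagPipelineInvokeEntity(**item) for item in json.loads(payload)]  # what the task rebuilds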

View File

@ -29,9 +29,7 @@ class Jieba(BaseKeyword):
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keyword_number = (
self.dataset.keyword_number or self._config.max_keywords_per_chunk
)
keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
@ -52,9 +50,7 @@ class Jieba(BaseKeyword):
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get("keywords_list")
keyword_number = (
self.dataset.keyword_number or self._config.max_keywords_per_chunk
)
keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
for i in range(len(texts)):
text = texts[i]
if keywords_list:
@ -239,9 +235,7 @@ class Jieba(BaseKeyword):
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keyword_number = (
self.dataset.keyword_number or self._config.max_keywords_per_chunk
)
keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
segment.keywords = list(keywords)

View File

@ -24,6 +24,7 @@ class SystemVariableKey(StrEnum):
WORKFLOW_EXECUTION_ID = "workflow_run_id"
# RAG Pipeline
DOCUMENT_ID = "document_id"
ORIGINAL_DOCUMENT_ID = "original_document_id"
BATCH = "batch"
DATASET_ID = "dataset_id"
DATASOURCE_TYPE = "datasource_type"

View File

@ -4,7 +4,7 @@ import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from sqlalchemy import func
from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -128,6 +128,8 @@ class KnowledgeIndexNode(Node):
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if not document_id:
raise KnowledgeIndexNodeError("Document ID is required.")
original_document_id = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
if not batch:
raise KnowledgeIndexNodeError("Batch is required.")
@ -137,6 +139,19 @@ class KnowledgeIndexNode(Node):
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
if original_document_id:
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
index_processor.index(dataset, document, chunks)
indexing_end_at = time.perf_counter()
document.indexing_latency = indexing_end_at - indexing_start_at

View File

@ -44,6 +44,7 @@ class SystemVariable(BaseModel):
conversation_id: str | None = None
dialogue_count: int | None = None
document_id: str | None = None
original_document_id: str | None = None
dataset_id: str | None = None
batch: str | None = None
datasource_type: str | None = None
@ -94,6 +95,8 @@ class SystemVariable(BaseModel):
d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count
if self.document_id is not None:
d[SystemVariableKey.DOCUMENT_ID] = self.document_id
if self.original_document_id is not None:
d[SystemVariableKey.ORIGINAL_DOCUMENT_ID] = self.original_document_id
if self.dataset_id is not None:
d[SystemVariableKey.DATASET_ID] = self.dataset_id
if self.batch is not None:

View File

@ -95,6 +95,7 @@ dataset_detail_fields = {
"is_published": fields.Boolean,
"total_documents": fields.Integer,
"total_available_documents": fields.Integer,
"enable_api": fields.Boolean,
}
dataset_query_detail_fields = {

View File

@ -0,0 +1,35 @@
"""add_pipeline_info_18
Revision ID: 0b2ca375fabe
Revises: b45e25c2d166
Create Date: 2025-09-12 14:29:38.078589
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '0b2ca375fabe'
down_revision = 'b45e25c2d166'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('enable_api', sa.Boolean(), server_default=sa.text('true'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('enable_api')
# ### end Alembic commands ###

View File

@ -72,6 +72,7 @@ class Dataset(Base):
runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
pipeline_id = db.Column(StringUUID, nullable=True)
chunk_structure = db.Column(db.String(255), nullable=True)
enable_api = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
@property
def total_documents(self):

View File

@ -48,6 +48,7 @@ from models.dataset import (
from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
from models.workflow import Workflow
from services.entities.knowledge_entities.knowledge_entities import (
ChildChunkUpdateArgs,
KnowledgeConfig,
@ -66,6 +67,7 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.add_document_to_index_task import add_document_to_index_task
@ -528,12 +530,97 @@ class DatasetService:
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.commit()
# update pipeline knowledge base node data
DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)
# Trigger vector index task if indexing technique changed
if action:
deal_dataset_vector_index_task.delay(dataset.id, action)
return dataset
@staticmethod
def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str):
"""
Update pipeline knowledge base node data.
"""
if dataset.runtime_mode != "rag_pipeline":
return
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
if not pipeline:
return
try:
rag_pipeline_service = RagPipelineService()
published_workflow = rag_pipeline_service.get_published_workflow(pipeline)
draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline)
# update knowledge nodes
def update_knowledge_nodes(workflow_graph: str) -> str:
"""Update knowledge-index nodes in workflow graph."""
data: dict[str, Any] = json.loads(workflow_graph)
nodes = data.get("nodes", [])
updated = False
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
try:
knowledge_index_node_data = node.get("data", {})
knowledge_index_node_data["embedding_model"] = dataset.embedding_model
knowledge_index_node_data["embedding_model_provider"] = dataset.embedding_model_provider
knowledge_index_node_data["retrieval_model"] = dataset.retrieval_model
knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure
knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue]
knowledge_index_node_data["keyword_number"] = dataset.keyword_number
node["data"] = knowledge_index_node_data
updated = True
except Exception:
logging.exception("Failed to update knowledge node")
continue
if updated:
data["nodes"] = nodes
return json.dumps(data)
return workflow_graph
# Update published workflow
if published_workflow:
updated_graph = update_knowledge_nodes(published_workflow.graph)
if updated_graph != published_workflow.graph:
# Create new workflow version
workflow = Workflow.new(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
type=published_workflow.type,
version=str(datetime.datetime.now(datetime.UTC).replace(tzinfo=None)),
graph=updated_graph,
features=published_workflow.features,
created_by=updata_user_id,
environment_variables=published_workflow.environment_variables,
conversation_variables=published_workflow.conversation_variables,
rag_pipeline_variables=published_workflow.rag_pipeline_variables,
marked_name="",
marked_comment="",
)
db.session.add(workflow)
# Update draft workflow
if draft_workflow:
updated_graph = update_knowledge_nodes(draft_workflow.graph)
if updated_graph != draft_workflow.graph:
draft_workflow.graph = updated_graph
db.session.add(draft_workflow)
# Commit all changes in one transaction
db.session.commit()
except Exception:
logging.exception("Failed to update pipeline knowledge base node data")
db.session.rollback()
raise
@staticmethod
def _handle_indexing_technique_change(dataset, data, filtered_data):
"""
@ -921,6 +1008,16 @@ class DatasetService:
.all()
)
@staticmethod
def update_dataset_api_status(dataset_id: str, status: bool):
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
dataset.enable_api = status
dataset.updated_by = current_user.id
dataset.updated_at = naive_utc_now()
db.session.commit()
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str):
assert isinstance(current_user, Account)
@ -1253,7 +1350,7 @@ class DocumentService:
redis_client.setex(retry_indexing_cache_key, 600, 1)
# trigger async task
document_ids = [document.id for document in documents]
retry_document_indexing_task.delay(dataset_id, document_ids)
retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)
@staticmethod
def sync_website_document(dataset_id: str, document: Document):

View File

@ -346,18 +346,17 @@ class DatasourceProviderService:
"""
check if tenant oauth params is enabled
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
enabled=True,
)
.count()
> 0
return (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
enabled=True,
)
.count()
> 0
)
def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
@ -365,23 +364,22 @@ class DatasourceProviderService:
"""
get tenant oauth client
"""
with Session(db.engine).no_autoflush as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
tenant_oauth_client_params = (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
if mask:
return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
else:
return encrypter.decrypt(tenant_oauth_client_params.client_params)
return None
.first()
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
if mask:
return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
else:
return encrypter.decrypt(tenant_oauth_client_params.client_params)
return None
def get_oauth_encrypter(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID

View File

@ -19,7 +19,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from libs.login import current_user
from models.account import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
@ -120,34 +119,30 @@ class FileService:
return file_size <= file_size_limit
@staticmethod
def upload_text(text: str, text_name: str) -> UploadFile:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
if len(text_name) > 200:
text_name = text_name[:200]
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"
file_key = "upload_files/" + tenant_id + "/" + file_uuid + ".txt"
# save file to storage
storage.save(file_key, text.encode("utf-8"))
# save file to db
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=text_name,
size=len(text),
extension="txt",
mime_type="text/plain",
created_by=current_user.id,
created_by=user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=current_user.id,
used_by=user_id,
used_at=naive_utc_now(),
)
@ -225,3 +220,23 @@ class FileService:
generator = storage.load(upload_file.key)
return generator, upload_file.mime_type
def get_file_content(self, file_id: str) -> str:
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found")
content = storage.load(upload_file.key)
return content.decode("utf-8")
def delete_file(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
return
storage.delete(upload_file.key)
session.delete(upload_file)
session.commit()
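upload_text is now an instance method that takes the acting user and tenant explicitly, so callers outside a logged-in request context (Celery tasks, service API handlers) can use it; get_file_content and delete_file round out the lifecycle. An illustrative call with placeholder IDs:
from extensions.ext_database import db
from services.file_service import FileService
file_service = FileService(db.engine)
upload_file = file_service.upload_text(
    text="hello pipeline",
    text_name="note.txt",
    user_id="account-uuid",    # placeholder account ID
    tenant_id="tenant-uuid",   # placeholder tenant ID
)
print(file_service.get_file_content(upload_file.id))  # "hello pipeline"
file_service.delete_file(upload_file.id)              # clean up via the new helper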

View File

@ -0,0 +1,22 @@
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import BaseModel
class DatasourceNodeRunApiEntity(BaseModel):
pipeline_id: str
node_id: str
inputs: Mapping[str, Any]
datasource_type: str
credential_id: Optional[str] = None
is_published: bool
class PipelineRunApiEntity(BaseModel):
inputs: Mapping[str, Any]
datasource_type: str
datasource_info_list: list[Mapping[str, Any]]
start_node_id: str
is_published: bool
response_mode: str

View File

@ -5,7 +5,7 @@ import threading
import time
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from typing import Any, Optional, Union, cast
from uuid import uuid4
from flask_login import current_user
@ -14,6 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
@ -54,7 +55,14 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin # type: ignore
from models.dataset import ( # type: ignore
Dataset,
Document,
DocumentPipelineExecutionLog,
Pipeline,
PipelineCustomizedTemplate,
PipelineRecommendedPlugin,
)
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
@ -65,7 +73,6 @@ from models.workflow import (
WorkflowType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
@ -346,6 +353,8 @@ class RagPipelineService:
graph = workflow.graph_dict
nodes = graph.get("nodes", [])
from services.dataset_service import DatasetService
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = node.get("data", {})
@ -1311,3 +1320,39 @@ class RagPipelineService:
"installed_recommended_plugins": installed_plugin_list,
"uninstalled_recommended_plugins": uninstalled_plugin_list,
}
def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]):
"""
Retry error document
"""
document_pipeline_excution_log = (
db.session.query(DocumentPipelineExecutionLog)
.filter(DocumentPipelineExecutionLog.document_id == document.id)
.first()
)
if not document_pipeline_excution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
# convert to app config
workflow = self.get_published_workflow(pipeline)
if not workflow:
raise ValueError("Workflow not found")
PipelineGenerator().generate(
pipeline=pipeline,
workflow=workflow,
user=user,
args={
"inputs": document_pipeline_excution_log.input_data,
"start_node_id": document_pipeline_excution_log.datasource_node_id,
"datasource_type": document_pipeline_excution_log.datasource_type,
"datasource_info_list": [json.loads(document_pipeline_excution_log.datasource_info)],
},
invoke_from=InvokeFrom.PUBLISHED,
streaming=False,
call_depth=0,
workflow_thread_pool_id=None,
is_retry=True,
documents=[document],
)

View File

@ -0,0 +1,173 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import click
from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
rag_pipeline_invoke_entities_file_id: str,
tenant_id: str,
):
"""
Async Run rag pipeline
:param rag_pipeline_invoke_entities_file_id: ID of the uploaded file that stores the serialized
    rag pipeline invoke entities
:param tenant_id: Tenant ID of the pipeline runs
Each serialized invoke entity (and its application_generate_entity) includes:
:param pipeline_id: Pipeline ID
:param user_id: User ID
:param tenant_id: Tenant ID
:param workflow_id: Workflow ID
:param invoke_from: Invoke source (debugger, published, etc.)
:param streaming: Whether to stream results
:param datasource_type: Type of datasource
:param datasource_info: Datasource information dict
:param batch: Batch identifier
:param document_id: Document ID (optional)
:param start_node_id: Starting node ID
:param inputs: Input parameters dict
:param workflow_execution_id: Workflow execution ID
:param workflow_thread_pool_id: Thread pool ID for workflow execution
"""
# run with threading, thread pool size is 10
try:
start_at = time.perf_counter()
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
rag_pipeline_invoke_entities_file_id
)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
# Get Flask app object for thread context
flask_app = current_app._get_current_object() # type: ignore
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
# Submit task to thread pool with Flask app
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
futures.append(future)
# Wait for all tasks to complete
for future in futures:
try:
future.result() # This will raise any exceptions that occurred in the thread
except Exception:
logging.exception("Error in pipeline task")
end_at = time.perf_counter()
logging.info(
click.style(
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
)
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
"""Run a single RAG pipeline task within Flask app context."""
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
workflow_id = rag_pipeline_invoke_entity_model.workflow_id
streaming = rag_pipeline_invoke_entity_model.streaming
workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine) as session:
# Load required entities
account = session.query(Account).filter(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")
if workflow_execution_id is None:
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
# Set the user directly in g for preserve_flask_contexts
g._login_user = account
# Copy context for passing to pipeline generator
context = contextvars.copy_context()
# Direct execution without creating another thread
# Since we're already in a thread pool, no need for nested threading
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
pipeline_generator = PipelineGenerator()
pipeline_generator._generate(
flask_app=flask_app,
context=context,
pipeline=pipeline,
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
invoke_from=InvokeFrom.PUBLISHED,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
except Exception:
logging.exception("Error in priority pipeline task")
raise

View File

@ -1,8 +1,11 @@
import contextvars
import json
import logging
import threading
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import click
from celery import shared_task # type: ignore
@ -10,6 +13,7 @@ from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
@ -18,21 +22,18 @@ from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
pipeline_id: str,
application_generate_entity: dict,
user_id: str,
rag_pipeline_invoke_entities_file_id: str,
tenant_id: str,
workflow_id: str,
streaming: bool,
workflow_execution_id: str | None = None,
workflow_thread_pool_id: str | None = None,
):
"""
Async Run rag pipeline
:param rag_pipeline_invoke_entities_file_id: ID of the uploaded file that stores the serialized
    rag pipeline invoke entities
Each serialized invoke entity (and its application_generate_entity) includes:
:param pipeline_id: Pipeline ID
:param user_id: User ID
:param tenant_id: Tenant ID
@ -48,94 +49,146 @@ def rag_pipeline_run_task(
:param workflow_execution_id: Workflow execution ID
:param workflow_thread_pool_id: Thread pool ID for workflow execution
"""
logging.info(click.style(f"Start run rag pipeline: {pipeline_id}", fg="green"))
start_at = time.perf_counter()
indexing_cache_key = f"rag_pipeline_run_{pipeline_id}_{user_id}"
# run with threading, thread pool size is 10
try:
with Session(db.engine) as session:
account = session.query(Account).filter(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
start_at = time.perf_counter()
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
rag_pipeline_invoke_entities_file_id
)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
# Get Flask app object for thread context
flask_app = current_app._get_current_object() # type: ignore
workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
# Submit task to thread pool with Flask app
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
futures.append(future)
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")
if workflow_execution_id is None:
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
# Wait for all tasks to complete
for future in futures:
try:
future.result() # This will raise any exceptions that occurred in the thread
except Exception:
logging.exception("Error in pipeline task")
end_at = time.perf_counter()
logging.info(
click.style(
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
)
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
if next_file_id:
# Process the next waiting task
# Keep the flag set to indicate a task is running
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
# Use app context to ensure Flask globals work properly
with current_app.app_context():
else:
# No more waiting tasks, clear the flag
redis_client.delete(tenant_pipeline_task_key)
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
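Condensed, the dispatcher half of this task fans the deserialized invoke entities out over a small thread pool and then either chains the next queued batch for the tenant or clears the tenant's running flag. The sketch below assumes the Redis key names used in the finally block above; process_entity stands in for the per-entity worker defined next, and celery_task for rag_pipeline_run_task itself:

import logging
from concurrent.futures import ThreadPoolExecutor

def fan_out(invoke_entities, flask_app, process_entity) -> None:
    # Run up to 10 invoke entities concurrently; each worker receives the
    # Flask app object so it can push its own app context.
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(process_entity, entity, flask_app) for entity in invoke_entities]
        for future in futures:
            try:
                future.result()  # surfaces exceptions raised in the worker
            except Exception:
                logging.exception("Error in pipeline task")

def chain_next(redis_client, celery_task, tenant_id: str) -> None:
    queue_key = f"tenant_self_pipeline_task_queue:{tenant_id}"
    flag_key = f"tenant_pipeline_task:{tenant_id}"
    # RPOP takes the oldest waiting file ID, assuming the producer pushes
    # new batches onto the head of the list (FIFO).
    next_file_id = redis_client.rpop(queue_key)
    if next_file_id:
        # Keep the "a task is running" flag alive and re-dispatch the
        # Celery task with the next stored batch of invoke entities.
        redis_client.setex(flag_key, 60 * 60, 1)
        celery_task.delay(
            rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
            if isinstance(next_file_id, bytes)
            else next_file_id,
            tenant_id=tenant_id,
        )
    else:
        # Nothing queued for this tenant: clear the flag so the next
        # incoming request can start immediately.
        redis_client.delete(flag_key)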
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
"""Run a single RAG pipeline task within Flask app context."""
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
workflow_id = rag_pipeline_invoke_entity_model.workflow_id
streaming = rag_pipeline_invoke_entity_model.streaming
workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine) as session:
# Load required entities
account = session.query(Account).filter(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")
if workflow_execution_id is None:
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
# Set the user directly in g for preserve_flask_contexts
g._login_user = account
# Copy context for thread (after setting user)
# Copy context for passing to pipeline generator
context = contextvars.copy_context()
# Get Flask app object in the main thread where app context exists
flask_app = current_app._get_current_object() # type: ignore
# Direct execution without creating another thread
# Since we're already in a thread pool, no need for nested threading
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
# Create a wrapper function that passes user context
def _run_with_user_context():
# Don't create a new app context here - let _generate handle it
# Just ensure the user is available in contextvars
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
pipeline_generator = PipelineGenerator()
pipeline_generator._generate(
flask_app=flask_app,
context=context,
pipeline=pipeline,
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
invoke_from=InvokeFrom.PUBLISHED,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
# Create and start worker thread
worker_thread = threading.Thread(target=_run_with_user_context)
worker_thread.start()
worker_thread.join() # Wait for worker thread to complete
end_at = time.perf_counter()
logging.info(
click.style(f"Rag pipeline run: {pipeline_id} completed. Latency: {end_at - start_at}s", fg="green")
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline {pipeline_id}", fg="red"))
raise
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
pipeline_generator = PipelineGenerator()
pipeline_generator._generate(
flask_app=flask_app,
context=context,
pipeline=pipeline,
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
invoke_from=InvokeFrom.PUBLISHED,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
except Exception:
logging.exception("Error in pipeline task")
raise
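The hand-off payload between enqueuer and worker is a JSON array of invoke-entity dicts persisted through FileService and referenced by file ID; only the consumer side appears in this hunk. A small sketch of that side, reusing the calls shown above (get_file_content, delete_file, RagPipelineInvokeEntity(**item)):

import json

from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from extensions.ext_database import db
from services.file_service import FileService

def load_invoke_entities(file_id: str) -> list[RagPipelineInvokeEntity]:
    # The batch is stored as a JSON array and addressed by its file ID.
    raw = FileService(db.engine).get_file_content(file_id)
    return [RagPipelineInvokeEntity(**item) for item in json.loads(raw)]

def delete_invoke_entities(file_id: str) -> None:
    # Drop the stored payload once the batch has been processed so it
    # cannot be replayed by a later task run.
    FileService(db.engine).delete_file(file_id)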

View File

@ -10,32 +10,44 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.account import Account, Tenant
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str):
"""
Async process document
:param dataset_id:
:param document_ids:
:param user_id:
Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
"""
start_at = time.perf_counter()
print("sadaddadadaaaadadadadsdsadasdadasdasda")
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
tenant_id = dataset.tenant_id
user = db.session.query(Account).where(Account.id == user_id).first()
if not user:
logger.info(click.style(f"User not found: {user_id}", fg="red"))
return
tenant = db.session.query(Tenant).filter(Tenant.id == dataset.tenant_id).first()
if not tenant:
raise ValueError("Tenant not found")
user.current_tenant = tenant
for document_id in document_ids:
retry_indexing_cache_key = f"document_{document_id}_is_retried"
# check document limit
features = FeatureService.get_features(tenant_id)
features = FeatureService.get_features(tenant.id)
try:
if features.billing.enabled:
vector_space = features.vector_space
@ -87,8 +99,12 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
if dataset.runtime_mode == "rag_pipeline":
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.retry_error_document(dataset, document, user)
else:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"