mirror of https://github.com/langgenius/dify.git
Feat/add dataset service api enable (#25765)
commit 1bf0dbc5d6
@@ -741,6 +741,19 @@ class DatasetApiDeleteApi(Resource):
        return {"result": "success"}, 204


+@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<str:status>")
+class DatasetEnableApiApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, dataset_id, status):
+        dataset_id_str = str(dataset_id)
+
+        DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
+
+        return {"result": "success"}, 200
+
+
@console_ns.route("/datasets/api-base-info")
class DatasetApiBaseUrlApi(Resource):
    @api.doc("get_dataset_api_base_info")
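For reference, a console client could toggle a dataset's Service API access through the new route roughly as follows. This is a minimal sketch, not part of the commit: the console base URL, dataset ID, and session token are placeholder assumptions.

import requests  # assumed HTTP client, not used by this commit

CONSOLE_API = "https://dify.example.com/console/api"           # assumption
DATASET_ID = "00000000-0000-0000-0000-000000000000"            # placeholder
HEADERS = {"Authorization": "Bearer <console-session-token>"}  # placeholder

# The <status> path segment is "enable" or "disable"; DatasetEnableApiApi.post() maps it to a boolean.
resp = requests.post(f"{CONSOLE_API}/datasets/{DATASET_ID}/api-keys/enable", headers=HEADERS, timeout=10)
print(resp.status_code, resp.json())  # expected: 200 {"result": "success"}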
@@ -62,7 +62,7 @@ class DraftRagPipelineApi(Resource):
        Get draft rag pipeline's workflow
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        # fetch draft workflow by app_model

@@ -84,7 +84,7 @@ class DraftRagPipelineApi(Resource):
        Sync draft workflow
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        content_type = request.headers.get("Content-Type", "")

@@ -119,9 +119,6 @@ class DraftRagPipelineApi(Resource):
        else:
            abort(415)

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
        try:
            environment_variables_list = args.get("environment_variables") or []
            environment_variables = [

@@ -161,10 +158,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
        Run draft workflow iteration node
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()

@@ -198,10 +192,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
        Run draft workflow loop node
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()

@@ -235,10 +226,7 @@ class DraftRagPipelineRunApi(Resource):
        Run draft workflow
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()

@@ -272,10 +260,7 @@ class PublishedRagPipelineRunApi(Resource):
        Run published workflow
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()

@@ -285,6 +270,7 @@ class PublishedRagPipelineRunApi(Resource):
        parser.add_argument("start_node_id", type=str, required=True, location="json")
        parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
        parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
+        parser.add_argument("original_document_id", type=str, required=False, location="json")
        args = parser.parse_args()

        streaming = args["response_mode"] == "streaming"

@@ -394,10 +380,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
        Run rag pipeline datasource
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()

@@ -439,7 +422,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
        Run rag pipeline datasource
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()

@@ -482,7 +465,7 @@ class RagPipelineDraftNodeRunApi(Resource):
        Run draft workflow node
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()

@@ -514,7 +497,7 @@ class RagPipelineTaskStopApi(Resource):
        Stop workflow task
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)

@@ -533,7 +516,7 @@ class PublishedRagPipelineApi(Resource):
        Get published pipeline
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()
        if not pipeline.is_published:
            return None

@@ -553,7 +536,7 @@ class PublishedRagPipelineApi(Resource):
        Publish workflow
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        rag_pipeline_service = RagPipelineService()

@@ -587,7 +570,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
        Get default block config
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        # Get default block configs

@@ -605,10 +588,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
        Get default block config
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()

@@ -651,7 +631,7 @@ class PublishedAllRagPipelineApi(Resource):
        """
        Get published workflows
        """
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()

@@ -700,7 +680,7 @@ class RagPipelineByIdApi(Resource):
        Update workflow attributes
        """
        # Check permission
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()

@@ -756,7 +736,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
        Get second step parameters of rag pipeline
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, location="args")

@@ -781,7 +761,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
        Get first step parameters of rag pipeline
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, location="args")

@@ -806,7 +786,7 @@ class DraftRagPipelineFirstStepApi(Resource):
        Get first step parameters of rag pipeline
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, location="args")

@@ -831,7 +811,7 @@ class DraftRagPipelineSecondStepApi(Resource):
        Get second step parameters of rag pipeline
        """
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, location="args")

@@ -953,7 +933,7 @@ class RagPipelineTransformApi(Resource):
        if not isinstance(current_user, Account):
            raise Forbidden()

-        if not (current_user.is_editor or current_user.is_dataset_operator):
+        if not (current_user.has_edit_permission or current_user.is_dataset_operator):
            raise Forbidden()

        dataset_id = str(dataset_id)

@@ -972,7 +952,7 @@ class RagPipelineDatasourceVariableApi(Resource):
        """
        Set datasource variables
        """
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
            raise Forbidden()

        parser = reqparse.RequestParser()
@@ -124,7 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource):
                args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
            )

-        upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name))
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
        data_source = {
            "type": "upload_file",
            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},

@@ -202,7 +204,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
        name = args.get("name")
        if text is None or name is None:
            raise ValueError("Both text and name must be strings.")
-        upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name))
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
        data_source = {
            "type": "upload_file",
            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@@ -133,7 +133,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
        return 204


-@service_api_ns.route("/datasets/metadata/built-in")
+@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
    @service_api_ns.doc("get_built_in_fields")
    @service_api_ns.doc(description="Get all built-in metadata fields")

@@ -143,7 +143,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
            401: "Unauthorized - invalid API token",
        }
    )
-    def get(self, tenant_id):
+    def get(self, tenant_id, dataset_id):
        """Get all built-in metadata fields."""
        built_in_fields = MetadataService.get_built_in_fields()
        return {"fields": built_in_fields}, 200
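Because the built-in metadata route now carries a dataset ID in its path, a Service API call would look roughly like the sketch below. This is a hedged illustration only; the host and the /v1 mount prefix are assumptions, and the token and dataset ID are placeholders.

import requests  # assumed HTTP client

SERVICE_API = "https://dify.example.com/v1"           # assumption
DATASET_ID = "00000000-0000-0000-0000-000000000000"   # placeholder
HEADERS = {"Authorization": "Bearer dataset-xxxxxxxx"}  # dataset API token (placeholder)

resp = requests.get(f"{SERVICE_API}/datasets/{DATASET_ID}/metadata/built-in", headers=HEADERS, timeout=10)
print(resp.json())  # expected shape: {"fields": [...]}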
@@ -0,0 +1,239 @@
import string
import uuid
from collections.abc import Generator
from typing import Any

from flask import request
from flask_restx import reqparse
from flask_restx.reqparse import ParseResult, RequestParser
from werkzeug.exceptions import Forbidden

import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.wraps import DatasetApiResource
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models.account import Account
from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService


@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
    """Resource for datasource plugins."""

    @service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins")
    @service_api_ns.doc(description="List all datasource plugins for a rag pipeline")
    @service_api_ns.doc(
        path={
            "dataset_id": "Dataset ID",
        }
    )
    @service_api_ns.doc(
        params={
            "is_published": "Whether to get published or draft datasource plugins "
            "(true for published, false for draft, default: true)"
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Datasource plugins retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def get(self, tenant_id: str, dataset_id: str):
        """Resource for getting datasource plugins."""
        # Get query parameter to determine published or draft
        is_published: bool = request.args.get("is_published", default=True, type=bool)

        rag_pipeline_service: RagPipelineService = RagPipelineService()
        datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
            tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
        )
        return datasource_plugins, 200


@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
class DatasourceNodeRunApi(DatasetApiResource):
    """Resource for datasource node run."""

    @service_api_ns.doc(shortcut="pipeline_datasource_node_run")
    @service_api_ns.doc(description="Run a datasource node for a rag pipeline")
    @service_api_ns.doc(
        path={
            "dataset_id": "Dataset ID",
        }
    )
    @service_api_ns.doc(
        body={
            "inputs": "User input variables",
            "datasource_type": "Datasource type, e.g. online_document",
            "credential_id": "Credential ID",
            "is_published": "Whether to get published or draft datasource plugins "
            "(true for published, false for draft, default: true)",
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Datasource node run successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def post(self, tenant_id: str, dataset_id: str, node_id: str):
        """Resource for getting datasource plugins."""
        # Get query parameter to determine published or draft
        parser: RequestParser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        parser.add_argument("credential_id", type=str, required=False, location="json")
        parser.add_argument("is_published", type=bool, required=True, location="json")
        args: ParseResult = parser.parse_args()

        datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
        assert isinstance(current_user, Account)
        rag_pipeline_service: RagPipelineService = RagPipelineService()
        pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
        return helper.compact_generate_response(
            PipelineGenerator.convert_to_event_stream(
                rag_pipeline_service.run_datasource_workflow_node(
                    pipeline=pipeline,
                    node_id=node_id,
                    user_inputs=datasource_node_run_api_entity.inputs,
                    account=current_user,
                    datasource_type=datasource_node_run_api_entity.datasource_type,
                    is_published=datasource_node_run_api_entity.is_published,
                    credential_id=datasource_node_run_api_entity.credential_id,
                )
            )
        )


@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
class PipelineRunApi(DatasetApiResource):
    """Resource for datasource node run."""

    @service_api_ns.doc(shortcut="pipeline_datasource_node_run")
    @service_api_ns.doc(description="Run a datasource node for a rag pipeline")
    @service_api_ns.doc(
        path={
            "dataset_id": "Dataset ID",
        }
    )
    @service_api_ns.doc(
        body={
            "inputs": "User input variables",
            "datasource_type": "Datasource type, e.g. online_document",
            "datasource_info_list": "Datasource info list",
            "start_node_id": "Start node ID",
            "is_published": "Whether to get published or draft datasource plugins "
            "(true for published, false for draft, default: true)",
            "streaming": "Whether to stream the response(streaming or blocking), default: streaming",
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Pipeline run successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def post(self, tenant_id: str, dataset_id: str):
        """Resource for running a rag pipeline."""
        parser: RequestParser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        parser.add_argument("datasource_info_list", type=list, required=True, location="json")
        parser.add_argument("start_node_id", type=str, required=True, location="json")
        parser.add_argument("is_published", type=bool, required=True, default=True, location="json")
        parser.add_argument(
            "response_mode",
            type=str,
            required=True,
            choices=["streaming", "blocking"],
            default="blocking",
            location="json",
        )
        args: ParseResult = parser.parse_args()

        if not isinstance(current_user, Account):
            raise Forbidden()

        rag_pipeline_service: RagPipelineService = RagPipelineService()
        pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
        try:
            response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
                pipeline=pipeline,
                user=current_user,
                args=args,
                invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
                streaming=args.get("response_mode") == "streaming",
            )

            return helper.compact_generate_response(response)
        except Exception as ex:
            raise PipelineRunError(description=str(ex))


@service_api_ns.route("/datasets/pipeline/file-upload")
class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
    """Resource for uploading a file to a knowledgebase pipeline."""

    @service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload")
    @service_api_ns.doc(description="Upload a file to a knowledgebase pipeline")
    @service_api_ns.doc(
        responses={
            201: "File uploaded successfully",
            400: "Bad request - no file or invalid file",
            401: "Unauthorized - invalid API token",
            413: "File too large",
            415: "Unsupported file type",
        }
    )
    def post(self, tenant_id: str):
        """Upload a file for use in conversations.

        Accepts a single file upload via multipart/form-data.
        """
        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()

        file = request.files["file"]
        if not file.mimetype:
            raise UnsupportedFileTypeError()

        if not file.filename:
            raise FilenameNotExistsError

        try:
            upload_file = FileService(db.engine).upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,
                user=current_user,
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()

        return {
            "id": upload_file.id,
            "name": upload_file.name,
            "size": upload_file.size,
            "extension": upload_file.extension,
            "mime_type": upload_file.mime_type,
            "created_by": upload_file.created_by,
            "created_at": upload_file.created_at,
        }, 201
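A client-side sketch of triggering a published pipeline through the new Service API endpoint follows. It is an assumption-based illustration, not part of the commit: the host, the /v1 prefix, the IDs, and the token are placeholders, and the body fields mirror the parser arguments defined above.

import requests  # assumed HTTP client

SERVICE_API = "https://dify.example.com/v1"            # assumption
DATASET_ID = "00000000-0000-0000-0000-000000000000"    # placeholder
HEADERS = {"Authorization": "Bearer dataset-xxxxxxxx"}  # dataset API token (placeholder)

payload = {
    "inputs": {},                           # user input variables for the pipeline
    "datasource_type": "online_document",
    "datasource_info_list": [{}],           # one entry per datasource item to process
    "start_node_id": "datasource-node-id",  # placeholder node ID
    "is_published": True,
    "response_mode": "blocking",            # or "streaming"
}
resp = requests.post(f"{SERVICE_API}/datasets/{DATASET_ID}/pipeline/run", headers=HEADERS, json=payload, timeout=60)
print(resp.status_code, resp.json())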
@@ -193,6 +193,47 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
    def decorator(view: Callable[Concatenate[T, P], R]):
        @wraps(view)
        def decorated(*args: P.args, **kwargs: P.kwargs):
+            # get url path dataset_id from positional args or kwargs
+            # Flask passes URL path parameters as positional arguments
+            dataset_id = None
+
+            # First try to get from kwargs (explicit parameter)
+            dataset_id = kwargs.get("dataset_id")
+
+            # If not in kwargs, try to extract from positional args
+            if not dataset_id and args:
+                # For class methods: args[0] is self, args[1] is dataset_id (if exists)
+                # Check if first arg is likely a class instance (has __dict__ or __class__)
+                if len(args) > 1 and hasattr(args[0], "__dict__"):
+                    # This is a class method, dataset_id should be in args[1]
+                    potential_id = args[1]
+                    # Validate it's a string-like UUID, not another object
+                    try:
+                        # Try to convert to string and check if it's a valid UUID format
+                        str_id = str(potential_id)
+                        # Basic check: UUIDs are 36 chars with hyphens
+                        if len(str_id) == 36 and str_id.count("-") == 4:
+                            dataset_id = str_id
+                    except:
+                        pass
+                elif len(args) > 0:
+                    # Not a class method, check if args[0] looks like a UUID
+                    potential_id = args[0]
+                    try:
+                        str_id = str(potential_id)
+                        if len(str_id) == 36 and str_id.count("-") == 4:
+                            dataset_id = str_id
+                    except:
+                        pass
+
+            # Validate dataset if dataset_id is provided
+            if dataset_id:
+                dataset_id = str(dataset_id)
+                dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+                if not dataset:
+                    raise NotFound("Dataset not found.")
+                if not dataset.enable_api:
+                    raise Forbidden("Dataset api access is not enabled.")
            api_token = validate_and_get_api_token("dataset")
            tenant_account_join = (
                db.session.query(Tenant, TenantAccountJoin)
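The dataset_id detection above reduces to a "looks like a canonical UUID" test on the first suitable argument. A standalone sketch of that heuristic, for illustration only (the helper name is not part of the diff):

def looks_like_uuid(value: object) -> bool:
    # Canonical UUIDs are 36 characters long and contain exactly 4 hyphens,
    # e.g. 123e4567-e89b-12d3-a456-426614174000
    try:
        text = str(value)
    except Exception:
        return False
    return len(text) == 36 and text.count("-") == 4


print(looks_like_uuid("123e4567-e89b-12d3-a456-426614174000"))  # True
print(looks_like_uuid("not-a-uuid"))                            # False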
@@ -25,6 +25,7 @@ from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
+from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,

@@ -41,14 +42,17 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
+from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
-from services.dataset_service import DocumentService
from services.datasource_provider_service import DatasourceProviderService
+from services.feature_service import FeatureService
+from services.file_service import FileService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
+from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task

logger = logging.getLogger(__name__)

@@ -67,6 +71,7 @@ class PipelineGenerator(BaseAppGenerator):
        streaming: Literal[True],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
+        is_retry: bool = False,
    ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...

    @overload

@@ -81,6 +86,7 @@ class PipelineGenerator(BaseAppGenerator):
        streaming: Literal[False],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
+        is_retry: bool = False,
    ) -> Mapping[str, Any]: ...

    @overload

@@ -95,6 +101,7 @@ class PipelineGenerator(BaseAppGenerator):
        streaming: bool,
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
+        is_retry: bool = False,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

    def generate(

@@ -108,6 +115,7 @@ class PipelineGenerator(BaseAppGenerator):
        streaming: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None,
+        is_retry: bool = False,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
        # Add null check for dataset

@@ -126,8 +134,10 @@ class PipelineGenerator(BaseAppGenerator):
        pipeline_config = PipelineConfigManager.get_pipeline_config(
            pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
        )
-        documents = []
-        if invoke_from == InvokeFrom.PUBLISHED:
+        documents: list[Document] = []
+        if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
+            from services.dataset_service import DocumentService
+
            for datasource_info in datasource_info_list:
                position = DocumentService.get_documents_position(dataset.id)
                document = self._build_document(

@@ -147,11 +157,12 @@ class PipelineGenerator(BaseAppGenerator):
            db.session.commit()

        # run in child thread
+        rag_pipeline_invoke_entities = []
        for i, datasource_info in enumerate(datasource_info_list):
            workflow_run_id = str(uuid.uuid4())
-            document_id = None
-            if invoke_from == InvokeFrom.PUBLISHED:
-                document_id = documents[i].id
+            document_id = args.get("original_document_id") or None
+            if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
+                document_id = document_id or documents[i].id
                document_pipeline_execution_log = DocumentPipelineExecutionLog(
                    document_id=document_id,
                    datasource_type=datasource_type,

@@ -170,6 +181,7 @@ class PipelineGenerator(BaseAppGenerator):
                datasource_type=datasource_type,
                datasource_info=datasource_info,
                dataset_id=dataset.id,
+                original_document_id=args.get("original_document_id"),
                start_node_id=start_node_id,
                batch=batch,
                document_id=document_id,

@@ -208,7 +220,7 @@ class PipelineGenerator(BaseAppGenerator):
                app_id=application_generate_entity.app_config.app_id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
            )
-            if invoke_from == InvokeFrom.DEBUGGER:
+            if invoke_from == InvokeFrom.DEBUGGER or is_retry:
                return self._generate(
                    flask_app=current_app._get_current_object(),  # type: ignore
                    context=contextvars.copy_context(),

@@ -223,16 +235,48 @@ class PipelineGenerator(BaseAppGenerator):
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
            else:
-                rag_pipeline_run_task.delay(  # type: ignore
-                    pipeline_id=pipeline.id,
-                    user_id=user.id,
-                    tenant_id=pipeline.tenant_id,
-                    workflow_id=workflow.id,
-                    streaming=streaming,
-                    workflow_execution_id=workflow_run_id,
-                    workflow_thread_pool_id=workflow_thread_pool_id,
-                    application_generate_entity=application_generate_entity.model_dump(),
+                rag_pipeline_invoke_entities.append(
+                    RagPipelineInvokeEntity(
+                        pipeline_id=pipeline.id,
+                        user_id=user.id,
+                        tenant_id=pipeline.tenant_id,
+                        workflow_id=workflow.id,
+                        streaming=streaming,
+                        workflow_execution_id=workflow_run_id,
+                        workflow_thread_pool_id=workflow_thread_pool_id,
+                        application_generate_entity=application_generate_entity.model_dump(),
+                    )
                )
+
+        if rag_pipeline_invoke_entities:
+            # store the rag_pipeline_invoke_entities to object storage
+            text = [item.model_dump() for item in rag_pipeline_invoke_entities]
+            name = "rag_pipeline_invoke_entities.json"
+            # Convert list to proper JSON string
+            json_text = json.dumps(text)
+            upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
+            features = FeatureService.get_features(dataset.tenant_id)
+            if features.billing.subscription.plan == "sandbox":
+                tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
+                tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
+
+                if redis_client.get(tenant_pipeline_task_key):
+                    # Add to waiting queue using List operations (lpush)
+                    redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
+                else:
+                    # Set flag and execute task
+                    redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
+                    rag_pipeline_run_task.delay(  # type: ignore
+                        rag_pipeline_invoke_entities_file_id=upload_file.id,
+                        tenant_id=dataset.tenant_id,
+                    )
+
+            else:
+                priority_rag_pipeline_run_task.delay(  # type: ignore
+                    rag_pipeline_invoke_entities_file_id=upload_file.id,
+                    tenant_id=dataset.tenant_id,
+                )
+
        # return batch, dataset, documents
        return {
            "batch": batch,
@@ -122,6 +122,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
            workflow_id=app_config.workflow_id,
            workflow_execution_id=self.application_generate_entity.workflow_execution_id,
            document_id=self.application_generate_entity.document_id,
+            original_document_id=self.application_generate_entity.original_document_id,
            batch=self.application_generate_entity.batch,
            dataset_id=self.application_generate_entity.dataset_id,
            datasource_type=self.application_generate_entity.datasource_type,
@@ -257,6 +257,7 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
    dataset_id: str
    batch: str
    document_id: Optional[str] = None
+    original_document_id: Optional[str] = None
    start_node_id: Optional[str] = None

    class SingleIterationRunEntity(BaseModel):
@@ -0,0 +1,14 @@
from typing import Any

from pydantic import BaseModel


class RagPipelineInvokeEntity(BaseModel):
    pipeline_id: str
    application_generate_entity: dict[str, Any]
    user_id: str
    tenant_id: str
    workflow_id: str
    streaming: bool
    workflow_execution_id: str | None = None
    workflow_thread_pool_id: str | None = None
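A small usage sketch of this entity, mirroring how the generator serializes a batch before handing it to the Celery task; the field values are placeholders and the print is illustrative only.

import json

from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity

batch = [
    RagPipelineInvokeEntity(
        pipeline_id="pipeline-id",          # placeholder
        application_generate_entity={},      # dumped RagPipelineGenerateEntity
        user_id="user-id",                   # placeholder
        tenant_id="tenant-id",               # placeholder
        workflow_id="workflow-id",           # placeholder
        streaming=False,
    )
]
json_text = json.dumps([item.model_dump() for item in batch])  # stored via FileService.upload_text()
print(json_text)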
@@ -29,9 +29,7 @@ class Jieba(BaseKeyword):
        with redis_client.lock(lock_name, timeout=600):
            keyword_table_handler = JiebaKeywordTableHandler()
            keyword_table = self._get_dataset_keyword_table()
-            keyword_number = (
-                self.dataset.keyword_number or self._config.max_keywords_per_chunk
-            )
+            keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk

            for text in texts:
                keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)

@@ -52,9 +50,7 @@ class Jieba(BaseKeyword):

            keyword_table = self._get_dataset_keyword_table()
            keywords_list = kwargs.get("keywords_list")
-            keyword_number = (
-                self.dataset.keyword_number or self._config.max_keywords_per_chunk
-            )
+            keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
            for i in range(len(texts)):
                text = texts[i]
                if keywords_list:

@@ -239,9 +235,7 @@ class Jieba(BaseKeyword):
                        keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
                    )
                else:
-                    keyword_number = (
-                        self.dataset.keyword_number or self._config.max_keywords_per_chunk
-                    )
+                    keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk

                    keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
                    segment.keywords = list(keywords)
@@ -24,6 +24,7 @@ class SystemVariableKey(StrEnum):
    WORKFLOW_EXECUTION_ID = "workflow_run_id"
    # RAG Pipeline
    DOCUMENT_ID = "document_id"
+    ORIGINAL_DOCUMENT_ID = "original_document_id"
    BATCH = "batch"
    DATASET_ID = "dataset_id"
    DATASOURCE_TYPE = "datasource_type"
@@ -4,7 +4,7 @@ import time
from collections.abc import Mapping
from typing import Any, Optional, cast

-from sqlalchemy import func
+from sqlalchemy import func, select

from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

@@ -128,6 +128,8 @@ class KnowledgeIndexNode(Node):
        document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
        if not document_id:
            raise KnowledgeIndexNodeError("Document ID is required.")
+        original_document_id = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
+
        batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
        if not batch:
            raise KnowledgeIndexNodeError("Batch is required.")

@@ -137,6 +139,19 @@ class KnowledgeIndexNode(Node):
        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
+        if original_document_id:
+            segments = db.session.scalars(
+                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+            ).all()
+            if segments:
+                index_node_ids = [segment.index_node_id for segment in segments]
+
+                # delete from vector index
+                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+                for segment in segments:
+                    db.session.delete(segment)
+            db.session.commit()
        index_processor.index(dataset, document, chunks)
        indexing_end_at = time.perf_counter()
        document.indexing_latency = indexing_end_at - indexing_start_at
@@ -44,6 +44,7 @@ class SystemVariable(BaseModel):
    conversation_id: str | None = None
    dialogue_count: int | None = None
    document_id: str | None = None
+    original_document_id: str | None = None
    dataset_id: str | None = None
    batch: str | None = None
    datasource_type: str | None = None

@@ -94,6 +95,8 @@ class SystemVariable(BaseModel):
            d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count
        if self.document_id is not None:
            d[SystemVariableKey.DOCUMENT_ID] = self.document_id
+        if self.original_document_id is not None:
+            d[SystemVariableKey.ORIGINAL_DOCUMENT_ID] = self.original_document_id
        if self.dataset_id is not None:
            d[SystemVariableKey.DATASET_ID] = self.dataset_id
        if self.batch is not None:
@@ -95,6 +95,7 @@ dataset_detail_fields = {
    "is_published": fields.Boolean,
    "total_documents": fields.Integer,
    "total_available_documents": fields.Integer,
+    "enable_api": fields.Boolean,
}

dataset_query_detail_fields = {
@@ -0,0 +1,35 @@
"""add_pipeline_info_18

Revision ID: 0b2ca375fabe
Revises: b45e25c2d166
Create Date: 2025-09-12 14:29:38.078589

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '0b2ca375fabe'
down_revision = 'b45e25c2d166'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.add_column(sa.Column('enable_api', sa.Boolean(), server_default=sa.text('true'), nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_column('enable_api')

    # ### end Alembic commands ###
@@ -72,6 +72,7 @@ class Dataset(Base):
    runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
    pipeline_id = db.Column(StringUUID, nullable=True)
    chunk_structure = db.Column(db.String(255), nullable=True)
+    enable_api = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))

    @property
    def total_documents(self):
@@ -48,6 +48,7 @@ from models.dataset import (
from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
+from models.workflow import Workflow
from services.entities.knowledge_entities.knowledge_entities import (
    ChildChunkUpdateArgs,
    KnowledgeConfig,

@@ -66,6 +67,7 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
+from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.add_document_to_index_task import add_document_to_index_task

@@ -528,12 +530,97 @@ class DatasetService:
        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
        db.session.commit()

+        # update pipeline knowledge base node data
+        DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)
+
        # Trigger vector index task if indexing technique changed
        if action:
            deal_dataset_vector_index_task.delay(dataset.id, action)

        return dataset

+    @staticmethod
+    def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str):
+        """
+        Update pipeline knowledge base node data.
+        """
+        if dataset.runtime_mode != "rag_pipeline":
+            return
+
+        pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
+        if not pipeline:
+            return
+
+        try:
+            rag_pipeline_service = RagPipelineService()
+            published_workflow = rag_pipeline_service.get_published_workflow(pipeline)
+            draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline)
+
+            # update knowledge nodes
+            def update_knowledge_nodes(workflow_graph: str) -> str:
+                """Update knowledge-index nodes in workflow graph."""
+                data: dict[str, Any] = json.loads(workflow_graph)
+
+                nodes = data.get("nodes", [])
+                updated = False
+
+                for node in nodes:
+                    if node.get("data", {}).get("type") == "knowledge-index":
+                        try:
+                            knowledge_index_node_data = node.get("data", {})
+                            knowledge_index_node_data["embedding_model"] = dataset.embedding_model
+                            knowledge_index_node_data["embedding_model_provider"] = dataset.embedding_model_provider
+                            knowledge_index_node_data["retrieval_model"] = dataset.retrieval_model
+                            knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure
+                            knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique  # pyright: ignore[reportAttributeAccessIssue]
+                            knowledge_index_node_data["keyword_number"] = dataset.keyword_number
+                            node["data"] = knowledge_index_node_data
+                            updated = True
+                        except Exception:
+                            logging.exception("Failed to update knowledge node")
+                            continue
+
+                if updated:
+                    data["nodes"] = nodes
+                    return json.dumps(data)
+                return workflow_graph
+
+            # Update published workflow
+            if published_workflow:
+                updated_graph = update_knowledge_nodes(published_workflow.graph)
+                if updated_graph != published_workflow.graph:
+                    # Create new workflow version
+                    workflow = Workflow.new(
+                        tenant_id=pipeline.tenant_id,
+                        app_id=pipeline.id,
+                        type=published_workflow.type,
+                        version=str(datetime.datetime.now(datetime.UTC).replace(tzinfo=None)),
+                        graph=updated_graph,
+                        features=published_workflow.features,
+                        created_by=updata_user_id,
+                        environment_variables=published_workflow.environment_variables,
+                        conversation_variables=published_workflow.conversation_variables,
+                        rag_pipeline_variables=published_workflow.rag_pipeline_variables,
+                        marked_name="",
+                        marked_comment="",
+                    )
+                    db.session.add(workflow)
+
+            # Update draft workflow
+            if draft_workflow:
+                updated_graph = update_knowledge_nodes(draft_workflow.graph)
+                if updated_graph != draft_workflow.graph:
+                    draft_workflow.graph = updated_graph
+                    db.session.add(draft_workflow)
+
+            # Commit all changes in one transaction
+            db.session.commit()
+
+        except Exception:
+            logging.exception("Failed to update pipeline knowledge base node data")
+            db.session.rollback()
+            raise
+
    @staticmethod
    def _handle_indexing_technique_change(dataset, data, filtered_data):
        """

@@ -921,6 +1008,16 @@ class DatasetService:
            .all()
        )

+    @staticmethod
+    def update_dataset_api_status(dataset_id: str, status: bool):
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        dataset.enable_api = status
+        dataset.updated_by = current_user.id
+        dataset.updated_at = naive_utc_now()
+        db.session.commit()
+
    @staticmethod
    def get_dataset_auto_disable_logs(dataset_id: str):
        assert isinstance(current_user, Account)

@@ -1253,7 +1350,7 @@ class DocumentService:
        redis_client.setex(retry_indexing_cache_key, 600, 1)
        # trigger async task
        document_ids = [document.id for document in documents]
-        retry_document_indexing_task.delay(dataset_id, document_ids)
+        retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)

    @staticmethod
    def sync_website_document(dataset_id: str, document: Document):
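Taken together with the wraps.py change, this new service-layer helper is what the console toggle ends up calling. A hedged sketch of invoking it directly inside an application/request context (the dataset ID is a placeholder, and note the method reads current_user, so it assumes a logged-in console account):

from services.dataset_service import DatasetService

# Disable Service API access for one dataset; subsequent service-API calls that
# reach validate_dataset_token for this dataset will then raise
# Forbidden("Dataset api access is not enabled.").
DatasetService.update_dataset_api_status("00000000-0000-0000-0000-000000000000", False)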
@@ -346,18 +346,17 @@ class DatasourceProviderService:
        """
        check if tenant oauth params is enabled
        """
-        with Session(db.engine).no_autoflush as session:
-            return (
-                session.query(DatasourceOauthTenantParamConfig)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    provider=datasource_provider_id.provider_name,
-                    plugin_id=datasource_provider_id.plugin_id,
-                    enabled=True,
-                )
-                .count()
-                > 0
+        return (
+            db.session.query(DatasourceOauthTenantParamConfig)
+            .filter_by(
+                tenant_id=tenant_id,
+                provider=datasource_provider_id.provider_name,
+                plugin_id=datasource_provider_id.plugin_id,
+                enabled=True,
+            )
+            .count()
+            > 0
        )

    def get_tenant_oauth_client(
        self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False

@@ -365,23 +364,22 @@ class DatasourceProviderService:
        """
        get tenant oauth client
        """
-        with Session(db.engine).no_autoflush as session:
-            tenant_oauth_client_params = (
-                session.query(DatasourceOauthTenantParamConfig)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    provider=datasource_provider_id.provider_name,
-                    plugin_id=datasource_provider_id.plugin_id,
-                )
-                .first()
-            if tenant_oauth_client_params:
-                encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
-                if mask:
-                    return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
-                else:
-                    return encrypter.decrypt(tenant_oauth_client_params.client_params)
-            return None
+        tenant_oauth_client_params = (
+            db.session.query(DatasourceOauthTenantParamConfig)
+            .filter_by(
+                tenant_id=tenant_id,
+                provider=datasource_provider_id.provider_name,
+                plugin_id=datasource_provider_id.plugin_id,
+            )
+            .first()
+        )
+        if tenant_oauth_client_params:
+            encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
+            if mask:
+                return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
+            else:
+                return encrypter.decrypt(tenant_oauth_client_params.client_params)
+        return None

    def get_oauth_encrypter(
        self, tenant_id: str, datasource_provider_id: DatasourceProviderID
@@ -19,7 +19,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
-from libs.login import current_user
from models.account import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile

@@ -120,34 +119,30 @@ class FileService:

        return file_size <= file_size_limit

-    @staticmethod
-    def upload_text(text: str, text_name: str) -> UploadFile:
-        assert isinstance(current_user, Account)
-        assert current_user.current_tenant_id is not None
-
+    def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
        if len(text_name) > 200:
            text_name = text_name[:200]
        # user uuid as file name
        file_uuid = str(uuid.uuid4())
-        file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"
+        file_key = "upload_files/" + tenant_id + "/" + file_uuid + ".txt"

        # save file to storage
        storage.save(file_key, text.encode("utf-8"))

        # save file to db
        upload_file = UploadFile(
-            tenant_id=current_user.current_tenant_id,
+            tenant_id=tenant_id,
            storage_type=dify_config.STORAGE_TYPE,
            key=file_key,
            name=text_name,
            size=len(text),
            extension="txt",
            mime_type="text/plain",
-            created_by=current_user.id,
+            created_by=user_id,
            created_by_role=CreatorUserRole.ACCOUNT,
            created_at=naive_utc_now(),
            used=True,
-            used_by=current_user.id,
+            used_by=user_id,
            used_at=naive_utc_now(),
        )


@@ -225,3 +220,23 @@ class FileService:
        generator = storage.load(upload_file.key)

        return generator, upload_file.mime_type
+
+    def get_file_content(self, file_id: str) -> str:
+        with self._session_maker(expire_on_commit=False) as session:
+            upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
+
+            if not upload_file:
+                raise NotFound("File not found")
+            content = storage.load(upload_file.key)
+
+            return content.decode("utf-8")
+
+    def delete_file(self, file_id: str):
+        with self._session_maker(expire_on_commit=False) as session:
+            upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
+
+            if not upload_file:
+                return
+            storage.delete(upload_file.key)
+            session.delete(upload_file)
+            session.commit()
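With upload_text now an instance method that takes the acting user and tenant explicitly, callers no longer depend on flask-login's current_user. A minimal sketch of the new call shape, mirroring the calls elsewhere in this diff (the IDs are placeholders and an application context is assumed):

from extensions.ext_database import db
from services.file_service import FileService

upload_file = FileService(db.engine).upload_text(
    text="hello world",
    text_name="hello.txt",
    user_id="account-id",    # placeholder
    tenant_id="tenant-id",   # placeholder
)
print(upload_file.id, upload_file.key)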
@@ -0,0 +1,22 @@
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel


class DatasourceNodeRunApiEntity(BaseModel):
    pipeline_id: str
    node_id: str
    inputs: Mapping[str, Any]
    datasource_type: str
    credential_id: Optional[str] = None
    is_published: bool


class PipelineRunApiEntity(BaseModel):
    inputs: Mapping[str, Any]
    datasource_type: str
    datasource_info_list: list[Mapping[str, Any]]
    start_node_id: str
    is_published: bool
    response_mode: str
@ -5,7 +5,7 @@ import threading
|
|||
import time
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from flask_login import current_user
|
||||
|
|
@ -14,6 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
|||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
|
|
@ -54,7 +55,14 @@ from core.workflow.workflow_entry import WorkflowEntry
|
|||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.account import Account
|
||||
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin # type: ignore
|
||||
from models.dataset import ( # type: ignore
|
||||
Dataset,
|
||||
Document,
|
||||
DocumentPipelineExecutionLog,
|
||||
Pipeline,
|
||||
PipelineCustomizedTemplate,
|
||||
PipelineRecommendedPlugin,
|
||||
)
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
|
|
@ -65,7 +73,6 @@ from models.workflow import (
|
|||
WorkflowType,
|
||||
)
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.dataset_service import DatasetService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
KnowledgeConfiguration,
|
||||
|
|
@ -346,6 +353,8 @@ class RagPipelineService:
|
|||
|
||||
graph = workflow.graph_dict
|
||||
nodes = graph.get("nodes", [])
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") == "knowledge-index":
|
||||
knowledge_configuration = node.get("data", {})
|
||||
|
|
@ -1311,3 +1320,39 @@ class RagPipelineService:
            "installed_recommended_plugins": installed_plugin_list,
            "uninstalled_recommended_plugins": uninstalled_plugin_list,
        }

    def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]):
        """
        Retry an error document by re-running its published pipeline execution.
        """
        document_pipeline_execution_log = (
            db.session.query(DocumentPipelineExecutionLog)
            .filter(DocumentPipelineExecutionLog.document_id == document.id)
            .first()
        )
        if not document_pipeline_execution_log:
            raise ValueError("Document pipeline execution log not found")
        pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        # convert to app config
        workflow = self.get_published_workflow(pipeline)
        if not workflow:
            raise ValueError("Workflow not found")
        PipelineGenerator().generate(
            pipeline=pipeline,
            workflow=workflow,
            user=user,
            args={
                "inputs": document_pipeline_execution_log.input_data,
                "start_node_id": document_pipeline_execution_log.datasource_node_id,
                "datasource_type": document_pipeline_execution_log.datasource_type,
                "datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)],
            },
            invoke_from=InvokeFrom.PUBLISHED,
            streaming=False,
            call_depth=0,
            workflow_thread_pool_id=None,
            is_retry=True,
            documents=[document],
        )
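
For orientation, a minimal sketch of how the new retry_error_document method might be driven from application code. The helper below is hypothetical and not part of this commit; only RagPipelineService.retry_error_document and the model imports come from the diff above.

# Hypothetical caller: re-run the published pipeline for a document that failed indexing.
from extensions.ext_database import db
from models.dataset import Dataset, Document
from services.rag_pipeline.rag_pipeline import RagPipelineService


def retry_failed_document(dataset_id: str, document_id: str, user) -> None:
    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
    document = db.session.query(Document).filter(Document.id == document_id).first()
    if not dataset or not document:
        raise ValueError("Dataset or document not found")
    # Delegates to the method added above; raises if no execution log exists for the document.
    RagPipelineService().retry_error_document(dataset, document, user)
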
@ -0,0 +1,173 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import click
from celery import shared_task  # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker

from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService


@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
    rag_pipeline_invoke_entities_file_id: str,
    tenant_id: str,
):
    """
    Async run rag pipeline.

    :param rag_pipeline_invoke_entities_file_id: ID of the stored file containing the serialized
        rag pipeline invoke entities
    :param tenant_id: Tenant ID

    Each invoke entity includes:
    - pipeline_id: Pipeline ID
    - user_id: User ID
    - tenant_id: Tenant ID
    - workflow_id: Workflow ID
    - invoke_from: Invoke source (debugger, published, etc.)
    - streaming: Whether to stream results
    - datasource_type: Type of datasource
    - datasource_info: Datasource information dict
    - batch: Batch identifier
    - document_id: Document ID (optional)
    - start_node_id: Starting node ID
    - inputs: Input parameters dict
    - workflow_execution_id: Workflow execution ID
    - workflow_thread_pool_id: Thread pool ID for workflow execution
    """
    # run with threading, thread pool size is 10
    try:
        start_at = time.perf_counter()
        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
            rag_pipeline_invoke_entities_file_id
        )
        rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

        # Get Flask app object for thread context
        flask_app = current_app._get_current_object()  # type: ignore

        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
                # Submit task to thread pool with Flask app
                future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
                futures.append(future)

            # Wait for all tasks to complete
            for future in futures:
                try:
                    future.result()  # This will raise any exceptions that occurred in the thread
                except Exception:
                    logging.exception("Error in pipeline task")
        end_at = time.perf_counter()
        logging.info(
            click.style(
                f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
            )
        )
    except Exception:
        logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
        raise
    finally:
        file_service = FileService(db.engine)
        file_service.delete_file(rag_pipeline_invoke_entities_file_id)
        db.session.close()


def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
    """Run a single RAG pipeline task within Flask app context."""
    # Create Flask application context for this thread
    with flask_app.app_context():
        try:
            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
            user_id = rag_pipeline_invoke_entity_model.user_id
            tenant_id = rag_pipeline_invoke_entity_model.tenant_id
            pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
            workflow_id = rag_pipeline_invoke_entity_model.workflow_id
            streaming = rag_pipeline_invoke_entity_model.streaming
            workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
            workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
            application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity

            with Session(db.engine) as session:
                # Load required entities
                account = session.query(Account).filter(Account.id == user_id).first()
                if not account:
                    raise ValueError(f"Account {user_id} not found")

                tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
                if not tenant:
                    raise ValueError(f"Tenant {tenant_id} not found")
                account.current_tenant = tenant

                pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
                if not pipeline:
                    raise ValueError(f"Pipeline {pipeline_id} not found")

                workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
                if not workflow:
                    raise ValueError(f"Workflow {pipeline.workflow_id} not found")

                if workflow_execution_id is None:
                    workflow_execution_id = str(uuid.uuid4())

                # Create application generate entity from dict
                entity = RagPipelineGenerateEntity(**application_generate_entity)

                # Create workflow repositories
                session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
                workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
                    session_factory=session_factory,
                    user=account,
                    app_id=entity.app_config.app_id,
                    triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
                )

                workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
                    session_factory=session_factory,
                    user=account,
                    app_id=entity.app_config.app_id,
                    triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
                )

                # Set the user directly in g for preserve_flask_contexts
                g._login_user = account

                # Copy context for passing to pipeline generator
                context = contextvars.copy_context()

                # Direct execution without creating another thread
                # Since we're already in a thread pool, no need for nested threading
                from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                pipeline_generator = PipelineGenerator()
                pipeline_generator._generate(
                    flask_app=flask_app,
                    context=context,
                    pipeline=pipeline,
                    workflow_id=workflow_id,
                    user=account,
                    application_generate_entity=entity,
                    invoke_from=InvokeFrom.PUBLISHED,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
        except Exception:
            logging.exception("Error in priority pipeline task")
            raise
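
As a usage note, a hedged sketch of how this priority task might be dispatched. The module path and the persist_json helper are assumptions for illustration; only the task name, its keyword arguments, and the FileService read/delete calls appear in the diff above.

import json

# Assumed module path for the task defined above.
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task


def enqueue_priority_run(invoke_entities: list[dict], tenant_id: str, persist_json) -> None:
    # persist_json is a caller-supplied function that stores the serialized entities
    # and returns the file id the task later reads via FileService.get_file_content().
    file_id = persist_json(json.dumps(invoke_entities))
    priority_rag_pipeline_run_task.delay(
        rag_pipeline_invoke_entities_file_id=file_id,
        tenant_id=tenant_id,
    )
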
@ -1,8 +1,11 @@
import contextvars
import json
import logging
import threading
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import click
from celery import shared_task  # type: ignore

@ -10,6 +13,7 @@ from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker

from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db

@ -18,21 +22,18 @@ from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService


@shared_task(queue="pipeline")
def rag_pipeline_run_task(
    pipeline_id: str,
    application_generate_entity: dict,
    user_id: str,
    rag_pipeline_invoke_entities_file_id: str,
    tenant_id: str,
    workflow_id: str,
    streaming: bool,
    workflow_execution_id: str | None = None,
    workflow_thread_pool_id: str | None = None,
):
    """
    Async Run rag pipeline
    :param rag_pipeline_invoke_entities: Rag pipeline invoke entities
    rag_pipeline_invoke_entities include:
    :param pipeline_id: Pipeline ID
    :param user_id: User ID
    :param tenant_id: Tenant ID

@ -48,94 +49,146 @@ def rag_pipeline_run_task(
    :param workflow_execution_id: Workflow execution ID
    :param workflow_thread_pool_id: Thread pool ID for workflow execution
    """
    logging.info(click.style(f"Start run rag pipeline: {pipeline_id}", fg="green"))
    start_at = time.perf_counter()
    indexing_cache_key = f"rag_pipeline_run_{pipeline_id}_{user_id}"
    # run with threading, thread pool size is 10

    try:
        with Session(db.engine) as session:
            account = session.query(Account).filter(Account.id == user_id).first()
            if not account:
                raise ValueError(f"Account {user_id} not found")
            tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
            if not tenant:
                raise ValueError(f"Tenant {tenant_id} not found")
            account.current_tenant = tenant
        start_at = time.perf_counter()
        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
            rag_pipeline_invoke_entities_file_id
        )
        rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

            pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
            if not pipeline:
                raise ValueError(f"Pipeline {pipeline_id} not found")
        # Get Flask app object for thread context
        flask_app = current_app._get_current_object()  # type: ignore

            workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
                # Submit task to thread pool with Flask app
                future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
                futures.append(future)

            if not workflow:
                raise ValueError(f"Workflow {pipeline.workflow_id} not found")

            if workflow_execution_id is None:
                workflow_execution_id = str(uuid.uuid4())

            # Create application generate entity from dict
            entity = RagPipelineGenerateEntity(**application_generate_entity)

            # Create workflow node execution repository
            session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
            workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
                session_factory=session_factory,
                user=account,
                app_id=entity.app_config.app_id,
                triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
            # Wait for all tasks to complete
            for future in futures:
                try:
                    future.result()  # This will raise any exceptions that occurred in the thread
                except Exception:
                    logging.exception("Error in pipeline task")
        end_at = time.perf_counter()
        logging.info(
            click.style(
                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
            )
        )
    except Exception:
        logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
        raise
    finally:
        tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
        tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"

            workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
                session_factory=session_factory,
                user=account,
                app_id=entity.app_config.app_id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        # Check if there are waiting tasks in the queue
        # Use rpop to get the next task from the queue (FIFO order)
        next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)

        if next_file_id:
            # Process the next waiting task
            # Keep the flag set to indicate a task is running
            redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
            rag_pipeline_run_task.delay(  # type: ignore
                rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
                if isinstance(next_file_id, bytes)
                else next_file_id,
                tenant_id=tenant_id,
            )
            # Use app context to ensure Flask globals work properly
            with current_app.app_context():
        else:
            # No more waiting tasks, clear the flag
            redis_client.delete(tenant_pipeline_task_key)
        file_service = FileService(db.engine)
        file_service.delete_file(rag_pipeline_invoke_entities_file_id)
        db.session.close()


def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
    """Run a single RAG pipeline task within Flask app context."""
    # Create Flask application context for this thread
    with flask_app.app_context():
        try:
            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
            user_id = rag_pipeline_invoke_entity_model.user_id
            tenant_id = rag_pipeline_invoke_entity_model.tenant_id
            pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
            workflow_id = rag_pipeline_invoke_entity_model.workflow_id
            streaming = rag_pipeline_invoke_entity_model.streaming
            workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
            workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
            application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity

            with Session(db.engine) as session:
                # Load required entities
                account = session.query(Account).filter(Account.id == user_id).first()
                if not account:
                    raise ValueError(f"Account {user_id} not found")

                tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
                if not tenant:
                    raise ValueError(f"Tenant {tenant_id} not found")
                account.current_tenant = tenant

                pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
                if not pipeline:
                    raise ValueError(f"Pipeline {pipeline_id} not found")

                workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
                if not workflow:
                    raise ValueError(f"Workflow {pipeline.workflow_id} not found")

                if workflow_execution_id is None:
                    workflow_execution_id = str(uuid.uuid4())

                # Create application generate entity from dict
                entity = RagPipelineGenerateEntity(**application_generate_entity)

                # Create workflow repositories
                session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
                workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
                    session_factory=session_factory,
                    user=account,
                    app_id=entity.app_config.app_id,
                    triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
                )

                workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
                    session_factory=session_factory,
                    user=account,
                    app_id=entity.app_config.app_id,
                    triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
                )

                # Set the user directly in g for preserve_flask_contexts
                g._login_user = account

                # Copy context for thread (after setting user)
                # Copy context for passing to pipeline generator
                context = contextvars.copy_context()

                # Get Flask app object in the main thread where app context exists
                flask_app = current_app._get_current_object()  # type: ignore
                # Direct execution without creating another thread
                # Since we're already in a thread pool, no need for nested threading
                from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                # Create a wrapper function that passes user context
                def _run_with_user_context():
                    # Don't create a new app context here - let _generate handle it
                    # Just ensure the user is available in contextvars
                    from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                    pipeline_generator = PipelineGenerator()
                    pipeline_generator._generate(
                        flask_app=flask_app,
                        context=context,
                        pipeline=pipeline,
                        workflow_id=workflow_id,
                        user=account,
                        application_generate_entity=entity,
                        invoke_from=InvokeFrom.PUBLISHED,
                        workflow_execution_repository=workflow_execution_repository,
                        workflow_node_execution_repository=workflow_node_execution_repository,
                        streaming=streaming,
                        workflow_thread_pool_id=workflow_thread_pool_id,
                    )

                # Create and start worker thread
                worker_thread = threading.Thread(target=_run_with_user_context)
                worker_thread.start()
                worker_thread.join()  # Wait for worker thread to complete

                end_at = time.perf_counter()
                logging.info(
                    click.style(f"Rag pipeline run: {pipeline_id} completed. Latency: {end_at - start_at}s", fg="green")
                )
        except Exception:
            logging.exception(click.style(f"Error running rag pipeline {pipeline_id}", fg="red"))
            raise
        finally:
            redis_client.delete(indexing_cache_key)
            db.session.close()
                pipeline_generator = PipelineGenerator()
                pipeline_generator._generate(
                    flask_app=flask_app,
                    context=context,
                    pipeline=pipeline,
                    workflow_id=workflow_id,
                    user=account,
                    application_generate_entity=entity,
                    invoke_from=InvokeFrom.PUBLISHED,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
        except Exception:
            logging.exception("Error in pipeline task")
            raise
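
For context, the finally block above drains a per-tenant Redis list: it rpops the next stored file id, keeps the tenant_pipeline_task flag alive, and re-enqueues rag_pipeline_run_task, clearing the flag only when the list is empty. A minimal sketch of the producer side this implies is shown below; the submit helper and the module path are assumptions, while the key names and the consumer-side rpop/setex/delete calls come from the diff.

from extensions.ext_redis import redis_client

# Assumed module path for the task defined above.
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task


def submit_pipeline_file(tenant_id: str, file_id: str) -> None:
    task_queue_key = f"tenant_self_pipeline_task_queue:{tenant_id}"
    running_flag_key = f"tenant_pipeline_task:{tenant_id}"

    # Start immediately if no task is running for this tenant (nx=True sets the flag atomically).
    if redis_client.set(running_flag_key, 1, ex=60 * 60, nx=True):
        rag_pipeline_run_task.delay(
            rag_pipeline_invoke_entities_file_id=file_id,
            tenant_id=tenant_id,
        )
    else:
        # A task is already running; its finally block will rpop this entry (lpush + rpop = FIFO).
        redis_client.lpush(task_queue_key, file_id)
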
@ -10,32 +10,44 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.account import Account, Tenant
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str):
    """
    Async process document
    :param dataset_id:
    :param document_ids:
    :param user_id:

    Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
    Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
    """
    start_at = time.perf_counter()
    try:
        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
            return
        tenant_id = dataset.tenant_id
        user = db.session.query(Account).where(Account.id == user_id).first()
        if not user:
            logger.info(click.style(f"User not found: {user_id}", fg="red"))
            return
        tenant = db.session.query(Tenant).filter(Tenant.id == dataset.tenant_id).first()
        if not tenant:
            raise ValueError("Tenant not found")
        user.current_tenant = tenant

        for document_id in document_ids:
            retry_indexing_cache_key = f"document_{document_id}_is_retried"
            # check document limit
            features = FeatureService.get_features(tenant_id)
            features = FeatureService.get_features(tenant.id)
            try:
                if features.billing.enabled:
                    vector_space = features.vector_space

@ -87,8 +99,12 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
                db.session.add(document)
                db.session.commit()

                indexing_runner = IndexingRunner()
                indexing_runner.run([document])
                if dataset.runtime_mode == "rag_pipeline":
                    rag_pipeline_service = RagPipelineService()
                    rag_pipeline_service.retry_error_document(dataset, document, user)
                else:
                    indexing_runner = IndexingRunner()
                    indexing_runner.run([document])
                redis_client.delete(retry_indexing_cache_key)
            except Exception as ex:
                document.indexing_status = "error"
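
Because the task signature now takes user_id, existing call sites have to pass the initiating account. A minimal, hypothetical call-site sketch follows; only the .delay signature comes from the docstring above, and the module path is an assumption.

# Assumed module path for the task defined above.
from tasks.retry_document_indexing_task import retry_document_indexing_task


def trigger_retry(dataset_id: str, document_ids: list[str], current_account_id: str) -> None:
    # The third argument is the user_id parameter introduced by this change.
    retry_document_indexing_task.delay(dataset_id, document_ids, current_account_id)
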