mirror of https://github.com/langgenius/dify.git
add dataset service api enable
This commit is contained in:
parent ad870de554
commit 80c32a130f
@@ -0,0 +1,234 @@
import string
import uuid
from collections.abc import Generator
from typing import Any

from flask import request
from flask_restx import reqparse
from flask_restx.reqparse import ParseResult, RequestParser
from werkzeug.exceptions import Forbidden

import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.wraps import DatasetApiResource
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models.account import Account
from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService

@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
|
||||
class DatasourcePluginsApi(DatasetApiResource):
|
||||
"""Resource for datasource plugins."""
|
||||
|
||||
@service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins")
|
||||
@service_api_ns.doc(description="List all datasource plugins for a rag pipeline")
|
||||
@service_api_ns.doc(
|
||||
path={
|
||||
"dataset_id": "Dataset ID",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
params={
|
||||
"is_published": "Whether to get published or draft datasource plugins "
|
||||
"(true for published, false for draft, default: true)"
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Datasource plugins retrieved successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Get query parameter to determine published or draft
|
||||
is_published: bool = request.args.get("is_published", default=True, type=bool)
|
||||
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
is_published=is_published
|
||||
)
|
||||
return datasource_plugins, 200
|
||||
|
||||
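Note: a minimal client-side sketch of calling the endpoint above. The base URL, Bearer dataset API token, and dataset ID are placeholders assumed for illustration; only the path and the is_published query parameter come from this diff.

import requests

BASE_URL = "https://api.dify.ai/v1"   # assumed deployment URL
API_KEY = "dataset-xxxxxxxx"          # assumed dataset API token
DATASET_ID = "your-dataset-uuid"      # placeholder

resp = requests.get(
    f"{BASE_URL}/datasets/{DATASET_ID}/pipeline/datasource-plugins",
    headers={"Authorization": f"Bearer {API_KEY}"},
    params={"is_published": "true"},  # true for published, false for draft
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # list of datasource plugin descriptors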
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
|
||||
class DatasourceNodeRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
|
||||
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
|
||||
@service_api_ns.doc(
|
||||
path={
|
||||
"dataset_id": "Dataset ID",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
body={
|
||||
"inputs": "User input variables",
|
||||
"datasource_type": "Datasource type, e.g. online_document",
|
||||
"credential_id": "Credential ID",
|
||||
"is_published": "Whether to get published or draft datasource plugins "
|
||||
"(true for published, false for draft, default: true)"
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Datasource node run successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
def post(self, tenant_id: str, dataset_id: str, node_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Get query parameter to determine published or draft
|
||||
parser: RequestParser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
||||
parser.add_argument("is_published", type=bool, required=True, location="json")
|
||||
args: ParseResult = parser.parse_args()
|
||||
|
||||
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
|
||||
assert isinstance(current_user, Account)
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
return helper.compact_generate_response(
|
||||
PipelineGenerator.convert_to_event_stream(
|
||||
rag_pipeline_service.run_datasource_workflow_node(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=datasource_node_run_api_entity.inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_node_run_api_entity.datasource_type,
|
||||
is_published=datasource_node_run_api_entity.is_published,
|
||||
credential_id=datasource_node_run_api_entity.credential_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
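A hedged sketch of invoking the node-run endpoint from a client. The response is an event stream, so it is read line by line; URL, token, and IDs are placeholders, and the exact shape of the streamed events is not defined by this diff.

import requests

BASE_URL = "https://api.dify.ai/v1"   # assumed
API_KEY = "dataset-xxxxxxxx"          # assumed
DATASET_ID = "your-dataset-uuid"
NODE_ID = "your-datasource-node-id"

resp = requests.post(
    f"{BASE_URL}/datasets/{DATASET_ID}/pipeline/datasource/nodes/{NODE_ID}/run",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "inputs": {},                      # user input variables for the node
        "datasource_type": "online_document",
        "credential_id": None,             # optional
        "is_published": True,
    },
    stream=True,
    timeout=120,
)
resp.raise_for_status()
for line in resp.iter_lines(decode_unicode=True):
    if line:
        print(line)                        # one serialized event per line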
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
|
||||
class PipelineRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
|
||||
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
|
||||
@service_api_ns.doc(
|
||||
path={
|
||||
"dataset_id": "Dataset ID",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
body={
|
||||
"inputs": "User input variables",
|
||||
"datasource_type": "Datasource type, e.g. online_document",
|
||||
"datasource_info_list": "Datasource info list",
|
||||
"start_node_id": "Start node ID",
|
||||
"is_published": "Whether to get published or draft datasource plugins "
|
||||
"(true for published, false for draft, default: true)",
|
||||
"streaming": "Whether to stream the response(streaming or blocking), default: streaming"
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Pipeline run successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
def post(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for running a rag pipeline."""
|
||||
parser: RequestParser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
parser.add_argument("is_published", type=bool, required=True, default=True, location="json")
|
||||
parser.add_argument("response_mode", type=str, required=True, choices=["streaming", "blocking"], default="blocking", location="json")
|
||||
args: ParseResult = parser.parse_args()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
try:
|
||||
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
|
||||
pipeline=pipeline,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
|
||||
streaming=args.get("response_mode") == "streaming",
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except Exception as ex:
|
||||
raise PipelineRunError(description=str(ex))
|
||||
|
||||
|
||||
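A minimal sketch of a blocking pipeline run request, assuming the same placeholder base URL and token as above. The field names mirror the reqparse arguments; the concrete contents of datasource_info_list depend on the datasource type and are not specified in this diff.

import requests

BASE_URL = "https://api.dify.ai/v1"   # assumed
API_KEY = "dataset-xxxxxxxx"          # assumed
DATASET_ID = "your-dataset-uuid"

resp = requests.post(
    f"{BASE_URL}/datasets/{DATASET_ID}/pipeline/run",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "inputs": {},
        "datasource_type": "online_document",
        "datasource_info_list": [{"workspace_id": "...", "page_id": "..."}],  # shape depends on the datasource
        "start_node_id": "datasource-node-id",
        "is_published": True,
        "response_mode": "blocking",   # or "streaming" for an event stream
    },
    timeout=300,
)
resp.raise_for_status()
print(resp.json())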
@service_api_ns.route("/datasets/pipeline/file-upload")
|
||||
class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
|
||||
"""Resource for uploading a file to a knowledgebase pipeline."""
|
||||
|
||||
@service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload")
|
||||
@service_api_ns.doc(description="Upload a file to a knowledgebase pipeline")
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
201: "File uploaded successfully",
|
||||
400: "Bad request - no file or invalid file",
|
||||
401: "Unauthorized - invalid API token",
|
||||
413: "File too large",
|
||||
415: "Unsupported file type",
|
||||
|
||||
}
|
||||
)
|
||||
def post(self, tenant_id: str):
|
||||
"""Upload a file for use in conversations.
|
||||
|
||||
Accepts a single file upload via multipart/form-data.
|
||||
"""
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.mimetype:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at,
|
||||
}, 201
|
||||
|
|
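A sketch of the corresponding multipart upload from a client; again the base URL and token are placeholders rather than part of this change.

import requests

BASE_URL = "https://api.dify.ai/v1"   # assumed
API_KEY = "dataset-xxxxxxxx"          # assumed

with open("handbook.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/datasets/pipeline/file-upload",
        headers={"Authorization": f"Bearer {API_KEY}"},
        files={"file": ("handbook.pdf", f, "application/pdf")},
        timeout=120,
    )
resp.raise_for_status()
upload = resp.json()
print(upload["id"], upload["name"], upload["size"])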
@@ -25,6 +25,7 @@ from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
@@ -41,6 +42,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
@@ -48,7 +50,10 @@ from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task

logger = logging.getLogger(__name__)
@@ -147,6 +152,7 @@ class PipelineGenerator(BaseAppGenerator):
        db.session.commit()

        # run in child thread
        rag_pipeline_invoke_entities = []
        for i, datasource_info in enumerate(datasource_info_list):
            workflow_run_id = str(uuid.uuid4())
            document_id = None
@@ -223,7 +229,7 @@ class PipelineGenerator(BaseAppGenerator):
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
            else:
                rag_pipeline_run_task.delay(  # type: ignore
                rag_pipeline_invoke_entities.append(RagPipelineInvokeEntity(
                    pipeline_id=pipeline.id,
                    user_id=user.id,
                    tenant_id=pipeline.tenant_id,
@@ -232,7 +238,36 @@ class PipelineGenerator(BaseAppGenerator):
                    workflow_execution_id=workflow_run_id,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                    application_generate_entity=application_generate_entity.model_dump(),
                ))

        if rag_pipeline_invoke_entities:
            # store the rag_pipeline_invoke_entities to object storage
            text = [item.model_dump() for item in rag_pipeline_invoke_entities]
            name = "rag_pipeline_invoke_entities.json"
            # Convert list to proper JSON string
            json_text = json.dumps(text)
            upload_file = FileService(db.engine).upload_text(json_text, name)
            features = FeatureService.get_features(dataset.tenant_id)
            if features.billing.subscription.plan == "sandbox":
                tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
                tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"

                if redis_client.get(tenant_pipeline_task_key):
                    # Add to waiting queue using List operations (lpush)
                    redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
                else:
                    # Set flag and execute task
                    redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
                    rag_pipeline_run_task.delay(  # type: ignore
                        rag_pipeline_invoke_entities_file_id=upload_file.id,
                        tenant_id=dataset.tenant_id,
                    )

            else:
                priority_rag_pipeline_run_task.delay(  # type: ignore
                    rag_pipeline_invoke_entities_file_id=upload_file.id,
                    tenant_id=dataset.tenant_id,
                )

        # return batch, dataset, documents
        return {
            "batch": batch,
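The sandbox branch above implements a simple per-tenant gate: a Redis flag marks that a run is in flight, and additional runs are parked on a per-tenant list until the running task drains them (see the finally block of rag_pipeline_run_task further down). A minimal standalone sketch of that pattern, with hypothetical key names mirroring the ones used here:

import redis

r = redis.Redis()  # assumes a reachable Redis instance

TENANT_ID = "tenant-123"                                    # hypothetical
flag_key = f"tenant_pipeline_task:{TENANT_ID}"              # "a run is in flight"
queue_key = f"tenant_self_pipeline_task_queue:{TENANT_ID}"  # waiting file ids

def start_run(file_id: str) -> None:
    # stand-in for rag_pipeline_run_task.delay(...)
    print(f"would dispatch pipeline run for {file_id}")

def submit(file_id: str) -> None:
    """Enqueue if a run is active, otherwise mark active and start one."""
    if r.get(flag_key):
        r.lpush(queue_key, file_id)          # park the work item
    else:
        r.set(flag_key, 1, ex=60 * 60)       # flag expires after an hour
        start_run(file_id)

def on_run_finished() -> None:
    """When a run completes, pick up the next parked item (FIFO via rpop)."""
    next_file_id = r.rpop(queue_key)
    if next_file_id:
        r.setex(flag_key, 60 * 60, 1)        # keep the flag alive for the next run
        start_run(next_file_id.decode())
    else:
        r.delete(flag_key)                   # nothing waiting, clear the gate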
@@ -0,0 +1,14 @@
from typing import Any

from pydantic import BaseModel


class RagPipelineInvokeEntity(BaseModel):
    pipeline_id: str
    application_generate_entity: dict[str, Any]
    user_id: str
    tenant_id: str
    workflow_id: str
    streaming: bool
    workflow_execution_id: str | None = None
    workflow_thread_pool_id: str | None = None
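These entities are what pipeline_generator serializes into the rag_pipeline_invoke_entities.json upload and what the run tasks read back. A small round-trip sketch (field values are made up):

import json

from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity

entity = RagPipelineInvokeEntity(
    pipeline_id="pipeline-1",
    application_generate_entity={"inputs": {}, "datasource_type": "online_document"},
    user_id="user-1",
    tenant_id="tenant-1",
    workflow_id="workflow-1",
    streaming=False,
)

# what the generator stores in object storage ...
json_text = json.dumps([entity.model_dump()])

# ... and what the task reconstructs after FileService.get_file_content()
restored = [RagPipelineInvokeEntity(**item) for item in json.loads(json_text)]
assert restored[0].pipeline_id == "pipeline-1"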
@@ -120,8 +120,7 @@ class FileService:

        return file_size <= file_size_limit

    @staticmethod
    def upload_text(text: str, text_name: str) -> UploadFile:
    def upload_text(self, text: str, text_name: str) -> UploadFile:
        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None

@@ -225,3 +224,23 @@ class FileService:
        generator = storage.load(upload_file.key)

        return generator, upload_file.mime_type

    def get_file_content(self, file_id: str) -> str:
        with self._session_maker(expire_on_commit=False) as session:
            upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()

            if not upload_file:
                raise NotFound("File not found")
            content = storage.load(upload_file.key)

            return content.decode("utf-8")

    def delete_file(self, file_id: str):
        with self._session_maker(expire_on_commit=False) as session:
            upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()

            if not upload_file:
                return
            storage.delete(upload_file.key)
            session.delete(upload_file)
            session.commit()
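A short usage sketch of the new FileService text helpers as the pipeline code uses them. Note that upload_text asserts a logged-in Account, so this only works inside an authenticated request or with the current user set up in the Flask context:

import json

from extensions.ext_database import db
from services.file_service import FileService

file_service = FileService(db.engine)

# store a JSON payload as an upload (as pipeline_generator does); requires current_user to be an Account
upload_file = file_service.upload_text(json.dumps([{"pipeline_id": "p-1"}]), "rag_pipeline_invoke_entities.json")

# later, in the task, read it back and clean up
payload = json.loads(file_service.get_file_content(upload_file.id))
file_service.delete_file(upload_file.id)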
@@ -0,0 +1,19 @@
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel


class DatasourceNodeRunApiEntity(BaseModel):
    pipeline_id: str
    node_id: str
    inputs: Mapping[str, Any]
    datasource_type: str
    credential_id: Optional[str] = None
    is_published: bool


class PipelineRunApiEntity(BaseModel):
    inputs: Mapping[str, Any]
    datasource_type: str
    datasource_info_list: list[Mapping[str, Any]]
    start_node_id: str
    is_published: bool
    response_mode: str
@@ -0,0 +1,167 @@
import contextvars
import json
import logging
import threading
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import click
from celery import shared_task  # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker

from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService


@shared_task(queue="priority_pipeline")
|
||||
def priority_rag_pipeline_run_task(
|
||||
rag_pipeline_invoke_entities_file_id: str,
|
||||
tenant_id: str,
|
||||
):
|
||||
"""
|
||||
Async Run rag pipeline
|
||||
:param rag_pipeline_invoke_entities: Rag pipeline invoke entities
|
||||
rag_pipeline_invoke_entities include:
|
||||
:param pipeline_id: Pipeline ID
|
||||
:param user_id: User ID
|
||||
:param tenant_id: Tenant ID
|
||||
:param workflow_id: Workflow ID
|
||||
:param invoke_from: Invoke source (debugger, published, etc.)
|
||||
:param streaming: Whether to stream results
|
||||
:param datasource_type: Type of datasource
|
||||
:param datasource_info: Datasource information dict
|
||||
:param batch: Batch identifier
|
||||
:param document_id: Document ID (optional)
|
||||
:param start_node_id: Starting node ID
|
||||
:param inputs: Input parameters dict
|
||||
:param workflow_execution_id: Workflow execution ID
|
||||
:param workflow_thread_pool_id: Thread pool ID for workflow execution
|
||||
"""
|
||||
# run with threading, thread pool size is 10
|
||||
|
||||
try:
|
||||
start_at = time.perf_counter()
|
||||
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id)
|
||||
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = []
|
||||
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
|
||||
# Submit task to thread pool
|
||||
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity)
|
||||
futures.append(future)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for future in futures:
|
||||
try:
|
||||
future.result() # This will raise any exceptions that occurred in the thread
|
||||
except Exception:
|
||||
logging.exception("Error in pipeline task")
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green")
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
|
||||
raise
|
||||
finally:
|
||||
file_service = FileService(db.engine)
|
||||
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
|
||||
db.session.close()
|
||||
|
||||
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any]):
    rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
    user_id = rag_pipeline_invoke_entity_model.user_id
    tenant_id = rag_pipeline_invoke_entity_model.tenant_id
    pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
    workflow_id = rag_pipeline_invoke_entity_model.workflow_id
    streaming = rag_pipeline_invoke_entity_model.streaming
    workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
    workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
    application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
    with Session(db.engine) as session:
        account = session.query(Account).filter(Account.id == user_id).first()
        if not account:
            raise ValueError(f"Account {user_id} not found")
        tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
        if not tenant:
            raise ValueError(f"Tenant {tenant_id} not found")
        account.current_tenant = tenant

        pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError(f"Pipeline {pipeline_id} not found")

        workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()

        if not workflow:
            raise ValueError(f"Workflow {pipeline.workflow_id} not found")

        if workflow_execution_id is None:
            workflow_execution_id = str(uuid.uuid4())

        # Create application generate entity from dict
        entity = RagPipelineGenerateEntity(**application_generate_entity)

        # Create workflow node execution repository
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
            session_factory=session_factory,
            user=account,
            app_id=entity.app_config.app_id,
            triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
        )

        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=account,
            app_id=entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )
        # Use app context to ensure Flask globals work properly
        with current_app.app_context():
            # Set the user directly in g for preserve_flask_contexts
            g._login_user = account

            # Copy context for thread (after setting user)
            context = contextvars.copy_context()

            # Get Flask app object in the main thread where app context exists
            flask_app = current_app._get_current_object()  # type: ignore

            # Create a wrapper function that passes user context
            def _run_with_user_context():
                # Don't create a new app context here - let _generate handle it
                # Just ensure the user is available in contextvars
                from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                pipeline_generator = PipelineGenerator()
                pipeline_generator._generate(
                    flask_app=flask_app,
                    context=context,
                    pipeline=pipeline,
                    workflow_id=workflow_id,
                    user=account,
                    application_generate_entity=entity,
                    invoke_from=InvokeFrom.PUBLISHED,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )

            # Create and start worker thread
            worker_thread = threading.Thread(target=_run_with_user_context)
            worker_thread.start()
            worker_thread.join()  # Wait for worker thread to complete

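For completeness, a hedged sketch of how a caller dispatches this task; it mirrors the .delay call added in pipeline_generator and assumes a Celery worker is consuming the priority_pipeline queue (worker setup is not part of this diff).

from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task

# hypothetical identifiers; in pipeline_generator these come from the stored upload and the dataset
file_id = "upload-file-uuid"
tenant_id = "tenant-uuid"

# enqueue onto the "priority_pipeline" Celery queue
priority_rag_pipeline_run_task.delay(
    rag_pipeline_invoke_entities_file_id=file_id,
    tenant_id=tenant_id,
)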
@@ -1,8 +1,12 @@
import contextvars
import json
import logging
import threading
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import click
from celery import shared_task  # type: ignore
@@ -10,6 +14,7 @@ from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker

from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
@@ -18,21 +23,18 @@ from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService


@shared_task(queue="pipeline")
def rag_pipeline_run_task(
    pipeline_id: str,
    application_generate_entity: dict,
    user_id: str,
    rag_pipeline_invoke_entities_file_id: str,
    tenant_id: str,
    workflow_id: str,
    streaming: bool,
    workflow_execution_id: str | None = None,
    workflow_thread_pool_id: str | None = None,
):
    """
    Async Run rag pipeline
    :param rag_pipeline_invoke_entities: Rag pipeline invoke entities
    rag_pipeline_invoke_entities include:
    :param pipeline_id: Pipeline ID
    :param user_id: User ID
    :param tenant_id: Tenant ID
@@ -48,94 +50,137 @@ def rag_pipeline_run_task(
    :param workflow_execution_id: Workflow execution ID
    :param workflow_thread_pool_id: Thread pool ID for workflow execution
    """
    logging.info(click.style(f"Start run rag pipeline: {pipeline_id}", fg="green"))
    start_at = time.perf_counter()
    indexing_cache_key = f"rag_pipeline_run_{pipeline_id}_{user_id}"
    # run with threading, thread pool size is 10

    try:
        with Session(db.engine) as session:
            account = session.query(Account).filter(Account.id == user_id).first()
            if not account:
                raise ValueError(f"Account {user_id} not found")
            tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
            if not tenant:
                raise ValueError(f"Tenant {tenant_id} not found")
            account.current_tenant = tenant

            pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
            if not pipeline:
                raise ValueError(f"Pipeline {pipeline_id} not found")

            workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()

            if not workflow:
                raise ValueError(f"Workflow {pipeline.workflow_id} not found")

            if workflow_execution_id is None:
                workflow_execution_id = str(uuid.uuid4())

            # Create application generate entity from dict
            entity = RagPipelineGenerateEntity(**application_generate_entity)

            # Create workflow node execution repository
            session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
            workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
                session_factory=session_factory,
                user=account,
                app_id=entity.app_config.app_id,
                triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
            )

            workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
                session_factory=session_factory,
                user=account,
                app_id=entity.app_config.app_id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
            )
            # Use app context to ensure Flask globals work properly
            with current_app.app_context():
                # Set the user directly in g for preserve_flask_contexts
                g._login_user = account

                # Copy context for thread (after setting user)
                context = contextvars.copy_context()

                # Get Flask app object in the main thread where app context exists
                flask_app = current_app._get_current_object()  # type: ignore

                # Create a wrapper function that passes user context
                def _run_with_user_context():
                    # Don't create a new app context here - let _generate handle it
                    # Just ensure the user is available in contextvars
                    from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                    pipeline_generator = PipelineGenerator()
                    pipeline_generator._generate(
                        flask_app=flask_app,
                        context=context,
                        pipeline=pipeline,
                        workflow_id=workflow_id,
                        user=account,
                        application_generate_entity=entity,
                        invoke_from=InvokeFrom.PUBLISHED,
                        workflow_execution_repository=workflow_execution_repository,
                        workflow_node_execution_repository=workflow_node_execution_repository,
                        streaming=streaming,
                        workflow_thread_pool_id=workflow_thread_pool_id,
                    )

                # Create and start worker thread
                worker_thread = threading.Thread(target=_run_with_user_context)
                worker_thread.start()
                worker_thread.join()  # Wait for worker thread to complete

        start_at = time.perf_counter()
        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
            rag_pipeline_invoke_entities_file_id
        )
        rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
                # Submit task to thread pool
                future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity)
                futures.append(future)

            # Wait for all tasks to complete
            for future in futures:
                try:
                    future.result()  # This will raise any exceptions that occurred in the thread
                except Exception:
                    logging.exception("Error in pipeline task")
        end_at = time.perf_counter()
        logging.info(
            click.style(f"Rag pipeline run: {pipeline_id} completed. Latency: {end_at - start_at}s", fg="green")
            click.style(f"tenant_id: {tenant_id}, rag pipeline run completed. Latency: {end_at - start_at}s", fg="green")
        )
    except Exception:
        logging.exception(click.style(f"Error running rag pipeline {pipeline_id}", fg="red"))
        logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
        raise
    finally:
        redis_client.delete(indexing_cache_key)
        tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
        tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"

        # Check if there are waiting tasks in the queue
        # Use rpop to get the next task from the queue (FIFO order)
        next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)

        if next_file_id:
            # Process the next waiting task
            # Keep the flag set to indicate a task is running
            redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
            rag_pipeline_run_task.delay(  # type: ignore
                rag_pipeline_invoke_entities_file_id=(
                    next_file_id.decode("utf-8") if isinstance(next_file_id, bytes) else next_file_id
                ),
                tenant_id=tenant_id,
            )
        else:
            # No more waiting tasks, clear the flag
            redis_client.delete(tenant_pipeline_task_key)
        file_service = FileService(db.engine)
        file_service.delete_file(rag_pipeline_invoke_entities_file_id)
        db.session.close()

def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any]):
    rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
    user_id = rag_pipeline_invoke_entity_model.user_id
    tenant_id = rag_pipeline_invoke_entity_model.tenant_id
    pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
    workflow_id = rag_pipeline_invoke_entity_model.workflow_id
    streaming = rag_pipeline_invoke_entity_model.streaming
    workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
    workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
    application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
    with Session(db.engine) as session:
        account = session.query(Account).filter(Account.id == user_id).first()
        if not account:
            raise ValueError(f"Account {user_id} not found")
        tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
        if not tenant:
            raise ValueError(f"Tenant {tenant_id} not found")
        account.current_tenant = tenant

        pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError(f"Pipeline {pipeline_id} not found")

        workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()

        if not workflow:
            raise ValueError(f"Workflow {pipeline.workflow_id} not found")

        if workflow_execution_id is None:
            workflow_execution_id = str(uuid.uuid4())

        # Create application generate entity from dict
        entity = RagPipelineGenerateEntity(**application_generate_entity)

        # Create workflow node execution repository
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
            session_factory=session_factory,
            user=account,
            app_id=entity.app_config.app_id,
            triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
        )

        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=account,
            app_id=entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )
        # Use app context to ensure Flask globals work properly
        with current_app.app_context():
            # Set the user directly in g for preserve_flask_contexts
            g._login_user = account

            # Copy context for thread (after setting user)
            context = contextvars.copy_context()

            # Get Flask app object in the main thread where app context exists
            flask_app = current_app._get_current_object()  # type: ignore

            # Create a wrapper function that passes user context
            def _run_with_user_context():
                # Don't create a new app context here - let _generate handle it
                # Just ensure the user is available in contextvars
                from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                pipeline_generator = PipelineGenerator()
                pipeline_generator._generate(
                    flask_app=flask_app,
                    context=context,
                    pipeline=pipeline,
                    workflow_id=workflow_id,
                    user=account,
                    application_generate_entity=entity,
                    invoke_from=InvokeFrom.PUBLISHED,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )

            # Create and start worker thread
            worker_thread = threading.Thread(target=_run_with_user_context)
            worker_thread.start()
            worker_thread.join()  # Wait for worker thread to complete