add dataset service api enable

This commit is contained in:
jyong 2025-09-14 20:43:49 +08:00
parent ad870de554
commit 80c32a130f
8 changed files with 625 additions and 92 deletions

View File

@ -0,0 +1,234 @@
import string
import uuid
from collections.abc import Generator
from typing import Any
from flask import request
from flask_restx import reqparse
from flask_restx.reqparse import ParseResult, RequestParser
from werkzeug.exceptions import Forbidden
import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.wraps import DatasetApiResource
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models.account import Account
from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins."""
@service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins")
@service_api_ns.doc(description="List all datasource plugins for a rag pipeline")
@service_api_ns.doc(
path={
"dataset_id": "Dataset ID",
}
)
@service_api_ns.doc(
params={
"is_published": "Whether to get published or draft datasource plugins "
"(true for published, false for draft, default: true)"
}
)
@service_api_ns.doc(
responses={
200: "Datasource plugins retrieved successfully",
401: "Unauthorized - invalid API token",
}
)
def get(self, tenant_id: str, dataset_id: str):
"""Resource for getting datasource plugins."""
# Get query parameter to determine published or draft
is_published: bool = request.args.get("is_published", default=True, type=bool)
rag_pipeline_service: RagPipelineService = RagPipelineService()
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
tenant_id=tenant_id,
dataset_id=dataset_id,
is_published=is_published
)
return datasource_plugins, 200
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
class DatasourceNodeRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
@service_api_ns.doc(
path={
"dataset_id": "Dataset ID",
}
)
@service_api_ns.doc(
body={
"inputs": "User input variables",
"datasource_type": "Datasource type, e.g. online_document",
"credential_id": "Credential ID",
"is_published": "Whether to get published or draft datasource plugins "
"(true for published, false for draft, default: true)"
}
)
@service_api_ns.doc(
responses={
200: "Datasource node run successfully",
401: "Unauthorized - invalid API token",
}
)
def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins."""
# Get query parameter to determine published or draft
parser: RequestParser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
parser.add_argument("is_published", type=bool, required=True, location="json")
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
return helper.compact_generate_response(
PipelineGenerator.convert_to_event_stream(
rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline,
node_id=node_id,
user_inputs=datasource_node_run_api_entity.inputs,
account=current_user,
datasource_type=datasource_node_run_api_entity.datasource_type,
is_published=datasource_node_run_api_entity.is_published,
credential_id=datasource_node_run_api_entity.credential_id,
)
)
)
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
class PipelineRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@service_api_ns.doc(shortcut="pipeline_datasource_node_run")
@service_api_ns.doc(description="Run a datasource node for a rag pipeline")
@service_api_ns.doc(
path={
"dataset_id": "Dataset ID",
}
)
@service_api_ns.doc(
body={
"inputs": "User input variables",
"datasource_type": "Datasource type, e.g. online_document",
"datasource_info_list": "Datasource info list",
"start_node_id": "Start node ID",
"is_published": "Whether to get published or draft datasource plugins "
"(true for published, false for draft, default: true)",
"streaming": "Whether to stream the response(streaming or blocking), default: streaming"
}
)
@service_api_ns.doc(
responses={
200: "Pipeline run successfully",
401: "Unauthorized - invalid API token",
}
)
def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline."""
parser: RequestParser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_published", type=bool, required=True, default=True, location="json")
parser.add_argument("response_mode", type=str, required=True, choices=["streaming", "blocking"], default="blocking", location="json")
args: ParseResult = parser.parse_args()
if not isinstance(current_user, Account):
raise Forbidden()
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
try:
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
streaming=args.get("response_mode") == "streaming",
)
return helper.compact_generate_response(response)
except Exception as ex:
raise PipelineRunError(description=str(ex))
@service_api_ns.route("/datasets/pipeline/file-upload")
class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
"""Resource for uploading a file to a knowledgebase pipeline."""
@service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload")
@service_api_ns.doc(description="Upload a file to a knowledgebase pipeline")
@service_api_ns.doc(
responses={
201: "File uploaded successfully",
400: "Bad request - no file or invalid file",
401: "Unauthorized - invalid API token",
413: "File too large",
415: "Unsupported file type",
}
)
def post(self, tenant_id: str):
"""Upload a file for use in conversations.
Accepts a single file upload via multipart/form-data.
"""
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.mimetype:
raise UnsupportedFileTypeError()
if not file.filename:
raise FilenameNotExistsError
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201

View File

@ -25,6 +25,7 @@ from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
@ -41,6 +42,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
@ -48,7 +50,10 @@ from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
@ -147,6 +152,7 @@ class PipelineGenerator(BaseAppGenerator):
db.session.commit()
# run in child thread
rag_pipeline_invoke_entities = []
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = None
@ -223,7 +229,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities.append(RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
user_id=user.id,
tenant_id=pipeline.tenant_id,
@ -232,7 +238,36 @@ class PipelineGenerator(BaseAppGenerator):
workflow_execution_id=workflow_run_id,
workflow_thread_pool_id=workflow_thread_pool_id,
application_generate_entity=application_generate_entity.model_dump(),
))
if rag_pipeline_invoke_entities:
# store the rag_pipeline_invoke_entities to object storage
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
if redis_client.get(tenant_pipeline_task_key):
# Add to waiting queue using List operations (lpush)
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
else:
# Set flag and execute task
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
else:
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
)
# return batch, dataset, documents
return {
"batch": batch,

View File

@ -0,0 +1,14 @@
from typing import Any
from pydantic import BaseModel
class RagPipelineInvokeEntity(BaseModel):
pipeline_id: str
application_generate_entity: dict[str, Any]
user_id: str
tenant_id: str
workflow_id: str
streaming: bool
workflow_execution_id: str | None = None
workflow_thread_pool_id: str | None = None

View File

@ -120,8 +120,7 @@ class FileService:
return file_size <= file_size_limit
@staticmethod
def upload_text(text: str, text_name: str) -> UploadFile:
def upload_text(self, text: str, text_name: str) -> UploadFile:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
@ -225,3 +224,23 @@ class FileService:
generator = storage.load(upload_file.key)
return generator, upload_file.mime_type
def get_file_content(self, file_id: str) -> str:
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found")
content = storage.load(upload_file.key)
return content.decode("utf-8")
def delete_file(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
return
storage.delete(upload_file.key)
session.delete(upload_file)
session.commit()

View File

@ -0,0 +1,19 @@
from typing import Any, Mapping, Optional
from pydantic import BaseModel
class DatasourceNodeRunApiEntity(BaseModel):
pipeline_id: str
node_id: str
inputs: Mapping[str, Any]
datasource_type: str
credential_id: Optional[str] = None
is_published: bool
class PipelineRunApiEntity(BaseModel):
inputs: Mapping[str, Any]
datasource_type: str
datasource_info_list: list[Mapping[str, Any]]
start_node_id: str
is_published: bool
response_mode: str

View File

@ -0,0 +1,167 @@
import contextvars
import json
import logging
import threading
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import click
from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
rag_pipeline_invoke_entities_file_id: str,
tenant_id: str,
):
"""
Async Run rag pipeline
:param rag_pipeline_invoke_entities: Rag pipeline invoke entities
rag_pipeline_invoke_entities include:
:param pipeline_id: Pipeline ID
:param user_id: User ID
:param tenant_id: Tenant ID
:param workflow_id: Workflow ID
:param invoke_from: Invoke source (debugger, published, etc.)
:param streaming: Whether to stream results
:param datasource_type: Type of datasource
:param datasource_info: Datasource information dict
:param batch: Batch identifier
:param document_id: Document ID (optional)
:param start_node_id: Starting node ID
:param inputs: Input parameters dict
:param workflow_execution_id: Workflow execution ID
:param workflow_thread_pool_id: Thread pool ID for workflow execution
"""
# run with threading, thread pool size is 10
try:
start_at = time.perf_counter()
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
# Submit task to thread pool
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity)
futures.append(future)
# Wait for all tasks to complete
for future in futures:
try:
future.result() # This will raise any exceptions that occurred in the thread
except Exception:
logging.exception("Error in pipeline task")
end_at = time.perf_counter()
logging.info(
click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green")
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any]):
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
workflow_id = rag_pipeline_invoke_entity_model.workflow_id
streaming = rag_pipeline_invoke_entity_model.streaming
workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine) as session:
account = session.query(Account).filter(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")
if workflow_execution_id is None:
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
# Use app context to ensure Flask globals work properly
with current_app.app_context():
# Set the user directly in g for preserve_flask_contexts
g._login_user = account
# Copy context for thread (after setting user)
context = contextvars.copy_context()
# Get Flask app object in the main thread where app context exists
flask_app = current_app._get_current_object() # type: ignore
# Create a wrapper function that passes user context
def _run_with_user_context():
# Don't create a new app context here - let _generate handle it
# Just ensure the user is available in contextvars
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
pipeline_generator = PipelineGenerator()
pipeline_generator._generate(
flask_app=flask_app,
context=context,
pipeline=pipeline,
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
invoke_from=InvokeFrom.PUBLISHED,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
# Create and start worker thread
worker_thread = threading.Thread(target=_run_with_user_context)
worker_thread.start()
worker_thread.join() # Wait for worker thread to complete

View File

@ -1,8 +1,12 @@
import contextvars
import json
import logging
import threading
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import click
from celery import shared_task # type: ignore
@ -10,6 +14,7 @@ from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
@ -18,21 +23,18 @@ from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
pipeline_id: str,
application_generate_entity: dict,
user_id: str,
rag_pipeline_invoke_entities_file_id: str,
tenant_id: str,
workflow_id: str,
streaming: bool,
workflow_execution_id: str | None = None,
workflow_thread_pool_id: str | None = None,
):
"""
Async Run rag pipeline
:param rag_pipeline_invoke_entities: Rag pipeline invoke entities
rag_pipeline_invoke_entities include:
:param pipeline_id: Pipeline ID
:param user_id: User ID
:param tenant_id: Tenant ID
@ -48,94 +50,137 @@ def rag_pipeline_run_task(
:param workflow_execution_id: Workflow execution ID
:param workflow_thread_pool_id: Thread pool ID for workflow execution
"""
logging.info(click.style(f"Start run rag pipeline: {pipeline_id}", fg="green"))
start_at = time.perf_counter()
indexing_cache_key = f"rag_pipeline_run_{pipeline_id}_{user_id}"
# run with threading, thread pool size is 10
try:
with Session(db.engine) as session:
account = session.query(Account).filter(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")
if workflow_execution_id is None:
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
# Use app context to ensure Flask globals work properly
with current_app.app_context():
# Set the user directly in g for preserve_flask_contexts
g._login_user = account
# Copy context for thread (after setting user)
context = contextvars.copy_context()
# Get Flask app object in the main thread where app context exists
flask_app = current_app._get_current_object() # type: ignore
# Create a wrapper function that passes user context
def _run_with_user_context():
# Don't create a new app context here - let _generate handle it
# Just ensure the user is available in contextvars
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
pipeline_generator = PipelineGenerator()
pipeline_generator._generate(
flask_app=flask_app,
context=context,
pipeline=pipeline,
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
invoke_from=InvokeFrom.PUBLISHED,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
# Create and start worker thread
worker_thread = threading.Thread(target=_run_with_user_context)
worker_thread.start()
worker_thread.join() # Wait for worker thread to complete
start_at = time.perf_counter()
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
# Submit task to thread pool
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity)
futures.append(future)
# Wait for all tasks to complete
for future in futures:
try:
future.result() # This will raise any exceptions that occurred in the thread
except Exception:
logging.exception("Error in pipeline task")
end_at = time.perf_counter()
logging.info(
click.style(f"Rag pipeline run: {pipeline_id} completed. Latency: {end_at - start_at}s", fg="green")
click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green")
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline {pipeline_id}", fg="red"))
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
redis_client.delete(indexing_cache_key)
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
if next_file_id:
# Process the next waiting task
# Keep the flag set to indicate a task is running
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode('utf-8') if isinstance(next_file_id, bytes) else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
redis_client.delete(tenant_pipeline_task_key)
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any]):
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
workflow_id = rag_pipeline_invoke_entity_model.workflow_id
streaming = rag_pipeline_invoke_entity_model.streaming
workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine) as session:
account = session.query(Account).filter(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")
if workflow_execution_id is None:
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
# Use app context to ensure Flask globals work properly
with current_app.app_context():
# Set the user directly in g for preserve_flask_contexts
g._login_user = account
# Copy context for thread (after setting user)
context = contextvars.copy_context()
# Get Flask app object in the main thread where app context exists
flask_app = current_app._get_current_object() # type: ignore
# Create a wrapper function that passes user context
def _run_with_user_context():
# Don't create a new app context here - let _generate handle it
# Just ensure the user is available in contextvars
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
pipeline_generator = PipelineGenerator()
pipeline_generator._generate(
flask_app=flask_app,
context=context,
pipeline=pipeline,
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
invoke_from=InvokeFrom.PUBLISHED,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
# Create and start worker thread
worker_thread = threading.Thread(target=_run_with_user_context)
worker_thread.start()
worker_thread.join() # Wait for worker thread to complete