diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index fef10b79a7..fa4020d7db 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -300,6 +300,86 @@ class PublishedRagPipelineRunApi(Resource):
         except InvokeRateLimitError as ex:
             raise InvokeRateLimitHttpError(ex.description)
 
+
+class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_rag_pipeline
+    def post(self, pipeline: Pipeline, node_id: str):
+        """
+        Get the run status of a rag pipeline datasource node (published workflow)
+        """
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("datasource_type", type=str, required=True, location="json")
+        args = parser.parse_args()
+
+        job_id = args.get("job_id")
+        if job_id is None:
+            raise ValueError("missing job_id")
+        datasource_type = args.get("datasource_type")
+        if datasource_type is None:
+            raise ValueError("missing datasource_type")
+
+        rag_pipeline_service = RagPipelineService()
+        result = rag_pipeline_service.run_datasource_workflow_node_status(
+            pipeline=pipeline,
+            node_id=node_id,
+            job_id=job_id,
+            account=current_user,
+            datasource_type=datasource_type,
+            is_published=True,
+        )
+
+        return result
+
+
+class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_rag_pipeline
+    def post(self, pipeline: Pipeline, node_id: str):
+        """
+        Get the run status of a rag pipeline datasource node (draft workflow)
+        """
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("datasource_type", type=str, required=True, location="json")
+        args = parser.parse_args()
+
+        job_id = args.get("job_id")
+        if job_id is None:
+            raise ValueError("missing job_id")
+        datasource_type = args.get("datasource_type")
+        if datasource_type is None:
+            raise ValueError("missing datasource_type")
+
+        rag_pipeline_service = RagPipelineService()
+        result = rag_pipeline_service.run_datasource_workflow_node_status(
+            pipeline=pipeline,
+            node_id=node_id,
+            job_id=job_id,
+            account=current_user,
+            datasource_type=datasource_type,
+            is_published=False,
+        )
+
+        return result
+
 
 class RagPipelinePublishedDatasourceNodeRunApi(Resource):
     @setup_required
@@ -894,6 +974,14 @@ api.add_resource(
     RagPipelinePublishedDatasourceNodeRunApi,
     "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
 )
+api.add_resource(
+    RagPipelinePublishedDatasourceNodeRunStatusApi,
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run-status",
+)
+api.add_resource(
+    RagPipelineDraftDatasourceNodeRunStatusApi,
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run-status",
+)
 
 api.add_resource(
     RagPipelineDrafDatasourceNodeRunApi,
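As a rough illustration of how the two new endpoints are meant to be consumed, a client could poll `run-status` until the datasource job settles. This is a sketch only: the console API prefix, the auth header, and the terminal status strings are assumptions not established by this diff; the URL shape and JSON body fields (`job_id`, `datasource_type`) follow the routes and request parser above.

```python
# Hypothetical polling client for the new run-status endpoints.
# CONSOLE_API, the bearer token, and the terminal status values are assumptions;
# the URL shape and JSON body ("job_id", "datasource_type") follow the diff above.
import time

import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # assumed auth scheme


def poll_datasource_run_status(
    pipeline_id: str,
    node_id: str,
    job_id: str,
    datasource_type: str = "website_crawl",
    published: bool = True,
    interval: float = 2.0,
) -> dict:
    """Poll the run-status endpoint until the crawl job reaches a terminal state."""
    stage = "published" if published else "draft"
    url = (
        f"{CONSOLE_API}/rag/pipelines/{pipeline_id}/workflows/{stage}"
        f"/datasource/nodes/{node_id}/run-status"
    )
    while True:
        resp = requests.post(
            url,
            json={"job_id": job_id, "datasource_type": datasource_type},
            headers=HEADERS,
            timeout=30,
        )
        resp.raise_for_status()
        data = resp.json()
        # "completed"/"failed" are assumed terminal statuses; the real values come
        # from the datasource plugin behind the endpoint.
        if data.get("status") in {"completed", "failed"}:
            return data
        time.sleep(interval)
```

The same request body works against both the published and draft routes; only the path segment differs.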
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index 6a9fc5d9f9..647d8f9a8c 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -304,4 +304,6 @@ class GetWebsiteCrawlResponse(BaseModel):
     Get website crawl response
     """
 
-    result: list[WebSiteInfo]
+    result: Optional[list[WebSiteInfo]] = []
+    job_id: str = Field(..., description="The job id")
+    status: str = Field(..., description="The status of the job")
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index cb42224c60..1d4e279d2a 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -415,6 +415,67 @@ class RagPipelineService:
 
         return workflow_node_execution
 
+    def run_datasource_workflow_node_status(
+        self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool
+    ) -> dict:
+        """
+        Get the status of a datasource node run (published or draft workflow)
+        """
+        if is_published:
+            # fetch published workflow for the pipeline
+            workflow = self.get_published_workflow(pipeline=pipeline)
+        else:
+            workflow = self.get_draft_workflow(pipeline=pipeline)
+        if not workflow:
+            raise ValueError("Workflow not initialized")
+
+        # locate the datasource node data in the workflow graph
+        datasource_node_data = None
+        start_at = time.perf_counter()
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
+        for datasource_node in datasource_nodes:
+            if datasource_node.get("id") == node_id:
+                datasource_node_data = datasource_node.get("data", {})
+                break
+        if not datasource_node_data:
+            raise ValueError("Datasource node data not found")
+
+        from core.datasource.datasource_manager import DatasourceManager
+
+        datasource_runtime = DatasourceManager.get_datasource_runtime(
+            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
+            datasource_name=datasource_node_data.get("datasource_name"),
+            tenant_id=pipeline.tenant_id,
+            datasource_type=DatasourceProviderType(datasource_type),
+        )
+        datasource_provider_service = DatasourceProviderService()
+        credentials = datasource_provider_service.get_real_datasource_credentials(
+            tenant_id=pipeline.tenant_id,
+            provider=datasource_node_data.get("provider_name"),
+            plugin_id=datasource_node_data.get("plugin_id"),
+        )
+        if credentials:
+            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
+        match datasource_type:
+            case DatasourceProviderType.WEBSITE_CRAWL:
+                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
+                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
+                    user_id=account.id,
+                    datasource_parameters={"job_id": job_id},
+                    provider_type=datasource_runtime.datasource_provider_type(),
+                )
+                return {
+                    "result": [result.model_dump() for result in website_crawl_result.result],
+                    "job_id": website_crawl_result.job_id,
+                    "status": website_crawl_result.status,
+                    "provider_type": datasource_node_data.get("provider_type"),
+                }
+            case _:
+                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
+
     def run_datasource_workflow_node(
         self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool
     ) -> dict:
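To make the new payload concrete, here is a minimal standalone sketch of how the extended `GetWebsiteCrawlResponse` maps onto the dict that `run_datasource_workflow_node_status` returns for the website-crawl case. The `WebSiteInfo` fields below are invented placeholders; only `result`, `job_id`, `status`, and the `provider_type` key mirror what the diff actually defines.

```python
# Standalone sketch of the response shape; WebSiteInfo here is an invented
# placeholder, not the real entity from datasource_entities.py.
from typing import Optional

from pydantic import BaseModel, Field


class WebSiteInfo(BaseModel):
    source_url: str = ""
    content: str = ""


class GetWebsiteCrawlResponse(BaseModel):
    result: Optional[list[WebSiteInfo]] = []
    job_id: str = Field(..., description="The job id")
    status: str = Field(..., description="The status of the job")


crawl = GetWebsiteCrawlResponse(
    result=[WebSiteInfo(source_url="https://example.com", content="...")],
    job_id="job-123",
    status="completed",
)

# Dict shape returned by run_datasource_workflow_node_status for the
# WEBSITE_CRAWL case; "provider_type" is read from the node data in the graph.
payload = {
    "result": [item.model_dump() for item in crawl.result or []],
    "job_id": crawl.job_id,
    "status": crawl.status,
    "provider_type": "website_crawl",
}
print(payload["status"], len(payload["result"]))
```

Because `result` is now `Optional` with a list default, callers that build the `"result"` list should still guard against an explicit `None`, as the `or []` above does.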
diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml
index a89a834906..6bd0e554ab 100644
--- a/docker/docker-compose.middleware.yaml
+++ b/docker/docker-compose.middleware.yaml
@@ -94,7 +94,6 @@ services:
       PLUGIN_REMOTE_INSTALLING_HOST: ${PLUGIN_DEBUGGING_HOST:-0.0.0.0}
       PLUGIN_REMOTE_INSTALLING_PORT: ${PLUGIN_DEBUGGING_PORT:-5003}
       PLUGIN_WORKING_PATH: ${PLUGIN_WORKING_PATH:-/app/storage/cwd}
-      FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true}
       PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120}
       PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600}
       PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
@@ -127,6 +126,7 @@ services:
       VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
       THIRD_PARTY_SIGNATURE_VERIFICATION_ENABLED: true
       THIRD_PARTY_SIGNATURE_VERIFICATION_PUBLIC_KEYS: /app/keys/publickey.pem
+      FORCE_VERIFYING_SIGNATURE: false
     ports:
       - "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}"
       - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
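One note on the compose change: `FORCE_VERIFYING_SIGNATURE` stops being driven by the `${FORCE_VERIFYING_SIGNATURE:-true}` environment default and is pinned to `false` next to the third-party signature settings. Read together with `THIRD_PARTY_SIGNATURE_VERIFICATION_ENABLED` and the configured public key, the likely intent is to let the plugin daemon accept plugins verified against that key instead of requiring the official signature, but that is an inference from the variable names rather than something this diff states, so it should be confirmed against the plugin daemon's configuration docs before deploying.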