From ca8f80ee334951c9d705e4fb5496e59a15099761 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 7 Aug 2025 11:13:02 +0800 Subject: [PATCH] notion fix --- .../rag_pipeline/rag_pipeline_workflow.py | 4 ++++ api/services/datasource_provider_service.py | 18 +++++++++++++----- api/services/rag_pipeline/rag_pipeline.py | 6 ++++-- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 99dae3cfc7..ea1114cbe0 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -405,6 +405,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("datasource_type", type=str, required=True, location="json") + parser.add_argument("credential_id", type=str, required=False, location="json") args = parser.parse_args() inputs = args.get("inputs") @@ -424,6 +425,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): account=current_user, datasource_type=datasource_type, is_published=False, + credential_id=args.get("credential_id"), ) ) ) @@ -448,6 +450,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("datasource_type", type=str, required=True, location="json") + parser.add_argument("credential_id", type=str, required=False, location="json") args = parser.parse_args() inputs = args.get("inputs") @@ -467,6 +470,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): account=current_user, datasource_type=datasource_type, is_published=False, + credential_id=args.get("credential_id"), ) ) ) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index adf63ca5c6..a0b61c758e 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import Any, Optional from flask_login import current_user from sqlalchemy.orm import Session @@ -71,15 +71,23 @@ class DatasourceProviderService: return copy_credentials def get_real_credential_by_id( - self, tenant_id: str, credential_id: str, provider: str, plugin_id: str + self, tenant_id: str, credential_id: Optional[str], provider: str, plugin_id: str ) -> dict[str, Any]: """ get credential by id """ with Session(db.engine) as session: - datasource_provider = ( - session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() - ) + if credential_id: + datasource_provider = ( + session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() + ) + else: + datasource_provider = ( + session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) + .first() + ) if not datasource_provider: return {} encrypted_credentials = datasource_provider.encrypted_credentials diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 8c70b80cb3..1aa030f58c 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -470,6 +470,7 @@ class RagPipelineService: account: Account, datasource_type: str, is_published: bool, + credential_id: Optional[str] = None, ) -> Generator[BaseDatasourceEvent, None, None]: """ Run published workflow datasource @@ -521,13 +522,14 @@ class RagPipelineService: datasource_type=DatasourceProviderType(datasource_type), ) datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_real_datasource_credentials( + credentials = datasource_provider_service.get_real_credential_by_id( tenant_id=pipeline.tenant_id, provider=datasource_node_data.get("provider_name"), plugin_id=datasource_node_data.get("plugin_id"), + credential_id=credential_id, ) if credentials: - datasource_runtime.runtime.credentials = credentials[0].get("credentials") + datasource_runtime.runtime.credentials = credentials match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)