notion fix

This commit is contained in:
jyong 2025-08-07 11:13:02 +08:00
parent 218e778099
commit ca8f80ee33
3 changed files with 21 additions and 7 deletions

View File

@ -405,6 +405,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
@ -424,6 +425,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
account=current_user,
datasource_type=datasource_type,
is_published=False,
credential_id=args.get("credential_id"),
)
)
)
@ -448,6 +450,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
@ -467,6 +470,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
account=current_user,
datasource_type=datasource_type,
is_published=False,
credential_id=args.get("credential_id"),
)
)
)

View File

@ -1,5 +1,5 @@
import logging
from typing import Any
from typing import Any, Optional
from flask_login import current_user
from sqlalchemy.orm import Session
@ -71,15 +71,23 @@ class DatasourceProviderService:
return copy_credentials
def get_real_credential_by_id(
self, tenant_id: str, credential_id: str, provider: str, plugin_id: str
self, tenant_id: str, credential_id: Optional[str], provider: str, plugin_id: str
) -> dict[str, Any]:
"""
get credential by id
"""
with Session(db.engine) as session:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
)
if credential_id:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
)
else:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
if not datasource_provider:
return {}
encrypted_credentials = datasource_provider.encrypted_credentials

View File

@ -470,6 +470,7 @@ class RagPipelineService:
account: Account,
datasource_type: str,
is_published: bool,
credential_id: Optional[str] = None,
) -> Generator[BaseDatasourceEvent, None, None]:
"""
Run published workflow datasource
@ -521,13 +522,14 @@ class RagPipelineService:
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_datasource_credentials(
credentials = datasource_provider_service.get_real_credential_by_id(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
credential_id=credential_id,
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
datasource_runtime.runtime.credentials = credentials
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)