From d2750f1a028a92732119d0e10fab8869c35cc31c Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 6 Jun 2025 14:22:00 +0800 Subject: [PATCH] r2 --- .../website_crawl/website_crawl_provider.py | 2 + api/services/datasource_provider_service.py | 41 +++++++++++++ api/services/rag_pipeline/rag_pipeline.py | 61 +++++++++++-------- 3 files changed, 79 insertions(+), 25 deletions(-) diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 8c0f20ce2d..a65efb750e 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -1,3 +1,4 @@ +from core.datasource.__base import datasource_provider from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType @@ -43,6 +44,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") + return WebsiteCrawlDatasourcePlugin( entity=datasource_entity, runtime=DatasourceRuntime(tenant_id=self.tenant_id), diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 5bede09a64..64fa97197d 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -120,6 +120,47 @@ class DatasourceProviderService: return copy_credentials_list + def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: + """ + get datasource credentials. 
+ + :param tenant_id: workspace id + :param provider_id: provider id + :return: + """ + # Get all provider configurations of the current workspace + datasource_providers: list[DatasourceProvider] = ( + db.session.query(DatasourceProvider) + .filter( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .all() + ) + if not datasource_providers: + return [] + copy_credentials_list = [] + for datasource_provider in datasource_providers: + encrypted_credentials = datasource_provider.encrypted_credentials + # Get provider credential secret variables + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}") + + # Obfuscate provider credentials + copy_credentials = encrypted_credentials.copy() + for key, value in copy_credentials.items(): + if key in credential_secret_variables: + copy_credentials[key] = encrypter.decrypt_token(tenant_id, value) + copy_credentials_list.append( + { + "credentials": copy_credentials, + "type": datasource_provider.auth_type, + } + ) + + return copy_credentials_list + + def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None: """ update datasource credentials. 
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index d899e89b02..cb42224c60 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -43,6 +43,7 @@ from models.account import Account from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import EndUser +from models.oauth import DatasourceProvider from models.workflow import ( Workflow, WorkflowNodeExecutionTriggeredFrom, @@ -50,6 +51,7 @@ from models.workflow import ( WorkflowType, ) from services.dataset_service import DatasetService +from services.datasource_provider_service import DatasourceProviderService from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, PipelineTemplateInfoEntity, @@ -442,6 +444,7 @@ class RagPipelineService: for key, value in datasource_parameters.items(): if not user_inputs.get(key): user_inputs[key] = value["value"] + from core.datasource.datasource_manager import DatasourceManager datasource_runtime = DatasourceManager.get_datasource_runtime( @@ -450,32 +453,40 @@ class RagPipelineService: tenant_id=pipeline.tenant_id, datasource_type=DatasourceProviderType(datasource_type), ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get('provider_name'), + plugin_id=datasource_node_data.get('plugin_id'), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( + user_id=account.id, + 
datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), + ) + return { + "result": [page.model_dump() for page in online_document_result.result], + "provider_type": datasource_node_data.get("provider_type"), + } - if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: - datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - return { - "result": [page.model_dump() for page in online_document_result.result], - "provider_type": datasource_node_data.get("provider_type"), - } - - elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL: - datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - return { - "result": [result.model_dump() for result in website_crawl_result.result], - "provider_type": datasource_node_data.get("provider_type"), - } - else: - raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + case DatasourceProviderType.WEBSITE_CRAWL: + datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), + ) + return { + "result": [result.model_dump() for result in website_crawl_result.result], + "provider_type": datasource_node_data.get("provider_type"), + } + case _: + raise ValueError(f"Unsupported 
datasource provider: {datasource_runtime.datasource_provider_type()}") def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]