diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 8c0f20ce2d..a65efb750e 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -1,3 +1,4 @@ +from core.datasource.__base import datasource_provider from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType @@ -43,6 +44,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") + return WebsiteCrawlDatasourcePlugin( entity=datasource_entity, runtime=DatasourceRuntime(tenant_id=self.tenant_id), diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 5bede09a64..64fa97197d 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -120,6 +120,47 @@ class DatasourceProviderService: return copy_credentials_list + def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: + """ + get datasource credentials. + + :param tenant_id: workspace id + :param provider_id: provider id + :return: + """ + # Get all provider configurations of the current workspace + datasource_providers: list[DatasourceProvider] = ( + db.session.query(DatasourceProvider) + .filter( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .all() + ) + if not datasource_providers: + return [] + copy_credentials_list = [] + for datasource_provider in datasource_providers: + encrypted_credentials = datasource_provider.encrypted_credentials + # Get provider credential secret variables + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}") + + # Obfuscate provider credentials + copy_credentials = encrypted_credentials.copy() + for key, value in copy_credentials.items(): + if key in credential_secret_variables: + copy_credentials[key] = encrypter.decrypt_token(tenant_id, value) + copy_credentials_list.append( + { + "credentials": copy_credentials, + "type": datasource_provider.auth_type, + } + ) + + return copy_credentials_list + + def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None: """ update datasource credentials. diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index ac7df87586..cb42224c60 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -43,6 +43,7 @@ from models.account import Account from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import EndUser +from models.oauth import DatasourceProvider from models.workflow import ( Workflow, WorkflowNodeExecutionTriggeredFrom, @@ -50,6 +51,7 @@ from models.workflow import ( WorkflowType, ) from services.dataset_service import DatasetService +from services.datasource_provider_service import DatasourceProviderService from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, PipelineTemplateInfoEntity, @@ -442,6 +444,7 @@ class RagPipelineService: for key, value in datasource_parameters.items(): if not user_inputs.get(key): user_inputs[key] = value["value"] + from core.datasource.datasource_manager import DatasourceManager datasource_runtime = DatasourceManager.get_datasource_runtime( @@ -450,6 +453,14 @@ class RagPipelineService: tenant_id=pipeline.tenant_id, datasource_type=DatasourceProviderType(datasource_type), ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get('provider_name'), + plugin_id=datasource_node_data.get('plugin_id'), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)