From 69529fb16dcf68a6f0e55967bf732e27cfe2af52 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 30 May 2025 00:37:27 +0800 Subject: [PATCH] r2 --- .../datasets/rag_pipeline/datasource_auth.py | 12 +++++++ api/core/plugin/entities/plugin_daemon.py | 1 + api/models/__init__.py | 3 ++ api/models/oauth.py | 1 + api/services/datasource_provider_service.py | 36 +++++++++++-------- .../rag_pipeline_manage_service.py | 11 +++++- 6 files changed, 49 insertions(+), 15 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 8894babcf7..ceb7a277e4 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -108,6 +108,18 @@ class DatasourceAuth(Resource): raise ValueError(str(ex)) return {"result": "success"}, 201 + + @setup_required + @login_required + @account_initialization_required + def get(self, provider, plugin_id): + datasource_provider_service = DatasourceProviderService() + datasources = datasource_provider_service.get_datasource_credentials( + tenant_id=current_user.current_tenant_id, + provider=provider, + plugin_id=plugin_id + ) + return {"result": datasources}, 200 class DatasourceAuthDeleteApi(Resource): @setup_required diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 6644706757..cc7dfb58ab 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -52,6 +52,7 @@ class PluginDatasourceProviderEntity(BaseModel): provider: str plugin_unique_identifier: str plugin_id: str + is_authorized: bool = False declaration: DatasourceProviderEntityWithPlugin diff --git a/api/models/__init__.py b/api/models/__init__.py index f652449e98..63fe2747ef 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -56,6 +56,7 @@ from .model import ( TraceAppConfig, UploadFile, ) +from .oauth import DatasourceOauthParamConfig, DatasourceProvider from .provider import ( LoadBalancingModelConfig, Provider, @@ -123,6 +124,8 @@ __all__ = [ "DatasetProcessRule", "DatasetQuery", "DatasetRetrieverResource", + "DatasourceOauthParamConfig", + "DatasourceProvider", "DifySetup", "Document", "DocumentSegment", diff --git a/api/models/oauth.py b/api/models/oauth.py index fefe743195..d662a4b50c 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -28,6 +28,7 @@ class DatasourceProvider(Base): db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 54abc64547..ef9a56a66e 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -92,7 +92,7 @@ class DatasourceProviderService: return secret_input_form_variables - def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> Optional[dict]: + def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: """ get datasource credentials. @@ -102,22 +102,30 @@ class DatasourceProviderService: :return: """ # Get all provider configurations of the current workspace - datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, + datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, provider=provider, - plugin_id=plugin_id).first() - if not datasource_provider: - return None - encrypted_credentials = datasource_provider.encrypted_credentials - # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) + plugin_id=plugin_id).all() + if not datasource_providers: + return [] + copy_credentials_list = [] + for datasource_provider in datasource_providers: + encrypted_credentials = datasource_provider.encrypted_credentials + # Get provider credential secret variables + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) - # Obfuscate provider credentials - copy_credentials = encrypted_credentials.copy() - for key, value in copy_credentials.items(): - if key in credential_secret_variables: - copy_credentials[key] = encrypter.obfuscated_token(value) + # Obfuscate provider credentials + copy_credentials = encrypted_credentials.copy() + for key, value in copy_credentials.items(): + if key in credential_secret_variables: + copy_credentials[key] = encrypter.obfuscated_token(value) + copy_credentials_list.append( + { + "credentials": copy_credentials, + "type": datasource_provider.auth_type, + } + ) - return copy_credentials + return copy_credentials_list def remove_datasource_credentials(self, diff --git a/api/services/rag_pipeline/rag_pipeline_manage_service.py b/api/services/rag_pipeline/rag_pipeline_manage_service.py index 4d8d69f913..df6085fafa 100644 --- a/api/services/rag_pipeline/rag_pipeline_manage_service.py +++ b/api/services/rag_pipeline/rag_pipeline_manage_service.py @@ -1,5 +1,6 @@ from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity from core.plugin.impl.datasource import PluginDatasourceManager +from services.datasource_provider_service import DatasourceProviderService class RagPipelineManageService: @@ -11,4 +12,12 @@ class RagPipelineManageService: # get all builtin providers manager = PluginDatasourceManager() - return manager.fetch_datasource_providers(tenant_id) + datasources = manager.fetch_datasource_providers(tenant_id) + for datasource in datasources: + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id, + provider=datasource.provider, + plugin_id=datasource.plugin_id) + if credentials: + datasource.is_authorized = True + return datasources