From 875aea1c22e20131337cce446e43770c7bb83eb0 Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 30 Jul 2025 13:39:04 +0800 Subject: [PATCH] feat: datasource reauthentication --- .../datasets/rag_pipeline/datasource_auth.py | 35 ++++++++---- api/services/datasource_provider_service.py | 54 ++++++++++++++----- api/services/plugin/oauth_service.py | 10 +++- 3 files changed, 76 insertions(+), 23 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 44954278d6..d67af182cd 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -32,6 +32,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): if not current_user.is_editor: raise Forbidden() + credential_id = request.args.get("credential_id") datasource_provider_id = DatasourceProviderID(provider_id) provider_name = datasource_provider_id.provider_name plugin_id = datasource_provider_id.plugin_id @@ -43,7 +44,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): raise ValueError(f"No OAuth Client Config for {provider_id}") context_id = OAuthProxyService.create_proxy_context( - user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name + user_id=current_user.id, + tenant_id=tenant_id, + plugin_id=plugin_id, + provider=provider_name, + credential_id=credential_id, ) oauth_handler = OAuthHandler() redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback" @@ -98,13 +103,24 @@ class DatasourceOAuthCallback(Resource): system_credentials=oauth_client_params, request=request, ) - datasource_provider_service.add_datasource_oauth_provider( - tenant_id=tenant_id, - provider_id=datasource_provider_id, - avatar_url=oauth_response.metadata.get("avatar_url") or None, - name=oauth_response.metadata.get("name") or None, - credentials=dict(oauth_response.credentials), - ) + credential_id = context.get("credential_id") + if credential_id: + datasource_provider_service.reauthorize_datasource_oauth_provider( + tenant_id=tenant_id, + provider_id=datasource_provider_id, + avatar_url=oauth_response.metadata.get("avatar_url") or None, + name=oauth_response.metadata.get("name") or None, + credentials=dict(oauth_response.credentials), + credential_id=context.get("credential_id"), + ) + else: + datasource_provider_service.add_datasource_oauth_provider( + tenant_id=tenant_id, + provider_id=datasource_provider_id, + avatar_url=oauth_response.metadata.get("avatar_url") or None, + name=oauth_response.metadata.get("name") or None, + credentials=dict(oauth_response.credentials), + ) return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") @@ -208,7 +224,8 @@ class DatasourceAuthListApi(Resource): tenant_id=current_user.current_tenant_id ) return {"result": jsonable_encoder(datasources)}, 200 - + + class DatasourceHardCodeAuthListApi(Resource): @setup_required @login_required diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 25966ed41a..dac406360f 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -82,19 +82,16 @@ class DatasourceProviderService: if key in credential_secret_variables: copy_credentials[key] = encrypter.decrypt_token(tenant_id, value) return copy_credentials - - def get_default_real_credential( - self, tenant_id: str, provider: str, plugin_id: str - ) -> dict[str, Any]: + + def get_default_real_credential(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]: """ get default credential """ with Session(db.engine) as session: datasource_provider = ( - session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, - is_default=True, - provider=provider, - plugin_id=plugin_id).first() + session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, is_default=True, provider=provider, plugin_id=plugin_id) + .first() ) if not datasource_provider: return {} @@ -357,6 +354,35 @@ class DatasourceProviderService: f"{credential_type.get_name()}", ) + def reauthorize_datasource_oauth_provider( + self, + name: str | None, + tenant_id: str, + provider_id: DatasourceProviderID, + avatar_url: str | None, + credentials: dict, + credential_id: str, + ) -> None: + """ + update datasource oauth provider + """ + with Session(db.engine) as session: + target_provider = session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first() + if target_provider is None: + raise ValueError("provider not found") + + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2 + ) + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + target_provider.encrypted_credentials = credentials + target_provider.avatar_url = avatar_url or target_provider.avatar_url + target_provider.name = name or target_provider.name + session.commit() + def add_datasource_oauth_provider( self, name: str | None, @@ -625,7 +651,7 @@ class DatasourceProviderService: } ) return datasource_credentials - + def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]: """ get hard code datasource credentials. @@ -637,14 +663,16 @@ class DatasourceProviderService: datasources = manager.fetch_installed_datasource_providers(tenant_id) datasource_credentials = [] for datasource in datasources: - if datasource.plugin_id in ["langgenius/firecrawl_datasource", "langgenius/notion_datasource", "langgenius/jina_datasource"]: + if datasource.plugin_id in [ + "langgenius/firecrawl_datasource", + "langgenius/notion_datasource", + "langgenius/jina_datasource", + ]: datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}") credentials = self.get_datasource_credentials( tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id ) - redirect_uri = ( - f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" - ) + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" datasource_credentials.append( { "provider": datasource.provider, diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index b84dd0afc5..4a09e71504 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -11,7 +11,13 @@ class OAuthProxyService(BasePluginClient): __KEY_PREFIX__ = "oauth_proxy_context:" @staticmethod - def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str): + def create_proxy_context( + user_id: str, + tenant_id: str, + plugin_id: str, + provider: str, + credential_id: str | None = None, + ): """ Create a proxy context for an OAuth 2.0 authorization request. @@ -31,6 +37,8 @@ class OAuthProxyService(BasePluginClient): "tenant_id": tenant_id, "provider": provider, } + if credential_id: + data["credential_id"] = credential_id redis_client.setex( f"{OAuthProxyService.__KEY_PREFIX__}{context_id}", OAuthProxyService.__MAX_AGE__,