From 17da96bdd8083bdd0635b765231d3c04621d7064 Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 21 Jul 2025 16:43:50 +0800 Subject: [PATCH] feat: refactor datasource authentication APIs for improved credential management --- .../datasets/rag_pipeline/datasource_auth.py | 46 +++++++++--- api/services/datasource_provider_service.py | 74 +++++++++++++------ 2 files changed, 86 insertions(+), 34 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index d1e4812bcd..31a8479f25 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -149,33 +149,40 @@ class DatasourceAuth(Resource): ) return {"result": datasources}, 200 +class DatasourceAuthDeleteApi(Resource): -class DatasourceAuthUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self, provider_id: str, auth_id: str): + def post(self, provider_id: str): datasource_provider_id = DatasourceProviderID(provider_id) plugin_id = datasource_provider_id.plugin_id provider_name = datasource_provider_id.provider_name if not current_user.is_editor: raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_datasource_credentials( tenant_id=current_user.current_tenant_id, - auth_id=auth_id, + auth_id=args["credential_id"], provider=provider_name, plugin_id=plugin_id, ) return {"result": "success"}, 200 +class DatasourceAuthUpdateApi(Resource): + @setup_required @login_required @account_initialization_required - def patch(self, provider_id: str, auth_id: str): + def post(self, provider_id: str): datasource_provider_id = DatasourceProviderID(provider_id) parser = reqparse.RequestParser() - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() if not current_user.is_editor: raise Forbidden() @@ -183,10 +190,11 @@ class DatasourceAuthUpdateDeleteApi(Resource): datasource_provider_service = DatasourceProviderService() datasource_provider_service.update_datasource_credentials( tenant_id=current_user.current_tenant_id, - auth_id=auth_id, + auth_id=args["credential_id"], provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id, - credentials=args["credentials"], + credentials=args.get("credentials", {}), + name=args.get("name", None), ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) @@ -228,6 +236,17 @@ class DatasourceAuthOauthCustomClient(Resource): ) return {"result": "success"}, 200 + @setup_required + @login_required + @account_initialization_required + def delete(self, provider_id: str): + datasource_provider_id = DatasourceProviderID(provider_id) + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.remove_oauth_custom_client_params( + tenant_id=current_user.current_tenant_id, + datasource_provider_id=datasource_provider_id, + ) + return {"result": "success"}, 200 class DatasourceAuthDefaultApi(Resource): @setup_required @@ -237,14 +256,14 @@ class DatasourceAuthDefaultApi(Resource): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.set_default_datasource_provider( tenant_id=current_user.current_tenant_id, datasource_provider_id=datasource_provider_id, - credential_id=args["credential_id"], + credential_id=args["id"], ) return {"result": "success"}, 200 @@ -284,8 +303,13 @@ api.add_resource( ) api.add_resource( - DatasourceAuthUpdateDeleteApi, - "/auth/plugin/datasource//", + DatasourceAuthUpdateApi, + "/auth/plugin/datasource//update", +) + +api.add_resource( + DatasourceAuthDeleteApi, + "/auth/plugin/datasource//delete", ) api.add_resource( diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 43af7651d8..28d5b338e5 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -4,6 +4,7 @@ from typing import Any from flask_login import current_user from sqlalchemy.orm import Session +from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper import encrypter from core.helper.name_generator import generate_incremental_name @@ -29,6 +30,18 @@ class DatasourceProviderService: def __init__(self) -> None: self.provider_manager = PluginDatasourceManager() + def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID): + """ + remove oauth custom client params + """ + with Session(db.engine) as session: + session.query(DatasourceOauthTenantParamConfig).filter_by( + tenant_id=tenant_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + ).delete() + session.commit() + def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]: """ get default credentials @@ -512,6 +525,10 @@ class DatasourceProviderService: credentials = self.get_datasource_credentials( tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id ) + redirect_uri = ( + f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/" + f"{datasource_provider_id}/datasource/callback" + ) datasource_credentials.append( { "provider": datasource.provider, @@ -542,6 +559,7 @@ class DatasourceProviderService: tenant_id, datasource_provider_id ), "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id), + "redirect_uri": redirect_uri } if datasource.declaration.oauth_schema else None, @@ -594,38 +612,50 @@ class DatasourceProviderService: return copy_credentials_list def update_datasource_credentials( - self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict + self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None ) -> None: """ update datasource credentials. """ - credential_valid = self.provider_manager.validate_provider_credentials( - tenant_id=tenant_id, - user_id=current_user.id, - provider=provider, - plugin_id=plugin_id, - credentials=credentials, - ) - if credential_valid: - # Get all provider configurations of the current workspace + with Session(db.engine) as session: datasource_provider = ( - db.session.query(DatasourceProvider) + session.query(DatasourceProvider) .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) .first() ) - if not datasource_provider: raise ValueError("Datasource provider not found") - else: - provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, - provider_id=f"{plugin_id}/{provider}", - credential_type=datasource_provider.auth_type, - ) + # update name + if name and name != datasource_provider.name: + if ( + session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id) + .count() + > 0 + ): + raise ValueError("name is already exists") + datasource_provider.name = name + + # update credentials + if credentials: + try: + self.provider_manager.validate_provider_credentials( + tenant_id=tenant_id, + user_id=current_user.id, + provider=provider, + plugin_id=plugin_id, + credentials=credentials, + ) + except Exception as e: + raise ValueError(f"Failed to validate credentials: {str(e)}") + original_credentials = datasource_provider.encrypted_credentials for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value + if key in self.extract_secret_variables( + tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}", + credential_type=datasource_provider.auth_type, + ): if value == HIDDEN_VALUE and key in original_credentials: original_value = encrypter.encrypt_token(tenant_id, original_credentials[key]) credentials[key] = encrypter.encrypt_token(tenant_id, original_value) @@ -633,9 +663,7 @@ class DatasourceProviderService: credentials[key] = encrypter.encrypt_token(tenant_id, value) datasource_provider.encrypted_credentials = credentials - db.session.commit() - else: - raise CredentialsValidateFailedError() + session.commit() def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None: """