From ba7f0b3004c3d06f94dd57eb609ff5f2f27d09bc Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 21 Jul 2025 18:51:46 +0800 Subject: [PATCH] feat: enhance datasource authentication by improving credential handling and updating API parameters --- .../datasets/rag_pipeline/datasource_auth.py | 36 +++++----- api/core/tools/entities/tool_entities.py | 4 +- api/services/datasource_provider_service.py | 69 +++++++++---------- 3 files changed, 53 insertions(+), 56 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 71e289185f..933a8fb9c9 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -134,7 +134,7 @@ class DatasourceAuth(Resource): ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) - return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") + return {"result": "success"}, 200 @setup_required @login_required @@ -149,8 +149,8 @@ class DatasourceAuth(Resource): ) return {"result": datasources}, 200 -class DatasourceAuthDeleteApi(Resource): +class DatasourceAuthDeleteApi(Resource): @setup_required @login_required @account_initialization_required @@ -172,8 +172,8 @@ class DatasourceAuthDeleteApi(Resource): ) return {"result": "success"}, 200 -class DatasourceAuthUpdateApi(Resource): +class DatasourceAuthUpdateApi(Resource): @setup_required @login_required @account_initialization_required @@ -186,19 +186,15 @@ class DatasourceAuthUpdateApi(Resource): args = parser.parse_args() if not current_user.is_editor: raise Forbidden() - try: - datasource_provider_service = DatasourceProviderService() - datasource_provider_service.update_datasource_credentials( - tenant_id=current_user.current_tenant_id, - auth_id=args["credential_id"], - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - credentials=args.get("credentials", {}), - name=args.get("name", None), - ) - except CredentialsValidateFailedError as ex: - raise ValueError(str(ex)) - + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.update_datasource_credentials( + tenant_id=current_user.current_tenant_id, + auth_id=args["credential_id"], + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + credentials=args.get("credentials", {}), + name=args.get("name", None), + ) return {"result": "success"}, 201 @@ -223,7 +219,7 @@ class DatasourceAuthOauthCustomClient(Resource): raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") - parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json") + parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") args = parser.parse_args() datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() @@ -231,7 +227,7 @@ class DatasourceAuthOauthCustomClient(Resource): tenant_id=current_user.current_tenant_id, datasource_provider_id=datasource_provider_id, client_params=args.get("client_params", {}), - enabled=args.get("enabled", False), + enabled=args.get("enable_oauth_custom_client", False), ) return {"result": "success"}, 200 @@ -247,6 +243,7 @@ class DatasourceAuthOauthCustomClient(Resource): ) return {"result": "success"}, 200 + class DatasourceAuthDefaultApi(Resource): @setup_required @login_required @@ -266,6 +263,7 @@ class DatasourceAuthDefaultApi(Resource): ) return {"result": "success"}, 200 + class DatasourceUpdateProviderNameApi(Resource): @setup_required @login_required @@ -329,4 +327,4 @@ api.add_resource( api.add_resource( DatasourceUpdateProviderNameApi, "/auth/plugin/datasource//update-name", -) \ No newline at end of file +) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 5377cbbb69..f9990260bc 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -496,9 +496,9 @@ class CredentialType(enum.StrEnum): @classmethod def of(cls, credential_type: str) -> "CredentialType": type_name = credential_type.lower() - if type_name == "api-key": + if type_name in {"api-key", "api_key"}: return cls.API_KEY - elif type_name == "oauth2": + elif type_name in {"oauth2", "oauth"}: return cls.OAUTH2 else: raise ValueError(f"Invalid credential type: {credential_type}") diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 4bf6fded4a..776b2cc7fe 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -293,7 +293,6 @@ class DatasourceProviderService: tenant_id=tenant_id, provider=provider_id.provider_name, plugin_id=provider_id.plugin_id, - auth_type=credential_type.value, ) .all() ) @@ -351,7 +350,7 @@ class DatasourceProviderService: ) provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type.value + tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type ) for key, value in credentials.items(): if key in provider_credential_secret_variables: @@ -387,7 +386,7 @@ class DatasourceProviderService: provider_name = provider_id.provider_name plugin_id = provider_id.plugin_id with Session(db.engine) as session: - lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_api_key" + lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" with redis_client.lock(lock, timeout=20): db_provider_name = name or self.generate_next_datasource_provider_name( session=session, @@ -400,35 +399,36 @@ class DatasourceProviderService: if session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=db_provider_name).count() > 0: raise ValueError("Authorization name is already exists") - credential_valid = self.provider_manager.validate_provider_credentials( - tenant_id=tenant_id, - user_id=current_user.id, - provider=provider_name, - plugin_id=plugin_id, - credentials=credentials, - ) - if credential_valid: - provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type="api_key" - ) - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - credentials[key] = encrypter.encrypt_token(tenant_id, value) - datasource_provider = DatasourceProvider( + try: + self.provider_manager.validate_provider_credentials( tenant_id=tenant_id, - name=db_provider_name, + user_id=current_user.id, provider=provider_name, plugin_id=plugin_id, - auth_type="api_key", - encrypted_credentials=credentials, + credentials=credentials, ) - db.session.add(datasource_provider) - db.session.commit() - else: - raise CredentialsValidateFailedError() + except Exception as e: + raise ValueError(f"Failed to validate credentials: {str(e)}") - def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: str) -> list[str]: + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY + ) + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + credentials[key] = encrypter.encrypt_token(tenant_id, value) + datasource_provider = DatasourceProvider( + tenant_id=tenant_id, + name=db_provider_name, + provider=provider_name, + plugin_id=plugin_id, + auth_type=CredentialType.API_KEY.value, + encrypted_credentials=credentials, + ) + db.session.add(datasource_provider) + db.session.commit() + + def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]: """ Extract secret input form variables. @@ -439,9 +439,9 @@ class DatasourceProviderService: tenant_id=tenant_id, provider_id=provider_id ) credential_form_schemas = [] - if credential_type == "api_key": + if credential_type == CredentialType.API_KEY: credential_form_schemas = list(datasource_provider.declaration.credentials_schema) - elif credential_type == "oauth2": + elif credential_type == CredentialType.OAUTH2: if not datasource_provider.declaration.oauth_schema: raise ValueError("Datasource provider oauth schema not found") credential_form_schemas = list(datasource_provider.declaration.oauth_schema.credentials_schema) @@ -489,7 +489,7 @@ class DatasourceProviderService: credential_secret_variables = self.extract_secret_variables( tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}", - credential_type=datasource_provider.auth_type, + credential_type=CredentialType.of(datasource_provider.auth_type), ) # Obfuscate provider credentials @@ -526,8 +526,7 @@ class DatasourceProviderService: tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id ) redirect_uri = ( - f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/" - f"{datasource_provider_id}/datasource/callback" + f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" ) datasource_credentials.append( { @@ -559,7 +558,7 @@ class DatasourceProviderService: tenant_id, datasource_provider_id ), "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id), - "redirect_uri": redirect_uri + "redirect_uri": redirect_uri, } if datasource.declaration.oauth_schema else None, @@ -594,7 +593,7 @@ class DatasourceProviderService: credential_secret_variables = self.extract_secret_variables( tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}", - credential_type=datasource_provider.auth_type, + credential_type=CredentialType.of(datasource_provider.auth_type), ) # Obfuscate provider credentials @@ -654,7 +653,7 @@ class DatasourceProviderService: if key in self.extract_secret_variables( tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}", - credential_type=datasource_provider.auth_type, + credential_type=CredentialType.of(datasource_provider.auth_type), ): if value == HIDDEN_VALUE and key in original_credentials: original_value = encrypter.encrypt_token(tenant_id, original_credentials[key])