From 4dab128900702eb39ae77cf1f4a8532583d7d060 Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 30 Jul 2025 15:52:59 +0800 Subject: [PATCH] feat: oauth --- api/services/datasource_provider_service.py | 57 ++++++++++++++++----- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index dac406360f..ab7da91032 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -367,21 +367,52 @@ class DatasourceProviderService: update datasource oauth provider """ with Session(db.engine) as session: - target_provider = session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first() - if target_provider is None: - raise ValueError("provider not found") + lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}" + with redis_client.lock(lock, timeout=20): + target_provider = ( + session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first() + ) + if target_provider is None: + raise ValueError("provider not found") - provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2 - ) - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - credentials[key] = encrypter.encrypt_token(tenant_id, value) + db_provider_name = name + if not db_provider_name: + db_provider_name = target_provider.name + else: + name_conflict = ( + session.query(DatasourceProvider) + .filter_by( + tenant_id=tenant_id, + name=db_provider_name, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, + auth_type=CredentialType.OAUTH2.value, + ) + .count() + ) + if name_conflict > 0: + db_provider_name = generate_incremental_name( + [ + provider.name + for provider in session.query(DatasourceProvider).filter_by( + tenant_id=tenant_id, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, + ) + ], + db_provider_name, + ) - target_provider.encrypted_credentials = credentials - target_provider.avatar_url = avatar_url or target_provider.avatar_url - target_provider.name = name or target_provider.name - session.commit() + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2 + ) + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + target_provider.encrypted_credentials = credentials + target_provider.avatar_url = avatar_url or target_provider.avatar_url + session.commit() def add_datasource_oauth_provider( self,