diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index dac406360f..c8685dedbf 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -54,7 +54,21 @@ class DatasourceProviderService: ) if not datasource_provider: return {} - return datasource_provider.encrypted_credentials + + encrypted_credentials = datasource_provider.encrypted_credentials + # Get provider credential secret variables + credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}", + credential_type=CredentialType.of(datasource_provider.auth_type), + ) + + # Obfuscate provider credentials + copy_credentials = encrypted_credentials.copy() + for key, value in copy_credentials.items(): + if key in credential_secret_variables: + copy_credentials[key] = encrypter.decrypt_token(tenant_id, value) + return copy_credentials def get_real_credential_by_id( self, tenant_id: str, credential_id: str, provider: str, plugin_id: str @@ -367,21 +381,52 @@ class DatasourceProviderService: update datasource oauth provider """ with Session(db.engine) as session: - target_provider = session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first() - if target_provider is None: - raise ValueError("provider not found") + lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}" + with redis_client.lock(lock, timeout=20): + target_provider = ( + session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first() + ) + if target_provider is None: + raise ValueError("provider not found") - provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2 - ) - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - credentials[key] = encrypter.encrypt_token(tenant_id, value) + db_provider_name = name + if not db_provider_name: + db_provider_name = target_provider.name + else: + name_conflict = ( + session.query(DatasourceProvider) + .filter_by( + tenant_id=tenant_id, + name=db_provider_name, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, + auth_type=CredentialType.OAUTH2.value, + ) + .count() + ) + if name_conflict > 0: + db_provider_name = generate_incremental_name( + [ + provider.name + for provider in session.query(DatasourceProvider).filter_by( + tenant_id=tenant_id, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, + ) + ], + db_provider_name, + ) - target_provider.encrypted_credentials = credentials - target_provider.avatar_url = avatar_url or target_provider.avatar_url - target_provider.name = name or target_provider.name - session.commit() + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2 + ) + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + target_provider.encrypted_credentials = credentials + target_provider.avatar_url = avatar_url or target_provider.avatar_url + session.commit() def add_datasource_oauth_provider( self, @@ -672,7 +717,9 @@ class DatasourceProviderService: credentials = self.get_datasource_credentials( tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id ) - redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" + redirect_uri = ( + f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" + ) datasource_credentials.append( { "provider": datasource.provider, diff --git a/web/app/components/header/account-setting/data-source-page-new/card.tsx b/web/app/components/header/account-setting/data-source-page-new/card.tsx index 0e0d5f4ea4..74e188d350 100644 --- a/web/app/components/header/account-setting/data-source-page-new/card.tsx +++ b/web/app/components/header/account-setting/data-source-page-new/card.tsx @@ -1,6 +1,7 @@ import { memo, useCallback, + useRef, } from 'react' import { useTranslation } from 'react-i18next' import Item from './item' @@ -17,6 +18,8 @@ import { } from '@/app/components/plugins/plugin-auth' import { useDataSourceAuthUpdate } from './hooks' import Confirm from '@/app/components/base/confirm' +import { useGetDataSourceOAuthUrl } from '@/service/use-datasource' +import { openOAuthPopup } from '@/hooks/use-oauth' type CardProps = { item: DataSourceAuth @@ -55,6 +58,20 @@ const Card = ({ closeConfirm, pendingOperationCredentialId, } = usePluginAuthAction(pluginPayload, handleAuthUpdate) + const changeCredentialIdRef = useRef(undefined) + const { + mutateAsync: getPluginOAuthUrl, + } = useGetDataSourceOAuthUrl(pluginPayload.provider) + const handleOAuth = useCallback(async () => { + const { authorization_url } = await getPluginOAuthUrl(changeCredentialIdRef.current) + + if (authorization_url) { + openOAuthPopup( + authorization_url, + handleAuthUpdate, + ) + } + }, [getPluginOAuthUrl, handleAuthUpdate]) const handleAction = useCallback(( action: string, credentialItem: DataSourceCredential, @@ -78,6 +95,11 @@ const Card = ({ if (action === 'rename') handleRename(renamePayload as any) + + if (action === 'change') { + changeCredentialIdRef.current = credentialItem.id + handleOAuth() + } }, [ openConfirm, handleEdit, diff --git a/web/service/use-datasource.ts b/web/service/use-datasource.ts index cbaa14a7e0..b923838f86 100644 --- a/web/service/use-datasource.ts +++ b/web/service/use-datasource.ts @@ -1,4 +1,7 @@ -import { useQuery } from '@tanstack/react-query' +import { + useMutation, + useQuery, +} from '@tanstack/react-query' import { get } from './base' import { useInvalid } from './use-base' import type { DataSourceAuth } from '@/app/components/header/account-setting/data-source-page-new/types' @@ -31,3 +34,18 @@ export const useInvalidDefaultDataSourceListAuth = ( ) => { return useInvalid([NAME_SPACE, 'default-list']) } +export const useGetDataSourceOAuthUrl = ( + provider: string, +) => { + return useMutation({ + mutationKey: [NAME_SPACE, 'oauth-url', provider], + mutationFn: (credentialId?: string) => { + return get< + { + authorization_url: string + state: string + context_id: string + }>(`/oauth/plugin/${provider}/datasource/get-authorization-url?credential_id=${credentialId}`) + }, + }) +}