diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 077282d959..88b24c6985 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -121,7 +121,7 @@ class DataSourceNotionListApi(Resource): if not credential_id: raise ValueError("Credential id is required.") datasource_provider_service = DatasourceProviderService() - credential = datasource_provider_service.get_real_credential_by_id( + credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_user.current_tenant_id, credential_id=credential_id, provider="notion_datasource", @@ -206,7 +206,7 @@ class DataSourceNotionApi(Resource): if not credential_id: raise ValueError("Credential id is required.") datasource_provider_service = DatasourceProviderService() - credential = datasource_provider_service.get_real_credential_by_id( + credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_user.current_tenant_id, credential_id=credential_id, provider="notion_datasource", diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 86932c00b1..ebbb1e06e5 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -373,7 +373,7 @@ class NotionExtractor(BaseExtractor): if not credential_id: raise Exception(f"No credential id found for tenant {tenant_id}") datasource_provider_service = DatasourceProviderService() - credential = datasource_provider_service.get_real_credential_by_id( + credential = datasource_provider_service.get_datasource_credentials( tenant_id=tenant_id, credential_id=credential_id, provider="notion_datasource", diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index d4f1c4f392..f4707f42fd 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -123,19 +123,12 @@ class DatasourceNode(BaseNode): try: datasource_provider_service = DatasourceProviderService() - if datasource_info.get("credential_id"): - credentials = datasource_provider_service.get_real_credential_by_id( - tenant_id=self.tenant_id, - credential_id=datasource_info.get("credential_id"), - provider=node_data.provider_name, - plugin_id=node_data.plugin_id, - ) - else: - credentials = datasource_provider_service.get_default_credentials( - tenant_id=self.tenant_id, - provider=node_data.provider_name, - plugin_id=node_data.plugin_id, - ) + credentials = datasource_provider_service.get_datasource_credentials( + tenant_id=self.tenant_id, + provider=node_data.provider_name, + plugin_id=node_data.plugin_id, + credential_id=datasource_info.get("credential_id"), + ) match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index ef29654a35..e36d09110f 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -1,6 +1,8 @@ import logging -from typing import Any, Optional +import time +from typing import Any +from api.core.plugin.impl.oauth import OAuthHandler from flask_login import current_user from sqlalchemy.orm import Session @@ -41,37 +43,48 @@ class DatasourceProviderService: ).delete() session.commit() - def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]: - """ - get default credentials - """ - with Session(db.engine) as session: - datasource_provider = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) - .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .first() - ) - if not datasource_provider: - return {} + def decrypt_datasource_provider_credentials( + self, + tenant_id: str, + datasource_provider: DatasourceProvider, + plugin_id: str, + provider: str, + ) -> dict[str, Any]: + encrypted_credentials = datasource_provider.encrypted_credentials + credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}", + credential_type=CredentialType.of(datasource_provider.auth_type), + ) + decrypted_credentials = encrypted_credentials.copy() + for key, value in decrypted_credentials.items(): + if key in credential_secret_variables: + decrypted_credentials[key] = encrypter.decrypt_token(tenant_id, value) + return decrypted_credentials - encrypted_credentials = datasource_provider.encrypted_credentials - # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, - provider_id=f"{plugin_id}/{provider}", - credential_type=CredentialType.of(datasource_provider.auth_type), - ) + def encrypt_datasource_provider_credentials( + self, + tenant_id: str, + provider: str, + plugin_id: str, + raw_credentials: dict[str, Any], + datasource_provider: DatasourceProvider, + ) -> dict[str, Any]: + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}", credential_type=datasource_provider.auth_type + ) + encrypted_credentials = raw_credentials.copy() + for key, value in encrypted_credentials.items(): + if key in provider_credential_secret_variables: + encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value) + return encrypted_credentials - # Obfuscate provider credentials - copy_credentials = encrypted_credentials.copy() - for key, value in copy_credentials.items(): - if key in credential_secret_variables: - copy_credentials[key] = encrypter.decrypt_token(tenant_id, value) - return copy_credentials - - def get_real_credential_by_id( - self, tenant_id: str, credential_id: Optional[str], provider: str, plugin_id: str + def get_datasource_credentials( + self, + tenant_id: str, + provider: str, + plugin_id: str, + credential_id: str | None = None, ) -> dict[str, Any]: """ get credential by id @@ -90,49 +103,47 @@ class DatasourceProviderService: ) if not datasource_provider: return {} - encrypted_credentials = datasource_provider.encrypted_credentials - # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables( + # refresh the credentials + if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()): + decrypted_credentials = self.decrypt_datasource_provider_credentials( + tenant_id=tenant_id, + datasource_provider=datasource_provider, + plugin_id=plugin_id, + provider=provider, + ) + datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}") + provider_name = datasource_provider_id.provider_name + redirect_uri = ( + f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/" + f"{datasource_provider_id}/datasource/callback" + ) + system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id) + refreshed_credentials = OAuthHandler().refresh_credentials( + tenant_id=tenant_id, + user_id=current_user.id, + plugin_id=datasource_provider_id.plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=system_credentials or {}, + credentials=decrypted_credentials, + ) + datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials( + tenant_id=tenant_id, + raw_credentials=refreshed_credentials.credentials, + provider=provider, + plugin_id=plugin_id, + datasource_provider=datasource_provider, + ) + datasource_provider.expires_at = refreshed_credentials.expires_at + db.session.commit() + + return self.decrypt_datasource_provider_credentials( tenant_id=tenant_id, - provider_id=f"{plugin_id}/{provider}", - credential_type=CredentialType.of(datasource_provider.auth_type), + datasource_provider=datasource_provider, + plugin_id=plugin_id, + provider=provider, ) - # Obfuscate provider credentials - copy_credentials = encrypted_credentials.copy() - for key, value in copy_credentials.items(): - if key in credential_secret_variables: - copy_credentials[key] = encrypter.decrypt_token(tenant_id, value) - return copy_credentials - - def get_default_real_credential(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]: - """ - get default credential - """ - with Session(db.engine) as session: - datasource_provider = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) - .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .first() - ) - if not datasource_provider: - return {} - encrypted_credentials = datasource_provider.encrypted_credentials - # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, - provider_id=f"{plugin_id}/{provider}", - credential_type=CredentialType.of(datasource_provider.auth_type), - ) - - # Obfuscate provider credentials - copy_credentials = encrypted_credentials.copy() - for key, value in copy_credentials.items(): - if key in credential_secret_variables: - copy_credentials[key] = encrypter.decrypt_token(tenant_id, value) - return copy_credentials - def update_datasource_provider_name( self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str ): @@ -454,7 +465,7 @@ class DatasourceProviderService: credential_type = CredentialType.OAUTH2 with Session(db.engine) as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}" - with redis_client.lock(lock, timeout=20): + with redis_client.lock(lock, timeout=60): db_provider_name = name if not db_provider_name: db_provider_name = self.generate_next_datasource_provider_name( @@ -599,9 +610,9 @@ class DatasourceProviderService: return secret_input_form_variables - def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: + def list_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: """ - get datasource credentials. + list datasource credentials with obfuscated sensitive fields. :param tenant_id: workspace id :param provider_id: provider id @@ -666,7 +677,7 @@ class DatasourceProviderService: datasource_credentials = [] for datasource in datasources: datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}") - credentials = self.get_datasource_credentials( + credentials = self.list_datasource_credentials( tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id ) redirect_uri = ( @@ -731,7 +742,8 @@ class DatasourceProviderService: tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id ) redirect_uri = ( - f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" + f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/" + f"{datasource_provider_id}/datasource/callback" ) datasource_credentials.append( { diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 363ff4f0e5..9bd17778bd 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -522,7 +522,7 @@ class RagPipelineService: datasource_type=DatasourceProviderType(datasource_type), ) datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_real_credential_by_id( + credentials = datasource_provider_service.get_datasource_credentials( tenant_id=pipeline.tenant_id, provider=datasource_node_data.get("provider_name"), plugin_id=datasource_node_data.get("plugin_id"), @@ -666,7 +666,7 @@ class RagPipelineService: datasource_type=DatasourceProviderType(datasource_type), ) datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_real_credential_by_id( + credentials = datasource_provider_service.get_datasource_credentials( tenant_id=pipeline.tenant_id, provider=datasource_node_data.get("provider_name"), plugin_id=datasource_node_data.get("plugin_id"), diff --git a/api/services/website_service.py b/api/services/website_service.py index 83256daf3d..854c213d91 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -125,7 +125,7 @@ class WebsiteService: elif provider == "jinareader": plugin_id = "langgenius/jina_datasource" datasource_provider_service = DatasourceProviderService() - credential = datasource_provider_service.get_default_real_credential( + credential = datasource_provider_service.get_datasource_credentials( tenant_id=tenant_id, provider=provider, plugin_id=plugin_id,