refactor: replace get_real_credential_by_id with get_datasource_credentials in multiple services for consistency

This commit is contained in:
Harry 2025-08-11 20:03:53 +08:00
parent 7f328328fb
commit 543f80ad5d
6 changed files with 99 additions and 94 deletions

View File

@ -121,7 +121,7 @@ class DataSourceNotionListApi(Resource):
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_real_credential_by_id(
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
@ -206,7 +206,7 @@ class DataSourceNotionApi(Resource):
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_real_credential_by_id(
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",

View File

@ -373,7 +373,7 @@ class NotionExtractor(BaseExtractor):
if not credential_id:
raise Exception(f"No credential id found for tenant {tenant_id}")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_real_credential_by_id(
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id,
credential_id=credential_id,
provider="notion_datasource",

View File

@ -123,19 +123,12 @@ class DatasourceNode(BaseNode):
try:
datasource_provider_service = DatasourceProviderService()
if datasource_info.get("credential_id"):
credentials = datasource_provider_service.get_real_credential_by_id(
tenant_id=self.tenant_id,
credential_id=datasource_info.get("credential_id"),
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
)
else:
credentials = datasource_provider_service.get_default_credentials(
tenant_id=self.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
)
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=self.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
credential_id=datasource_info.get("credential_id"),
)
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)

View File

@ -1,6 +1,8 @@
import logging
from typing import Any, Optional
import time
from typing import Any
from api.core.plugin.impl.oauth import OAuthHandler
from flask_login import current_user
from sqlalchemy.orm import Session
@ -41,37 +43,48 @@ class DatasourceProviderService:
).delete()
session.commit()
def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
"""
get default credentials
"""
with Session(db.engine) as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
if not datasource_provider:
return {}
def decrypt_datasource_provider_credentials(
self,
tenant_id: str,
datasource_provider: DatasourceProvider,
plugin_id: str,
provider: str,
) -> dict[str, Any]:
encrypted_credentials = datasource_provider.encrypted_credentials
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
decrypted_credentials = encrypted_credentials.copy()
for key, value in decrypted_credentials.items():
if key in credential_secret_variables:
decrypted_credentials[key] = encrypter.decrypt_token(tenant_id, value)
return decrypted_credentials
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
def encrypt_datasource_provider_credentials(
self,
tenant_id: str,
provider: str,
plugin_id: str,
raw_credentials: dict[str, Any],
datasource_provider: DatasourceProvider,
) -> dict[str, Any]:
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}", credential_type=datasource_provider.auth_type
)
encrypted_credentials = raw_credentials.copy()
for key, value in encrypted_credentials.items():
if key in provider_credential_secret_variables:
encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
return encrypted_credentials
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
return copy_credentials
def get_real_credential_by_id(
self, tenant_id: str, credential_id: Optional[str], provider: str, plugin_id: str
def get_datasource_credentials(
self,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str | None = None,
) -> dict[str, Any]:
"""
get credential by id
@ -90,49 +103,47 @@ class DatasourceProviderService:
)
if not datasource_provider:
return {}
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
# refresh the credentials
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
decrypted_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
provider_name = datasource_provider_id.provider_name
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
f"{datasource_provider_id}/datasource/callback"
)
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
refreshed_credentials = OAuthHandler().refresh_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
plugin_id=datasource_provider_id.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
tenant_id=tenant_id,
raw_credentials=refreshed_credentials.credentials,
provider=provider,
plugin_id=plugin_id,
datasource_provider=datasource_provider,
)
datasource_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
return self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
return copy_credentials
def get_default_real_credential(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
"""
get default credential
"""
with Session(db.engine) as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
if not datasource_provider:
return {}
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
return copy_credentials
def update_datasource_provider_name(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
):
@ -454,7 +465,7 @@ class DatasourceProviderService:
credential_type = CredentialType.OAUTH2
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
with redis_client.lock(lock, timeout=20):
with redis_client.lock(lock, timeout=60):
db_provider_name = name
if not db_provider_name:
db_provider_name = self.generate_next_datasource_provider_name(
@ -599,9 +610,9 @@ class DatasourceProviderService:
return secret_input_form_variables
def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
def list_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
"""
get datasource credentials.
list datasource credentials with obfuscated sensitive fields.
:param tenant_id: workspace id
:param provider_id: provider id
@ -666,7 +677,7 @@ class DatasourceProviderService:
datasource_credentials = []
for datasource in datasources:
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
credentials = self.get_datasource_credentials(
credentials = self.list_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
redirect_uri = (
@ -731,7 +742,8 @@ class DatasourceProviderService:
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
f"{datasource_provider_id}/datasource/callback"
)
datasource_credentials.append(
{

View File

@ -522,7 +522,7 @@ class RagPipelineService:
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_credential_by_id(
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
@ -666,7 +666,7 @@ class RagPipelineService:
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_credential_by_id(
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),

View File

@ -125,7 +125,7 @@ class WebsiteService:
elif provider == "jinareader":
plugin_id = "langgenius/jina_datasource"
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_default_real_credential(
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,