mirror of https://github.com/langgenius/dify.git
refactor: replace get_real_credential_by_id with get_datasource_credentials in multiple services for consistency
This commit is contained in:
parent
7f328328fb
commit
543f80ad5d
|
|
@ -121,7 +121,7 @@ class DataSourceNotionListApi(Resource):
|
|||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_real_credential_by_id(
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
|
|
@ -206,7 +206,7 @@ class DataSourceNotionApi(Resource):
|
|||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_real_credential_by_id(
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
|
|
|
|||
|
|
@ -373,7 +373,7 @@ class NotionExtractor(BaseExtractor):
|
|||
if not credential_id:
|
||||
raise Exception(f"No credential id found for tenant {tenant_id}")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_real_credential_by_id(
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
|
|
|
|||
|
|
@ -123,19 +123,12 @@ class DatasourceNode(BaseNode):
|
|||
|
||||
try:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
if datasource_info.get("credential_id"):
|
||||
credentials = datasource_provider_service.get_real_credential_by_id(
|
||||
tenant_id=self.tenant_id,
|
||||
credential_id=datasource_info.get("credential_id"),
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
)
|
||||
else:
|
||||
credentials = datasource_provider_service.get_default_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
)
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=datasource_info.get("credential_id"),
|
||||
)
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
from typing import Any, Optional
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from api.core.plugin.impl.oauth import OAuthHandler
|
||||
from flask_login import current_user
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -41,37 +43,48 @@ class DatasourceProviderService:
|
|||
).delete()
|
||||
session.commit()
|
||||
|
||||
def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
get default credentials
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
if not datasource_provider:
|
||||
return {}
|
||||
def decrypt_datasource_provider_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
datasource_provider: DatasourceProvider,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
) -> dict[str, Any]:
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=CredentialType.of(datasource_provider.auth_type),
|
||||
)
|
||||
decrypted_credentials = encrypted_credentials.copy()
|
||||
for key, value in decrypted_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
decrypted_credentials[key] = encrypter.decrypt_token(tenant_id, value)
|
||||
return decrypted_credentials
|
||||
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=CredentialType.of(datasource_provider.auth_type),
|
||||
)
|
||||
def encrypt_datasource_provider_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
plugin_id: str,
|
||||
raw_credentials: dict[str, Any],
|
||||
datasource_provider: DatasourceProvider,
|
||||
) -> dict[str, Any]:
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}", credential_type=datasource_provider.auth_type
|
||||
)
|
||||
encrypted_credentials = raw_credentials.copy()
|
||||
for key, value in encrypted_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
|
||||
return encrypted_credentials
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = encrypted_credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
|
||||
return copy_credentials
|
||||
|
||||
def get_real_credential_by_id(
|
||||
self, tenant_id: str, credential_id: Optional[str], provider: str, plugin_id: str
|
||||
def get_datasource_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
plugin_id: str,
|
||||
credential_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
get credential by id
|
||||
|
|
@ -90,49 +103,47 @@ class DatasourceProviderService:
|
|||
)
|
||||
if not datasource_provider:
|
||||
return {}
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(
|
||||
# refresh the credentials
|
||||
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
|
||||
decrypted_credentials = self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider=datasource_provider,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
)
|
||||
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
redirect_uri = (
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
|
||||
f"{datasource_provider_id}/datasource/callback"
|
||||
)
|
||||
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
|
||||
refreshed_credentials = OAuthHandler().refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
raw_credentials=refreshed_credentials.credentials,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
datasource_provider=datasource_provider,
|
||||
)
|
||||
datasource_provider.expires_at = refreshed_credentials.expires_at
|
||||
db.session.commit()
|
||||
|
||||
return self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=CredentialType.of(datasource_provider.auth_type),
|
||||
datasource_provider=datasource_provider,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = encrypted_credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
|
||||
return copy_credentials
|
||||
|
||||
def get_default_real_credential(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
get default credential
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
if not datasource_provider:
|
||||
return {}
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=CredentialType.of(datasource_provider.auth_type),
|
||||
)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = encrypted_credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
|
||||
return copy_credentials
|
||||
|
||||
def update_datasource_provider_name(
|
||||
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
|
||||
):
|
||||
|
|
@ -454,7 +465,7 @@ class DatasourceProviderService:
|
|||
credential_type = CredentialType.OAUTH2
|
||||
with Session(db.engine) as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
with redis_client.lock(lock, timeout=60):
|
||||
db_provider_name = name
|
||||
if not db_provider_name:
|
||||
db_provider_name = self.generate_next_datasource_provider_name(
|
||||
|
|
@ -599,9 +610,9 @@ class DatasourceProviderService:
|
|||
|
||||
return secret_input_form_variables
|
||||
|
||||
def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
|
||||
def list_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
|
||||
"""
|
||||
get datasource credentials.
|
||||
list datasource credentials with obfuscated sensitive fields.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_id: provider id
|
||||
|
|
@ -666,7 +677,7 @@ class DatasourceProviderService:
|
|||
datasource_credentials = []
|
||||
for datasource in datasources:
|
||||
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
|
||||
credentials = self.get_datasource_credentials(
|
||||
credentials = self.list_datasource_credentials(
|
||||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
redirect_uri = (
|
||||
|
|
@ -731,7 +742,8 @@ class DatasourceProviderService:
|
|||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
redirect_uri = (
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
|
||||
f"{datasource_provider_id}/datasource/callback"
|
||||
)
|
||||
datasource_credentials.append(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -522,7 +522,7 @@ class RagPipelineService:
|
|||
datasource_type=DatasourceProviderType(datasource_type),
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_real_credential_by_id(
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
provider=datasource_node_data.get("provider_name"),
|
||||
plugin_id=datasource_node_data.get("plugin_id"),
|
||||
|
|
@ -666,7 +666,7 @@ class RagPipelineService:
|
|||
datasource_type=DatasourceProviderType(datasource_type),
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_real_credential_by_id(
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
provider=datasource_node_data.get("provider_name"),
|
||||
plugin_id=datasource_node_data.get("plugin_id"),
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class WebsiteService:
|
|||
elif provider == "jinareader":
|
||||
plugin_id = "langgenius/jina_datasource"
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_default_real_credential(
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
|
|
|
|||
Loading…
Reference in New Issue