mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 09:57:03 +08:00
refactor
This commit is contained in:
parent
6d09330f98
commit
5d6b9b0cb1
@ -1,4 +1,5 @@
|
||||
import contextlib
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
@ -11,7 +12,7 @@ class ProviderConfigCache(Protocol):
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
def get(self) -> Optional[dict[str, Any]]:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
@ -39,19 +40,19 @@ class ProviderConfigEncrypter:
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
def _deep_copy(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
def encrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
data = dict(self._deep_copy(data))
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
@ -66,13 +67,13 @@ class ProviderConfigEncrypter:
|
||||
|
||||
return data
|
||||
|
||||
def mask_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
def mask_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""
|
||||
mask credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
data = dict(self._deep_copy(data))
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
@ -91,10 +92,10 @@ class ProviderConfigEncrypter:
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
def mask_plugin_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
return self.mask_credentials(data)
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||
def decrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
@ -104,7 +105,7 @@ class ProviderConfigEncrypter:
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
data = dict(self._deep_copy(data))
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
@ -120,7 +121,7 @@ class ProviderConfigEncrypter:
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
|
||||
self.provider_config_cache.set(data)
|
||||
self.provider_config_cache.set(dict(data))
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@ -830,7 +830,7 @@ class ToolManager:
|
||||
controller=controller,
|
||||
)
|
||||
|
||||
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
|
||||
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
|
||||
|
||||
try:
|
||||
icon = json.loads(provider_obj.icon)
|
||||
|
||||
@ -110,9 +110,8 @@ class WebhookTriggerDebugEventPoller(TriggerDebugEventPoller):
|
||||
webhook_data = payload.get("webhook_data", {})
|
||||
workflow_inputs = WebhookService.build_workflow_inputs(webhook_data)
|
||||
|
||||
workflow_args = {
|
||||
workflow_args: Mapping[str, Any] = {
|
||||
"inputs": workflow_inputs or {},
|
||||
"query": "",
|
||||
"files": [],
|
||||
}
|
||||
return TriggerDebugEvent(workflow_args=workflow_args, node_id=self.node_id)
|
||||
|
||||
@ -19,10 +19,10 @@ class TriggerProviderSubscriptionApiEntity(BaseModel):
|
||||
name: str = Field(description="The name of the subscription")
|
||||
provider: str = Field(description="The provider id of the subscription")
|
||||
credential_type: CredentialType = Field(description="The type of the credential")
|
||||
credentials: dict = Field(description="The credentials of the subscription")
|
||||
credentials: dict[str, Any] = Field(description="The credentials of the subscription")
|
||||
endpoint: str = Field(description="The endpoint of the subscription")
|
||||
parameters: dict = Field(description="The parameters of the subscription")
|
||||
properties: dict = Field(description="The properties of the subscription")
|
||||
parameters: dict[str, Any] = Field(description="The parameters of the subscription")
|
||||
properties: dict[str, Any] = Field(description="The properties of the subscription")
|
||||
workflows_in_use: int = Field(description="The number of workflows using this subscription")
|
||||
|
||||
|
||||
|
||||
@ -415,7 +415,7 @@ class MCPToolProvider(TypeBase):
|
||||
|
||||
# First decrypt, then mask
|
||||
decrypted_headers = encrypter_instance.decrypt(headers_data)
|
||||
result = encrypter_instance.mask_tool_credentials(decrypted_headers)
|
||||
result = encrypter_instance.mask_plugin_credentials(decrypted_headers)
|
||||
return result
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@ -390,7 +390,7 @@ class DatasourceProviderService:
|
||||
if tenant_oauth_client_params:
|
||||
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
||||
if mask:
|
||||
return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
|
||||
return encrypter.mask_plugin_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
|
||||
else:
|
||||
return encrypter.decrypt(tenant_oauth_client_params.client_params)
|
||||
return None
|
||||
|
||||
@ -300,7 +300,7 @@ class ApiToolManageService:
|
||||
)
|
||||
|
||||
original_credentials = encrypter.decrypt(provider.credentials)
|
||||
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
@ -417,7 +417,7 @@ class ApiToolManageService:
|
||||
)
|
||||
decrypted_credentials = encrypter.decrypt(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(decrypted_credentials)
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = decrypted_credentials[name]
|
||||
|
||||
@ -350,7 +350,7 @@ class BuiltinToolManageService:
|
||||
encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, provider, provider.provider, provider_controller
|
||||
)
|
||||
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
|
||||
decrypt_credential = encrypter.mask_plugin_credentials(encrypter.decrypt(provider.credentials))
|
||||
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||
provider=provider,
|
||||
credentials=decrypt_credential,
|
||||
@ -724,4 +724,4 @@ class BuiltinToolManageService:
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
|
||||
return encrypter.mask_plugin_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
|
||||
|
||||
@ -162,7 +162,7 @@ class ToolTransformService:
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = encrypter.decrypt(data=credentials)
|
||||
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
result.original_credentials = decrypted_credentials
|
||||
@ -316,7 +316,7 @@ class ToolTransformService:
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = encrypter.decrypt(data=credentials)
|
||||
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
|
||||
|
||||
@ -113,7 +113,7 @@ class TriggerProviderService:
|
||||
subscription_id: Optional[str] = None,
|
||||
credential_expires_at: int = -1,
|
||||
expires_at: int = -1,
|
||||
) -> dict:
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Add a new trigger provider with credentials.
|
||||
Supports multiple credential instances per provider.
|
||||
@ -187,7 +187,10 @@ class TriggerProviderService:
|
||||
session.add(subscription)
|
||||
session.commit()
|
||||
|
||||
return {"result": "success", "id": str(subscription.id)}
|
||||
return {
|
||||
"result": "success",
|
||||
"id": str(subscription.id),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to add trigger provider")
|
||||
@ -278,7 +281,7 @@ class TriggerProviderService:
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_id: str,
|
||||
) -> dict:
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Refresh OAuth token for a trigger provider.
|
||||
|
||||
@ -372,7 +375,7 @@ class TriggerProviderService:
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
oauth_params = encrypter.decrypt(tenant_client.oauth_params)
|
||||
oauth_params = encrypter.decrypt(dict(tenant_client.oauth_params))
|
||||
return oauth_params
|
||||
|
||||
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
|
||||
@ -399,9 +402,9 @@ class TriggerProviderService:
|
||||
cls,
|
||||
tenant_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
client_params: Optional[dict] = None,
|
||||
client_params: Optional[Mapping[str, Any]] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
) -> dict:
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Save or update custom OAuth client parameters for a trigger provider.
|
||||
|
||||
@ -449,8 +452,8 @@ class TriggerProviderService:
|
||||
)
|
||||
|
||||
# Handle hidden values
|
||||
original_params = encrypter.decrypt(custom_client.oauth_params)
|
||||
new_params: dict = {
|
||||
original_params = encrypter.decrypt(dict(custom_client.oauth_params))
|
||||
new_params: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||
for key, value in client_params.items()
|
||||
}
|
||||
@ -466,7 +469,7 @@ class TriggerProviderService:
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
|
||||
def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]:
|
||||
"""
|
||||
Get custom OAuth client parameters for a trigger provider.
|
||||
|
||||
@ -500,10 +503,10 @@ class TriggerProviderService:
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_client.oauth_params))
|
||||
return encrypter.mask_plugin_credentials(encrypter.decrypt(dict(custom_client.oauth_params)))
|
||||
|
||||
@classmethod
|
||||
def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
|
||||
def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]:
|
||||
"""
|
||||
Delete custom OAuth client parameters for a trigger provider.
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user