diff --git a/api/core/helper/provider_encryption.py b/api/core/helper/provider_encryption.py index 98130fed58..4f3ed75d86 100644 --- a/api/core/helper/provider_encryption.py +++ b/api/core/helper/provider_encryption.py @@ -1,4 +1,5 @@ import contextlib +from collections.abc import Mapping from copy import deepcopy from typing import Any, Optional, Protocol @@ -11,7 +12,7 @@ class ProviderConfigCache(Protocol): Interface for provider configuration cache operations """ - def get(self) -> Optional[dict]: + def get(self) -> Optional[dict[str, Any]]: """Get cached provider configuration""" ... @@ -39,19 +40,19 @@ class ProviderConfigEncrypter: self.config = config self.provider_config_cache = provider_config_cache - def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: + def _deep_copy(self, data: Mapping[str, Any]) -> Mapping[str, Any]: """ deep copy data """ return deepcopy(data) - def encrypt(self, data: dict[str, str]) -> dict[str, str]: + def encrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]: """ encrypt tool credentials with tenant id return a deep copy of credentials with encrypted values """ - data = self._deep_copy(data) + data = dict(self._deep_copy(data)) # get fields need to be decrypted fields = dict[str, BasicProviderConfig]() @@ -66,13 +67,13 @@ class ProviderConfigEncrypter: return data - def mask_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + def mask_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]: """ mask credentials return a deep copy of credentials with masked values """ - data = self._deep_copy(data) + data = dict(self._deep_copy(data)) # get fields need to be decrypted fields = dict[str, BasicProviderConfig]() @@ -91,10 +92,10 @@ class ProviderConfigEncrypter: return data - def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + def mask_plugin_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]: return self.mask_credentials(data) - def decrypt(self, data: dict[str, str]) -> dict[str, Any]: + def decrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]: """ decrypt tool credentials with tenant id @@ -104,7 +105,7 @@ class ProviderConfigEncrypter: if cached_credentials: return cached_credentials - data = self._deep_copy(data) + data = dict(self._deep_copy(data)) # get fields need to be decrypted fields = dict[str, BasicProviderConfig]() for credential in self.config: @@ -120,7 +121,7 @@ class ProviderConfigEncrypter: data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) - self.provider_config_cache.set(data) + self.provider_config_cache.set(dict(data)) return data diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 9e0040a0d4..9b774c9c6d 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -830,7 +830,7 @@ class ToolManager: controller=controller, ) - masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials)) + masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials)) try: icon = json.loads(provider_obj.icon) diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index f17c9a95a0..ce7d0a27ce 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -110,9 +110,8 @@ class WebhookTriggerDebugEventPoller(TriggerDebugEventPoller): webhook_data = payload.get("webhook_data", {}) workflow_inputs = WebhookService.build_workflow_inputs(webhook_data) - workflow_args = { + workflow_args: Mapping[str, Any] = { "inputs": workflow_inputs or {}, - "query": "", "files": [], } return TriggerDebugEvent(workflow_args=workflow_args, node_id=self.node_id) diff --git a/api/core/trigger/entities/api_entities.py b/api/core/trigger/entities/api_entities.py index 9f969b9f79..7e77e89932 100644 --- a/api/core/trigger/entities/api_entities.py +++ b/api/core/trigger/entities/api_entities.py @@ -19,10 +19,10 @@ class TriggerProviderSubscriptionApiEntity(BaseModel): name: str = Field(description="The name of the subscription") provider: str = Field(description="The provider id of the subscription") credential_type: CredentialType = Field(description="The type of the credential") - credentials: dict = Field(description="The credentials of the subscription") + credentials: dict[str, Any] = Field(description="The credentials of the subscription") endpoint: str = Field(description="The endpoint of the subscription") - parameters: dict = Field(description="The parameters of the subscription") - properties: dict = Field(description="The properties of the subscription") + parameters: dict[str, Any] = Field(description="The parameters of the subscription") + properties: dict[str, Any] = Field(description="The properties of the subscription") workflows_in_use: int = Field(description="The number of workflows using this subscription") diff --git a/api/models/tools.py b/api/models/tools.py index aec53da50c..667e049d0b 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -415,7 +415,7 @@ class MCPToolProvider(TypeBase): # First decrypt, then mask decrypted_headers = encrypter_instance.decrypt(headers_data) - result = encrypter_instance.mask_tool_credentials(decrypted_headers) + result = encrypter_instance.mask_plugin_credentials(decrypted_headers) return result except Exception: return {} diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index b53320bf5e..1e018af19f 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -390,7 +390,7 @@ class DatasourceProviderService: if tenant_oauth_client_params: encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) if mask: - return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params)) + return encrypter.mask_plugin_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params)) else: return encrypter.decrypt(tenant_oauth_client_params.client_params) return None diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index bb024cc846..f0a0bcde1b 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -300,7 +300,7 @@ class ApiToolManageService: ) original_credentials = encrypter.decrypt(provider.credentials) - masked_credentials = encrypter.mask_tool_credentials(original_credentials) + masked_credentials = encrypter.mask_plugin_credentials(original_credentials) # check if the credential has changed, save the original credential for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: @@ -417,7 +417,7 @@ class ApiToolManageService: ) decrypted_credentials = encrypter.decrypt(credentials) # check if the credential has changed, save the original credential - masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_plugin_credentials(decrypted_credentials) for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = decrypted_credentials[name] diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 01d13d8e5b..db6d510d81 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -350,7 +350,7 @@ class BuiltinToolManageService: encrypter, _ = BuiltinToolManageService.create_tool_encrypter( tenant_id, provider, provider.provider, provider_controller ) - decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials)) + decrypt_credential = encrypter.mask_plugin_credentials(encrypter.decrypt(provider.credentials)) credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( provider=provider, credentials=decrypt_credential, @@ -724,4 +724,4 @@ class BuiltinToolManageService: cache=NoOpProviderCredentialCache(), ) - return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params)) + return encrypter.mask_plugin_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params)) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 2d55420d40..ed04f41ba3 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -162,7 +162,7 @@ class ToolTransformService: ) # decrypt the credentials and mask the credentials decrypted_credentials = encrypter.decrypt(data=credentials) - masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) + masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials result.original_credentials = decrypted_credentials @@ -316,7 +316,7 @@ class ToolTransformService: # decrypt the credentials and mask the credentials decrypted_credentials = encrypter.decrypt(data=credentials) - masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) + masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 8395d88b1d..77336c08d8 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -113,7 +113,7 @@ class TriggerProviderService: subscription_id: Optional[str] = None, credential_expires_at: int = -1, expires_at: int = -1, - ) -> dict: + ) -> Mapping[str, Any]: """ Add a new trigger provider with credentials. Supports multiple credential instances per provider. @@ -187,7 +187,10 @@ class TriggerProviderService: session.add(subscription) session.commit() - return {"result": "success", "id": str(subscription.id)} + return { + "result": "success", + "id": str(subscription.id), + } except Exception as e: logger.exception("Failed to add trigger provider") @@ -278,7 +281,7 @@ class TriggerProviderService: cls, tenant_id: str, subscription_id: str, - ) -> dict: + ) -> Mapping[str, Any]: """ Refresh OAuth token for a trigger provider. @@ -372,7 +375,7 @@ class TriggerProviderService: config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), ) - oauth_params = encrypter.decrypt(tenant_client.oauth_params) + oauth_params = encrypter.decrypt(dict(tenant_client.oauth_params)) return oauth_params is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id) @@ -399,9 +402,9 @@ class TriggerProviderService: cls, tenant_id: str, provider_id: TriggerProviderID, - client_params: Optional[dict] = None, + client_params: Optional[Mapping[str, Any]] = None, enabled: Optional[bool] = None, - ) -> dict: + ) -> Mapping[str, Any]: """ Save or update custom OAuth client parameters for a trigger provider. @@ -449,8 +452,8 @@ class TriggerProviderService: ) # Handle hidden values - original_params = encrypter.decrypt(custom_client.oauth_params) - new_params: dict = { + original_params = encrypter.decrypt(dict(custom_client.oauth_params)) + new_params: dict[str, Any] = { key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE) for key, value in client_params.items() } @@ -466,7 +469,7 @@ class TriggerProviderService: return {"result": "success"} @classmethod - def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict: + def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]: """ Get custom OAuth client parameters for a trigger provider. @@ -500,10 +503,10 @@ class TriggerProviderService: cache=NoOpProviderCredentialCache(), ) - return encrypter.mask_tool_credentials(encrypter.decrypt(custom_client.oauth_params)) + return encrypter.mask_plugin_credentials(encrypter.decrypt(dict(custom_client.oauth_params))) @classmethod - def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict: + def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]: """ Delete custom OAuth client parameters for a trigger provider.