This commit is contained in:
Yeuoly 2025-10-18 19:41:53 +08:00
parent 6d09330f98
commit 5d6b9b0cb1
10 changed files with 38 additions and 35 deletions

View File

@ -1,4 +1,5 @@
import contextlib
from collections.abc import Mapping
from copy import deepcopy
from typing import Any, Optional, Protocol
@ -11,7 +12,7 @@ class ProviderConfigCache(Protocol):
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
def get(self) -> Optional[dict[str, Any]]:
"""Get cached provider configuration"""
...
@ -39,19 +40,19 @@ class ProviderConfigEncrypter:
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
def _deep_copy(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
def encrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
data = dict(self._deep_copy(data))
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
@ -66,13 +67,13 @@ class ProviderConfigEncrypter:
return data
def mask_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
def mask_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
"""
mask credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
data = dict(self._deep_copy(data))
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
@ -91,10 +92,10 @@ class ProviderConfigEncrypter:
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
def mask_plugin_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
return self.mask_credentials(data)
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
def decrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
"""
decrypt tool credentials with tenant id
@ -104,7 +105,7 @@ class ProviderConfigEncrypter:
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
data = dict(self._deep_copy(data))
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
@ -120,7 +121,7 @@ class ProviderConfigEncrypter:
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
self.provider_config_cache.set(data)
self.provider_config_cache.set(dict(data))
return data

View File

@ -830,7 +830,7 @@ class ToolManager:
controller=controller,
)
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
try:
icon = json.loads(provider_obj.icon)

View File

@ -110,9 +110,8 @@ class WebhookTriggerDebugEventPoller(TriggerDebugEventPoller):
webhook_data = payload.get("webhook_data", {})
workflow_inputs = WebhookService.build_workflow_inputs(webhook_data)
workflow_args = {
workflow_args: Mapping[str, Any] = {
"inputs": workflow_inputs or {},
"query": "",
"files": [],
}
return TriggerDebugEvent(workflow_args=workflow_args, node_id=self.node_id)

View File

@ -19,10 +19,10 @@ class TriggerProviderSubscriptionApiEntity(BaseModel):
name: str = Field(description="The name of the subscription")
provider: str = Field(description="The provider id of the subscription")
credential_type: CredentialType = Field(description="The type of the credential")
credentials: dict = Field(description="The credentials of the subscription")
credentials: dict[str, Any] = Field(description="The credentials of the subscription")
endpoint: str = Field(description="The endpoint of the subscription")
parameters: dict = Field(description="The parameters of the subscription")
properties: dict = Field(description="The properties of the subscription")
parameters: dict[str, Any] = Field(description="The parameters of the subscription")
properties: dict[str, Any] = Field(description="The properties of the subscription")
workflows_in_use: int = Field(description="The number of workflows using this subscription")

View File

@ -415,7 +415,7 @@ class MCPToolProvider(TypeBase):
# First decrypt, then mask
decrypted_headers = encrypter_instance.decrypt(headers_data)
result = encrypter_instance.mask_tool_credentials(decrypted_headers)
result = encrypter_instance.mask_plugin_credentials(decrypted_headers)
return result
except Exception:
return {}

View File

@ -390,7 +390,7 @@ class DatasourceProviderService:
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
if mask:
return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
return encrypter.mask_plugin_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
else:
return encrypter.decrypt(tenant_oauth_client_params.client_params)
return None

View File

@ -300,7 +300,7 @@ class ApiToolManageService:
)
original_credentials = encrypter.decrypt(provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
@ -417,7 +417,7 @@ class ApiToolManageService:
)
decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential
masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_plugin_credentials(decrypted_credentials)
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name]

View File

@ -350,7 +350,7 @@ class BuiltinToolManageService:
encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
tenant_id, provider, provider.provider, provider_controller
)
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
decrypt_credential = encrypter.mask_plugin_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
@ -724,4 +724,4 @@ class BuiltinToolManageService:
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
return encrypter.mask_plugin_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))

View File

@ -162,7 +162,7 @@ class ToolTransformService:
)
# decrypt the credentials and mask the credentials
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials
@ -316,7 +316,7 @@ class ToolTransformService:
# decrypt the credentials and mask the credentials
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials

View File

@ -113,7 +113,7 @@ class TriggerProviderService:
subscription_id: Optional[str] = None,
credential_expires_at: int = -1,
expires_at: int = -1,
) -> dict:
) -> Mapping[str, Any]:
"""
Add a new trigger provider with credentials.
Supports multiple credential instances per provider.
@ -187,7 +187,10 @@ class TriggerProviderService:
session.add(subscription)
session.commit()
return {"result": "success", "id": str(subscription.id)}
return {
"result": "success",
"id": str(subscription.id),
}
except Exception as e:
logger.exception("Failed to add trigger provider")
@ -278,7 +281,7 @@ class TriggerProviderService:
cls,
tenant_id: str,
subscription_id: str,
) -> dict:
) -> Mapping[str, Any]:
"""
Refresh OAuth token for a trigger provider.
@ -372,7 +375,7 @@ class TriggerProviderService:
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
oauth_params = encrypter.decrypt(tenant_client.oauth_params)
oauth_params = encrypter.decrypt(dict(tenant_client.oauth_params))
return oauth_params
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
@ -399,9 +402,9 @@ class TriggerProviderService:
cls,
tenant_id: str,
provider_id: TriggerProviderID,
client_params: Optional[dict] = None,
client_params: Optional[Mapping[str, Any]] = None,
enabled: Optional[bool] = None,
) -> dict:
) -> Mapping[str, Any]:
"""
Save or update custom OAuth client parameters for a trigger provider.
@ -449,8 +452,8 @@ class TriggerProviderService:
)
# Handle hidden values
original_params = encrypter.decrypt(custom_client.oauth_params)
new_params: dict = {
original_params = encrypter.decrypt(dict(custom_client.oauth_params))
new_params: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
@ -466,7 +469,7 @@ class TriggerProviderService:
return {"result": "success"}
@classmethod
def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]:
"""
Get custom OAuth client parameters for a trigger provider.
@ -500,10 +503,10 @@ class TriggerProviderService:
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_client.oauth_params))
return encrypter.mask_plugin_credentials(encrypter.decrypt(dict(custom_client.oauth_params)))
@classmethod
def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]:
"""
Delete custom OAuth client parameters for a trigger provider.