refactor(api): replace dict type hints with Mapping for improved type safety

Updated type hints in several services to use Mapping instead of dict for better compatibility with various dictionary-like objects. Adjusted credential handling to ensure consistent encryption and decryption processes across ToolManager, DatasourceProviderService, ApiToolManageService, BuiltinToolManageService, and MCPToolManageService. This change enhances code clarity and adheres to strong typing practices.
This commit is contained in:
Harry 2025-10-29 18:10:23 +08:00
parent fb12f31df2
commit 9b5e5f0f50
5 changed files with 9 additions and 11 deletions

View File

@ -8,7 +8,6 @@ from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa import sqlalchemy as sa
from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from yarl import URL from yarl import URL
@ -289,10 +288,8 @@ class ToolManager:
credentials=decrypted_credentials, credentials=decrypted_credentials,
) )
# update the credentials # update the credentials
builtin_provider.encrypted_credentials = ( builtin_provider.encrypted_credentials = json.dumps(
TypeAdapter(dict[str, Any]) encrypter.encrypt(refreshed_credentials.credentials)
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
) )
builtin_provider.expires_at = refreshed_credentials.expires_at builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit() db.session.commit()
@ -322,7 +319,7 @@ class ToolManager:
return api_provider.get_tool(tool_name).fork_tool_runtime( return api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime( runtime=ToolRuntime(
tenant_id=tenant_id, tenant_id=tenant_id,
credentials=encrypter.decrypt(credentials), credentials=dict(encrypter.decrypt(credentials)),
invoke_from=invoke_from, invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from, tool_invoke_from=tool_invoke_from,
) )

View File

@ -374,7 +374,7 @@ class DatasourceProviderService:
def get_tenant_oauth_client( def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
) -> dict[str, Any] | None: ) -> Mapping[str, Any] | None:
""" """
get tenant oauth client get tenant oauth client
""" """
@ -434,7 +434,7 @@ class DatasourceProviderService:
) )
if tenant_oauth_client_params: if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
return encrypter.decrypt(tenant_oauth_client_params.client_params) return dict(encrypter.decrypt(tenant_oauth_client_params.client_params))
provider_controller = self.provider_manager.fetch_datasource_provider( provider_controller = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id) tenant_id=tenant_id, provider_id=str(datasource_provider_id)

View File

@ -306,7 +306,7 @@ class ApiToolManageService:
if name in masked_credentials and value == masked_credentials[name]: if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name] credentials[name] = original_credentials[name]
credentials = encrypter.encrypt(credentials) credentials = dict(encrypter.encrypt(credentials))
provider.credentials_str = json.dumps(credentials) provider.credentials_str = json.dumps(credentials)
db.session.add(provider) db.session.add(provider)

View File

@ -353,7 +353,7 @@ class BuiltinToolManageService:
decrypt_credential = encrypter.mask_plugin_credentials(encrypter.decrypt(provider.credentials)) decrypt_credential = encrypter.mask_plugin_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider, provider=provider,
credentials=decrypt_credential, credentials=dict(decrypt_credential),
) )
credentials.append(credential_entity) credentials.append(credential_entity)
return credentials return credentials

View File

@ -1,6 +1,7 @@
import hashlib import hashlib
import json import json
import logging import logging
from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import Any from typing import Any
@ -420,7 +421,7 @@ class MCPToolManageService:
return json.dumps({"content": icon, "background": icon_background}) return json.dumps({"content": icon, "background": icon_background})
return icon return icon
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> dict[str, str]: def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> Mapping[str, str]:
"""Encrypt specified fields in a dictionary. """Encrypt specified fields in a dictionary.
Args: Args: