refactor(api): replace dict type hints with Mapping for improved type safety

Updated type hints in several services to use Mapping instead of dict for better compatibility with various dictionary-like objects. Adjusted credential handling to ensure consistent encryption and decryption processes across ToolManager, DatasourceProviderService, ApiToolManageService, BuiltinToolManageService, and MCPToolManageService. This change enhances code clarity and adheres to strong typing practices.
This commit is contained in:
Harry 2025-10-29 18:10:23 +08:00
parent fb12f31df2
commit 9b5e5f0f50
5 changed files with 9 additions and 11 deletions

View File

@ -8,7 +8,6 @@ from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import Session
from yarl import URL
@ -289,10 +288,8 @@ class ToolManager:
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = (
TypeAdapter(dict[str, Any])
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
builtin_provider.encrypted_credentials = json.dumps(
encrypter.encrypt(refreshed_credentials.credentials)
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
@ -322,7 +319,7 @@ class ToolManager:
return api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=encrypter.decrypt(credentials),
credentials=dict(encrypter.decrypt(credentials)),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)

View File

@ -374,7 +374,7 @@ class DatasourceProviderService:
def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
) -> dict[str, Any] | None:
) -> Mapping[str, Any] | None:
"""
get tenant oauth client
"""
@ -434,7 +434,7 @@ class DatasourceProviderService:
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
return encrypter.decrypt(tenant_oauth_client_params.client_params)
return dict(encrypter.decrypt(tenant_oauth_client_params.client_params))
provider_controller = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)

View File

@ -306,7 +306,7 @@ class ApiToolManageService:
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
credentials = encrypter.encrypt(credentials)
credentials = dict(encrypter.encrypt(credentials))
provider.credentials_str = json.dumps(credentials)
db.session.add(provider)

View File

@ -353,7 +353,7 @@ class BuiltinToolManageService:
decrypt_credential = encrypter.mask_plugin_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
credentials=dict(decrypt_credential),
)
credentials.append(credential_entity)
return credentials

View File

@ -1,6 +1,7 @@
import hashlib
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any
@ -420,7 +421,7 @@ class MCPToolManageService:
return json.dumps({"content": icon, "background": icon_background})
return icon
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> dict[str, str]:
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> Mapping[str, str]:
"""Encrypt specified fields in a dictionary.
Args: