refactor: replace bare dict with dict[str, Any] in provider entities and plugin client (#35077)

This commit is contained in:
wdeveloper16 2026-04-13 19:09:25 +02:00 committed by GitHub
parent 7056d2ae99
commit 4e0273bb28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 46 additions and 32 deletions

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
id: str id: str
position: int position: int
data_source_type: str data_source_type: str
data_source_info: dict | None = None data_source_info: dict[str, Any] | None = None
name: str name: str
indexing_status: str indexing_status: str
error: str | None = None error: str | None = None

View File

@ -6,6 +6,7 @@ import re
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterator, Sequence from collections.abc import Iterator, Sequence
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import ( from graphon.model_runtime.entities.provider_entities import (
@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
return ModelProviderFactory(model_runtime=self._bound_model_runtime) return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id) return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
""" """
Get current credentials. Get current credentials.
@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):
return session.execute(stmt).scalar_one_or_none() return session.execute(stmt).scalar_one_or_none()
def _get_specific_provider_credential(self, credential_id: str) -> dict | None: def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
""" """
Get a specific provider credential by ID. Get a specific provider credential by ID.
:param credential_id: Credential ID :param credential_id: Credential ID
@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderCredential.id != exclude_id) stmt = stmt.where(ProviderCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None return session.execute(stmt).scalar_one_or_none() is not None
def get_provider_credential(self, credential_id: str | None = None) -> dict | None: def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
""" """
Get provider credentials. Get provider credentials.
@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
else [], else [],
) )
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None): def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
""" """
Validate custom credentials. Validate custom credentials.
:param credentials: provider credentials :param credentials: provider credentials
@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
provider_names.append(model_provider_id.provider_name) provider_names.append(model_provider_id.provider_name)
return provider_names return provider_names
def create_provider_credential(self, credentials: dict, credential_name: str | None): def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
""" """
Add custom provider credentials. Add custom provider credentials.
:param credentials: provider credentials :param credentials: provider credentials
@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):
def update_provider_credential( def update_provider_credential(
self, self,
credentials: dict, credentials: dict[str, Any],
credential_id: str, credential_id: str,
credential_name: str | None, credential_name: str | None,
): ):
@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):
def _get_specific_custom_model_credential( def _get_specific_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str self, model_type: ModelType, model: str, credential_id: str
) -> dict | None: ) -> dict[str, Any] | None:
""" """
Get a specific provider credential by ID. Get a specific provider credential by ID.
:param credential_id: Credential ID :param credential_id: Credential ID
@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderModelCredential.id != exclude_id) stmt = stmt.where(ProviderModelCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None return session.execute(stmt).scalar_one_or_none() is not None
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None: def get_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str | None
) -> dict[str, Any] | None:
""" """
Get custom model credentials. Get custom model credentials.
@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
self, self,
model_type: ModelType, model_type: ModelType,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
credential_id: str = "", credential_id: str = "",
session: Session | None = None, session: Session | None = None,
): ):
@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session) return _validate(new_session)
def create_custom_model_credential( def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
) -> None: ) -> None:
""" """
Create a custom model credential. Create a custom model credential.
@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
raise raise
def update_custom_model_credential( def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str self,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
credential_name: str | None,
credential_id: str,
) -> None: ) -> None:
""" """
Update a custom model credential. Update a custom model credential.
@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
# Get model instance of LLM # Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None: def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
) -> AIModelEntity | None:
""" """
Get model schema Get model schema
""" """
@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables return secret_input_form_variables
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]): def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
""" """
Obfuscated credentials. Obfuscated credentials.

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from enum import StrEnum, auto from enum import StrEnum, auto
from typing import Union from typing import Any, Union
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
enabled: bool enabled: bool
current_quota_type: ProviderQuotaType | None = None current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = [] quota_configurations: list[QuotaConfiguration] = []
credentials: dict | None = None credentials: dict[str, Any] | None = None
class CustomProviderConfiguration(BaseModel): class CustomProviderConfiguration(BaseModel):
@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
Model class for provider custom configuration. Model class for provider custom configuration.
""" """
credentials: dict credentials: dict[str, Any]
current_credential_id: str | None = None current_credential_id: str | None = None
current_credential_name: str | None = None current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] = [] available_credentials: list[CredentialConfiguration] = []
@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
model: str model: str
model_type: ModelType model_type: ModelType
credentials: dict | None credentials: dict[str, Any] | None
current_credential_id: str | None = None current_credential_id: str | None = None
current_credential_name: str | None = None current_credential_name: str | None = None
available_model_credentials: list[CredentialConfiguration] = [] available_model_credentials: list[CredentialConfiguration] = []

View File

@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
provider: str, provider: str,
model_type: str, model_type: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
) -> AIModelEntity | None: ) -> AIModelEntity | None:
""" """
Get model schema Get model schema
@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
provider: str, provider: str,
model_type: str, model_type: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
) -> bool: ) -> bool:
""" """
validate the credentials of the provider validate the credentials of the provider
@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
plugin_id: str, plugin_id: str,
provider: str, provider: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict | None = None, model_parameters: dict[str, Any] | None = None,
tools: list[PromptMessageTool] | None = None, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stop: list[str] | None = None,
stream: bool = True, stream: bool = True,
@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
provider: str, provider: str,
model_type: str, model_type: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None, tools: list[PromptMessageTool] | None = None,
) -> int: ) -> int:
@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str, plugin_id: str,
provider: str, provider: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
texts: list[str], texts: list[str],
input_type: str, input_type: str,
) -> EmbeddingResult: ) -> EmbeddingResult:
@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str, plugin_id: str,
provider: str, provider: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
documents: list[dict], documents: list[dict],
input_type: str, input_type: str,
) -> EmbeddingResult: ) -> EmbeddingResult:
@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str, plugin_id: str,
provider: str, provider: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
texts: list[str], texts: list[str],
) -> list[int]: ) -> list[int]:
""" """
@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str, plugin_id: str,
provider: str, provider: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
query: str, query: str,
docs: list[str], docs: list[str],
score_threshold: float | None = None, score_threshold: float | None = None,
@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str, plugin_id: str,
provider: str, provider: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
query: MultimodalRerankInput, query: MultimodalRerankInput,
docs: list[MultimodalRerankInput], docs: list[MultimodalRerankInput],
score_threshold: float | None = None, score_threshold: float | None = None,
@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str, plugin_id: str,
provider: str, provider: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
content_text: str, content_text: str,
voice: str, voice: str,
) -> Generator[bytes, None, None]: ) -> Generator[bytes, None, None]:
@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str, plugin_id: str,
provider: str, provider: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
language: str | None = None, language: str | None = None,
): ):
""" """
@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str, plugin_id: str,
provider: str, provider: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
file: IO[bytes], file: IO[bytes],
) -> str: ) -> str:
""" """
@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str, plugin_id: str,
provider: str, provider: str,
model: str, model: str,
credentials: dict, credentials: dict[str, Any],
text: str, text: str,
) -> bool: ) -> bool:
""" """