mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor: replace bare dict with dict[str, Any] in core provider services and misc modules (#35124)
This commit is contained in:
parent
2f682780fa
commit
eeebedcfe8
@ -1,5 +1,7 @@
|
|||||||
"""Configuration for InterSystems IRIS vector database."""
|
"""Configuration for InterSystems IRIS vector database."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field, PositiveInt, model_validator
|
from pydantic import Field, PositiveInt, model_validator
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings):
|
|||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Validate IRIS configuration values.
|
"""Validate IRIS configuration values.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
|
|||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
credentials: dict
|
credentials: dict[str, Any]
|
||||||
credential_source_type: str | None = None
|
credential_source_type: str | None = None
|
||||||
credential_id: str | None = None
|
credential_id: str | None = None
|
||||||
|
|
||||||
|
|||||||
@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension
|
|||||||
|
|
||||||
|
|
||||||
class ExternalDataToolFactory:
|
class ExternalDataToolFactory:
|
||||||
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
|
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]):
|
||||||
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||||
self.__extension_instance = extension_class(
|
self.__extension_instance = extension_class(
|
||||||
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
|
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config(cls, name: str, tenant_id: str, config: dict):
|
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
|
||||||
"""
|
"""
|
||||||
Validate the incoming form config data.
|
Validate the incoming form config data.
|
||||||
|
|
||||||
|
|||||||
@ -77,7 +77,7 @@ class ModelInstance:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_load_balancing_manager(
|
def _get_load_balancing_manager(
|
||||||
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
|
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any]
|
||||||
) -> Optional["LBModelManager"]:
|
) -> Optional["LBModelManager"]:
|
||||||
"""
|
"""
|
||||||
Get load balancing model credentials
|
Get load balancing model credentials
|
||||||
|
|||||||
@ -96,11 +96,11 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
app_mode: AppMode,
|
app_mode: AppMode,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
pre_prompt: str,
|
pre_prompt: str,
|
||||||
inputs: dict,
|
inputs: dict[str, Any],
|
||||||
query: str | None = None,
|
query: str | None = None,
|
||||||
context: str | None = None,
|
context: str | None = None,
|
||||||
histories: str | None = None,
|
histories: str | None = None,
|
||||||
) -> tuple[str, dict]:
|
) -> tuple[str, dict[str, Any]]:
|
||||||
# get prompt template
|
# get prompt template
|
||||||
prompt_template_config = self.get_prompt_template(
|
prompt_template_config = self.get_prompt_template(
|
||||||
app_mode=app_mode,
|
app_mode=app_mode,
|
||||||
@ -187,7 +187,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
self,
|
self,
|
||||||
app_mode: AppMode,
|
app_mode: AppMode,
|
||||||
pre_prompt: str,
|
pre_prompt: str,
|
||||||
inputs: dict,
|
inputs: dict[str, Any],
|
||||||
query: str,
|
query: str,
|
||||||
context: str | None,
|
context: str | None,
|
||||||
files: Sequence["File"],
|
files: Sequence["File"],
|
||||||
@ -234,7 +234,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
self,
|
self,
|
||||||
app_mode: AppMode,
|
app_mode: AppMode,
|
||||||
pre_prompt: str,
|
pre_prompt: str,
|
||||||
inputs: dict,
|
inputs: dict[str, Any],
|
||||||
query: str,
|
query: str,
|
||||||
context: str | None,
|
context: str | None,
|
||||||
files: Sequence["File"],
|
files: Sequence["File"],
|
||||||
|
|||||||
@ -856,7 +856,7 @@ class ProviderManager:
|
|||||||
secret_variables: list[str],
|
secret_variables: list[str],
|
||||||
cache_type: ProviderCredentialsCacheType,
|
cache_type: ProviderCredentialsCacheType,
|
||||||
is_provider: bool = False,
|
is_provider: bool = False,
|
||||||
) -> dict:
|
) -> dict[str, Any]:
|
||||||
"""Get and decrypt credentials with caching."""
|
"""Get and decrypt credentials with caching."""
|
||||||
credentials_cache = ProviderCredentialsCache(
|
credentials_cache = ProviderCredentialsCache(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
|||||||
@ -174,8 +174,8 @@ class RetrievalService:
|
|||||||
cls,
|
cls,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
query: str,
|
query: str,
|
||||||
external_retrieval_model: dict | None = None,
|
external_retrieval_model: dict[str, Any] | None = None,
|
||||||
metadata_filtering_conditions: dict | None = None,
|
metadata_filtering_conditions: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
stmt = select(Dataset).where(Dataset.id == dataset_id)
|
stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||||
dataset = db.session.scalar(stmt)
|
dataset = db.session.scalar(stmt)
|
||||||
|
|||||||
@ -232,7 +232,7 @@ class CacheEmbedding(Embeddings):
|
|||||||
|
|
||||||
return embedding_results # type: ignore
|
return embedding_results # type: ignore
|
||||||
|
|
||||||
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
|
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||||
"""Embed multimodal documents."""
|
"""Embed multimodal documents."""
|
||||||
# use doc embedding cache or store if not exists
|
# use doc embedding cache or store if not exists
|
||||||
file_id = multimodel_document["file_id"]
|
file_id = multimodel_document["file_id"]
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class Embeddings(ABC):
|
class Embeddings(ABC):
|
||||||
@ -20,7 +21,7 @@ class Embeddings(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
|
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||||
"""Embed multimodal query."""
|
"""Embed multimodal query."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@ -89,7 +89,7 @@ def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
|
|||||||
return _case_routing
|
return _case_routing
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str) -> dict:
|
def __getattr__(name: str) -> Any:
|
||||||
"""Lazy module-level access to routing tables."""
|
"""Lazy module-level access to routing tables."""
|
||||||
if name == "CASE_ROUTING":
|
if name == "CASE_ROUTING":
|
||||||
return _get_case_routing()
|
return _get_case_routing()
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import tempfile
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import clickzetta
|
import clickzetta
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):
|
|||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config(cls, values: dict):
|
def validate_config(cls, values: dict[str, Any]):
|
||||||
"""Validate the configuration values.
|
"""Validate the configuration values.
|
||||||
|
|
||||||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||||
@ -6,7 +8,7 @@ from models.model import AppMode, AppModelConfigDict
|
|||||||
|
|
||||||
class AppModelConfigService:
|
class AppModelConfigService:
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
|
def validate_configuration(cls, tenant_id: str, config: dict[str, Any], app_mode: AppMode) -> AppModelConfigDict:
|
||||||
match app_mode:
|
match app_mode:
|
||||||
case AppMode.CHAT:
|
case AppMode.CHAT:
|
||||||
return ChatAppConfigManager.config_validate(tenant_id, config)
|
return ChatAppConfigManager.config_validate(tenant_id, config)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@ -19,7 +20,7 @@ class ApiKeyAuthService:
|
|||||||
return data_source_api_key_bindings
|
return data_source_api_key_bindings
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_provider_auth(tenant_id: str, args: dict):
|
def create_provider_auth(tenant_id: str, args: dict[str, Any]):
|
||||||
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
|
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
|
||||||
if auth_result:
|
if auth_result:
|
||||||
# Encrypt the api key
|
# Encrypt the api key
|
||||||
|
|||||||
@ -428,7 +428,7 @@ class ToolTransformService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_builtin_provider_to_credential_entity(
|
def convert_builtin_provider_to_credential_entity(
|
||||||
provider: BuiltinToolProvider, credentials: dict
|
provider: BuiltinToolProvider, credentials: dict[str, Any]
|
||||||
) -> ToolProviderCredentialApiEntity:
|
) -> ToolProviderCredentialApiEntity:
|
||||||
return ToolProviderCredentialApiEntity(
|
return ToolProviderCredentialApiEntity(
|
||||||
id=provider.id,
|
id=provider.id,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user