mirror of
https://github.com/langgenius/dify.git
synced 2026-04-18 04:16:28 +08:00
refactor: replace bare dict with dict[str, Any] in provider entities and plugin client (#35077)
This commit is contained in:
parent
7056d2ae99
commit
4e0273bb28
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
position: int
|
position: int
|
||||||
data_source_type: str
|
data_source_type: str
|
||||||
data_source_info: dict | None = None
|
data_source_info: dict[str, Any] | None = None
|
||||||
name: str
|
name: str
|
||||||
indexing_status: str
|
indexing_status: str
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import re
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Iterator, Sequence
|
from collections.abc import Iterator, Sequence
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||||
from graphon.model_runtime.entities.provider_entities import (
|
from graphon.model_runtime.entities.provider_entities import (
|
||||||
@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
|
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
|
||||||
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||||
|
|
||||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
|
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
Get current credentials.
|
Get current credentials.
|
||||||
|
|
||||||
@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
|
|
||||||
return session.execute(stmt).scalar_one_or_none()
|
return session.execute(stmt).scalar_one_or_none()
|
||||||
|
|
||||||
def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
|
def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
Get a specific provider credential by ID.
|
Get a specific provider credential by ID.
|
||||||
:param credential_id: Credential ID
|
:param credential_id: Credential ID
|
||||||
@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
stmt = stmt.where(ProviderCredential.id != exclude_id)
|
stmt = stmt.where(ProviderCredential.id != exclude_id)
|
||||||
return session.execute(stmt).scalar_one_or_none() is not None
|
return session.execute(stmt).scalar_one_or_none() is not None
|
||||||
|
|
||||||
def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
|
def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
Get provider credentials.
|
Get provider credentials.
|
||||||
|
|
||||||
@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
|
|||||||
else [],
|
else [],
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
|
def validate_provider_credentials(
|
||||||
|
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Validate custom credentials.
|
Validate custom credentials.
|
||||||
:param credentials: provider credentials
|
:param credentials: provider credentials
|
||||||
@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
provider_names.append(model_provider_id.provider_name)
|
provider_names.append(model_provider_id.provider_name)
|
||||||
return provider_names
|
return provider_names
|
||||||
|
|
||||||
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
|
||||||
"""
|
"""
|
||||||
Add custom provider credentials.
|
Add custom provider credentials.
|
||||||
:param credentials: provider credentials
|
:param credentials: provider credentials
|
||||||
@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
|
|
||||||
def update_provider_credential(
|
def update_provider_credential(
|
||||||
self,
|
self,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
credential_id: str,
|
credential_id: str,
|
||||||
credential_name: str | None,
|
credential_name: str | None,
|
||||||
):
|
):
|
||||||
@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
|
|
||||||
def _get_specific_custom_model_credential(
|
def _get_specific_custom_model_credential(
|
||||||
self, model_type: ModelType, model: str, credential_id: str
|
self, model_type: ModelType, model: str, credential_id: str
|
||||||
) -> dict | None:
|
) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
Get a specific provider credential by ID.
|
Get a specific provider credential by ID.
|
||||||
:param credential_id: Credential ID
|
:param credential_id: Credential ID
|
||||||
@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
|
|||||||
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
|
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
|
||||||
return session.execute(stmt).scalar_one_or_none() is not None
|
return session.execute(stmt).scalar_one_or_none() is not None
|
||||||
|
|
||||||
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
|
def get_custom_model_credential(
|
||||||
|
self, model_type: ModelType, model: str, credential_id: str | None
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
Get custom model credentials.
|
Get custom model credentials.
|
||||||
|
|
||||||
@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
self,
|
self,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
credential_id: str = "",
|
credential_id: str = "",
|
||||||
session: Session | None = None,
|
session: Session | None = None,
|
||||||
):
|
):
|
||||||
@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
return _validate(new_session)
|
return _validate(new_session)
|
||||||
|
|
||||||
def create_custom_model_credential(
|
def create_custom_model_credential(
|
||||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
|
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create a custom model credential.
|
Create a custom model credential.
|
||||||
@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def update_custom_model_credential(
|
def update_custom_model_credential(
|
||||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
|
self,
|
||||||
|
model_type: ModelType,
|
||||||
|
model: str,
|
||||||
|
credentials: dict[str, Any],
|
||||||
|
credential_name: str | None,
|
||||||
|
credential_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Update a custom model credential.
|
Update a custom model credential.
|
||||||
@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
|
|||||||
# Get model instance of LLM
|
# Get model instance of LLM
|
||||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||||
|
|
||||||
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
|
def get_model_schema(
|
||||||
|
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
|
||||||
|
) -> AIModelEntity | None:
|
||||||
"""
|
"""
|
||||||
Get model schema
|
Get model schema
|
||||||
"""
|
"""
|
||||||
@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
|
|
||||||
return secret_input_form_variables
|
return secret_input_form_variables
|
||||||
|
|
||||||
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
|
def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
|
||||||
"""
|
"""
|
||||||
Obfuscated credentials.
|
Obfuscated credentials.
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
|
|||||||
enabled: bool
|
enabled: bool
|
||||||
current_quota_type: ProviderQuotaType | None = None
|
current_quota_type: ProviderQuotaType | None = None
|
||||||
quota_configurations: list[QuotaConfiguration] = []
|
quota_configurations: list[QuotaConfiguration] = []
|
||||||
credentials: dict | None = None
|
credentials: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class CustomProviderConfiguration(BaseModel):
|
class CustomProviderConfiguration(BaseModel):
|
||||||
@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
|
|||||||
Model class for provider custom configuration.
|
Model class for provider custom configuration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
credentials: dict
|
credentials: dict[str, Any]
|
||||||
current_credential_id: str | None = None
|
current_credential_id: str | None = None
|
||||||
current_credential_name: str | None = None
|
current_credential_name: str | None = None
|
||||||
available_credentials: list[CredentialConfiguration] = []
|
available_credentials: list[CredentialConfiguration] = []
|
||||||
@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
|
|||||||
|
|
||||||
model: str
|
model: str
|
||||||
model_type: ModelType
|
model_type: ModelType
|
||||||
credentials: dict | None
|
credentials: dict[str, Any] | None
|
||||||
current_credential_id: str | None = None
|
current_credential_id: str | None = None
|
||||||
current_credential_name: str | None = None
|
current_credential_name: str | None = None
|
||||||
available_model_credentials: list[CredentialConfiguration] = []
|
available_model_credentials: list[CredentialConfiguration] = []
|
||||||
|
|||||||
@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
provider: str,
|
provider: str,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
) -> AIModelEntity | None:
|
) -> AIModelEntity | None:
|
||||||
"""
|
"""
|
||||||
Get model schema
|
Get model schema
|
||||||
@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
provider: str,
|
provider: str,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
validate the credentials of the provider
|
validate the credentials of the provider
|
||||||
@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict | None = None,
|
model_parameters: dict[str, Any] | None = None,
|
||||||
tools: list[PromptMessageTool] | None = None,
|
tools: list[PromptMessageTool] | None = None,
|
||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
provider: str,
|
provider: str,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
tools: list[PromptMessageTool] | None = None,
|
tools: list[PromptMessageTool] | None = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
input_type: str,
|
input_type: str,
|
||||||
) -> EmbeddingResult:
|
) -> EmbeddingResult:
|
||||||
@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
documents: list[dict],
|
documents: list[dict],
|
||||||
input_type: str,
|
input_type: str,
|
||||||
) -> EmbeddingResult:
|
) -> EmbeddingResult:
|
||||||
@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
"""
|
"""
|
||||||
@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
query: str,
|
query: str,
|
||||||
docs: list[str],
|
docs: list[str],
|
||||||
score_threshold: float | None = None,
|
score_threshold: float | None = None,
|
||||||
@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
query: MultimodalRerankInput,
|
query: MultimodalRerankInput,
|
||||||
docs: list[MultimodalRerankInput],
|
docs: list[MultimodalRerankInput],
|
||||||
score_threshold: float | None = None,
|
score_threshold: float | None = None,
|
||||||
@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
content_text: str,
|
content_text: str,
|
||||||
voice: str,
|
voice: str,
|
||||||
) -> Generator[bytes, None, None]:
|
) -> Generator[bytes, None, None]:
|
||||||
@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
language: str | None = None,
|
language: str | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
file: IO[bytes],
|
file: IO[bytes],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict[str, Any],
|
||||||
text: str,
|
text: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user