diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index 62573ba2f5..d555f4d965 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -4,7 +4,7 @@ import hashlib import logging from collections.abc import Generator, Iterable, Sequence from threading import Lock -from typing import IO, Any, Literal, cast, overload +from typing import IO, Any, Literal, cast, overload, override from pydantic import ValidationError from redis import RedisError @@ -118,6 +118,7 @@ class PluginModelRuntime(ModelRuntime): self._provider_entities = None self._provider_entities_lock = Lock() + @override def fetch_model_providers(self) -> Sequence[ProviderEntity]: if self._provider_entities is not None: return self._provider_entities @@ -130,6 +131,7 @@ class PluginModelRuntime(ModelRuntime): return self._provider_entities + @override def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: provider_schema = self._get_provider_schema(provider) @@ -172,6 +174,7 @@ class PluginModelRuntime(ModelRuntime): mime_type = image_mime_types.get(extension, "image/png") return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type + @override def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: plugin_id, provider_name = self._split_provider(provider) self.client.validate_provider_credentials( @@ -182,6 +185,7 @@ class PluginModelRuntime(ModelRuntime): credentials=credentials, ) + @override def validate_model_credentials( self, *, @@ -201,6 +205,7 @@ class PluginModelRuntime(ModelRuntime): credentials=credentials, ) + @override def get_model_schema( self, *, @@ -294,6 +299,7 @@ class PluginModelRuntime(ModelRuntime): stream: Literal[True], ) -> Generator[LLMResultChunk, None, None]: ... + @override def invoke_llm( self, *, @@ -357,6 +363,7 @@ class PluginModelRuntime(ModelRuntime): stream: Literal[True], ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + @override def invoke_llm_with_structured_output( self, *, @@ -396,6 +403,7 @@ class PluginModelRuntime(ModelRuntime): stream=stream, ) + @override def get_llm_num_tokens( self, *, @@ -422,6 +430,7 @@ class PluginModelRuntime(ModelRuntime): tools=list(tools) if tools else None, ) + @override def invoke_text_embedding( self, *, @@ -443,6 +452,7 @@ class PluginModelRuntime(ModelRuntime): input_type=input_type, ) + @override def invoke_multimodal_embedding( self, *, @@ -464,6 +474,7 @@ class PluginModelRuntime(ModelRuntime): input_type=input_type, ) + @override def get_text_embedding_num_tokens( self, *, @@ -483,6 +494,7 @@ class PluginModelRuntime(ModelRuntime): texts=texts, ) + @override def invoke_rerank( self, *, @@ -508,6 +520,7 @@ class PluginModelRuntime(ModelRuntime): top_n=top_n, ) + @override def invoke_multimodal_rerank( self, *, @@ -533,6 +546,7 @@ class PluginModelRuntime(ModelRuntime): top_n=top_n, ) + @override def invoke_tts( self, *, @@ -554,6 +568,7 @@ class PluginModelRuntime(ModelRuntime): voice=voice, ) + @override def get_tts_model_voices( self, *, @@ -573,6 +588,7 @@ class PluginModelRuntime(ModelRuntime): language=language, ) + @override def invoke_speech_to_text( self, *, @@ -592,6 +608,7 @@ class PluginModelRuntime(ModelRuntime): file=file, ) + @override def invoke_moderation( self, *,