diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 23596558db..b311f069a8 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -20,6 +20,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator +from core.plugin.entities.plugin import ModelProviderID from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.manager.asset import PluginAssetManager from core.plugin.manager.model import PluginModelManager @@ -112,6 +113,9 @@ class ModelProviderFactory: :param provider: provider name :return: provider schema """ + if "/" not in provider: + provider = str(ModelProviderID(provider)) + # fetch plugin model providers plugin_model_provider_entities = self.get_plugin_model_providers() @@ -363,4 +367,4 @@ class ModelProviderFactory: plugin_id = "/".join(provider.split("/")[:-1]) provider_name = provider.split("/")[-1] - return plugin_id, provider_name + return str(plugin_id), provider_name diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index ee65e86826..aa78eb919c 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -169,6 +169,13 @@ class GenericProviderID: return f"{self.organization}/{self.plugin_name}" +class ModelProviderID(GenericProviderID): + def __init__(self, value: str, is_hardcoded: bool = False) -> None: + super().__init__(value, is_hardcoded) + if self.organization == "langgenius" and self.provider_name == "google": + self.provider_name = "gemini" + + class PluginDependency(BaseModel): class Type(enum.StrEnum): Github = PluginInstallationSource.Github.value