diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index abb817b244..d689008409 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required, enterpris from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner from core.model_runtime.entities.model_entities import ModelType +from core.plugin.entities.plugin import ModelProviderID from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -72,6 +73,8 @@ class DatasetListApi(Resource): data = marshal(datasets, dataset_detail_fields) for item in data: + # convert embedding_model_provider to plugin standard format + item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 958a4b69e4..9beb26c870 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -173,7 +173,7 @@ class ModelProviderID(GenericProviderID): def __init__(self, value: str, is_hardcoded: bool = False) -> None: super().__init__(value, is_hardcoded) if self.organization == "langgenius" and self.provider_name == "google": - self.provider_name = "gemini" + self.plugin_name = "gemini" class ToolProviderID(GenericProviderID): @@ -181,7 +181,7 @@ class ToolProviderID(GenericProviderID): super().__init__(value, is_hardcoded) if self.organization == "langgenius": if self.provider_name in ["jina", "siliconflow"]: - self.provider_name = f"{self.provider_name}_tool" + self.plugin_name = f"{self.provider_name}_tool" class PluginDependency(BaseModel): diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index e328d59a8b..8e93252727 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -30,6 +30,7 @@ from core.model_runtime.entities.provider_entities import ( ProviderEntity, ) from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.entities.plugin import ModelProviderID from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -191,7 +192,7 @@ class ProviderManager: model_settings=model_settings, ) - provider_configurations[provider_name] = provider_configuration + provider_configurations[str(ModelProviderID(provider_name))] = provider_configuration # Return the encapsulated object return provider_configurations