diff --git a/api/core/index/index_builder.py b/api/core/index/index_builder.py index baf16b0f3a..7f0486546e 100644 --- a/api/core/index/index_builder.py +++ b/api/core/index/index_builder.py @@ -33,8 +33,11 @@ class IndexBuilder: max_chunk_overlap=20 ) + provider = LLMBuilder.get_default_provider(tenant_id) + model_credentials = LLMBuilder.get_model_credentials( tenant_id=tenant_id, + model_provider=provider, model_name='text-embedding-ada-002' ) diff --git a/api/core/llm/llm_builder.py b/api/core/llm/llm_builder.py index 9c4b0f9abd..cd8511cf28 100644 --- a/api/core/llm/llm_builder.py +++ b/api/core/llm/llm_builder.py @@ -5,11 +5,15 @@ from langchain.callbacks import CallbackManager from langchain.llms.fake import FakeListLLM from core.constant import llm_constant +from core.llm.error import ProviderTokenNotInitError +from core.llm.provider.base import BaseProvider from core.llm.provider.llm_provider_service import LLMProviderService +from core.llm.provider.openai_provider import OpenAIProvider from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI from core.llm.streamable_chat_open_ai import StreamableChatOpenAI from core.llm.streamable_open_ai import StreamableOpenAI +from models.provider import ProviderType class LLMBuilder: @@ -34,7 +38,7 @@ class LLMBuilder: if model_name == 'fake': return FakeListLLM(responses=[]) - provider = current_app.config.get('DEFAULT_LLM_PROVIDER') + provider = cls.get_default_provider(tenant_id) mode = cls.get_mode_by_model(model_name) if mode == 'chat': @@ -50,7 +54,7 @@ class LLMBuilder: else: raise ValueError(f"model name {model_name} is not supported.") - model_credentials = cls.get_model_credentials(tenant_id, model_name) + model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) return llm_cls( model_name=model_name, @@ -96,7 +100,7 @@ class LLMBuilder: raise ValueError(f"model name {model_name} is not supported.") @classmethod - def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict: + def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict: """ Returns the API credentials for the given tenant_id and model_name, based on the model's provider. Raises an exception if the model_name is not found or if the provider is not found. @@ -108,7 +112,19 @@ class LLMBuilder: # raise Exception('model {} not found'.format(model_name)) # model_provider = llm_constant.models[model_name] - model_provider = current_app.config.get('DEFAULT_LLM_PROVIDER') provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider) return provider_service.get_credentials(model_name) + + @classmethod + def get_default_provider(cls, tenant_id: str) -> str: + provider = BaseProvider.get_valid_provider(tenant_id) + if not provider: + raise ProviderTokenNotInitError() + + if provider.provider_type == ProviderType.SYSTEM.value: + provider_name = 'openai' + else: + provider_name = provider.provider_name + + return provider_name diff --git a/api/core/llm/provider/azure_provider.py b/api/core/llm/provider/azure_provider.py index 0377a9d8b9..2736b193bc 100644 --- a/api/core/llm/provider/azure_provider.py +++ b/api/core/llm/provider/azure_provider.py @@ -38,7 +38,7 @@ class AzureProvider(BaseProvider): """ config = self.get_provider_api_key(model_id=model_id) config['openai_api_type'] = 'azure' - config['deployment_name'] = model_id + config['deployment_name'] = model_id.replace('.', '') return config def get_provider_name(self): @@ -50,7 +50,6 @@ class AzureProvider(BaseProvider): """ try: config = self.get_provider_api_key() - config = json.loads(config) except: config = { 'openai_api_type': 'azure', diff --git a/api/core/llm/provider/base.py b/api/core/llm/provider/base.py index 717a8298a7..2865489b95 100644 --- a/api/core/llm/provider/base.py +++ b/api/core/llm/provider/base.py @@ -43,23 +43,35 @@ class BaseProvider(ABC): Returns the Provider instance for the given tenant_id and provider_name. If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. """ - providers = db.session.query(Provider).filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.get_provider_name().value - ).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() + return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom) + + @classmethod + def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]: + """ + Returns the Provider instance for the given tenant_id and provider_name. + If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. + """ + query = db.session.query(Provider).filter( + Provider.tenant_id == tenant_id + ) + + if provider_name: + query = query.filter(Provider.provider_name == provider_name) + + providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() custom_provider = None system_provider = None for provider in providers: - if provider.provider_type == ProviderType.CUSTOM.value: + if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: custom_provider = provider - elif provider.provider_type == ProviderType.SYSTEM.value: + elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: system_provider = provider - if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config: + if custom_provider: return custom_provider - elif system_provider and system_provider.is_valid: + elif system_provider: return system_provider else: return None