diff --git a/api/config.py b/api/config.py
index 1e6000c8ae..f81527da61 100644
--- a/api/config.py
+++ b/api/config.py
@@ -47,6 +47,7 @@ DEFAULTS = {
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
+    'DEFAULT_LLM_PROVIDER': 'openai'
 }


@@ -181,6 +182,10 @@ class Config:
         # You could disable it for compatibility with certain OpenAPI providers
         self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')

+        # Temporary setting: selects the default LLM provider.
+        # Defaults to 'openai'; 'azure_openai' is also supported.
+        self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
+

 class CloudEditionConfig(Config):

     def __init__(self):
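Note: `DEFAULT_LLM_PROVIDER` is resolved through `get_env`, so the new `DEFAULTS` entry acts as the fallback when no environment variable is set. A minimal sketch of the assumed resolution order (`get_env`'s body is not part of this diff):

```python
import os

# Mirrors the DEFAULTS entry added above; the fallback behavior of get_env is assumed.
DEFAULTS = {'DEFAULT_LLM_PROVIDER': 'openai'}


def get_env(key: str) -> str:
    # The environment variable wins; otherwise fall back to DEFAULTS.
    return os.environ.get(key, DEFAULTS.get(key))


# e.g. `export DEFAULT_LLM_PROVIDER=azure_openai` routes every tenant to Azure.
print(get_env('DEFAULT_LLM_PROVIDER'))  # 'openai' unless overridden
```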
return [d["embedding"] for d in data] @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) async def aget_embeddings( - list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None + list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs ) -> List[List[float]]: """Asynchronously get embeddings. @@ -90,7 +93,7 @@ async def aget_embeddings( # replace newlines, which can negatively affect performance. list_of_text = [text.replace("\n", " ") for text in list_of_text] - data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data + data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. return [d["embedding"] for d in data] @@ -98,19 +101,30 @@ async def aget_embeddings( class OpenAIEmbedding(BaseEmbedding): def __init__( - self, - mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE, - model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002, - deployment_name: Optional[str] = None, - openai_api_key: Optional[str] = None, - **kwargs: Any, + self, + mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE, + model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002, + deployment_name: Optional[str] = None, + openai_api_key: Optional[str] = None, + **kwargs: Any, ) -> None: """Init params.""" - super().__init__(**kwargs) + new_kwargs = {} + + if 'embed_batch_size' in kwargs: + new_kwargs['embed_batch_size'] = kwargs['embed_batch_size'] + + if 'tokenizer' in kwargs: + new_kwargs['tokenizer'] = kwargs['tokenizer'] + + super().__init__(**new_kwargs) self.mode = OpenAIEmbeddingMode(mode) self.model = OpenAIEmbeddingModelType(model) self.deployment_name = deployment_name self.openai_api_key = openai_api_key + self.openai_api_type = kwargs.get('openai_api_type') + self.openai_api_version = kwargs.get('openai_api_version') + self.openai_api_base = kwargs.get('openai_api_base') @handle_llm_exceptions def _get_query_embedding(self, query: str) -> List[float]: @@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding): if key not in _QUERY_MODE_MODEL_DICT: raise ValueError(f"Invalid mode, model combination: {key}") engine = _QUERY_MODE_MODEL_DICT[key] - return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key) + return get_embedding(query, engine=engine, api_key=self.openai_api_key, + api_type=self.openai_api_type, api_version=self.openai_api_version, + api_base=self.openai_api_base) def _get_text_embedding(self, text: str) -> List[float]: """Get text embedding.""" @@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding): if key not in _TEXT_MODE_MODEL_DICT: raise ValueError(f"Invalid mode, model combination: {key}") engine = _TEXT_MODE_MODEL_DICT[key] - return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key) + return get_embedding(text, engine=engine, api_key=self.openai_api_key, + api_type=self.openai_api_type, api_version=self.openai_api_version, + api_base=self.openai_api_base) async def _aget_text_embedding(self, text: str) -> List[float]: """Asynchronously get text embedding.""" @@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding): if key not in _TEXT_MODE_MODEL_DICT: raise ValueError(f"Invalid mode, model combination: {key}") engine = _TEXT_MODE_MODEL_DICT[key] - return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key) + return await aget_embedding(text, 
diff --git a/api/core/llm/llm_builder.py b/api/core/llm/llm_builder.py
index 4355593c5d..9c4b0f9abd 100644
--- a/api/core/llm/llm_builder.py
+++ b/api/core/llm/llm_builder.py
@@ -1,10 +1,13 @@
 from typing import Union, Optional

+from flask import current_app
 from langchain.callbacks import CallbackManager
 from langchain.llms.fake import FakeListLLM

 from core.constant import llm_constant
 from core.llm.provider.llm_provider_service import LLMProviderService
+from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
+from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI

@@ -31,12 +34,19 @@ class LLMBuilder:
         if model_name == 'fake':
             return FakeListLLM(responses=[])

+        provider = current_app.config.get('DEFAULT_LLM_PROVIDER')
+
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            # llm_cls = StreamableAzureChatOpenAI
-            llm_cls = StreamableChatOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableChatOpenAI
+            else:
+                llm_cls = StreamableAzureChatOpenAI
         elif mode == 'completion':
-            llm_cls = StreamableOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableOpenAI
+            else:
+                llm_cls = StreamableAzureOpenAI
         else:
             raise ValueError(f"model name {model_name} is not supported.")

@@ -93,11 +103,12 @@ class LLMBuilder:
         """
         if not model_name:
             raise Exception('model name not found')
+        # Temporary: bypass the model registry lookup.
+        # if model_name not in llm_constant.models:
+        #     raise Exception('model {} not found'.format(model_name))

-        if model_name not in llm_constant.models:
-            raise Exception('model {} not found'.format(model_name))
-
-        model_provider = llm_constant.models[model_name]
+        # model_provider = llm_constant.models[model_name]
+        model_provider = current_app.config.get('DEFAULT_LLM_PROVIDER')

         provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
         return provider_service.get_credentials(model_name)
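The branching above treats any provider other than `'openai'` as Azure. If more providers are planned, a lookup table keeps the dispatch flat and makes the supported combinations explicit; a hypothetical refactor, not part of this PR:

```python
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI

# (provider, mode) -> LLM class; stricter than the if/else, which defaults to Azure.
LLM_CLASSES = {
    ('openai', 'chat'): StreamableChatOpenAI,
    ('openai', 'completion'): StreamableOpenAI,
    ('azure_openai', 'chat'): StreamableAzureChatOpenAI,
    ('azure_openai', 'completion'): StreamableAzureOpenAI,
}


def pick_llm_class(provider: str, mode: str):
    try:
        return LLM_CLASSES[(provider, mode)]
    except KeyError:
        raise ValueError(f"provider {provider} with mode {mode} is not supported.")
```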
diff --git a/api/core/llm/provider/azure_provider.py b/api/core/llm/provider/azure_provider.py
index e0ba0d0734..0377a9d8b9 100644
--- a/api/core/llm/provider/azure_provider.py
+++ b/api/core/llm/provider/azure_provider.py
@@ -36,8 +36,7 @@ class AzureProvider(BaseProvider):
         """
         Returns the API credentials for Azure OpenAI as a dictionary.
         """
-        encrypted_config = self.get_provider_api_key(model_id=model_id)
-        config = json.loads(encrypted_config)
+        config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
         config['deployment_name'] = model_id
         return config
diff --git a/api/core/llm/provider/base.py b/api/core/llm/provider/base.py
index 89343ff62a..717a8298a7 100644
--- a/api/core/llm/provider/base.py
+++ b/api/core/llm/provider/base.py
@@ -14,7 +14,7 @@ class BaseProvider(ABC):
     def __init__(self, tenant_id: str):
         self.tenant_id = tenant_id

-    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
+    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str, dict]:
         """
         Returns the decrypted API key for the given tenant_id and provider_name.
         If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
diff --git a/api/core/llm/streamable_azure_open_ai.py b/api/core/llm/streamable_azure_open_ai.py
new file mode 100644
index 0000000000..be69b6a5a2
--- /dev/null
+++ b/api/core/llm/streamable_azure_open_ai.py
@@ -0,0 +1,20 @@
+from typing import List, Optional
+
+from langchain.llms import AzureOpenAI
+from langchain.schema import LLMResult
+
+from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+
+
+class StreamableAzureOpenAI(AzureOpenAI):
+
+    @handle_llm_exceptions
+    def generate(
+            self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return super().generate(prompts, stop)
+
+    @handle_llm_exceptions_async
+    async def agenerate(
+            self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return await super().agenerate(prompts, stop)
diff --git a/web/app/components/header/account-setting/provider-page/azure-provider/index.tsx b/web/app/components/header/account-setting/provider-page/azure-provider/index.tsx
index 71236120e5..5681fd2204 100644
--- a/web/app/components/header/account-setting/provider-page/azure-provider/index.tsx
+++ b/web/app/components/header/account-setting/provider-page/azure-provider/index.tsx
@@ -20,7 +20,7 @@ const AzureProvider = ({
   const [token, setToken] = useState(provider.token as ProviderAzureToken || {})
   const handleFocus = () => {
     if (token === provider.token) {
-      token.azure_api_key = ''
+      token.openai_api_key = ''
       setToken({...token})
       onTokenChange({...token})
     }
@@ -35,31 +35,17 @@
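Taken together on the backend: `get_provider_api_key` may now return a decrypted dict for Azure (hence the `Union[str, dict]` return type), `AzureProvider` stamps it with `openai_api_type` and `deployment_name`, and `OpenAIEmbedding` forwards those fields on every request. A condensed sketch of the assumed flow; the tenant id, model id, and stored credential fields are placeholders, and the enclosing Azure method is taken to be `get_credentials` per its docstring:

```python
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.provider.azure_provider import AzureProvider

provider = AzureProvider(tenant_id="tenant-123")                     # placeholder tenant
config = provider.get_credentials(model_id="text-embedding-ada-002")
# config now carries openai_api_type='azure' and deployment_name=model_id.

embedding = OpenAIEmbedding(
    openai_api_key=config['openai_api_key'],              # field name matches the UI token change
    openai_api_type=config['openai_api_type'],
    openai_api_version=config.get('openai_api_version'),  # assumed stored with the credentials
    openai_api_base=config.get('openai_api_base'),
)
```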