From 2b4030b7baa1cd517e352767c66478fae60aa4aa Mon Sep 17 00:00:00 2001 From: John Wang Date: Fri, 19 May 2023 12:04:56 +0800 Subject: [PATCH] feat: optimize langchain llm credential pass in --- api/core/llm/streamable_azure_chat_open_ai.py | 42 ++++++++++++++++- api/core/llm/streamable_azure_open_ai.py | 46 ++++++++++++++++++- api/core/llm/streamable_chat_open_ai.py | 42 ++++++++++++++++- api/core/llm/streamable_open_ai.py | 44 +++++++++++++++++- 4 files changed, 169 insertions(+), 5 deletions(-) diff --git a/api/core/llm/streamable_azure_chat_open_ai.py b/api/core/llm/streamable_azure_chat_open_ai.py index 539ce92774..f3d514cf58 100644 --- a/api/core/llm/streamable_azure_chat_open_ai.py +++ b/api/core/llm/streamable_azure_chat_open_ai.py @@ -1,12 +1,50 @@ -import requests from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.chat_models import AzureChatOpenAI -from typing import Optional, List +from typing import Optional, List, Dict, Any + +from pydantic import root_validator from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async class StreamableAzureChatOpenAI(AzureChatOpenAI): + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + import openai + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + try: + values["client"] = openai.ChatCompletion + except AttributeError: + raise ValueError( + "`openai` has no `ChatCompletion` attribute, this is likely " + "due to an old version of the openai package. Try upgrading it " + "with `pip install --upgrade openai`." + ) + if values["n"] < 1: + raise ValueError("n must be at least 1.") + if values["n"] > 1 and values["streaming"]: + raise ValueError("n must be 1 when streaming.") + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling OpenAI API.""" + return { + **super()._default_params, + "engine": self.deployment_name, + "api_type": self.openai_api_type, + "api_base": self.openai_api_base, + "api_version": self.openai_api_version, + "api_key": self.openai_api_key, + "organization": self.openai_organization if self.openai_organization else None, + } + def get_messages_tokens(self, messages: List[BaseMessage]) -> int: """Get the number of tokens in a list of messages. diff --git a/api/core/llm/streamable_azure_open_ai.py b/api/core/llm/streamable_azure_open_ai.py index be69b6a5a2..e383f8cf23 100644 --- a/api/core/llm/streamable_azure_open_ai.py +++ b/api/core/llm/streamable_azure_open_ai.py @@ -1,11 +1,55 @@ +import os + from langchain.llms import AzureOpenAI from langchain.schema import LLMResult -from typing import Optional, List +from typing import Optional, List, Dict, Mapping, Any + +from pydantic import root_validator from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async class StreamableAzureOpenAI(AzureOpenAI): + openai_api_type: str = "azure" + openai_api_version: str = "" + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + import openai + + values["client"] = openai.Completion + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + if values["streaming"] and values["n"] > 1: + raise ValueError("Cannot stream results when n > 1.") + if values["streaming"] and values["best_of"] > 1: + raise ValueError("Cannot stream results when best_of > 1.") + return values + + @property + def _invocation_params(self) -> Dict[str, Any]: + return {**super()._invocation_params, **{ + "api_type": self.openai_api_type, + "api_base": self.openai_api_base, + "api_version": self.openai_api_version, + "api_key": self.openai_api_key, + "organization": self.openai_organization if self.openai_organization else None, + }} + + @property + def _identifying_params(self) -> Mapping[str, Any]: + return {**super()._identifying_params, **{ + "api_type": self.openai_api_type, + "api_base": self.openai_api_base, + "api_version": self.openai_api_version, + "api_key": self.openai_api_key, + "organization": self.openai_organization if self.openai_organization else None, + }} @handle_llm_exceptions def generate( diff --git a/api/core/llm/streamable_chat_open_ai.py b/api/core/llm/streamable_chat_open_ai.py index 59391e4ce0..582041ba09 100644 --- a/api/core/llm/streamable_chat_open_ai.py +++ b/api/core/llm/streamable_chat_open_ai.py @@ -1,12 +1,52 @@ +import os + from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.chat_models import ChatOpenAI -from typing import Optional, List +from typing import Optional, List, Dict, Any + +from pydantic import root_validator from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async class StreamableChatOpenAI(ChatOpenAI): + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + import openai + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + try: + values["client"] = openai.ChatCompletion + except AttributeError: + raise ValueError( + "`openai` has no `ChatCompletion` attribute, this is likely " + "due to an old version of the openai package. Try upgrading it " + "with `pip install --upgrade openai`." + ) + if values["n"] < 1: + raise ValueError("n must be at least 1.") + if values["n"] > 1 and values["streaming"]: + raise ValueError("n must be 1 when streaming.") + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling OpenAI API.""" + return { + **super()._default_params, + "api_type": 'openai', + "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), + "api_version": None, + "api_key": self.openai_api_key, + "organization": self.openai_organization if self.openai_organization else None, + } + def get_messages_tokens(self, messages: List[BaseMessage]) -> int: """Get the number of tokens in a list of messages. diff --git a/api/core/llm/streamable_open_ai.py b/api/core/llm/streamable_open_ai.py index 94754af30e..9cf1b4c4bb 100644 --- a/api/core/llm/streamable_open_ai.py +++ b/api/core/llm/streamable_open_ai.py @@ -1,12 +1,54 @@ +import os + from langchain.schema import LLMResult -from typing import Optional, List +from typing import Optional, List, Dict, Any, Mapping from langchain import OpenAI +from pydantic import root_validator from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async class StreamableOpenAI(OpenAI): + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + import openai + + values["client"] = openai.Completion + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + if values["streaming"] and values["n"] > 1: + raise ValueError("Cannot stream results when n > 1.") + if values["streaming"] and values["best_of"] > 1: + raise ValueError("Cannot stream results when best_of > 1.") + return values + + @property + def _invocation_params(self) -> Dict[str, Any]: + return {**super()._invocation_params, **{ + "api_type": 'openai', + "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), + "api_version": None, + "api_key": self.openai_api_key, + "organization": self.openai_organization if self.openai_organization else None, + }} + + @property + def _identifying_params(self) -> Mapping[str, Any]: + return {**super()._identifying_params, **{ + "api_type": 'openai', + "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), + "api_version": None, + "api_key": self.openai_api_key, + "organization": self.openai_organization if self.openai_organization else None, + }} + + @handle_llm_exceptions def generate( self, prompts: List[str], stop: Optional[List[str]] = None