mirror of https://github.com/langgenius/dify.git
feat: optimize langchain llm credential pass in
This commit is contained in:
parent
ef658f24b9
commit
2b4030b7ba
|
|
@ -1,12 +1,50 @@
|
|||
import requests
|
||||
from langchain.schema import BaseMessage, ChatResult, LLMResult
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
|
||||
|
||||
class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
try:
|
||||
values["client"] = openai.ChatCompletion
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`."
|
||||
)
|
||||
if values["n"] < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
**super()._default_params,
|
||||
"engine": self.deployment_name,
|
||||
"api_type": self.openai_api_type,
|
||||
"api_base": self.openai_api_base,
|
||||
"api_version": self.openai_api_version,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,55 @@
|
|||
import os
|
||||
|
||||
from langchain.llms import AzureOpenAI
|
||||
from langchain.schema import LLMResult
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Mapping, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
|
||||
|
||||
class StreamableAzureOpenAI(AzureOpenAI):
|
||||
openai_api_type: str = "azure"
|
||||
openai_api_version: str = ""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
import openai
|
||||
|
||||
values["client"] = openai.Completion
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
if values["streaming"] and values["n"] > 1:
|
||||
raise ValueError("Cannot stream results when n > 1.")
|
||||
if values["streaming"] and values["best_of"] > 1:
|
||||
raise ValueError("Cannot stream results when best_of > 1.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
return {**super()._invocation_params, **{
|
||||
"api_type": self.openai_api_type,
|
||||
"api_base": self.openai_api_base,
|
||||
"api_version": self.openai_api_version,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {**super()._identifying_params, **{
|
||||
"api_type": self.openai_api_type,
|
||||
"api_base": self.openai_api_base,
|
||||
"api_version": self.openai_api_version,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@handle_llm_exceptions
|
||||
def generate(
|
||||
|
|
|
|||
|
|
@ -1,12 +1,52 @@
|
|||
import os
|
||||
|
||||
from langchain.schema import BaseMessage, ChatResult, LLMResult
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
|
||||
|
||||
class StreamableChatOpenAI(ChatOpenAI):
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
try:
|
||||
values["client"] = openai.ChatCompletion
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`."
|
||||
)
|
||||
if values["n"] < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
**super()._default_params,
|
||||
"api_type": 'openai',
|
||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,54 @@
|
|||
import os
|
||||
|
||||
from langchain.schema import LLMResult
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Any, Mapping
|
||||
from langchain import OpenAI
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
|
||||
|
||||
class StreamableOpenAI(OpenAI):
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
import openai
|
||||
|
||||
values["client"] = openai.Completion
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
if values["streaming"] and values["n"] > 1:
|
||||
raise ValueError("Cannot stream results when n > 1.")
|
||||
if values["streaming"] and values["best_of"] > 1:
|
||||
raise ValueError("Cannot stream results when best_of > 1.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
return {**super()._invocation_params, **{
|
||||
"api_type": 'openai',
|
||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {**super()._identifying_params, **{
|
||||
"api_type": 'openai',
|
||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
|
||||
@handle_llm_exceptions
|
||||
def generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue