mirror of https://github.com/langgenius/dify.git
feat: support Azure OpenAI temporarily; a deployment must be created whose ID matches the OpenAI model name, e.g. gpt-3.5-turbo / text-embedding-ada-002 / ...
This commit is contained in:
parent 3b3c604eb5
commit cbe0f6f3ad
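As the message says, this is a stopgap: provider selection is driven by a single config flag, and Azure deployments must be named exactly after the OpenAI models they serve. A hypothetical environment entry for switching a self-hosted instance over to Azure (the variable name comes from the diff below; `azure_openai` is the alternative value the config comment documents):

```
# illustrative .env entry; when unset, DEFAULT_LLM_PROVIDER falls back to 'openai'
DEFAULT_LLM_PROVIDER=azure_openai
```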
@@ -47,6 +47,7 @@ DEFAULTS = {
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
+    'DEFAULT_LLM_PROVIDER': 'openai'
 }
 
 
@@ -181,6 +182,10 @@ class Config:
         # You could disable it for compatibility with certain OpenAPI providers
         self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
 
+        # For temp use only
+        # set default LLM provider, default is 'openai', support `azure_openai`
+        self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
+
 class CloudEditionConfig(Config):
 
     def __init__(self):
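For context, `get_env` here resolves a key against the `DEFAULTS` dict above. A minimal sketch of that assumed pattern (the real helper is defined elsewhere in this module): environment first, then the default.

```python
import os

# Sketch of the assumed config helpers: the process environment wins,
# otherwise the value falls back to the DEFAULTS dict.
DEFAULTS = {'DEFAULT_LLM_PROVIDER': 'openai'}  # trimmed to the key this commit adds

def get_env(key: str) -> str:
    return os.environ.get(key, DEFAULTS.get(key))

assert get_env('DEFAULT_LLM_PROVIDER') == 'openai'  # with the env var unset
```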
@@ -11,9 +11,10 @@ from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_except
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embedding(
-    text: str,
-    engine: Optional[str] = None,
-    openai_api_key: Optional[str] = None,
+        text: str,
+        engine: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
 ) -> List[float]:
     """Get embedding.
 
@@ -25,11 +26,12 @@ def get_embedding(
     """
     text = text.replace("\n", " ")
-    return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"]
+    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]:
+async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
+    float]:
     """Asynchronously get embedding.
 
     NOTE: Copied from OpenAI's embedding utils:
@@ -42,16 +44,17 @@ async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key
     # replace newlines, which can negatively affect performance.
     text = text.replace("\n", " ")
 
-    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][
+    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
         "embedding"
     ]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embeddings(
-    list_of_text: List[str],
-    engine: Optional[str] = None,
-    openai_api_key: Optional[str] = None
+        list_of_text: List[str],
+        engine: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
 ) -> List[List[float]]:
     """Get embeddings.
@@ -67,14 +70,14 @@ def get_embeddings(
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
-    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data
+    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embeddings(
-    list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None
+        list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
 ) -> List[List[float]]:
     """Asynchronously get embeddings.
@@ -90,7 +93,7 @@ async def aget_embeddings(
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
-    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data
+    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
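With `**kwargs` now forwarded into `openai.Embedding.create`, callers can pass the Azure connection fields the OpenAI SDK already understands (`api_type`, `api_version`, `api_base`). A hypothetical call against an Azure deployment named after the model; endpoint, key and version are placeholders:

```python
# Hypothetical invocation of the patched helper; all credentials are placeholders.
vector = get_embedding(
    "What is Dify?",
    engine="text-embedding-ada-002",          # Azure deployment ID == model name
    api_key="<azure-openai-key>",
    api_type="azure",
    api_version="2023-03-15-preview",
    api_base="https://<resource-name>.openai.azure.com",
)
print(len(vector))  # 1536 dimensions for text-embedding-ada-002
```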
@@ -98,19 +101,30 @@ async def aget_embeddings(
 class OpenAIEmbedding(BaseEmbedding):
 
     def __init__(
-        self,
-        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
-        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
-        deployment_name: Optional[str] = None,
-        openai_api_key: Optional[str] = None,
-        **kwargs: Any,
+            self,
+            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
+            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
+            deployment_name: Optional[str] = None,
+            openai_api_key: Optional[str] = None,
+            **kwargs: Any,
     ) -> None:
         """Init params."""
-        super().__init__(**kwargs)
+        new_kwargs = {}
+
+        if 'embed_batch_size' in kwargs:
+            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
+
+        if 'tokenizer' in kwargs:
+            new_kwargs['tokenizer'] = kwargs['tokenizer']
+
+        super().__init__(**new_kwargs)
         self.mode = OpenAIEmbeddingMode(mode)
         self.model = OpenAIEmbeddingModelType(model)
         self.deployment_name = deployment_name
         self.openai_api_key = openai_api_key
+        self.openai_api_type = kwargs.get('openai_api_type')
+        self.openai_api_version = kwargs.get('openai_api_version')
+        self.openai_api_base = kwargs.get('openai_api_base')
 
     @handle_llm_exceptions
     def _get_query_embedding(self, query: str) -> List[float]:
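Because `__init__` now picks the Azure fields out of `**kwargs` (and forwards only `embed_batch_size`/`tokenizer` to the base class), an Azure-backed instance would be constructed roughly like this. A sketch, not from the commit; all values are placeholders:

```python
# Sketch of constructing the embedding wrapper for Azure; values are placeholders.
embedding = OpenAIEmbedding(
    openai_api_key="<azure-openai-key>",
    openai_api_type="azure",
    openai_api_version="2023-03-15-preview",
    openai_api_base="https://<resource-name>.openai.azure.com",
)
# _get_query_embedding is the internal hook exercised below; shown here only
# to illustrate how the stored openai_* fields reach get_embedding.
query_vector = embedding._get_query_embedding("hello")
```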
@@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _QUERY_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _QUERY_MODE_MODEL_DICT[key]
-        return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key)
+        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
+                             api_type=self.openai_api_type, api_version=self.openai_api_version,
+                             api_base=self.openai_api_base)
 
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
@@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
+                             api_type=self.openai_api_type, api_version=self.openai_api_version,
+                             api_base=self.openai_api_base)
 
     async def _aget_text_embedding(self, text: str) -> List[float]:
         """Asynchronously get text embedding."""
@@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
+                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                    api_base=self.openai_api_base)
 
     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get text embeddings.
@@ -160,7 +180,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                    api_base=self.openai_api_base)
         return embeddings
 
     async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
@@ -172,5 +194,7 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                           api_base=self.openai_api_base)
         return embeddings
@@ -1,10 +1,13 @@
 from typing import Union, Optional
 
+from flask import current_app
 from langchain.callbacks import CallbackManager
 from langchain.llms.fake import FakeListLLM
 
 from core.constant import llm_constant
 from core.llm.provider.llm_provider_service import LLMProviderService
+from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
+from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
 
@@ -31,12 +34,19 @@ class LLMBuilder:
         if model_name == 'fake':
             return FakeListLLM(responses=[])
 
+        provider = current_app.config.get('DEFAULT_LLM_PROVIDER')
+
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            # llm_cls = StreamableAzureChatOpenAI
-            llm_cls = StreamableChatOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableChatOpenAI
+            else:
+                llm_cls = StreamableAzureChatOpenAI
         elif mode == 'completion':
-            llm_cls = StreamableOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableOpenAI
+            else:
+                llm_cls = StreamableAzureOpenAI
         else:
             raise ValueError(f"model name {model_name} is not supported.")
 
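The new branches amount to a lookup on (mode, provider), with any provider value other than 'openai' routed to the Azure classes. An equivalent table-driven sketch of the same selection, not part of the commit:

```python
# Equivalent form of the selection above; the commit's else-branches mean any
# provider other than 'openai' is treated as Azure.
def pick_llm_class(mode: str, provider: str):
    table = {
        ('chat', True): StreamableChatOpenAI,
        ('chat', False): StreamableAzureChatOpenAI,
        ('completion', True): StreamableOpenAI,
        ('completion', False): StreamableAzureOpenAI,
    }
    try:
        return table[(mode, provider == 'openai')]
    except KeyError:
        raise ValueError(f"mode {mode} is not supported.")
```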
@@ -93,11 +103,12 @@ class LLMBuilder:
         """
         if not model_name:
             raise Exception('model name not found')
-        #
-        # if model_name not in llm_constant.models:
-        #     raise Exception('model {} not found'.format(model_name))
+
+        if model_name not in llm_constant.models:
+            raise Exception('model {} not found'.format(model_name))
 
-        model_provider = llm_constant.models[model_name]
+        # model_provider = llm_constant.models[model_name]
+        model_provider = current_app.config.get('DEFAULT_LLM_PROVIDER')
 
         provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
         return provider_service.get_credentials(model_name)
@@ -36,8 +36,7 @@ class AzureProvider(BaseProvider):
         """
         Returns the API credentials for Azure OpenAI as a dictionary.
         """
-        encrypted_config = self.get_provider_api_key(model_id=model_id)
-        config = json.loads(encrypted_config)
+        config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
         config['deployment_name'] = model_id
         return config
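After this change, `get_provider_api_key` hands back the stored Azure config as a dict (no more `json.loads`), and `get_credentials` stamps the API type and deployment name onto it. Its return value plausibly looks like this for a gpt-3.5-turbo deployment (key names follow the token shape defined later in this commit; values are placeholders):

```python
# Hypothetical result of AzureProvider(tenant_id).get_credentials('gpt-3.5-turbo');
# the base/key values are placeholders, the last two fields are stamped on above.
credentials = {
    'openai_api_base': 'https://<resource-name>.openai.azure.com',
    'openai_api_key': '<azure-openai-key>',
    'openai_api_type': 'azure',
    'deployment_name': 'gpt-3.5-turbo',  # equals the model name by convention
}
```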
@@ -14,7 +14,7 @@ class BaseProvider(ABC):
     def __init__(self, tenant_id: str):
         self.tenant_id = tenant_id
 
-    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
+    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str, dict]:
         """
         Returns the decrypted API key for the given tenant_id and provider_name.
         If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
@@ -0,0 +1,20 @@
+from langchain.llms import AzureOpenAI
+from langchain.schema import LLMResult
+from typing import Optional, List
+
+from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+
+
+class StreamableAzureOpenAI(AzureOpenAI):
+
+    @handle_llm_exceptions
+    def generate(
+            self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return super().generate(prompts, stop)
+
+    @handle_llm_exceptions_async
+    async def agenerate(
+            self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return await super().agenerate(prompts, stop)
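The new wrapper only layers the error-handling decorators on top of LangChain's `AzureOpenAI`. Construction and use would look roughly like this; a sketch against the LangChain 0.0.x API this code targets, with placeholder values:

```python
# Sketch: relies on langchain's AzureOpenAI accepting deployment_name plus the
# usual openai_* settings; all values below are placeholders.
llm = StreamableAzureOpenAI(
    deployment_name="text-davinci-003",  # deployment ID must match the model name
    openai_api_key="<azure-openai-key>",
)
result = llm.generate(["Say hello in one word."])
print(result.generations[0][0].text)
```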
@@ -20,7 +20,7 @@ const AzureProvider = ({
   const [token, setToken] = useState(provider.token as ProviderAzureToken || {})
   const handleFocus = () => {
     if (token === provider.token) {
-      token.azure_api_key = ''
+      token.openai_api_key = ''
       setToken({...token})
       onTokenChange({...token})
     }
@@ -35,31 +35,17 @@ const AzureProvider = ({
     <div className='px-4 py-3'>
       <ProviderInput
         className='mb-4'
-        name={t('common.provider.azure.resourceName')}
-        placeholder={t('common.provider.azure.resourceNamePlaceholder')}
-        value={token.azure_api_base}
-        onChange={(v) => handleChange('azure_api_base', v)}
-      />
-      <ProviderInput
-        className='mb-4'
-        name={t('common.provider.azure.deploymentId')}
-        placeholder={t('common.provider.azure.deploymentIdPlaceholder')}
-        value={token.azure_api_type}
-        onChange={v => handleChange('azure_api_type', v)}
-      />
-      <ProviderInput
-        className='mb-4'
-        name={t('common.provider.azure.apiVersion')}
-        placeholder={t('common.provider.azure.apiVersionPlaceholder')}
-        value={token.azure_api_version}
-        onChange={v => handleChange('azure_api_version', v)}
+        name={t('common.provider.azure.apiBase')}
+        placeholder={t('common.provider.azure.apiBasePlaceholder')}
+        value={token.openai_api_base}
+        onChange={(v) => handleChange('openai_api_base', v)}
       />
       <ProviderValidateTokenInput
         className='mb-4'
         name={t('common.provider.azure.apiKey')}
         placeholder={t('common.provider.azure.apiKeyPlaceholder')}
-        value={token.azure_api_key}
-        onChange={v => handleChange('azure_api_key', v)}
+        value={token.openai_api_key}
+        onChange={v => handleChange('openai_api_key', v)}
         onFocus={handleFocus}
         onValidatedStatus={onValidatedStatus}
         providerName={provider.provider_name}
@@ -72,4 +58,4 @@ const AzureProvider = ({
   )
 }
 
-export default AzureProvider
\ No newline at end of file
+export default AzureProvider
@@ -33,12 +33,12 @@ const ProviderItem = ({
   const { notify } = useContext(ToastContext)
   const [token, setToken] = useState<ProviderAzureToken | string>(
     provider.provider_name === 'azure_openai'
-      ? { azure_api_base: '', azure_api_type: '', azure_api_version: '', azure_api_key: '' }
+      ? { openai_api_base: '', openai_api_key: '' }
       : ''
   )
   const id = `${provider.provider_name}-${provider.provider_type}`
   const isOpen = id === activeId
-  const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.azure_api_key : provider.token
+  const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.openai_api_key : provider.token
   const comingSoon = false
   const isValid = provider.is_valid
 
@@ -135,4 +135,4 @@ const ProviderItem = ({
   )
 }
 
-export default ProviderItem
\ No newline at end of file
+export default ProviderItem
@@ -148,12 +148,8 @@ const translation = {
     editKey: 'Edit',
     invalidApiKey: 'Invalid API key',
     azure: {
-      resourceName: 'Resource Name',
-      resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
-      deploymentId: 'Deployment ID',
-      deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
-      apiVersion: 'API Version',
-      apiVersionPlaceholder: 'The API version to use for this operation.',
+      apiBase: 'API Base',
+      apiBasePlaceholder: 'The API Base URL of your Azure OpenAI Resource.',
       apiKey: 'API Key',
       apiKeyPlaceholder: 'Enter your API key here',
       helpTip: 'Learn Azure OpenAI Service',
@@ -149,14 +149,10 @@ const translation = {
     editKey: '编辑',
     invalidApiKey: '无效的 API 密钥',
     azure: {
-      resourceName: 'Resource Name',
-      resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
-      deploymentId: 'Deployment ID',
-      deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
-      apiVersion: 'API Version',
-      apiVersionPlaceholder: 'The API version to use for this operation.',
+      apiBase: 'API Base',
+      apiBasePlaceholder: '输入您的 Azure OpenAI API Base 地址',
       apiKey: 'API Key',
-      apiKeyPlaceholder: 'Enter your API key here',
+      apiKeyPlaceholder: '输入你的 API 密钥',
       helpTip: '了解 Azure OpenAI Service',
     },
     openaiHosted: {
@@ -55,10 +55,8 @@ export type Member = Pick<UserProfileResponse, 'id' | 'name' | 'email' | 'last_l
 }
 
 export type ProviderAzureToken = {
-  azure_api_base: string
-  azure_api_key: string
-  azure_api_type: string
-  azure_api_version: string
+  openai_api_base: string
+  openai_api_key: string
 }
 export type Provider = {
   provider_name: string