mirror of https://github.com/langgenius/dify.git
feat: remove unused codes
This commit is contained in:
parent
196bfeaaf4
commit
8236373498
|
|
@ -67,6 +67,10 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
|||
# Download nltk data
|
||||
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
|
||||
|
||||
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
|
||||
|
||||
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
|
||||
|
||||
# Copy source code
|
||||
COPY . /app/api/
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from flask import request
|
|||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import asc, desc
|
||||
from transformers.hf_argparser import string_to_bool
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -145,7 +144,19 @@ class DatasetDocumentListApi(Resource):
|
|||
sort = request.args.get("sort", default="-created_at", type=str)
|
||||
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
|
||||
try:
|
||||
fetch = string_to_bool(request.args.get("fetch", default="false"))
|
||||
fetch_val = request.args.get("fetch", default="false")
|
||||
if isinstance(fetch_val, bool):
|
||||
fetch = fetch_val
|
||||
else:
|
||||
if fetch_val.lower() in ("yes", "true", "t", "y", "1"):
|
||||
fetch = True
|
||||
elif fetch_val.lower() in ("no", "false", "f", "n", "0"):
|
||||
fetch = False
|
||||
else:
|
||||
raise ArgumentTypeError(
|
||||
f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 "
|
||||
f"(case insensitive)."
|
||||
)
|
||||
except (ArgumentTypeError, ValueError, Exception) as e:
|
||||
fetch = False
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ class ModelType(Enum):
|
|||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
TTS = "tts"
|
||||
TEXT2IMG = "text2img"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, origin_model_type: str) -> "ModelType":
|
||||
|
|
@ -37,8 +36,6 @@ class ModelType(Enum):
|
|||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type in {"tts", cls.TTS.value}:
|
||||
return cls.TTS
|
||||
elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
|
||||
return cls.TEXT2IMG
|
||||
elif origin_model_type == cls.MODERATION.value:
|
||||
return cls.MODERATION
|
||||
else:
|
||||
|
|
@ -62,8 +59,6 @@ class ModelType(Enum):
|
|||
return "tts"
|
||||
elif self == self.MODERATION:
|
||||
return "moderation"
|
||||
elif self == self.TEXT2IMG:
|
||||
return "text2img"
|
||||
else:
|
||||
raise ValueError(f"invalid model type {self}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,26 +1,18 @@
|
|||
import decimal
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from core.helper.position_helper import get_position_map, sort_by_position_map
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
FetchFrom,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceInfo,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.plugin.manager.model import PluginModelManager
|
||||
|
||||
|
||||
class AIModel:
|
||||
|
|
@ -117,93 +109,7 @@ class AIModel:
|
|||
currency=price_config.currency,
|
||||
)
|
||||
|
||||
def predefined_models(self) -> list[AIModelEntity]:
|
||||
"""
|
||||
Get all predefined models for given provider.
|
||||
|
||||
:return:
|
||||
"""
|
||||
if self.model_schemas:
|
||||
return self.model_schemas
|
||||
|
||||
model_schemas = []
|
||||
|
||||
# get module name
|
||||
model_type = self.__class__.__module__.split(".")[-1]
|
||||
|
||||
# get provider name
|
||||
provider_name = self.__class__.__module__.split(".")[-3]
|
||||
|
||||
# get the path of current classes
|
||||
current_path = os.path.abspath(__file__)
|
||||
# get parent path of the current path
|
||||
provider_model_type_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(current_path)), provider_name, model_type
|
||||
)
|
||||
|
||||
# get all yaml files path under provider_model_type_path that do not start with __
|
||||
model_schema_yaml_paths = [
|
||||
os.path.join(provider_model_type_path, model_schema_yaml)
|
||||
for model_schema_yaml in os.listdir(provider_model_type_path)
|
||||
if not model_schema_yaml.startswith("__")
|
||||
and not model_schema_yaml.startswith("_")
|
||||
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
|
||||
and model_schema_yaml.endswith(".yaml")
|
||||
]
|
||||
|
||||
# get _position.yaml file path
|
||||
position_map = get_position_map(provider_model_type_path)
|
||||
|
||||
# traverse all model_schema_yaml_paths
|
||||
for model_schema_yaml_path in model_schema_yaml_paths:
|
||||
# read yaml data from yaml file
|
||||
yaml_data = load_yaml_file(model_schema_yaml_path)
|
||||
|
||||
new_parameter_rules = []
|
||||
for parameter_rule in yaml_data.get("parameter_rules", []):
|
||||
if "use_template" in parameter_rule:
|
||||
try:
|
||||
default_parameter_name = DefaultParameterName.value_of(parameter_rule["use_template"])
|
||||
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
|
||||
copy_default_parameter_rule = default_parameter_rule.copy()
|
||||
copy_default_parameter_rule.update(parameter_rule)
|
||||
parameter_rule = copy_default_parameter_rule
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if "label" not in parameter_rule:
|
||||
parameter_rule["label"] = {"zh_Hans": parameter_rule["name"], "en_US": parameter_rule["name"]}
|
||||
|
||||
new_parameter_rules.append(parameter_rule)
|
||||
|
||||
yaml_data["parameter_rules"] = new_parameter_rules
|
||||
|
||||
if "label" not in yaml_data:
|
||||
yaml_data["label"] = {"zh_Hans": yaml_data["model"], "en_US": yaml_data["model"]}
|
||||
|
||||
yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value
|
||||
|
||||
try:
|
||||
# yaml_data to entity
|
||||
model_schema = AIModelEntity(**yaml_data)
|
||||
except Exception as e:
|
||||
model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml")
|
||||
raise Exception(
|
||||
f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}: {str(e)}"
|
||||
)
|
||||
|
||||
# cache model schema
|
||||
model_schemas.append(model_schema)
|
||||
|
||||
# resort model schemas by position
|
||||
model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
|
||||
|
||||
# cache model schemas
|
||||
self.model_schemas = model_schemas
|
||||
|
||||
return model_schemas
|
||||
|
||||
def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
|
||||
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get model schema by model name and credentials
|
||||
|
||||
|
|
@ -211,117 +117,13 @@ class AIModel:
|
|||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
# get predefined models (predefined_models)
|
||||
models = self.predefined_models()
|
||||
|
||||
model_map = {model.model: model for model in models}
|
||||
if model in model_map:
|
||||
return model_map[model]
|
||||
|
||||
if credentials:
|
||||
model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
|
||||
if model_schema:
|
||||
return model_schema
|
||||
|
||||
return None
|
||||
|
||||
def get_customizable_model_schema_from_credentials(
|
||||
self, model: str, credentials: Mapping
|
||||
) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema from credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
return self._get_customizable_model_schema(model, credentials)
|
||||
|
||||
def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema and fill in the template
|
||||
"""
|
||||
schema = self.get_customizable_model_schema(model, credentials)
|
||||
|
||||
if not schema:
|
||||
return None
|
||||
|
||||
# fill in the template
|
||||
new_parameter_rules = []
|
||||
for parameter_rule in schema.parameter_rules:
|
||||
if parameter_rule.use_template:
|
||||
try:
|
||||
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
|
||||
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
|
||||
if not parameter_rule.max and "max" in default_parameter_rule:
|
||||
parameter_rule.max = default_parameter_rule["max"]
|
||||
if not parameter_rule.min and "min" in default_parameter_rule:
|
||||
parameter_rule.min = default_parameter_rule["min"]
|
||||
if not parameter_rule.default and "default" in default_parameter_rule:
|
||||
parameter_rule.default = default_parameter_rule["default"]
|
||||
if not parameter_rule.precision and "precision" in default_parameter_rule:
|
||||
parameter_rule.precision = default_parameter_rule["precision"]
|
||||
if not parameter_rule.required and "required" in default_parameter_rule:
|
||||
parameter_rule.required = default_parameter_rule["required"]
|
||||
if not parameter_rule.help and "help" in default_parameter_rule:
|
||||
parameter_rule.help = I18nObject(
|
||||
en_US=default_parameter_rule["help"]["en_US"],
|
||||
)
|
||||
if (
|
||||
parameter_rule.help
|
||||
and not parameter_rule.help.en_US
|
||||
and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
|
||||
):
|
||||
parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
|
||||
if (
|
||||
parameter_rule.help
|
||||
and not parameter_rule.help.zh_Hans
|
||||
and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
|
||||
):
|
||||
parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
|
||||
"zh_Hans", default_parameter_rule["help"]["en_US"]
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
new_parameter_rules.append(parameter_rule)
|
||||
|
||||
schema.parameter_rules = new_parameter_rules
|
||||
|
||||
return schema
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
return None
|
||||
|
||||
def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict:
|
||||
"""
|
||||
Get default parameter rule for given name
|
||||
|
||||
:param name: parameter name
|
||||
:return: parameter rule
|
||||
"""
|
||||
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
|
||||
|
||||
if not default_parameter_rule:
|
||||
raise Exception(f"Invalid model parameter rule name {name}")
|
||||
|
||||
return default_parameter_rule
|
||||
|
||||
def _get_num_tokens_by_gpt2(self, text: str) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages by gpt2
|
||||
Some provider models do not provide an interface for obtaining the number of tokens.
|
||||
Here, the gpt2 tokenizer is used to calculate the number of tokens.
|
||||
This method can be executed offline, and the gpt2 tokenizer has been cached in the project.
|
||||
|
||||
:param text: plain text of prompt. You need to convert the original message to plain text
|
||||
:return: number of tokens
|
||||
"""
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
plugin_model_manager = PluginModelManager()
|
||||
return plugin_model_manager.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model_type=self.model_type.value,
|
||||
model=model,
|
||||
credentials=credentials or {},
|
||||
)
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -1,120 +0,0 @@
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
class ModelProvider(ABC):
|
||||
provider_schema: Optional[ProviderEntity] = None
|
||||
model_instance_map: dict[str, AIModel] = {}
|
||||
|
||||
@abstractmethod
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
You can choose any validate_credentials method of model type or implement validate method by yourself,
|
||||
such as: get model list api
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_provider_schema(self) -> ProviderEntity:
|
||||
"""
|
||||
Get provider schema
|
||||
|
||||
:return: provider schema
|
||||
"""
|
||||
if self.provider_schema:
|
||||
return self.provider_schema
|
||||
|
||||
# get dirname of the current path
|
||||
provider_name = self.__class__.__module__.split(".")[-1]
|
||||
|
||||
# get the path of the model_provider classes
|
||||
base_path = os.path.abspath(__file__)
|
||||
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
|
||||
|
||||
# read provider schema from yaml file
|
||||
yaml_path = os.path.join(current_path, f"{provider_name}.yaml")
|
||||
yaml_data = load_yaml_file(yaml_path)
|
||||
|
||||
try:
|
||||
# yaml_data to entity
|
||||
provider_schema = ProviderEntity(**yaml_data)
|
||||
except Exception as e:
|
||||
raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}")
|
||||
|
||||
# cache schema
|
||||
self.provider_schema = provider_schema
|
||||
|
||||
return provider_schema
|
||||
|
||||
def models(self, model_type: ModelType) -> list[AIModelEntity]:
|
||||
"""
|
||||
Get all models for given model type
|
||||
|
||||
:param model_type: model type defined in `ModelType`
|
||||
:return: list of models
|
||||
"""
|
||||
provider_schema = self.get_provider_schema()
|
||||
if model_type not in provider_schema.supported_model_types:
|
||||
return []
|
||||
|
||||
# get model instance of the model type
|
||||
model_instance = self.get_model_instance(model_type)
|
||||
|
||||
# get predefined models (predefined_models)
|
||||
models = model_instance.predefined_models()
|
||||
|
||||
# return models
|
||||
return models
|
||||
|
||||
def get_model_instance(self, model_type: ModelType) -> AIModel:
|
||||
"""
|
||||
Get model instance
|
||||
|
||||
:param model_type: model type defined in `ModelType`
|
||||
:return:
|
||||
"""
|
||||
# get dirname of the current path
|
||||
provider_name = self.__class__.__module__.split(".")[-1]
|
||||
|
||||
if f"{provider_name}.{model_type.value}" in self.model_instance_map:
|
||||
return self.model_instance_map[f"{provider_name}.{model_type.value}"]
|
||||
|
||||
# get the path of the model type classes
|
||||
base_path = os.path.abspath(__file__)
|
||||
model_type_name = model_type.value.replace("-", "_")
|
||||
model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name)
|
||||
model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py")
|
||||
|
||||
if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path):
|
||||
raise Exception(f"Invalid model type {model_type} for provider {provider_name}")
|
||||
|
||||
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
|
||||
parent_module = ".".join(self.__class__.__module__.split(".")[:-1])
|
||||
mod = import_module_from_source(
|
||||
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
|
||||
)
|
||||
model_class = next(
|
||||
filter(
|
||||
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
|
||||
get_subclasses_from_module(mod, AIModel),
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not model_class:
|
||||
raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")
|
||||
|
||||
model_instance_map = model_class()
|
||||
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
|
||||
|
||||
return model_instance_map
|
||||
|
|
@ -1,11 +1,11 @@
|
|||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.manager.model import PluginModelManager
|
||||
|
||||
|
||||
class ModerationModel(AIModel):
|
||||
|
|
@ -31,19 +31,15 @@ class ModerationModel(AIModel):
|
|||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
return self._invoke(model, credentials, text, user)
|
||||
plugin_model_manager = PluginModelManager()
|
||||
return plugin_model_manager.invoke_moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
text=text,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.manager.model import PluginModelManager
|
||||
|
||||
|
||||
class RerankModel(AIModel):
|
||||
|
|
@ -36,34 +35,19 @@ class RerankModel(AIModel):
|
|||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
return self._invoke(model, credentials, query, docs, score_threshold, top_n, user)
|
||||
plugin_model_manager = PluginModelManager()
|
||||
return plugin_model_manager.invoke_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import IO, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.manager.model import PluginModelManager
|
||||
|
||||
|
||||
class Speech2TextModel(AIModel):
|
||||
|
|
@ -20,7 +19,7 @@ class Speech2TextModel(AIModel):
|
|||
|
||||
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
Invoke speech to text model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
|
|
@ -29,31 +28,15 @@ class Speech2TextModel(AIModel):
|
|||
:return: text for given audio file
|
||||
"""
|
||||
try:
|
||||
return self._invoke(model, credentials, file, user)
|
||||
plugin_model_manager = PluginModelManager()
|
||||
return plugin_model_manager.invoke_speech_to_text(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
file=file,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_demo_file_path(self) -> str:
|
||||
"""
|
||||
Get demo file for given model
|
||||
|
||||
:return: demo file
|
||||
"""
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Construct the path to the audio file
|
||||
return os.path.join(current_dir, "audio.mp3")
|
||||
|
|
|
|||
|
|
@ -1,54 +0,0 @@
|
|||
from abc import abstractmethod
|
||||
from typing import IO, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class Text2ImageModel(AIModel):
|
||||
"""
|
||||
Model class for text2img model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.TEXT2IMG
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(
|
||||
self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None
|
||||
) -> list[IO[bytes]]:
|
||||
"""
|
||||
Invoke Text2Image model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt: prompt for image generation
|
||||
:param model_parameters: model parameters
|
||||
:param user: unique user id
|
||||
|
||||
:return: image bytes
|
||||
"""
|
||||
try:
|
||||
return self._invoke(model, credentials, prompt, model_parameters, user)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None
|
||||
) -> list[IO[bytes]]:
|
||||
"""
|
||||
Invoke Text2Image model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt: prompt for image generation
|
||||
:param model_parameters: model parameters
|
||||
:param user: unique user id
|
||||
|
||||
:return: image bytes
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
|
@ -39,34 +37,21 @@ class TextEmbeddingModel(AIModel):
|
|||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
return self._invoke(model, credentials, texts, user, input_type)
|
||||
plugin_model_manager = PluginModelManager()
|
||||
return plugin_model_manager.invoke_text_embedding(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=input_type.value,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
|
@ -82,7 +67,6 @@ class TextEmbeddingModel(AIModel):
|
|||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model_type=self.model_type.value,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
|
|
|
|||
|
|
@ -1,34 +1,9 @@
|
|||
from os.path import abspath, dirname, join
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
|
||||
|
||||
_tokenizer = None
|
||||
_lock = Lock()
|
||||
import tiktoken
|
||||
|
||||
|
||||
class GPT2Tokenizer:
|
||||
@staticmethod
|
||||
def _get_num_tokens_by_gpt2(text: str) -> int:
|
||||
"""
|
||||
use gpt2 tokenizer to get num tokens
|
||||
"""
|
||||
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||
tokens = _tokenizer.encode(text, verbose=False)
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
def get_num_tokens(text: str) -> int:
|
||||
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
||||
|
||||
@staticmethod
|
||||
def get_encoder() -> Any:
|
||||
global _tokenizer, _lock
|
||||
with _lock:
|
||||
if _tokenizer is None:
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
|
||||
return _tokenizer
|
||||
encoding = tiktoken.encoding_for_model("gpt2")
|
||||
tiktoken_vec = encoding.encode(text)
|
||||
return len(tiktoken_vec)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import logging
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.manager.model import PluginModelManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -37,36 +36,21 @@ class TTSModel(AIModel):
|
|||
:return: translated audio file
|
||||
"""
|
||||
try:
|
||||
return self._invoke(
|
||||
plugin_model_manager = PluginModelManager()
|
||||
return plugin_model_manager.invoke_tts(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
user=user,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param voice: model timbre
|
||||
:param content_text: text content to be translated
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: translated audio file
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
|
||||
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list[dict]:
|
||||
"""
|
||||
Get voice for given tts model voices
|
||||
|
||||
|
|
@ -75,83 +59,13 @@ class TTSModel(AIModel):
|
|||
:param credentials: model credentials
|
||||
:return: voices lists
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
|
||||
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
|
||||
if language:
|
||||
return [
|
||||
{"name": d["name"], "value": d["mode"]}
|
||||
for d in voices
|
||||
if language and language in d.get("language")
|
||||
]
|
||||
else:
|
||||
return [{"name": d["name"], "value": d["mode"]} for d in voices]
|
||||
|
||||
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
|
||||
"""
|
||||
Get voice for given tts model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: voice
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.DEFAULT_VOICE]
|
||||
|
||||
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
|
||||
"""
|
||||
Get audio type for given tts model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: voice
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
|
||||
|
||||
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get audio type for given tts model
|
||||
:return: audio type
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
|
||||
|
||||
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get audio max workers for given tts model
|
||||
:return: audio type
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
||||
|
||||
@staticmethod
|
||||
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
|
||||
match = re.compile(pattern)
|
||||
tx = match.finditer(org_text)
|
||||
start = 0
|
||||
result = []
|
||||
one_sentence = ""
|
||||
for i in tx:
|
||||
end = i.regs[0][1]
|
||||
tmp = org_text[start:end]
|
||||
if len(one_sentence + tmp) > max_length:
|
||||
result.append(one_sentence)
|
||||
one_sentence = ""
|
||||
one_sentence += tmp
|
||||
start = end
|
||||
last_sens = org_text[start:]
|
||||
if last_sens:
|
||||
one_sentence += last_sens
|
||||
if one_sentence != "":
|
||||
result.append(one_sentence)
|
||||
return result
|
||||
plugin_model_manager = PluginModelManager()
|
||||
return plugin_model_manager.get_tts_model_voices(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
language=language,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,191 +0,0 @@
|
|||
import base64
|
||||
import copy
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
||||
from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_BASE_MODELS, AzureBaseModel
|
||||
|
||||
|
||||
class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
base_model_name = credentials["base_model_name"]
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = AzureOpenAI(**credentials_kwargs)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
extra_model_kwargs["encoding_format"] = "base64"
|
||||
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
tokens = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(base_model_name)
|
||||
except KeyError:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
token = enc.encode(text)
|
||||
for j in range(0, len(token), context_size):
|
||||
tokens += [token[j : j + context_size]]
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(tokens), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
|
||||
for i in range(len(texts)):
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
average = embeddings_batch[0]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=base_model_name)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(credentials["base_model_name"])
|
||||
except KeyError:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
# calculate the number of tokens in the encoded text
|
||||
tokenized_text = enc.encode(text)
|
||||
total_num_tokens += len(tokenized_text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
if "openai_api_base" not in credentials:
|
||||
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
|
||||
|
||||
if "openai_api_key" not in credentials:
|
||||
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
|
||||
|
||||
if "base_model_name" not in credentials:
|
||||
raise CredentialsValidateFailedError("Base Model Name is required")
|
||||
|
||||
if not self._get_ai_model_entity(credentials["base_model_name"], model):
|
||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||
|
||||
try:
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = AzureOpenAI(**credentials_kwargs)
|
||||
|
||||
self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={})
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
|
||||
return ai_model_entity.entity
|
||||
|
||||
@staticmethod
|
||||
def _embedding_invoke(
|
||||
model: str, client: AzureOpenAI, texts: Union[list[str], str], extra_model_kwargs: dict
|
||||
) -> tuple[list[list[float]], int]:
|
||||
response = client.embeddings.create(
|
||||
input=texts,
|
||||
model=model,
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64":
|
||||
# decode base64 embedding
|
||||
return (
|
||||
[list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data],
|
||||
response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return [data.embedding for data in response.data], response.usage.total_tokens
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
||||
for ai_model_entity in EMBEDDING_BASE_MODELS:
|
||||
if ai_model_entity.base_model_name == base_model_name:
|
||||
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
|
||||
ai_model_entity_copy.entity.model = model
|
||||
ai_model_entity_copy.entity.label.en_US = model
|
||||
ai_model_entity_copy.entity.label.zh_Hans = model
|
||||
return ai_model_entity_copy
|
||||
|
||||
return None
|
||||
|
|
@ -1,207 +0,0 @@
|
|||
import time
|
||||
from json import dumps
|
||||
from typing import Optional
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
RateLimitReachedError,
|
||||
)
|
||||
|
||||
|
||||
class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for BaiChuan text embedding model.
|
||||
"""
|
||||
|
||||
api_base: str = "http://api.baichuan-ai.com/v1/embeddings"
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials["api_key"]
|
||||
if model != "baichuan-text-embedding":
|
||||
raise ValueError("Invalid model name")
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError("api_key is required")
|
||||
|
||||
# split into chunks of batch size 16
|
||||
chunks = []
|
||||
for i in range(0, len(texts), 16):
|
||||
chunks.append(texts[i : i + 16])
|
||||
|
||||
embeddings = []
|
||||
token_usage = 0
|
||||
|
||||
for chunk in chunks:
|
||||
# embedding chunk
|
||||
chunk_embeddings, chunk_usage = self.embedding(model=model, api_key=api_key, texts=chunk, user=user)
|
||||
|
||||
embeddings.extend(chunk_embeddings)
|
||||
token_usage += chunk_usage
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def embedding(
|
||||
self, model: str, api_key, texts: list[str], user: Optional[str] = None
|
||||
) -> tuple[list[list[float]], int]:
|
||||
"""
|
||||
Embed given texts
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
url = self.api_base
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
|
||||
data = {"model": "Baichuan-Text-Embedding", "input": texts}
|
||||
|
||||
try:
|
||||
response = post(url, headers=headers, data=dumps(data))
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
# try to parse error message
|
||||
err = resp["error"]["code"]
|
||||
msg = resp["error"]["message"]
|
||||
except Exception as e:
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
if err == "invalid_api_key":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == "insufficient_quota":
|
||||
raise InsufficientAccountBalanceError(msg)
|
||||
elif err == "invalid_authentication":
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err and "rate" in err:
|
||||
raise RateLimitReachedError(msg)
|
||||
elif err and "internal" in err:
|
||||
raise InternalServerError(msg)
|
||||
elif err == "api_key_empty":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
else:
|
||||
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
embeddings = resp["data"]
|
||||
usage = resp["usage"]
|
||||
except Exception as e:
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
return [data["embedding"] for data in embeddings], usage["total_tokens"]
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
# use BaichuanTokenizer to get num tokens
|
||||
num_tokens += BaichuanTokenizer._get_num_tokens(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except InvalidAPIKeyError:
|
||||
raise CredentialsValidateFailedError("Invalid api key")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [InternalServerError],
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,223 +0,0 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
import cohere
|
||||
import numpy as np
|
||||
from cohere.core import RequestOptions
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
|
||||
class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Cohere text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
tokens = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
tokenize_response = self._tokenize(model=model, credentials=credentials, text=text)
|
||||
|
||||
for j in range(0, len(tokenize_response), context_size):
|
||||
tokens += [tokenize_response[j : j + context_size]]
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(tokens), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
# call embedding model
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model, credentials=credentials, texts=["".join(token) for token in tokens[i : i + max_chunks]]
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
|
||||
for i in range(len(texts)):
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model, credentials=credentials, texts=[" "]
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
average = embeddings_batch[0]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
full_text = " ".join(texts)
|
||||
|
||||
try:
|
||||
response = self._tokenize(model=model, credentials=credentials, text=full_text)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
return len(response)
|
||||
|
||||
def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]:
|
||||
"""
|
||||
Tokenize text
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to tokenize
|
||||
:return:
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
|
||||
|
||||
response = client.tokenize(text=text, model=model, offline=False, request_options=RequestOptions(max_retries=0))
|
||||
|
||||
return response.token_strings
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# call embedding model
|
||||
self._embedding_invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
|
||||
"""
|
||||
Invoke embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return: embeddings and used tokens
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
|
||||
|
||||
# call embedding model
|
||||
response = client.embed(
|
||||
texts=texts,
|
||||
model=model,
|
||||
input_type="search_document" if len(texts) > 1 else "search_query",
|
||||
request_options=RequestOptions(max_retries=1),
|
||||
)
|
||||
|
||||
return response.embeddings, int(response.meta.billed_units.input_tokens)
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError],
|
||||
InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError],
|
||||
InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError],
|
||||
InvokeAuthorizationError: [
|
||||
cohere.errors.unauthorized_error.UnauthorizedError,
|
||||
cohere.errors.forbidden_error.ForbiddenError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
cohere.core.api_error.ApiError,
|
||||
cohere.errors.bad_request_error.BadRequestError,
|
||||
cohere.errors.not_found_error.NotFoundError,
|
||||
],
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
provider: fireworks
|
||||
label:
|
||||
zh_Hans: Fireworks AI
|
||||
en_US: Fireworks AI
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#FCFDFF"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API Key from Fireworks AI
|
||||
zh_Hans: 从 Fireworks AI 获取 API Key
|
||||
url:
|
||||
en_US: https://fireworks.ai/account/api-keys
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: fireworks_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
model: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
|
||||
label:
|
||||
zh_Hans: Llama 3.2 11B Vision Instruct
|
||||
en_US: Llama 3.2 11B Vision Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.2'
|
||||
output: '0.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
model: accounts/fireworks/models/llama-v3p2-1b-instruct
|
||||
label:
|
||||
zh_Hans: Llama 3.2 1B Instruct
|
||||
en_US: Llama 3.2 1B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.1'
|
||||
output: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
model: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||
label:
|
||||
zh_Hans: Llama 3.2 3B Instruct
|
||||
en_US: Llama 3.2 3B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.1'
|
||||
output: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
model: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
|
||||
label:
|
||||
zh_Hans: Llama 3.2 90B Vision Instruct
|
||||
en_US: Llama 3.2 90B Vision Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.9'
|
||||
output: '0.9'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
model: WhereIsAI/UAE-Large-V1
|
||||
label:
|
||||
zh_Hans: UAE-Large-V1
|
||||
en_US: UAE-Large-V1
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 512
|
||||
max_chunks: 1
|
||||
pricing:
|
||||
input: '0.008'
|
||||
unit: '0.000001'
|
||||
currency: 'USD'
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
model: thenlper/gte-base
|
||||
label:
|
||||
zh_Hans: GTE-base
|
||||
en_US: GTE-base
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 512
|
||||
max_chunks: 1
|
||||
pricing:
|
||||
input: '0.008'
|
||||
unit: '0.000001'
|
||||
currency: 'USD'
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
model: thenlper/gte-large
|
||||
label:
|
||||
zh_Hans: GTE-large
|
||||
en_US: GTE-large
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 512
|
||||
max_chunks: 1
|
||||
pricing:
|
||||
input: '0.008'
|
||||
unit: '0.000001'
|
||||
currency: 'USD'
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
model: nomic-ai/nomic-embed-text-v1.5
|
||||
label:
|
||||
zh_Hans: nomic-embed-text-v1.5
|
||||
en_US: nomic-embed-text-v1.5
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 8192
|
||||
max_chunks: 16
|
||||
pricing:
|
||||
input: '0.008'
|
||||
unit: '0.000001'
|
||||
currency: 'USD'
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
model: nomic-ai/nomic-embed-text-v1
|
||||
label:
|
||||
zh_Hans: nomic-embed-text-v1
|
||||
en_US: nomic-embed-text-v1
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 8192
|
||||
max_chunks: 16
|
||||
pricing:
|
||||
input: '0.008'
|
||||
unit: '0.000001'
|
||||
currency: 'USD'
|
||||
|
|
@ -1,151 +0,0 @@
|
|||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from openai import OpenAI
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
||||
|
||||
|
||||
class FireworksTextEmbeddingModel(_CommonFireworks, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Fireworks text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
extra_model_kwargs["encoding_format"] = "float"
|
||||
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
# TODO: Optimize for better token estimation and chunking
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=inputs[i : i + max_chunks],
|
||||
extra_model_kwargs=extra_model_kwargs,
|
||||
)
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: Mapping) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
|
||||
# call embedding model
|
||||
self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={})
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _embedding_invoke(
|
||||
self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict
|
||||
) -> tuple[list[list[float]], int]:
|
||||
"""
|
||||
Invoke embedding model
|
||||
:param model: model name
|
||||
:param client: model client
|
||||
:param texts: texts to embed
|
||||
:param extra_model_kwargs: extra model kwargs
|
||||
:return: embeddings and used tokens
|
||||
"""
|
||||
response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs)
|
||||
return [data.embedding for data in response.data], response.usage.total_tokens
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT
|
||||
)
|
||||
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
provider: fishaudio
|
||||
label:
|
||||
en_US: Fish Audio
|
||||
description:
|
||||
en_US: Models provided by Fish Audio, currently only support TTS.
|
||||
zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。
|
||||
icon_small:
|
||||
en_US: fishaudio_s_en.svg
|
||||
icon_large:
|
||||
en_US: fishaudio_l_en.svg
|
||||
background: "#E5E7EB"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API key from Fish Audio
|
||||
zh_Hans: 从 Fish Audio 获取你的 API Key
|
||||
url:
|
||||
en_US: https://fish.audio/go-api/
|
||||
supported_model_types:
|
||||
- tts
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: api_base
|
||||
label:
|
||||
en_US: API URL
|
||||
type: text-input
|
||||
required: false
|
||||
default: https://api.fish.audio
|
||||
placeholder:
|
||||
en_US: Enter your API URL
|
||||
zh_Hans: 在此输入您的 API URL
|
||||
- variable: use_public_models
|
||||
label:
|
||||
en_US: Use Public Models
|
||||
type: select
|
||||
required: false
|
||||
default: "false"
|
||||
placeholder:
|
||||
en_US: Toggle to use public models
|
||||
zh_Hans: 切换以使用公共模型
|
||||
options:
|
||||
- value: "true"
|
||||
label:
|
||||
en_US: Allow Public Models
|
||||
zh_Hans: 使用公共模型
|
||||
- value: "false"
|
||||
label:
|
||||
en_US: Private Models Only
|
||||
zh_Hans: 仅使用私有模型
|
||||
- variable: latency
|
||||
label:
|
||||
en_US: Latency
|
||||
type: select
|
||||
required: false
|
||||
default: "normal"
|
||||
placeholder:
|
||||
en_US: Toggle to choice latency
|
||||
zh_Hans: 切换以调整延迟
|
||||
options:
|
||||
- value: "balanced"
|
||||
label:
|
||||
en_US: Low (may affect quality)
|
||||
zh_Hans: 低延迟 (可能降低质量)
|
||||
- value: "normal"
|
||||
label:
|
||||
en_US: Normal
|
||||
zh_Hans: 标准
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
model: gemini-1.5-flash-001
|
||||
label:
|
||||
en_US: Gemini 1.5 Flash 001
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
model: gemini-1.5-flash-002
|
||||
label:
|
||||
en_US: Gemini 1.5 Flash 002
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
model: gemini-1.5-flash-8b-exp-0924
|
||||
label:
|
||||
en_US: Gemini 1.5 Flash 8B 0924
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
model: gemini-1.5-flash
|
||||
label:
|
||||
en_US: Gemini 1.5 Flash
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
model: gemini-1.5-pro-001
|
||||
label:
|
||||
en_US: Gemini 1.5 Pro 001
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
model: gemini-1.5-pro-002
|
||||
label:
|
||||
en_US: Gemini 1.5 Pro 002
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
model: gemini-1.5-pro
|
||||
label:
|
||||
en_US: Gemini 1.5 Pro
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
model: llama-3.2-11b-text-preview
|
||||
label:
|
||||
zh_Hans: Llama 3.2 11B Text (Preview)
|
||||
en_US: Llama 3.2 11B Text (Preview)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
model: llama-3.2-1b-preview
|
||||
label:
|
||||
zh_Hans: Llama 3.2 1B Text (Preview)
|
||||
en_US: Llama 3.2 1B Text (Preview)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
model: llama-3.2-3b-preview
|
||||
label:
|
||||
zh_Hans: Llama 3.2 3B Text (Preview)
|
||||
en_US: Llama 3.2 3B Text (Preview)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
model: llama-3.2-90b-text-preview
|
||||
label:
|
||||
zh_Hans: Llama 3.2 90B Text (Preview)
|
||||
en_US: Llama 3.2 90B Text (Preview)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,189 +0,0 @@
|
|||
import json
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import HfApi, InferenceClient
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub
|
||||
|
||||
HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/"
|
||||
|
||||
|
||||
class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
|
||||
|
||||
execute_model = model
|
||||
|
||||
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
|
||||
execute_model = credentials["huggingfacehub_endpoint_url"]
|
||||
|
||||
output = client.post(
|
||||
json={"inputs": texts, "options": {"wait_for_model": False, "use_cache": False}}, model=execute_model
|
||||
)
|
||||
|
||||
embeddings = json.loads(output.decode())
|
||||
|
||||
tokens = self.get_num_tokens(model, credentials, texts)
|
||||
usage = self._calc_response_usage(model, credentials, tokens)
|
||||
|
||||
return TextEmbeddingResult(embeddings=self._mean_pooling(embeddings), usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
try:
|
||||
if "huggingfacehub_api_type" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
|
||||
|
||||
if "huggingfacehub_api_token" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub API Token must be provided.")
|
||||
|
||||
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
|
||||
if "huggingface_namespace" not in credentials:
|
||||
raise CredentialsValidateFailedError(
|
||||
"Huggingface Hub User Name / Organization Name must be provided."
|
||||
)
|
||||
|
||||
if "huggingfacehub_endpoint_url" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.")
|
||||
|
||||
if "task_type" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.")
|
||||
|
||||
if credentials["task_type"] != "feature-extraction":
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Task Type is invalid.")
|
||||
|
||||
self._check_endpoint_url_model_repository_name(credentials, model)
|
||||
|
||||
model = credentials["huggingfacehub_endpoint_url"]
|
||||
|
||||
elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
|
||||
self._check_hosted_model_task_type(credentials["huggingfacehub_api_token"], model)
|
||||
else:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
|
||||
|
||||
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
|
||||
client.feature_extraction(text="hello world", model=model)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={"context_size": 10000, "max_chunks": 1},
|
||||
)
|
||||
return entity
|
||||
|
||||
# https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
|
||||
# Returned values are a list of floats, or a list[list[floats]]
|
||||
# (depending on if you sent a string or a list of string,
|
||||
# and if the automatic reduction, usually mean_pooling for instance was applied for you or not.
|
||||
# This should be explained on the model's README.)
|
||||
@staticmethod
|
||||
def _mean_pooling(embeddings: list) -> list[float]:
|
||||
# If automatic reduction by giving model, no need to mean_pooling.
|
||||
# For example one: List[List[float]]
|
||||
if not isinstance(embeddings[0][0], list):
|
||||
return embeddings
|
||||
|
||||
# For example two: List[List[List[float]]], need to mean_pooling.
|
||||
sentence_embeddings = [np.mean(embedding[0], axis=0).tolist() for embedding in embeddings]
|
||||
return sentence_embeddings
|
||||
|
||||
@staticmethod
|
||||
def _check_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str) -> None:
|
||||
hf_api = HfApi(token=huggingfacehub_api_token)
|
||||
model_info = hf_api.model_info(repo_id=model_name)
|
||||
|
||||
try:
|
||||
if not model_info:
|
||||
raise ValueError(f"Model {model_name} not found.")
|
||||
|
||||
if "inference" in model_info.cardData and not model_info.cardData["inference"]:
|
||||
raise ValueError(f"Inference API has been turned off for this model {model_name}.")
|
||||
|
||||
valid_tasks = "feature-extraction"
|
||||
if model_info.pipeline_tag not in valid_tasks:
|
||||
raise ValueError(f"Model {model_name} is not a valid task, must be one of {valid_tasks}.")
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"{str(e)}")
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@staticmethod
|
||||
def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str):
|
||||
try:
|
||||
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
|
||||
headers = {
|
||||
"Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}',
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.get(url=url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ValueError("User Name or Organization Name is invalid.")
|
||||
|
||||
model_repository_name = ""
|
||||
|
||||
for item in response.json().get("items", []):
|
||||
if item.get("status", {}).get("url") == credentials["huggingfacehub_endpoint_url"]:
|
||||
model_repository_name = item.get("model", {}).get("repository")
|
||||
break
|
||||
|
||||
if model_repository_name != model_name:
|
||||
raise ValueError(
|
||||
f"Model Name {model_name} is invalid. Please check it on the inference endpoints console."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
|
@ -1,209 +0,0 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
|
||||
|
||||
|
||||
class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Text Embedding Inference text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
credentials should be like:
|
||||
{
|
||||
'server_url': 'server url',
|
||||
'model_uid': 'model uid',
|
||||
}
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
server_url = credentials["server_url"]
|
||||
|
||||
server_url = server_url.removesuffix("/")
|
||||
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
# get tokenized results from TEI
|
||||
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
|
||||
|
||||
for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
|
||||
# Check if the number of tokens is larger than the context size
|
||||
num_tokens = len(tokenize_result)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
# Find the best cutoff point
|
||||
pre_special_token_count = 0
|
||||
for token in tokenize_result:
|
||||
if token["special"]:
|
||||
pre_special_token_count += 1
|
||||
else:
|
||||
break
|
||||
rest_special_token_count = (
|
||||
len([token for token in tokenize_result if token["special"]]) - pre_special_token_count
|
||||
)
|
||||
|
||||
# Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit
|
||||
token_cutoff = context_size - rest_special_token_count - 20
|
||||
|
||||
# Find the cutoff index
|
||||
cutpoint_token = tokenize_result[token_cutoff]
|
||||
cutoff = cutpoint_token["start"]
|
||||
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
|
||||
try:
|
||||
used_tokens = 0
|
||||
for i in _iter:
|
||||
iter_texts = inputs[i : i + max_chunks]
|
||||
results = TeiHelper.invoke_embeddings(server_url, iter_texts)
|
||||
embeddings = results["data"]
|
||||
embeddings = [embedding["embedding"] for embedding in embeddings]
|
||||
batched_embeddings.extend(embeddings)
|
||||
|
||||
usage = results["usage"]
|
||||
used_tokens += usage["total_tokens"]
|
||||
except RuntimeError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
server_url = credentials["server_url"]
|
||||
|
||||
server_url = server_url.removesuffix("/")
|
||||
|
||||
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
|
||||
num_tokens = sum(len(tokens) for tokens in batch_tokens)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
server_url = credentials["server_url"]
|
||||
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
|
||||
print(extra_args)
|
||||
if extra_args.model_type != "embedding":
|
||||
raise CredentialsValidateFailedError("Current model is not a embedding model")
|
||||
|
||||
credentials["context_size"] = extra_args.max_input_length
|
||||
credentials["max_chunks"] = extra_args.max_client_batch_size
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={
|
||||
ModelPropertyKey.MAX_CHUNKS: int(credentials.get("max_chunks", 1)),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
|
||||
},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
return entity
|
||||
|
|
@ -1,169 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.common.exception import TencentCloudSDKException
|
||||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Hunyuan text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
if model != "hunyuan-embedding":
|
||||
raise ValueError("Invalid model name")
|
||||
|
||||
client = self._setup_hunyuan_client(credentials)
|
||||
|
||||
embeddings = []
|
||||
token_usage = 0
|
||||
|
||||
for input in texts:
|
||||
request = models.GetEmbeddingRequest()
|
||||
params = {"Input": input}
|
||||
request.from_json_string(json.dumps(params))
|
||||
response = client.GetEmbedding(request)
|
||||
usage = response.Usage.TotalTokens
|
||||
|
||||
embeddings.extend([data.Embedding for data in response.Data])
|
||||
token_usage += usage
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate credentials
|
||||
"""
|
||||
try:
|
||||
client = self._setup_hunyuan_client(credentials)
|
||||
|
||||
req = models.ChatCompletionsRequest()
|
||||
params = {
|
||||
"Model": model,
|
||||
"Messages": [{"Role": "user", "Content": "hello"}],
|
||||
"TopP": 1,
|
||||
"Temperature": 0,
|
||||
"Stream": False,
|
||||
}
|
||||
req.from_json_string(json.dumps(params))
|
||||
client.ChatCompletions(req)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||
|
||||
def _setup_hunyuan_client(self, credentials):
|
||||
secret_id = credentials["secret_id"]
|
||||
secret_key = credentials["secret_key"]
|
||||
cred = credential.Credential(secret_id, secret_key)
|
||||
httpProfile = HttpProfile()
|
||||
httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
|
||||
clientProfile = ClientProfile()
|
||||
clientProfile.httpProfile = httpProfile
|
||||
client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
|
||||
return client
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeError: [TencentCloudSDKException],
|
||||
}
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
# client = self._setup_hunyuan_client(credentials)
|
||||
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
# use client.GetTokenCount to get num tokens
|
||||
# request = models.GetTokenCountRequest()
|
||||
# params = {
|
||||
# "Prompt": text
|
||||
# }
|
||||
# request.from_json_string(json.dumps(params))
|
||||
# response = client.GetTokenCount(request)
|
||||
# num_tokens += response.TokenCount
|
||||
|
||||
return num_tokens
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
provider: jina
|
||||
label:
|
||||
en_US: Jina AI
|
||||
description:
|
||||
en_US: Embedding and Rerank Model Supported
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#EFFDFD"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API key from Jina AI
|
||||
zh_Hans: 从 Jina AI 获取 API Key
|
||||
url:
|
||||
en_US: https://jina.ai/
|
||||
supported_model_types:
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: base_url
|
||||
label:
|
||||
zh_Hans: 服务器 URL
|
||||
en_US: Base URL
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: Base URL, e.g. https://api.jina.ai/v1
|
||||
en_US: Base URL, e.g. https://api.jina.ai/v1
|
||||
default: 'https://api.jina.ai/v1'
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 上下文大小
|
||||
en_US: Context size
|
||||
placeholder:
|
||||
zh_Hans: 输入上下文大小
|
||||
en_US: Enter context size
|
||||
required: false
|
||||
type: text-input
|
||||
default: '8192'
|
||||
|
|
@ -1,199 +0,0 @@
|
|||
import time
|
||||
from json import JSONDecodeError, dumps
|
||||
from typing import Optional
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.jina.text_embedding.jina_tokenizer import JinaTokenizer
|
||||
|
||||
|
||||
class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Jina text embedding model.
|
||||
"""
|
||||
|
||||
api_base: str = "https://api.jina.ai/v1"
|
||||
|
||||
def _to_payload(self, model: str, texts: list[str], credentials: dict, input_type: EmbeddingInputType) -> dict:
|
||||
"""
|
||||
Parse model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return: parsed credentials
|
||||
"""
|
||||
|
||||
def transform_jina_input_text(model, text):
|
||||
if model == "jina-clip-v1":
|
||||
return {"text": text}
|
||||
return text
|
||||
|
||||
data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]}
|
||||
|
||||
# model specific parameters
|
||||
if model == "jina-embeddings-v3":
|
||||
# set `task` type according to input type for the best performance
|
||||
data["task"] = "retrieval.query" if input_type == EmbeddingInputType.QUERY else "retrieval.passage"
|
||||
|
||||
return data
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials["api_key"]
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError("api_key is required")
|
||||
|
||||
base_url = credentials.get("base_url", self.api_base)
|
||||
base_url = base_url.removesuffix("/")
|
||||
|
||||
url = base_url + "/embeddings"
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
|
||||
data = self._to_payload(model=model, texts=texts, credentials=credentials, input_type=input_type)
|
||||
|
||||
try:
|
||||
response = post(url, headers=headers, data=dumps(data))
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
msg = resp["detail"]
|
||||
if response.status_code == 401:
|
||||
raise InvokeAuthorizationError(msg)
|
||||
elif response.status_code == 429:
|
||||
raise InvokeRateLimitError(msg)
|
||||
elif response.status_code == 500:
|
||||
raise InvokeServerUnavailableError(msg)
|
||||
else:
|
||||
raise InvokeBadRequestError(msg)
|
||||
except JSONDecodeError as e:
|
||||
raise InvokeServerUnavailableError(
|
||||
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||
)
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
embeddings = resp["data"]
|
||||
usage = resp["usage"]
|
||||
except Exception as e:
|
||||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
# use JinaTokenizer to get num tokens
|
||||
num_tokens += JinaTokenizer.get_num_tokens(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
)
|
||||
|
||||
return entity
|
||||
|
|
@ -1,189 +0,0 @@
|
|||
import time
|
||||
from json import JSONDecodeError, dumps
|
||||
from typing import Optional
|
||||
|
||||
from requests import post
|
||||
from yarl import URL
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
|
||||
class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for LocalAI text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
if len(texts) != 1:
|
||||
raise InvokeBadRequestError("Only one text is supported")
|
||||
|
||||
server_url = credentials["server_url"]
|
||||
model_name = model
|
||||
if not server_url:
|
||||
raise CredentialsValidateFailedError("server_url is required")
|
||||
if not model_name:
|
||||
raise CredentialsValidateFailedError("model_name is required")
|
||||
|
||||
url = server_url
|
||||
headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"}
|
||||
|
||||
data = {"model": model_name, "input": texts[0]}
|
||||
|
||||
try:
|
||||
response = post(str(URL(url) / "embeddings"), headers=headers, data=dumps(data), timeout=10)
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
code = resp["error"]["code"]
|
||||
msg = resp["error"]["message"]
|
||||
if code == 500:
|
||||
raise InvokeServerUnavailableError(msg)
|
||||
|
||||
if response.status_code == 401:
|
||||
raise InvokeAuthorizationError(msg)
|
||||
elif response.status_code == 429:
|
||||
raise InvokeRateLimitError(msg)
|
||||
elif response.status_code == 500:
|
||||
raise InvokeServerUnavailableError(msg)
|
||||
else:
|
||||
raise InvokeError(msg)
|
||||
except JSONDecodeError as e:
|
||||
raise InvokeServerUnavailableError(
|
||||
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||
)
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
embeddings = resp["data"]
|
||||
usage = resp["usage"]
|
||||
except Exception as e:
|
||||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
# use GPT2Tokenizer to get num tokens
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(zh_Hans=model, en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512")),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except InvokeAuthorizationError:
|
||||
raise CredentialsValidateFailedError("Invalid credentials")
|
||||
except InvokeConnectionError as e:
|
||||
raise CredentialsValidateFailedError(f"Invalid credentials: {e}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
import time
|
||||
from json import dumps
|
||||
from typing import Optional
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.minimax.llm.errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
RateLimitReachedError,
|
||||
)
|
||||
|
||||
|
||||
class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Minimax text embedding model.
|
||||
"""
|
||||
|
||||
api_base: str = "https://api.minimax.chat/v1/embeddings"
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials["minimax_api_key"]
|
||||
group_id = credentials["minimax_group_id"]
|
||||
if model != "embo-01":
|
||||
raise ValueError("Invalid model name")
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError("api_key is required")
|
||||
url = f"{self.api_base}?GroupId={group_id}"
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
|
||||
data = {"model": "embo-01", "texts": texts, "type": "db"}
|
||||
|
||||
try:
|
||||
response = post(url, headers=headers, data=dumps(data))
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise InvokeServerUnavailableError(response.text)
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
# check if there is an error
|
||||
if resp["base_resp"]["status_code"] != 0:
|
||||
code = resp["base_resp"]["status_code"]
|
||||
msg = resp["base_resp"]["status_msg"]
|
||||
self._handle_error(code, msg)
|
||||
|
||||
embeddings = resp["vectors"]
|
||||
total_tokens = resp["total_tokens"]
|
||||
except InvalidAuthenticationError:
|
||||
raise InvalidAPIKeyError("Invalid api key")
|
||||
except KeyError as e:
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens)
|
||||
|
||||
result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
# use MinimaxTokenizer to get num tokens
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except InvalidAPIKeyError:
|
||||
raise CredentialsValidateFailedError("Invalid api key")
|
||||
|
||||
def _handle_error(self, code: int, msg: str):
|
||||
if code in {1000, 1001}:
|
||||
raise InternalServerError(msg)
|
||||
elif code == 1002:
|
||||
raise RateLimitReachedError(msg)
|
||||
elif code == 1004:
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif code == 1008:
|
||||
raise InsufficientAccountBalanceError(msg)
|
||||
elif code == 2013:
|
||||
raise BadRequestError(msg)
|
||||
else:
|
||||
raise InternalServerError(msg)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [InternalServerError],
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,170 +0,0 @@
|
|||
import time
|
||||
from json import JSONDecodeError, dumps
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
|
||||
class MixedBreadTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for MixedBread text embedding model.
|
||||
"""
|
||||
|
||||
api_base: str = "https://api.mixedbread.ai/v1"
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials["api_key"]
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError("api_key is required")
|
||||
|
||||
base_url = credentials.get("base_url", self.api_base)
|
||||
base_url = base_url.removesuffix("/")
|
||||
|
||||
url = base_url + "/embeddings"
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
|
||||
data = {"model": model, "input": texts}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, data=dumps(data))
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
msg = resp["detail"]
|
||||
if response.status_code == 401:
|
||||
raise InvokeAuthorizationError(msg)
|
||||
elif response.status_code == 429:
|
||||
raise InvokeRateLimitError(msg)
|
||||
elif response.status_code == 500:
|
||||
raise InvokeServerUnavailableError(msg)
|
||||
else:
|
||||
raise InvokeBadRequestError(msg)
|
||||
except JSONDecodeError as e:
|
||||
raise InvokeServerUnavailableError(
|
||||
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||
)
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
embeddings = resp["data"]
|
||||
usage = resp["usage"]
|
||||
except Exception as e:
|
||||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512"))},
|
||||
)
|
||||
|
||||
return entity
|
||||
|
|
@ -13,7 +13,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
|||
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from core.model_runtime.model_providers.__base.text2img_model import Text2ImageModel
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
||||
|
|
@ -284,8 +283,6 @@ class ModelProviderFactory:
|
|||
return ModerationModel(**init_params)
|
||||
elif model_type == ModelType.TTS:
|
||||
return TTSModel(**init_params)
|
||||
elif model_type == ModelType.TEXT2IMG:
|
||||
return Text2ImageModel(**init_params)
|
||||
|
||||
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> bytes:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,165 +0,0 @@
|
|||
import time
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from nomic import embed
|
||||
from nomic import login as nomic_login
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import (
|
||||
EmbeddingUsage,
|
||||
TextEmbeddingResult,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import (
|
||||
TextEmbeddingModel,
|
||||
)
|
||||
from core.model_runtime.model_providers.nomic._common import _CommonNomic
|
||||
|
||||
|
||||
def nomic_login_required(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
if not kwargs.get("credentials"):
|
||||
raise ValueError("missing credentials parameters")
|
||||
credentials = kwargs.get("credentials")
|
||||
if "nomic_api_key" not in credentials:
|
||||
raise ValueError("missing nomic_api_key in credentials parameters")
|
||||
# nomic login
|
||||
nomic_login(credentials["nomic_api_key"])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class NomicTextEmbeddingModel(_CommonNomic, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for nomic text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
embeddings, prompt_tokens, total_tokens = self.embed_text(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(
|
||||
model=model, credentials=credentials, tokens=prompt_tokens, total_tokens=total_tokens
|
||||
)
|
||||
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# call embedding model
|
||||
self.embed_text(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@nomic_login_required
|
||||
def embed_text(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int, int]:
|
||||
"""Call out to Nomic's embedding endpoint.
|
||||
|
||||
Args:
|
||||
model: The model to use for embedding.
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text, and tokens usage.
|
||||
"""
|
||||
embeddings: list[list[float]] = []
|
||||
prompt_tokens = 0
|
||||
total_tokens = 0
|
||||
|
||||
response = embed.text(
|
||||
model=model,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
if not (response and "embeddings" in response):
|
||||
raise ValueError("Embedding data is missing in the response.")
|
||||
|
||||
if not (response and "usage" in response):
|
||||
raise ValueError("Response usage is missing.")
|
||||
|
||||
if "prompt_tokens" not in response["usage"]:
|
||||
raise ValueError("Response usage does not contain prompt tokens.")
|
||||
|
||||
if "total_tokens" not in response["usage"]:
|
||||
raise ValueError("Response usage does not contain total tokens.")
|
||||
|
||||
embeddings = [list(map(float, e)) for e in response["embeddings"]]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
prompt_tokens = response["usage"]["prompt_tokens"]
|
||||
return embeddings, prompt_tokens, total_tokens
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: prompt tokens
|
||||
:param total_tokens: total tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens,
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=total_tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,158 +0,0 @@
|
|||
import time
|
||||
from json import JSONDecodeError, dumps
|
||||
from typing import Optional
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
|
||||
class NvidiaTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Nvidia text embedding model.
|
||||
"""
|
||||
|
||||
api_base: str = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
|
||||
models: list[str] = ["NV-Embed-QA"]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials["api_key"]
|
||||
if model not in self.models:
|
||||
raise InvokeBadRequestError("Invalid model name")
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError("api_key is required")
|
||||
url = self.api_base
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
|
||||
data = {"model": model, "input": texts[0], "input_type": "query"}
|
||||
|
||||
try:
|
||||
response = post(url, headers=headers, data=dumps(data))
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
msg = resp["detail"]
|
||||
if response.status_code == 401:
|
||||
raise InvokeAuthorizationError(msg)
|
||||
elif response.status_code == 429:
|
||||
raise InvokeRateLimitError(msg)
|
||||
elif response.status_code == 500:
|
||||
raise InvokeServerUnavailableError(msg)
|
||||
else:
|
||||
raise InvokeError(msg)
|
||||
except JSONDecodeError as e:
|
||||
raise InvokeServerUnavailableError(
|
||||
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||
)
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
embeddings = resp["data"]
|
||||
usage = resp["usage"]
|
||||
except Exception as e:
|
||||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
# use JinaTokenizer to get num tokens
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except InvokeAuthorizationError:
|
||||
raise CredentialsValidateFailedError("Invalid api key")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,224 +0,0 @@
|
|||
import base64
|
||||
import copy
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import oci
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
request_template = {
|
||||
"compartmentId": "",
|
||||
"servingMode": {"modelId": "cohere.embed-english-light-v3.0", "servingType": "ON_DEMAND"},
|
||||
"truncate": "NONE",
|
||||
"inputs": [""],
|
||||
}
|
||||
oci_config_template = {
|
||||
"user": "",
|
||||
"fingerprint": "",
|
||||
"tenancy": "",
|
||||
"region": "",
|
||||
"compartment_id": "",
|
||||
"key_content": "",
|
||||
}
|
||||
|
||||
|
||||
class OCITextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Cohere text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
# call embedding model
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model, credentials=credentials, texts=inputs[i : i + max_chunks]
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
characters = 0
|
||||
for text in texts:
|
||||
characters += len(text)
|
||||
return characters
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# call embedding model
|
||||
self._embedding_invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
|
||||
"""
|
||||
Invoke embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return: embeddings and used tokens
|
||||
"""
|
||||
|
||||
# oci
|
||||
# initialize client
|
||||
oci_config = copy.deepcopy(oci_config_template)
|
||||
if "oci_config_content" in credentials:
|
||||
oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8")
|
||||
config_items = oci_config_content.split("/")
|
||||
if len(config_items) != 5:
|
||||
raise CredentialsValidateFailedError(
|
||||
"oci_config_content should be base64.b64encode("
|
||||
"'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
|
||||
)
|
||||
oci_config["user"] = config_items[0]
|
||||
oci_config["fingerprint"] = config_items[1]
|
||||
oci_config["tenancy"] = config_items[2]
|
||||
oci_config["region"] = config_items[3]
|
||||
oci_config["compartment_id"] = config_items[4]
|
||||
else:
|
||||
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
|
||||
if "oci_key_content" in credentials:
|
||||
oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8")
|
||||
oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
|
||||
else:
|
||||
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
|
||||
# oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
|
||||
compartment_id = oci_config["compartment_id"]
|
||||
client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
|
||||
# call embedding model
|
||||
request_args = copy.deepcopy(request_template)
|
||||
request_args["compartmentId"] = compartment_id
|
||||
request_args["servingMode"]["modelId"] = model
|
||||
request_args["inputs"] = texts
|
||||
response = client.embed_text(request_args)
|
||||
return response.data.embeddings, self.get_num_characters(model=model, credentials=credentials, texts=texts)
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError],
|
||||
}
|
||||
|
|
@ -1,211 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OllamaEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for an Ollama text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
# Prepare headers and payload for the request
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
endpoint_url = credentials.get("base_url")
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
endpoint_url = urljoin(endpoint_url, "api/embed")
|
||||
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
|
||||
inputs = []
|
||||
used_tokens = 0
|
||||
|
||||
for text in texts:
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
|
||||
# Prepare the payload for the request
|
||||
payload = {"input": inputs, "model": model, "options": {"use_mmap": True}}
|
||||
|
||||
# Make the request to the Ollama API
|
||||
response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
||||
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
response_data = response.json()
|
||||
|
||||
# Extract embeddings and used tokens from the response
|
||||
embeddings = response_data["embeddings"]
|
||||
embedding_used_tokens = self.get_num_tokens(model, credentials, inputs)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Approximate number of tokens for given messages using GPT2 tokenizer
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except InvokeError as ex:
|
||||
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}")
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get("input_price", 0)),
|
||||
unit=Decimal(credentials.get("unit", 0)),
|
||||
currency=credentials.get("currency", "USD"),
|
||||
),
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeAuthorizationError: [
|
||||
requests.exceptions.InvalidHeader, # Missing or Invalid API Key
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
requests.exceptions.HTTPError, # Invalid Endpoint URL or model name
|
||||
requests.exceptions.InvalidURL, # Misconfigured request or other API error
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
requests.exceptions.RetryError # Too many requests sent in a short period of time
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
requests.exceptions.ConnectionError, # Engine Overloaded
|
||||
requests.exceptions.HTTPError, # Server Error
|
||||
],
|
||||
InvokeConnectionError: [
|
||||
requests.exceptions.ConnectTimeout, # Timeout
|
||||
requests.exceptions.ReadTimeout, # Timeout
|
||||
],
|
||||
}
|
||||
|
|
@ -1,203 +0,0 @@
|
|||
import base64
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
from openai import OpenAI
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.openai._common import _CommonOpenAI
|
||||
|
||||
|
||||
class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for OpenAI text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
# init model client
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
extra_model_kwargs["encoding_format"] = "base64"
|
||||
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
tokens = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
token = enc.encode(text)
|
||||
for j in range(0, len(token), context_size):
|
||||
tokens += [token[j : j + context_size]]
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(tokens), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
# call embedding model
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
|
||||
for i in range(len(texts)):
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
average = embeddings_batch[0]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
# calculate the number of tokens in the encoded text
|
||||
tokenized_text = enc.encode(text)
|
||||
total_num_tokens += len(tokenized_text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
|
||||
# call embedding model
|
||||
self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={})
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _embedding_invoke(
|
||||
self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict
|
||||
) -> tuple[list[list[float]], int]:
|
||||
"""
|
||||
Invoke embedding model
|
||||
|
||||
:param model: model name
|
||||
:param client: model client
|
||||
:param texts: texts to embed
|
||||
:param extra_model_kwargs: extra model kwargs
|
||||
:return: embeddings and used tokens
|
||||
"""
|
||||
# call embedding model
|
||||
response = client.embeddings.create(
|
||||
input=texts,
|
||||
model=model,
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64":
|
||||
# decode base64 embedding
|
||||
return (
|
||||
[list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data],
|
||||
response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return [data.embedding for data in response.data], response.usage.total_tokens
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,217 +0,0 @@
|
|||
import json
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for an OpenAI API-compatible text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
# Prepare headers and payload for the request
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
api_key = credentials.get("api_key")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
endpoint_url = credentials.get("endpoint_url")
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
endpoint_url = urljoin(endpoint_url, "embeddings")
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
extra_model_kwargs["encoding_format"] = "float"
|
||||
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
# TODO: Optimize for better token estimation and chunking
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
# Prepare the payload for the request
|
||||
payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs}
|
||||
|
||||
# Make the request to the OpenAI API
|
||||
response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
||||
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
response_data = response.json()
|
||||
|
||||
# Extract embeddings and used tokens from the response
|
||||
embeddings_batch = [data["embedding"] for data in response_data["data"]]
|
||||
embedding_used_tokens = response_data["usage"]["total_tokens"]
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Approximate number of tokens for given messages using GPT2 tokenizer
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
api_key = credentials.get("api_key")
|
||||
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
endpoint_url = credentials.get("endpoint_url")
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
endpoint_url = urljoin(endpoint_url, "embeddings")
|
||||
|
||||
payload = {"input": "ping", "model": model}
|
||||
|
||||
response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise CredentialsValidateFailedError(
|
||||
f"Credentials validation failed with status code {response.status_code}"
|
||||
)
|
||||
|
||||
try:
|
||||
json_result = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
|
||||
|
||||
if "model" not in json_result:
|
||||
raise CredentialsValidateFailedError("Credentials validation failed: invalid response")
|
||||
except CredentialsValidateFailedError:
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get("input_price", 0)),
|
||||
unit=Decimal(credentials.get("unit", 0)),
|
||||
currency=credentials.get("currency", "USD"),
|
||||
),
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,155 +0,0 @@
|
|||
import time
|
||||
from json import dumps
|
||||
from typing import Optional
|
||||
|
||||
from requests import post
|
||||
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
|
||||
class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for OpenLLM text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
server_url = credentials["server_url"]
|
||||
if not server_url:
|
||||
raise CredentialsValidateFailedError("server_url is required")
|
||||
|
||||
headers = {"Content-Type": "application/json", "accept": "application/json"}
|
||||
|
||||
url = f"{server_url}/v1/embeddings"
|
||||
|
||||
data = texts
|
||||
try:
|
||||
response = post(url, headers=headers, data=dumps(data))
|
||||
except (ConnectionError, InvalidSchema, MissingSchema) as e:
|
||||
# cloud not connect to the server
|
||||
raise InvokeAuthorizationError(f"Invalid server URL: {e}")
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
if response.status_code == 400:
|
||||
raise InvokeBadRequestError(response.text)
|
||||
elif response.status_code == 404:
|
||||
raise InvokeAuthorizationError(response.text)
|
||||
elif response.status_code == 500:
|
||||
raise InvokeServerUnavailableError(response.text)
|
||||
|
||||
try:
|
||||
resp = response.json()[0]
|
||||
embeddings = resp["embeddings"]
|
||||
total_tokens = resp["num_tokens"]
|
||||
except KeyError as e:
|
||||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens)
|
||||
|
||||
result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
# use GPT2Tokenizer to get num tokens
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except InvokeAuthorizationError:
|
||||
raise CredentialsValidateFailedError("Invalid server_url")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
import json
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from replicate import Client as ReplicateClient
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.replicate._common import _CommonReplicate
|
||||
|
||||
|
||||
class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30)
|
||||
|
||||
if "model_version" in credentials:
|
||||
model_version = credentials["model_version"]
|
||||
else:
|
||||
model_info = client.models.get(model)
|
||||
model_version = model_info.latest_version.id
|
||||
|
||||
replicate_model_version = f"{model}:{model_version}"
|
||||
|
||||
text_input_key = self._get_text_input_key(model, model_version, client)
|
||||
|
||||
embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, texts)
|
||||
|
||||
tokens = self.get_num_tokens(model, credentials, texts)
|
||||
usage = self._calc_response_usage(model, credentials, tokens)
|
||||
|
||||
return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
if "replicate_api_token" not in credentials:
|
||||
raise CredentialsValidateFailedError("Replicate Access Token must be provided.")
|
||||
|
||||
try:
|
||||
client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30)
|
||||
|
||||
if "model_version" in credentials:
|
||||
model_version = credentials["model_version"]
|
||||
else:
|
||||
model_info = client.models.get(model)
|
||||
model_version = model_info.latest_version.id
|
||||
|
||||
replicate_model_version = f"{model}:{model_version}"
|
||||
|
||||
text_input_key = self._get_text_input_key(model, model_version, client)
|
||||
|
||||
self._generate_embeddings_by_text_input_key(
|
||||
client, replicate_model_version, text_input_key, ["Hello worlds!"]
|
||||
)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(str(e))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={"context_size": 4096, "max_chunks": 1},
|
||||
)
|
||||
return entity
|
||||
|
||||
@staticmethod
|
||||
def _get_text_input_key(model: str, model_version: str, client: ReplicateClient) -> str:
|
||||
model_info = client.models.get(model)
|
||||
model_info_version = model_info.versions.get(model_version)
|
||||
|
||||
# sort through the openapi schema to get the name of text, texts or inputs
|
||||
input_properties = sorted(
|
||||
model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(),
|
||||
key=lambda item: item[1].get("x-order", 0),
|
||||
)
|
||||
|
||||
for input_property in input_properties:
|
||||
if input_property[0] in {"text", "texts", "inputs"}:
|
||||
text_input_key = input_property[0]
|
||||
return text_input_key
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _generate_embeddings_by_text_input_key(
|
||||
client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str]
|
||||
) -> list[list[float]]:
|
||||
if text_input_key in {"text", "inputs"}:
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
result = client.run(replicate_model_version, input={text_input_key: text})
|
||||
embeddings.append(result[0].get("embedding"))
|
||||
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
elif "texts" == text_input_key:
|
||||
result = client.run(
|
||||
replicate_model_version,
|
||||
input={
|
||||
"texts": json.dumps(texts),
|
||||
"batch_size": 4,
|
||||
"convert_to_numpy": False,
|
||||
"normalize_embeddings": True,
|
||||
},
|
||||
)
|
||||
return result
|
||||
else:
|
||||
raise ValueError(f"embeddings input key is invalid: {text_input_key}")
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,463 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
# from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
import boto3
|
||||
from sagemaker import Predictor, serializers
|
||||
from sagemaker.session import Session
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
I18nObject,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def inference(predictor, messages: list[dict[str, Any]], params: dict[str, Any], stop: list, stream=False):
|
||||
"""
|
||||
params:
|
||||
predictor : Sagemaker Predictor
|
||||
messages (List[Dict[str,Any]]): message list。
|
||||
messages = [
|
||||
{"role": "system", "content":"please answer in Chinese"},
|
||||
{"role": "user", "content": "who are you? what are you doing?"},
|
||||
]
|
||||
params (Dict[str,Any]): model parameters for LLM。
|
||||
stream (bool): False by default。
|
||||
|
||||
response:
|
||||
result of inference if stream is False
|
||||
Iterator of Chunks if stream is True
|
||||
"""
|
||||
payload = {
|
||||
"model": params.get("model_name"),
|
||||
"stop": stop,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"max_tokens": params.get("max_new_tokens", params.get("max_tokens", 2048)),
|
||||
"temperature": params.get("temperature", 0.1),
|
||||
"top_p": params.get("top_p", 0.9),
|
||||
}
|
||||
|
||||
if not stream:
|
||||
response = predictor.predict(payload)
|
||||
return response
|
||||
else:
|
||||
response_stream = predictor.predict_stream(payload)
|
||||
return response_stream
|
||||
|
||||
|
||||
class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
Model class for Cohere large language model.
|
||||
"""
|
||||
|
||||
sagemaker_session: Any = None
|
||||
predictor: Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
resp: bytes,
|
||||
) -> LLMResult:
|
||||
"""
|
||||
handle normal chat generate response
|
||||
"""
|
||||
resp_obj = json.loads(resp.decode("utf-8"))
|
||||
resp_str = resp_obj.get("choices")[0].get("message").get("content")
|
||||
|
||||
if len(resp_str) == 0:
|
||||
raise InvokeServerUnavailableError("Empty response")
|
||||
|
||||
assistant_prompt_message = AssistantPromptMessage(content=resp_str, tool_calls=[])
|
||||
|
||||
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
||||
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
|
||||
|
||||
usage = self._calc_response_usage(
|
||||
model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
|
||||
)
|
||||
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=None,
|
||||
usage=usage,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _handle_chat_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
resp: Iterator[bytes],
|
||||
) -> Generator:
|
||||
"""
|
||||
handle stream chat generate response
|
||||
"""
|
||||
full_response = ""
|
||||
buffer = ""
|
||||
for chunk_bytes in resp:
|
||||
buffer += chunk_bytes.decode("utf-8")
|
||||
last_idx = 0
|
||||
for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer):
|
||||
try:
|
||||
data = json.loads(match.group(1).strip())
|
||||
last_idx = match.span()[1]
|
||||
|
||||
if "content" in data["choices"][0]["delta"]:
|
||||
chunk_content = data["choices"][0]["delta"]["content"]
|
||||
assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
|
||||
|
||||
if data["choices"][0]["finish_reason"] is not None:
|
||||
temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[])
|
||||
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
||||
completion_tokens = self._num_tokens_from_messages(
|
||||
messages=[temp_assistant_prompt_message], tools=[]
|
||||
)
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=None,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=data["choices"][0]["finish_reason"],
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=None,
|
||||
delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
|
||||
)
|
||||
|
||||
full_response += chunk_content
|
||||
except (json.JSONDecodeError, KeyError, IndexError) as e:
|
||||
logger.info("json parse exception, content: {}".format(match.group(1).strip()))
|
||||
pass
|
||||
|
||||
buffer = buffer[last_idx:]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
if not self.sagemaker_session:
|
||||
access_key = credentials.get("aws_access_key_id")
|
||||
secret_key = credentials.get("aws_secret_access_key")
|
||||
aws_region = credentials.get("aws_region")
|
||||
boto_session = None
|
||||
if aws_region:
|
||||
if access_key and secret_key:
|
||||
boto_session = boto3.Session(
|
||||
aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region
|
||||
)
|
||||
else:
|
||||
boto_session = boto3.Session(region_name=aws_region)
|
||||
else:
|
||||
boto_session = boto3.Session()
|
||||
|
||||
sagemaker_client = boto_session.client("sagemaker")
|
||||
self.sagemaker_session = Session(boto_session=boto_session, sagemaker_client=sagemaker_client)
|
||||
|
||||
if self.sagemaker_endpoint != credentials.get("sagemaker_endpoint"):
|
||||
self.sagemaker_endpoint = credentials.get("sagemaker_endpoint")
|
||||
self.predictor = Predictor(
|
||||
endpoint_name=self.sagemaker_endpoint,
|
||||
sagemaker_session=self.sagemaker_session,
|
||||
serializer=serializers.JSONSerializer(),
|
||||
)
|
||||
|
||||
messages: list[dict[str, Any]] = [{"role": p.role.value, "content": p.content} for p in prompt_messages]
|
||||
response = inference(
|
||||
predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream
|
||||
)
|
||||
|
||||
if stream:
|
||||
if tools and len(tools) > 0:
|
||||
raise InvokeBadRequestError(f"{model}'s tool calls does not support stream mode")
|
||||
|
||||
return self._handle_chat_stream_response(
|
||||
model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response
|
||||
)
|
||||
return self._handle_chat_generate_response(
|
||||
model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response
|
||||
)
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for OpenAI Compatibility API
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(PromptMessageContent, message_content)
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls and len(message.tool_calls) > 0:
|
||||
message_dict["function_call"] = {
|
||||
"name": message.tool_calls[0].function.name,
|
||||
"arguments": message.tool_calls[0].function.arguments,
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Unknown message type {type(message)}")
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False
|
||||
) -> int:
|
||||
def tokens(text: str):
|
||||
return self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if is_completion_model:
|
||||
return sum(tokens(str(message.content)) for message in messages)
|
||||
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if isinstance(value, list):
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += tokens(t_key)
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += tokens(f_key)
|
||||
num_tokens += tokens(f_value)
|
||||
else:
|
||||
num_tokens += tokens(t_key)
|
||||
num_tokens += tokens(t_value)
|
||||
if key == "function_call":
|
||||
for t_key, t_value in value.items():
|
||||
num_tokens += tokens(t_key)
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += tokens(f_key)
|
||||
num_tokens += tokens(f_value)
|
||||
else:
|
||||
num_tokens += tokens(t_key)
|
||||
num_tokens += tokens(t_value)
|
||||
else:
|
||||
num_tokens += tokens(str(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
# get model mode
|
||||
try:
|
||||
return self._num_tokens_from_messages(prompt_messages, tools)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# get model mode
|
||||
pass
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
type=ParameterType.FLOAT,
|
||||
use_template="temperature",
|
||||
label=I18nObject(zh_Hans="温度", en_US="Temperature"),
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
type=ParameterType.FLOAT,
|
||||
use_template="top_p",
|
||||
label=I18nObject(zh_Hans="Top P", en_US="Top P"),
|
||||
),
|
||||
ParameterRule(
|
||||
name="max_tokens",
|
||||
type=ParameterType.INT,
|
||||
use_template="max_tokens",
|
||||
min=1,
|
||||
max=credentials.get("context_length", 2048),
|
||||
default=512,
|
||||
label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
|
||||
),
|
||||
]
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"]).value
|
||||
|
||||
features = []
|
||||
|
||||
support_function_call = credentials.get("support_function_call", False)
|
||||
if support_function_call:
|
||||
features.append(ModelFeature.TOOL_CALL)
|
||||
|
||||
support_vision = credentials.get("support_vision", False)
|
||||
if support_vision:
|
||||
features.append(ModelFeature.VISION)
|
||||
|
||||
context_length = credentials.get("context_length", 2048)
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
features=features,
|
||||
model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length},
|
||||
parameter_rules=rules,
|
||||
)
|
||||
|
||||
return entity
|
||||
|
|
@ -1,200 +0,0 @@
|
|||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import boto3
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
BATCH_SIZE = 20
|
||||
CONTEXT_SIZE = 8192
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def batch_generator(generator, batch_size):
|
||||
while True:
|
||||
batch = list(itertools.islice(generator, batch_size))
|
||||
if not batch:
|
||||
break
|
||||
yield batch
|
||||
|
||||
|
||||
class SageMakerEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Cohere text embedding model.
|
||||
"""
|
||||
|
||||
sagemaker_client: Any = None
|
||||
|
||||
def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]):
|
||||
response_model = sm_client.invoke_endpoint(
|
||||
EndpointName=endpoint_name,
|
||||
Body=json.dumps({"inputs": content_list, "parameters": {}, "is_query": False, "instruction": ""}),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model["Body"].read().decode("utf8")
|
||||
json_obj = json.loads(json_str)
|
||||
embeddings = json_obj["embeddings"]
|
||||
return embeddings
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
# get model properties
|
||||
try:
|
||||
line = 1
|
||||
if not self.sagemaker_client:
|
||||
access_key = credentials.get("aws_access_key_id")
|
||||
secret_key = credentials.get("aws_secret_access_key")
|
||||
aws_region = credentials.get("aws_region")
|
||||
if aws_region:
|
||||
if access_key and secret_key:
|
||||
self.sagemaker_client = boto3.client(
|
||||
"sagemaker-runtime",
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
region_name=aws_region,
|
||||
)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
|
||||
line = 2
|
||||
sagemaker_endpoint = credentials.get("sagemaker_endpoint")
|
||||
|
||||
line = 3
|
||||
truncated_texts = [item[:CONTEXT_SIZE] for item in texts]
|
||||
|
||||
batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
|
||||
all_embeddings = []
|
||||
|
||||
line = 4
|
||||
for batch in batches:
|
||||
embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
|
||||
all_embeddings.extend(embeddings)
|
||||
|
||||
line = 5
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=0, # It's not SAAS API, usage is meaningless
|
||||
)
|
||||
line = 6
|
||||
|
||||
return TextEmbeddingResult(embeddings=all_embeddings, usage=usage, model=model)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception {e}, line : {line}")
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return 0
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
print("validate_credentials ok....")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
|
||||
ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
|
||||
},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
return entity
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
- Qwen/Qwen2.5-72B-Instruct
|
||||
- Qwen/Qwen2.5-32B-Instruct
|
||||
- Qwen/Qwen2.5-14B-Instruct
|
||||
- Qwen/Qwen2.5-7B-Instruct
|
||||
- Qwen/Qwen2.5-Coder-7B-Instruct
|
||||
- Qwen/Qwen2.5-Math-72B-Instruct
|
||||
- Qwen/Qwen2-72B-Instruct
|
||||
- Qwen/Qwen2-57B-A14B-Instruct
|
||||
- Qwen/Qwen2-7B-Instruct
|
||||
- Qwen/Qwen2-1.5B-Instruct
|
||||
- deepseek-ai/DeepSeek-V2.5
|
||||
- deepseek-ai/DeepSeek-V2-Chat
|
||||
- deepseek-ai/DeepSeek-Coder-V2-Instruct
|
||||
- THUDM/glm-4-9b-chat
|
||||
- 01-ai/Yi-1.5-34B-Chat-16K
|
||||
- 01-ai/Yi-1.5-9B-Chat-16K
|
||||
- 01-ai/Yi-1.5-6B-Chat
|
||||
- internlm/internlm2_5-20b-chat
|
||||
- internlm/internlm2_5-7b-chat
|
||||
- meta-llama/Meta-Llama-3.1-405B-Instruct
|
||||
- meta-llama/Meta-Llama-3.1-70B-Instruct
|
||||
- meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
- meta-llama/Meta-Llama-3-70B-Instruct
|
||||
- meta-llama/Meta-Llama-3-8B-Instruct
|
||||
- google/gemma-2-27b-it
|
||||
- google/gemma-2-9b-it
|
||||
- mistralai/Mistral-7B-Instruct-v0.2
|
||||
- mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
model: internlm/internlm2_5-20b-chat
|
||||
label:
|
||||
en_US: internlm/internlm2_5-20b-chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
pricing:
|
||||
input: '1'
|
||||
output: '1'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
|
|
@ -1,74 +0,0 @@
|
|||
model: Qwen/Qwen2.5-Coder-7B-Instruct
|
||||
label:
|
||||
en_US: Qwen/Qwen2.5-Coder-7B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: seed
|
||||
required: false
|
||||
type: int
|
||||
default: 1234
|
||||
label:
|
||||
zh_Hans: 随机种子
|
||||
en_US: Random seed
|
||||
help:
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
zh_Hans: 重复惩罚
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0'
|
||||
output: '0'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
|
|
@ -1,74 +0,0 @@
|
|||
model: Qwen/Qwen2.5-Math-72B-Instruct
|
||||
label:
|
||||
en_US: Qwen/Qwen2.5-Math-72B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 4096
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 2000
|
||||
min: 1
|
||||
max: 2000
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: seed
|
||||
required: false
|
||||
type: int
|
||||
default: 1234
|
||||
label:
|
||||
zh_Hans: 随机种子
|
||||
en_US: Random seed
|
||||
help:
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
zh_Hans: 重复惩罚
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '4.13'
|
||||
output: '4.13'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
|
||||
OAICompatEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
class SiliconflowTextEmbeddingModel(OAICompatEmbeddingModel):
|
||||
"""
|
||||
Model class for Siliconflow text embedding model.
|
||||
"""
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
self._add_custom_parameters(credentials)
|
||||
return super()._invoke(model, credentials, texts, user)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
self._add_custom_parameters(credentials)
|
||||
return super().get_num_tokens(model, credentials, texts)
|
||||
|
||||
@classmethod
|
||||
def _add_custom_parameters(cls, credentials: dict) -> None:
|
||||
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"
|
||||
|
|
@ -1,309 +0,0 @@
|
|||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
from ._client import SparkLLMClient
|
||||
|
||||
|
||||
class SparkLargeLanguageModel(LargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# invoke model
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[
|
||||
UserPromptMessage(content="ping"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.5,
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
extra_model_kwargs = {}
|
||||
if stop:
|
||||
extra_model_kwargs["stop_sequences"] = stop
|
||||
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
client = SparkLLMClient(
|
||||
model=model,
|
||||
**credentials_kwargs,
|
||||
)
|
||||
|
||||
thread = threading.Thread(
|
||||
target=client.run,
|
||||
args=(
|
||||
[
|
||||
{"role": prompt_message.role.value, "content": prompt_message.content}
|
||||
for prompt_message in prompt_messages
|
||||
],
|
||||
user,
|
||||
model_parameters,
|
||||
stream,
|
||||
),
|
||||
)
|
||||
thread.start()
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(thread, model, credentials, client, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(thread, model, credentials, client, prompt_messages)
|
||||
|
||||
def _handle_generate_response(
|
||||
self,
|
||||
thread: threading.Thread,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
client: SparkLLMClient,
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
|
||||
:param model: model name
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
completion = ""
|
||||
|
||||
for content in client.subscribe():
|
||||
if isinstance(content, dict):
|
||||
delta = content["data"]
|
||||
else:
|
||||
delta = content
|
||||
|
||||
completion += delta
|
||||
|
||||
thread.join()
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=completion)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self,
|
||||
thread: threading.Thread,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
client: SparkLLMClient,
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param thread: thread
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator result
|
||||
"""
|
||||
completion = ""
|
||||
for index, content in enumerate(client.subscribe()):
|
||||
if isinstance(content, dict):
|
||||
delta = content["data"]
|
||||
else:
|
||||
delta = content
|
||||
completion += delta
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta or "",
|
||||
)
|
||||
temp_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=completion,
|
||||
)
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [temp_assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage),
|
||||
)
|
||||
|
||||
thread.join()
|
||||
|
||||
def _to_credential_kwargs(self, credentials: dict) -> dict:
|
||||
"""
|
||||
Transform credentials to kwargs for model instance
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
credentials_kwargs = {
|
||||
"app_id": credentials["app_id"],
|
||||
"api_secret": credentials["api_secret"],
|
||||
"api_key": credentials["api_key"],
|
||||
}
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
||||
"""
|
||||
Convert a single message to a string.
|
||||
|
||||
:param message: PromptMessage to convert.
|
||||
:return: String representation of the message.
|
||||
"""
|
||||
human_prompt = "\n\nHuman:"
|
||||
ai_prompt = "\n\nAssistant:"
|
||||
content = message.content
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message_text = content
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_text
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Anthropic model
|
||||
|
||||
:param messages: List of PromptMessage to combine.
|
||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(self._convert_one_message_to_text(message) for message in messages)
|
||||
|
||||
# trim off the trailing ' ' that might come from the "Assistant: "
|
||||
return text.rstrip()
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: [],
|
||||
}
|
||||
|
|
@ -1,177 +0,0 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
import dashscope
|
||||
import numpy as np
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import (
|
||||
EmbeddingUsage,
|
||||
TextEmbeddingResult,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import (
|
||||
TextEmbeddingModel,
|
||||
)
|
||||
from core.model_runtime.model_providers.tongyi._common import _CommonTongyi
|
||||
|
||||
|
||||
class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Tongyi text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
embeddings_batch, embedding_used_tokens = self.embed_documents(
|
||||
credentials_kwargs=credentials_kwargs,
|
||||
model=model,
|
||||
texts=inputs[i : i + max_chunks],
|
||||
)
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
total_num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
# call embedding model
|
||||
self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@staticmethod
|
||||
def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]:
|
||||
"""Call out to Tongyi's embedding endpoint.
|
||||
|
||||
Args:
|
||||
credentials_kwargs: The credentials to use for the call.
|
||||
model: The model to use for embedding.
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text, and tokens usage.
|
||||
"""
|
||||
embeddings = []
|
||||
embedding_used_tokens = 0
|
||||
for text in texts:
|
||||
response = dashscope.TextEmbedding.call(
|
||||
api_key=credentials_kwargs["dashscope_api_key"],
|
||||
model=model,
|
||||
input=text,
|
||||
text_type="document",
|
||||
)
|
||||
if response.output and "embeddings" in response.output and response.output["embeddings"]:
|
||||
data = response.output["embeddings"][0]
|
||||
if "embedding" in data:
|
||||
embeddings.append(data["embedding"])
|
||||
else:
|
||||
raise ValueError("Embedding data is missing in the response.")
|
||||
else:
|
||||
raise ValueError("Response output is missing or does not contain embeddings.")
|
||||
|
||||
if response.usage and "total_tokens" in response.usage:
|
||||
embedding_used_tokens += response.usage["total_tokens"]
|
||||
else:
|
||||
raise ValueError("Response usage is missing or does not contain total tokens.")
|
||||
|
||||
return [list(map(float, e)) for e in embeddings], embedding_used_tokens
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens,
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,197 +0,0 @@
|
|||
import base64
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from openai import OpenAI
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.upstage._common import _CommonUpstage
|
||||
|
||||
|
||||
class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Upstage text embedding model.
|
||||
"""
|
||||
|
||||
def _get_tokenizer(self) -> Tokenizer:
|
||||
return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer")
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: str | None = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
extra_model_kwargs["encoding_format"] = "base64"
|
||||
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
tokens = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
token = tokenizer.encode(text, add_special_tokens=False).tokens
|
||||
for j in range(0, len(token), context_size):
|
||||
tokens += [token[j : j + context_size]]
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(tokens), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=tokens[i : i + max_chunks],
|
||||
extra_model_kwargs=extra_model_kwargs,
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
|
||||
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
|
||||
for i in range(len(texts)):
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=[texts[i]],
|
||||
extra_model_kwargs=extra_model_kwargs,
|
||||
)
|
||||
used_tokens += embedding_used_tokens
|
||||
average = embeddings_batch[0]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
tokenizer = self._get_tokenizer()
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
# calculate the number of tokens in the encoded text
|
||||
tokenized_text = tokenizer.encode(text)
|
||||
total_num_tokens += len(tokenized_text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: Mapping) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
|
||||
# call embedding model
|
||||
self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={})
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _embedding_invoke(
|
||||
self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict
|
||||
) -> tuple[list[list[float]], int]:
|
||||
"""
|
||||
Invoke embedding model
|
||||
:param model: model name
|
||||
:param client: model client
|
||||
:param texts: texts to embed
|
||||
:param extra_model_kwargs: extra model kwargs
|
||||
:return: embeddings and used tokens
|
||||
"""
|
||||
response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs)
|
||||
|
||||
if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64":
|
||||
return (
|
||||
[
|
||||
list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32))
|
||||
for embedding in response.data
|
||||
],
|
||||
response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return [data.embedding for data in response.data], response.usage.total_tokens
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT
|
||||
)
|
||||
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
model: gemini-1.5-flash-001
|
||||
label:
|
||||
en_US: Gemini 1.5 Flash 001
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
model: gemini-1.5-flash-002
|
||||
label:
|
||||
en_US: Gemini 1.5 Flash 002
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
model: gemini-1.5-pro-001
|
||||
label:
|
||||
en_US: Gemini 1.5 Pro 001
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
model: gemini-1.5-pro-002
|
||||
label:
|
||||
en_US: Gemini 1.5 Pro 002
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
model: gemini-flash-experimental
|
||||
label:
|
||||
en_US: Gemini Flash Experimental
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
model: gemini-pro-experimental
|
||||
label:
|
||||
en_US: Gemini Pro Experimental
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -1,733 +0,0 @@
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import google.auth.transport.requests
|
||||
import vertexai.generative_models as glm
|
||||
from anthropic import AnthropicVertex, Stream
|
||||
from anthropic.types import (
|
||||
ContentBlockDeltaEvent,
|
||||
Message,
|
||||
MessageDeltaEvent,
|
||||
MessageStartEvent,
|
||||
MessageStopEvent,
|
||||
MessageStreamEvent,
|
||||
)
|
||||
from google.api_core import exceptions
|
||||
from google.cloud import aiplatform
|
||||
from google.oauth2 import service_account
|
||||
from PIL import Image
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# invoke anthropic models via anthropic official SDK
|
||||
if "claude" in model:
|
||||
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
||||
# invoke Gemini model
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def _generate_anthropic(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke Anthropic large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# use Anthropic official SDK references
|
||||
# - https://github.com/anthropics/anthropic-sdk-python
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
token = ""
|
||||
|
||||
# get access token from service account credential
|
||||
if service_account_info:
|
||||
credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
|
||||
request = google.auth.transport.requests.Request()
|
||||
credentials.refresh(request)
|
||||
token = credentials.token
|
||||
|
||||
# Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available
|
||||
# in us-central1 region
|
||||
if "opus" in model or "claude-3-5-sonnet" in model:
|
||||
location = "us-east5"
|
||||
else:
|
||||
location = "us-central1"
|
||||
|
||||
# use access token to authenticate
|
||||
if token:
|
||||
client = AnthropicVertex(region=location, project_id=project_id, access_token=token)
|
||||
# When access token is empty, try to use the Google Cloud VM's built-in service account
|
||||
# or the GOOGLE_APPLICATION_CREDENTIALS environment variable
|
||||
else:
|
||||
client = AnthropicVertex(
|
||||
region=location,
|
||||
project_id=project_id,
|
||||
)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if stop:
|
||||
extra_model_kwargs["stop_sequences"] = stop
|
||||
|
||||
system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages)
|
||||
|
||||
if system:
|
||||
extra_model_kwargs["system"] = system
|
||||
|
||||
response = client.messages.create(
|
||||
model=model, messages=prompt_message_dicts, stream=stream, **model_parameters, **extra_model_kwargs
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_claude_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_claude_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_claude_response(
|
||||
self, model: str, credentials: dict, response: Message, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: full response chunk generator result
|
||||
"""
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=response.content[0].text)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
# transform usage
|
||||
prompt_tokens = response.usage.input_tokens
|
||||
completion_tokens = response.usage.output_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _handle_claude_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[MessageStreamEvent],
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
try:
|
||||
full_assistant_content = ""
|
||||
return_model = None
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
finish_reason = None
|
||||
index = 0
|
||||
|
||||
for chunk in response:
|
||||
if isinstance(chunk, MessageStartEvent):
|
||||
return_model = chunk.message.model
|
||||
input_tokens = chunk.message.usage.input_tokens
|
||||
elif isinstance(chunk, MessageDeltaEvent):
|
||||
output_tokens = chunk.usage.output_tokens
|
||||
finish_reason = chunk.delta.stop_reason
|
||||
elif isinstance(chunk, MessageStopEvent):
|
||||
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
|
||||
yield LLMResultChunk(
|
||||
model=return_model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
elif isinstance(chunk, ContentBlockDeltaEvent):
|
||||
chunk_text = chunk.delta.text or ""
|
||||
full_assistant_content += chunk_text
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=chunk_text or "",
|
||||
)
|
||||
index = chunk.index
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
except Exception as ex:
|
||||
raise InvokeError(str(ex))
|
||||
|
||||
def _calc_claude_response_usage(
|
||||
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
|
||||
) -> LLMUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_tokens: prompt tokens
|
||||
:param completion_tokens: completion tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get prompt price info
|
||||
prompt_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=prompt_tokens,
|
||||
)
|
||||
|
||||
# get completion price info
|
||||
completion_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = LLMUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_unit_price=prompt_price_info.unit_price,
|
||||
prompt_price_unit=prompt_price_info.unit,
|
||||
prompt_price=prompt_price_info.total_amount,
|
||||
completion_tokens=completion_tokens,
|
||||
completion_unit_price=completion_price_info.unit_price,
|
||||
completion_price_unit=completion_price_info.unit,
|
||||
completion_price=completion_price_info.total_amount,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
|
||||
currency=prompt_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
|
||||
"""
|
||||
Convert prompt messages to dict list and system
|
||||
"""
|
||||
|
||||
system = ""
|
||||
first_loop = True
|
||||
for message in prompt_messages:
|
||||
if isinstance(message, SystemPromptMessage):
|
||||
message.content = message.content.strip()
|
||||
if first_loop:
|
||||
system = message.content
|
||||
first_loop = False
|
||||
else:
|
||||
system += "\n"
|
||||
system += message.content
|
||||
|
||||
prompt_message_dicts = []
|
||||
for message in prompt_messages:
|
||||
if not isinstance(message, SystemPromptMessage):
|
||||
prompt_message_dicts.append(self._convert_claude_prompt_message_to_dict(message))
|
||||
|
||||
return system, prompt_message_dicts
|
||||
|
||||
def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
if not message_content.data.startswith("data:"):
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
with Image.open(io.BytesIO(image_content)) as img:
|
||||
mime_type = f"image/{img.format.lower()}"
|
||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
else:
|
||||
data_split = message_content.data.split(";base64,")
|
||||
mime_type = data_split[0].replace("data:", "")
|
||||
base64_data = data_split[1]
|
||||
|
||||
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
|
||||
raise ValueError(
|
||||
f"Unsupported image type {mime_type}, "
|
||||
f"only support image/jpeg, image/png, image/gif, and image/webp"
|
||||
)
|
||||
|
||||
sub_message_dict = {
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": mime_type, "data": base64_data},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:md = gml.GenerativeModel(model)
|
||||
"""
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Google model
|
||||
|
||||
:param messages: List of PromptMessage to combine.
|
||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(self._convert_one_message_to_text(message) for message in messages)
|
||||
|
||||
return text.rstrip()
|
||||
|
||||
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
|
||||
"""
|
||||
Convert tool messages to glm tools
|
||||
|
||||
:param tools: tool messages
|
||||
:return: glm tools
|
||||
"""
|
||||
return glm.Tool(
|
||||
function_declarations=[
|
||||
glm.FunctionDeclaration(
|
||||
name=tool.name,
|
||||
parameters=glm.Schema(
|
||||
type=glm.Type.OBJECT,
|
||||
properties={
|
||||
key: {
|
||||
"type_": value.get("type", "string").upper(),
|
||||
"description": value.get("description", ""),
|
||||
"enum": value.get("enum", []),
|
||||
}
|
||||
for key, value in tool.parameters.get("properties", {}).items()
|
||||
},
|
||||
required=tool.parameters.get("required", []),
|
||||
),
|
||||
)
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
|
||||
try:
|
||||
ping_message = SystemPromptMessage(content="ping")
|
||||
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
|
||||
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials kwargs
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
config_kwargs = model_parameters.copy()
|
||||
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
|
||||
|
||||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_info:
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
aiplatform.init(project=project_id, location=location)
|
||||
|
||||
history = []
|
||||
system_instruction = ""
|
||||
# hack for gemini-pro-vision, which currently does not support multi-turn chat
|
||||
if model == "gemini-1.0-pro-vision-001":
|
||||
last_msg = prompt_messages[-1]
|
||||
content = self._format_message_to_glm_content(last_msg)
|
||||
history.append(content)
|
||||
else:
|
||||
for msg in prompt_messages:
|
||||
if isinstance(msg, SystemPromptMessage):
|
||||
system_instruction = msg.content
|
||||
else:
|
||||
content = self._format_message_to_glm_content(msg)
|
||||
if history and history[-1].role == content.role:
|
||||
history[-1].parts.extend(content.parts)
|
||||
else:
|
||||
history.append(content)
|
||||
|
||||
google_model = glm.GenerativeModel(model_name=model, system_instruction=system_instruction)
|
||||
|
||||
response = google_model.generate_content(
|
||||
contents=history,
|
||||
generation_config=glm.GenerationConfig(**config_kwargs),
|
||||
stream=stream,
|
||||
tools=self._convert_tools_to_glm_tool(tools) if tools else None,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=response.candidates[0].content.parts[0].text)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator result
|
||||
"""
|
||||
index = -1
|
||||
for chunk in response:
|
||||
for part in chunk.candidates[0].content.parts:
|
||||
assistant_prompt_message = AssistantPromptMessage(content="")
|
||||
|
||||
if part.text:
|
||||
assistant_prompt_message.content += part.text
|
||||
|
||||
if part.function_call:
|
||||
assistant_prompt_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=part.function_call.name,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=part.function_call.name,
|
||||
arguments=json.dumps(dict(part.function_call.args.items())),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
index += 1
|
||||
|
||||
if not hasattr(chunk, "finish_reason") or not chunk.finish_reason:
|
||||
# transform assistant message to prompt message
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
|
||||
)
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=chunk.candidates[0].finish_reason,
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
|
||||
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
||||
"""
|
||||
Convert a single message to a string.
|
||||
|
||||
:param message: PromptMessage to convert.
|
||||
:return: String representation of the message.
|
||||
"""
|
||||
human_prompt = "\n\nuser:"
|
||||
ai_prompt = "\n\nmodel:"
|
||||
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_text
|
||||
|
||||
def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
|
||||
"""
|
||||
Format a single message into glm.Content for Google API
|
||||
|
||||
:param message: one PromptMessage
|
||||
:return: glm Content representation of message
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
glm_content = glm.Content(role="user", parts=[])
|
||||
|
||||
if isinstance(message.content, str):
|
||||
glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)])
|
||||
else:
|
||||
parts = []
|
||||
for c in message.content:
|
||||
if c.type == PromptMessageContentType.TEXT:
|
||||
parts.append(glm.Part.from_text(c.data))
|
||||
else:
|
||||
metadata, data = c.data.split(",", 1)
|
||||
mime_type = metadata.split(";", 1)[0].split(":")[1]
|
||||
parts.append(glm.Part.from_data(mime_type=mime_type, data=data))
|
||||
glm_content = glm.Content(role="user", parts=parts)
|
||||
return glm_content
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
if message.content:
|
||||
glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)])
|
||||
if message.tool_calls:
|
||||
glm_content = glm.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
glm.Part.from_function_response(
|
||||
glm.FunctionCall(
|
||||
name=message.tool_calls[0].function.name,
|
||||
args=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
return glm_content
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
glm_content = glm.Content(
|
||||
role="function",
|
||||
parts=[
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
name=message.name, response={"response": message.content}
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
return glm_content
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = gml.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = gml.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = gml.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [exceptions.RetryError],
|
||||
InvokeServerUnavailableError: [
|
||||
exceptions.ServiceUnavailable,
|
||||
exceptions.InternalServerError,
|
||||
exceptions.BadGateway,
|
||||
exceptions.GatewayTimeout,
|
||||
exceptions.DeadlineExceeded,
|
||||
],
|
||||
InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
|
||||
InvokeAuthorizationError: [
|
||||
exceptions.Unauthenticated,
|
||||
exceptions.PermissionDenied,
|
||||
exceptions.Unauthenticated,
|
||||
exceptions.Forbidden,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
exceptions.BadRequest,
|
||||
exceptions.InvalidArgument,
|
||||
exceptions.FailedPrecondition,
|
||||
exceptions.OutOfRange,
|
||||
exceptions.NotFound,
|
||||
exceptions.MethodNotAllowed,
|
||||
exceptions.Conflict,
|
||||
exceptions.AlreadyExists,
|
||||
exceptions.Aborted,
|
||||
exceptions.LengthRequired,
|
||||
exceptions.PreconditionFailed,
|
||||
exceptions.RequestRangeNotSatisfiable,
|
||||
exceptions.Cancelled,
|
||||
],
|
||||
}
|
||||
|
|
@ -1,187 +0,0 @@
|
|||
import base64
|
||||
import json
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
|
||||
import tiktoken
|
||||
from google.cloud import aiplatform
|
||||
from google.oauth2 import service_account
|
||||
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
|
||||
|
||||
|
||||
class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Vertex AI text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_info:
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
aiplatform.init(project=project_id, location=location)
|
||||
|
||||
client = VertexTextEmbeddingModel.from_pretrained(model)
|
||||
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(client=client, texts=texts)
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=embedding_used_tokens)
|
||||
|
||||
return TextEmbeddingResult(embeddings=embeddings_batch, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
# calculate the number of tokens in the encoded text
|
||||
tokenized_text = enc.encode(text)
|
||||
total_num_tokens += len(tokenized_text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_info:
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
aiplatform.init(project=project_id, location=location)
|
||||
|
||||
client = VertexTextEmbeddingModel.from_pretrained(model)
|
||||
|
||||
# call embedding model
|
||||
self._embedding_invoke(model=model, client=client, texts=["ping"])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore
|
||||
"""
|
||||
Invoke embedding model
|
||||
|
||||
:param model: model name
|
||||
:param client: model client
|
||||
:param texts: texts to embed
|
||||
:return: embeddings and used tokens
|
||||
"""
|
||||
response = client.get_embeddings(texts)
|
||||
|
||||
embeddings = []
|
||||
token_usage = 0
|
||||
|
||||
for i in range(len(response)):
|
||||
embeddings.append(response[i].values)
|
||||
token_usage += int(response[i].statistics.token_count)
|
||||
|
||||
return embeddings, token_usage
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get("input_price", 0)),
|
||||
unit=Decimal(credentials.get("unit", 0)),
|
||||
currency=credentials.get("currency", "USD"),
|
||||
),
|
||||
)
|
||||
|
||||
return entity
|
||||
|
|
@ -1,198 +0,0 @@
|
|||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
||||
AuthErrors,
|
||||
BadRequestErrors,
|
||||
ConnectionErrors,
|
||||
MaasError,
|
||||
RateLimitErrors,
|
||||
ServerUnavailableErrors,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
|
||||
|
||||
|
||||
class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for VolcengineMaaS text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
if ArkClientV3.is_legacy(credentials):
|
||||
return self._generate_v2(model, credentials, texts, user)
|
||||
|
||||
return self._generate_v3(model, credentials, texts, user)
|
||||
|
||||
def _generate_v2(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
client = MaaSClient.from_credential(credentials)
|
||||
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp["usage"]["total_tokens"])
|
||||
|
||||
result = TextEmbeddingResult(model=model, embeddings=[v["embedding"] for v in resp["data"]], usage=usage)
|
||||
|
||||
return result
|
||||
|
||||
def _generate_v3(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
client = ArkClientV3.from_credentials(credentials)
|
||||
resp = client.embeddings(texts)
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp.usage.total_tokens)
|
||||
|
||||
result = TextEmbeddingResult(model=model, embeddings=[v.embedding for v in resp.data], usage=usage)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
# use GPT2Tokenizer to get num tokens
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
if ArkClientV3.is_legacy(credentials):
|
||||
return self._validate_credentials_v2(model, credentials)
|
||||
return self._validate_credentials_v3(model, credentials)
|
||||
|
||||
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except MaasError as e:
|
||||
raise CredentialsValidateFailedError(e.message)
|
||||
|
||||
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(e)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: ConnectionErrors.values(),
|
||||
InvokeServerUnavailableError: ServerUnavailableErrors.values(),
|
||||
InvokeRateLimitError: RateLimitErrors.values(),
|
||||
InvokeAuthorizationError: AuthErrors.values(),
|
||||
InvokeBadRequestError: BadRequestErrors.values(),
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
model_config = get_model_config(credentials)
|
||||
model_properties = {
|
||||
ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size,
|
||||
ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks,
|
||||
}
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties=model_properties,
|
||||
parameter_rules=[],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get("input_price", 0)),
|
||||
unit=Decimal(credentials.get("unit", 0)),
|
||||
currency=credentials.get("currency", "USD"),
|
||||
),
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,187 +0,0 @@
|
|||
import time
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from json import dumps
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from requests import Response, post
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
|
||||
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
invoke_error_mapping,
|
||||
)
|
||||
|
||||
|
||||
class TextEmbedding:
|
||||
@abstractmethod
|
||||
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
|
||||
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
|
||||
access_token = self._get_access_token()
|
||||
url = f"{self.api_bases[model]}?access_token={access_token}"
|
||||
body = self._build_embed_request_body(model, texts, user)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
resp = post(url, data=dumps(body), headers=headers)
|
||||
if resp.status_code != 200:
|
||||
raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}")
|
||||
return self._handle_embed_response(model, resp)
|
||||
|
||||
def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
|
||||
if len(texts) == 0:
|
||||
raise BadRequestError("The number of texts should not be zero.")
|
||||
body = {
|
||||
"input": texts,
|
||||
"user_id": user,
|
||||
}
|
||||
return body
|
||||
|
||||
def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int):
|
||||
data = response.json()
|
||||
if "error_code" in data:
|
||||
code = data["error_code"]
|
||||
msg = data["error_msg"]
|
||||
# raise error
|
||||
self._handle_error(code, msg)
|
||||
|
||||
embeddings = [v["embedding"] for v in data["data"]]
|
||||
_usage = data["usage"]
|
||||
tokens = _usage["prompt_tokens"]
|
||||
total_tokens = _usage["total_tokens"]
|
||||
|
||||
return embeddings, tokens, total_tokens
|
||||
|
||||
|
||||
class WenxinTextEmbeddingModel(TextEmbeddingModel):
|
||||
def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
|
||||
return WenxinTextEmbedding(api_key, secret_key)
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
api_key = credentials["api_key"]
|
||||
secret_key = credentials["secret_key"]
|
||||
embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
|
||||
user = user or "ErnieBotDefault"
|
||||
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
used_total_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
for i in _iter:
|
||||
embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
|
||||
model, inputs[i : i + max_chunks], user
|
||||
)
|
||||
used_tokens += _used_tokens
|
||||
used_total_tokens += _total_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
|
||||
return TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=batched_embeddings,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
total_num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: Mapping) -> None:
|
||||
api_key = credentials["api_key"]
|
||||
secret_key = credentials["secret_key"]
|
||||
try:
|
||||
BaiduAccessToken.get_access_token(api_key, secret_key)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return invoke_error_mapping()
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=total_tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
|
@ -1,204 +0,0 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
|
||||
|
||||
|
||||
class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Xinference text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
credentials should be like:
|
||||
{
|
||||
'server_url': 'server url',
|
||||
'model_uid': 'model uid',
|
||||
}
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
server_url = credentials["server_url"]
|
||||
model_uid = credentials["model_uid"]
|
||||
api_key = credentials.get("api_key")
|
||||
server_url = server_url.removesuffix("/")
|
||||
auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
|
||||
try:
|
||||
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers)
|
||||
embeddings = handle.create_embedding(input=texts)
|
||||
except RuntimeError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
"""
|
||||
for convenience, the response json is like:
|
||||
class Embedding(TypedDict):
|
||||
object: Literal["list"]
|
||||
model: str
|
||||
data: List[EmbeddingData]
|
||||
usage: EmbeddingUsage
|
||||
class EmbeddingUsage(TypedDict):
|
||||
prompt_tokens: int
|
||||
total_tokens: int
|
||||
class EmbeddingData(TypedDict):
|
||||
index: int
|
||||
object: str
|
||||
embedding: List[float]
|
||||
"""
|
||||
|
||||
usage = embeddings["usage"]
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model, embeddings=[embedding["embedding"] for embedding in embeddings["data"]], usage=usage
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
# use GPT2Tokenizer to get num tokens
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
server_url = credentials["server_url"]
|
||||
model_uid = credentials["model_uid"]
|
||||
api_key = credentials.get("api_key")
|
||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=server_url,
|
||||
model_uid=model_uid,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
if extra_args.max_tokens:
|
||||
credentials["max_tokens"] = extra_args.max_tokens
|
||||
server_url = server_url.removesuffix("/")
|
||||
|
||||
client = Client(
|
||||
base_url=server_url,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
try:
|
||||
handle = client.get_model(model_uid=model_uid)
|
||||
except RuntimeError as e:
|
||||
raise InvokeAuthorizationError(e)
|
||||
|
||||
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
||||
raise InvokeBadRequestError(
|
||||
"please check model type, the model you want to invoke is not a text embedding model"
|
||||
)
|
||||
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except InvokeAuthorizationError as e:
|
||||
raise CredentialsValidateFailedError(f"Failed to validate credentials for model {model}: {e}")
|
||||
except RuntimeError as e:
|
||||
raise CredentialsValidateFailedError(e)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
ModelPropertyKey.CONTEXT_SIZE: "max_tokens" in credentials and credentials["max_tokens"] or 512,
|
||||
},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
return entity
|
||||
|
|
@ -235,6 +235,7 @@ class PluginModelManager(BasePluginManager):
|
|||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
input_type: str,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding
|
||||
|
|
@ -252,6 +253,7 @@ class PluginModelManager(BasePluginManager):
|
|||
"model": model,
|
||||
"credentials": credentials,
|
||||
"texts": texts,
|
||||
"input_type": input_type,
|
||||
},
|
||||
}
|
||||
),
|
||||
|
|
@ -272,7 +274,6 @@ class PluginModelManager(BasePluginManager):
|
|||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
|
|
@ -289,7 +290,7 @@ class PluginModelManager(BasePluginManager):
|
|||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"texts": texts,
|
||||
|
|
@ -313,7 +314,6 @@ class PluginModelManager(BasePluginManager):
|
|||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
|
|
@ -333,7 +333,7 @@ class PluginModelManager(BasePluginManager):
|
|||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model_type": "rerank",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"query": query,
|
||||
|
|
@ -360,7 +360,6 @@ class PluginModelManager(BasePluginManager):
|
|||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
|
|
@ -378,7 +377,7 @@ class PluginModelManager(BasePluginManager):
|
|||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"content_text": content_text,
|
||||
|
|
@ -405,7 +404,6 @@ class PluginModelManager(BasePluginManager):
|
|||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
language: Optional[str] = None,
|
||||
|
|
@ -422,7 +420,7 @@ class PluginModelManager(BasePluginManager):
|
|||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"language": language,
|
||||
|
|
@ -447,7 +445,6 @@ class PluginModelManager(BasePluginManager):
|
|||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
file: IO[bytes],
|
||||
|
|
@ -464,7 +461,7 @@ class PluginModelManager(BasePluginManager):
|
|||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model_type": "speech2text",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"file": binascii.hexlify(file.read()).decode(),
|
||||
|
|
@ -488,7 +485,6 @@ class PluginModelManager(BasePluginManager):
|
|||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
text: str,
|
||||
|
|
@ -505,7 +501,7 @@ class PluginModelManager(BasePluginManager):
|
|||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model_type": "moderation",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"text": text,
|
||||
|
|
|
|||
|
|
@ -244,12 +244,11 @@ class ProviderManager:
|
|||
(model for model in available_models if model.model == "gpt-4"), available_models[0]
|
||||
)
|
||||
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
provider_name=available_model.provider.provider,
|
||||
model_name=available_model.model,
|
||||
)
|
||||
default_model = TenantDefaultModel()
|
||||
default_model.tenant_id = tenant_id
|
||||
default_model.model_type = model_type.to_origin_model_type()
|
||||
default_model.provider_name = available_model.provider.provider
|
||||
default_model.model_name = available_model.model
|
||||
db.session.add(default_model)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -489,15 +488,14 @@ class ProviderManager:
|
|||
# Init trial provider records if not exists
|
||||
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
|
||||
try:
|
||||
provider_record = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
quota_limit=quota.quota_limit,
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
)
|
||||
provider_record = Provider()
|
||||
provider_record.tenant_id = tenant_id
|
||||
provider_record.provider_name = provider_name
|
||||
provider_record.provider_type = ProviderType.SYSTEM.value
|
||||
provider_record.quota_type = ProviderQuotaType.TRIAL.value
|
||||
provider_record.quota_limit = quota.quota_limit
|
||||
provider_record.quota_used = 0
|
||||
provider_record.is_valid = True
|
||||
db.session.add(provider_record)
|
||||
db.session.commit()
|
||||
except IntegrityError:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from core.model_runtime.entities.message_entities import PromptMessage, SystemPr
|
|||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
|
||||
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
|
||||
and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
|
||||
|
|
@ -124,9 +123,3 @@ class BuiltinTool(Tool):
|
|||
return self.summary(user_id=user_id, content=result)
|
||||
|
||||
return result
|
||||
|
||||
def get_url(self, url: str, user_agent: str | None = None) -> str:
|
||||
"""
|
||||
get url
|
||||
"""
|
||||
return get_url(url, user_agent=user_agent)
|
||||
|
|
|
|||
|
|
@ -1,357 +0,0 @@
|
|||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import site
|
||||
import subprocess
|
||||
import tempfile
|
||||
import unicodedata
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import cloudscraper
|
||||
from bs4 import BeautifulSoup, CData, Comment, NavigableString
|
||||
from regex import regex
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor import extract_processor
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
|
||||
FULL_TEMPLATE = """
|
||||
TITLE: {title}
|
||||
AUTHORS: {authors}
|
||||
PUBLISH DATE: {publish_date}
|
||||
TOP_IMAGE_URL: {top_image}
|
||||
TEXT:
|
||||
|
||||
{text}
|
||||
"""
|
||||
|
||||
|
||||
def page_result(text: str, cursor: int, max_length: int) -> str:
|
||||
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
|
||||
return text[cursor : cursor + max_length]
|
||||
|
||||
|
||||
def get_url(url: str, user_agent: str | None = None) -> str:
|
||||
"""Fetch URL and return the contents as a string."""
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
||||
" Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
if user_agent:
|
||||
headers["User-Agent"] = user_agent
|
||||
|
||||
main_content_type = None
|
||||
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
||||
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
|
||||
|
||||
if response.status_code == 200:
|
||||
# check content-type
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if content_type:
|
||||
main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
|
||||
else:
|
||||
content_disposition = response.headers.get("Content-Disposition", "")
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
extension = re.search(r"\.(\w+)$", filename)
|
||||
if extension:
|
||||
main_content_type = mimetypes.guess_type(filename)[0]
|
||||
|
||||
if main_content_type not in supported_content_types:
|
||||
return "Unsupported content-type [{}] of URL.".format(main_content_type)
|
||||
|
||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return ExtractProcessor.load_from_url(url, return_text=True)
|
||||
|
||||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
elif response.status_code == 403:
|
||||
scraper = cloudscraper.create_scraper()
|
||||
scraper.perform_request = ssrf_proxy.make_request
|
||||
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
return "URL returned status code {}.".format(response.status_code)
|
||||
|
||||
# Detect encoding using chardet
|
||||
detected_encoding = chardet.detect(response.content)
|
||||
encoding = detected_encoding["encoding"]
|
||||
if encoding:
|
||||
try:
|
||||
content = response.content.decode(encoding)
|
||||
except (UnicodeDecodeError, TypeError):
|
||||
content = response.text
|
||||
else:
|
||||
content = response.text
|
||||
|
||||
a = extract_using_readabilipy(content)
|
||||
|
||||
if not a["plain_text"] or not a["plain_text"].strip():
|
||||
return ""
|
||||
|
||||
res = FULL_TEMPLATE.format(
|
||||
title=a["title"],
|
||||
authors=a["byline"],
|
||||
publish_date=a["date"],
|
||||
top_image="",
|
||||
text=a["plain_text"] or "",
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def extract_using_readabilipy(html):
|
||||
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
|
||||
f_html.write(html)
|
||||
f_html.close()
|
||||
html_path = f_html.name
|
||||
|
||||
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
|
||||
article_json_path = html_path + ".json"
|
||||
jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
|
||||
with chdir(jsdir):
|
||||
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
|
||||
|
||||
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
|
||||
input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
|
||||
|
||||
# Deleting files after processing
|
||||
os.unlink(article_json_path)
|
||||
os.unlink(html_path)
|
||||
|
||||
article_json = {
|
||||
"title": None,
|
||||
"byline": None,
|
||||
"date": None,
|
||||
"content": None,
|
||||
"plain_content": None,
|
||||
"plain_text": None,
|
||||
}
|
||||
# Populate article fields from readability fields where present
|
||||
if input_json:
|
||||
if input_json.get("title"):
|
||||
article_json["title"] = input_json["title"]
|
||||
if input_json.get("byline"):
|
||||
article_json["byline"] = input_json["byline"]
|
||||
if input_json.get("date"):
|
||||
article_json["date"] = input_json["date"]
|
||||
if input_json.get("content"):
|
||||
article_json["content"] = input_json["content"]
|
||||
article_json["plain_content"] = plain_content(article_json["content"], False, False)
|
||||
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
|
||||
if input_json.get("textContent"):
|
||||
article_json["plain_text"] = input_json["textContent"]
|
||||
article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])
|
||||
|
||||
return article_json
|
||||
|
||||
|
||||
def find_module_path(module_name):
|
||||
for package_path in site.getsitepackages():
|
||||
potential_path = os.path.join(package_path, module_name)
|
||||
if os.path.exists(potential_path):
|
||||
return potential_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def chdir(path):
|
||||
"""Change directory in context and return to original on exit"""
|
||||
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
|
||||
original_path = os.getcwd()
|
||||
os.chdir(path)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.chdir(original_path)
|
||||
|
||||
|
||||
def extract_text_blocks_as_plain_text(paragraph_html):
|
||||
# Load article as DOM
|
||||
soup = BeautifulSoup(paragraph_html, "html.parser")
|
||||
# Select all lists
|
||||
list_elements = soup.find_all(["ul", "ol"])
|
||||
# Prefix text in all list items with "* " and make lists paragraphs
|
||||
for list_element in list_elements:
|
||||
plain_items = "".join(
|
||||
list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
|
||||
)
|
||||
list_element.string = plain_items
|
||||
list_element.name = "p"
|
||||
# Select all text blocks
|
||||
text_blocks = [s.parent for s in soup.find_all(string=True)]
|
||||
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
|
||||
# Drop empty paragraphs
|
||||
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
|
||||
return text_blocks
|
||||
|
||||
|
||||
def plain_text_leaf_node(element):
|
||||
# Extract all text, stripped of any child HTML elements and normalize it
|
||||
plain_text = normalize_text(element.get_text())
|
||||
if plain_text != "" and element.name == "li":
|
||||
plain_text = "* {}, ".format(plain_text)
|
||||
if plain_text == "":
|
||||
plain_text = None
|
||||
if "data-node-index" in element.attrs:
|
||||
plain = {"node_index": element["data-node-index"], "text": plain_text}
|
||||
else:
|
||||
plain = {"text": plain_text}
|
||||
return plain
|
||||
|
||||
|
||||
def plain_content(readability_content, content_digests, node_indexes):
|
||||
# Load article as DOM
|
||||
soup = BeautifulSoup(readability_content, "html.parser")
|
||||
# Make all elements plain
|
||||
elements = plain_elements(soup.contents, content_digests, node_indexes)
|
||||
if node_indexes:
|
||||
# Add node index attributes to nodes
|
||||
elements = [add_node_indexes(element) for element in elements]
|
||||
# Replace article contents with plain elements
|
||||
soup.contents = elements
|
||||
return str(soup)
|
||||
|
||||
|
||||
def plain_elements(elements, content_digests, node_indexes):
|
||||
# Get plain content versions of all elements
|
||||
elements = [plain_element(element, content_digests, node_indexes) for element in elements]
|
||||
if content_digests:
|
||||
# Add content digest attribute to nodes
|
||||
elements = [add_content_digest(element) for element in elements]
|
||||
return elements
|
||||
|
||||
|
||||
def plain_element(element, content_digests, node_indexes):
|
||||
# For lists, we make each item plain text
|
||||
if is_leaf(element):
|
||||
# For leaf node elements, extract the text content, discarding any HTML tags
|
||||
# 1. Get element contents as text
|
||||
plain_text = element.get_text()
|
||||
# 2. Normalize the extracted text string to a canonical representation
|
||||
plain_text = normalize_text(plain_text)
|
||||
# 3. Update element content to be plain text
|
||||
element.string = plain_text
|
||||
elif is_text(element):
|
||||
if is_non_printing(element):
|
||||
# The simplified HTML may have come from Readability.js so might
|
||||
# have non-printing text (e.g. Comment or CData). In this case, we
|
||||
# keep the structure, but ensure that the string is empty.
|
||||
element = type(element)("")
|
||||
else:
|
||||
plain_text = element.string
|
||||
plain_text = normalize_text(plain_text)
|
||||
element = type(element)(plain_text)
|
||||
else:
|
||||
# If not a leaf node or leaf type call recursively on child nodes, replacing
|
||||
element.contents = plain_elements(element.contents, content_digests, node_indexes)
|
||||
return element
|
||||
|
||||
|
||||
def add_node_indexes(element, node_index="0"):
|
||||
# Can't add attributes to string types
|
||||
if is_text(element):
|
||||
return element
|
||||
# Add index to current element
|
||||
element["data-node-index"] = node_index
|
||||
# Add index to child elements
|
||||
for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
|
||||
# Can't add attributes to leaf string types
|
||||
child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
|
||||
add_node_indexes(child, node_index=child_index)
|
||||
return element
|
||||
|
||||
|
||||
def normalize_text(text):
|
||||
"""Normalize unicode and whitespace."""
|
||||
# Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
|
||||
text = strip_control_characters(text)
|
||||
text = normalize_unicode(text)
|
||||
text = normalize_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def strip_control_characters(text):
|
||||
"""Strip out unicode control characters which might break the parsing."""
|
||||
# Unicode control characters
|
||||
# [Cc]: Other, Control [includes new lines]
|
||||
# [Cf]: Other, Format
|
||||
# [Cn]: Other, Not Assigned
|
||||
# [Co]: Other, Private Use
|
||||
# [Cs]: Other, Surrogate
|
||||
control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
|
||||
retained_chars = ["\t", "\n", "\r", "\f"]
|
||||
|
||||
# Remove non-printing control characters
|
||||
return "".join(
|
||||
[
|
||||
"" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
|
||||
for char in text
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def normalize_unicode(text):
|
||||
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
|
||||
normal_form = "NFKC"
|
||||
text = unicodedata.normalize(normal_form, text)
|
||||
return text
|
||||
|
||||
|
||||
def normalize_whitespace(text):
|
||||
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
|
||||
text = regex.sub(r"\s+", " ", text)
|
||||
# Remove leading and trailing whitespace
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def is_leaf(element):
|
||||
return element.name in {"p", "li"}
|
||||
|
||||
|
||||
def is_text(element):
|
||||
return isinstance(element, NavigableString)
|
||||
|
||||
|
||||
def is_non_printing(element):
|
||||
return any(isinstance(element, _e) for _e in [Comment, CData])
|
||||
|
||||
|
||||
def add_content_digest(element):
|
||||
if not is_text(element):
|
||||
element["data-content-digest"] = content_digest(element)
|
||||
return element
|
||||
|
||||
|
||||
def content_digest(element):
|
||||
if is_text(element):
|
||||
# Hash
|
||||
trimmed_string = element.string.strip()
|
||||
if trimmed_string == "":
|
||||
digest = ""
|
||||
else:
|
||||
digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
|
||||
else:
|
||||
contents = element.contents
|
||||
num_contents = len(contents)
|
||||
if num_contents == 0:
|
||||
# No hash when no child elements exist
|
||||
digest = ""
|
||||
elif num_contents == 1:
|
||||
# If single child, use digest of child
|
||||
digest = content_digest(contents[0])
|
||||
else:
|
||||
# Build content digest from the "non-empty" digests of child nodes
|
||||
digest = hashlib.sha256()
|
||||
child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
|
||||
for child in child_digests:
|
||||
digest.update(child.encode("utf-8"))
|
||||
digest = digest.hexdigest()
|
||||
return digest
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -134,7 +134,6 @@ package-mode = false
|
|||
############################################################
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
anthropic = "~0.23.1"
|
||||
authlib = "1.3.1"
|
||||
azure-identity = "1.16.1"
|
||||
azure-storage-blob = "12.13.0"
|
||||
|
|
@ -145,10 +144,8 @@ bs4 = "~0.0.1"
|
|||
cachetools = "~5.3.0"
|
||||
celery = "~5.3.6"
|
||||
chardet = "~5.1.0"
|
||||
cohere = "~5.2.4"
|
||||
cos-python-sdk-v5 = "1.9.30"
|
||||
esdk-obs-python = "3.24.6.1"
|
||||
dashscope = { version = "~1.17.0", extras = ["tokenizer"] }
|
||||
flask = "~3.0.1"
|
||||
flask-compress = "~1.14"
|
||||
flask-cors = "~4.0.0"
|
||||
|
|
@ -169,13 +166,12 @@ google-generativeai = "0.8.1"
|
|||
googleapis-common-protos = "1.63.0"
|
||||
gunicorn = "~22.0.0"
|
||||
httpx = { version = "~0.27.0", extras = ["socks"] }
|
||||
huggingface-hub = "~0.16.4"
|
||||
jieba = "0.42.1"
|
||||
langfuse = "^2.48.0"
|
||||
langsmith = "^0.1.77"
|
||||
mailchimp-transactional = "~1.0.50"
|
||||
markdown = "~3.5.1"
|
||||
novita-client = "^0.5.7"
|
||||
nltk = "3.8.1"
|
||||
numpy = "~1.26.4"
|
||||
openai = "~1.29.0"
|
||||
openpyxl = "~3.1.5"
|
||||
|
|
@ -192,9 +188,7 @@ python = ">=3.10,<3.13"
|
|||
python-docx = "~1.1.0"
|
||||
python-dotenv = "1.0.0"
|
||||
pyyaml = "~6.0.1"
|
||||
readabilipy = "0.2.0"
|
||||
redis = { version = "~5.0.3", extras = ["hiredis"] }
|
||||
replicate = "~0.22.0"
|
||||
resend = "~0.7.0"
|
||||
scikit-learn = "^1.5.1"
|
||||
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
|
||||
|
|
@ -202,21 +196,15 @@ sqlalchemy = "~2.0.29"
|
|||
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
|
||||
tiktoken = "~0.7.0"
|
||||
tokenizers = "~0.15.0"
|
||||
transformers = "~4.35.0"
|
||||
unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
|
||||
websocket-client = "~1.7.0"
|
||||
werkzeug = "~3.0.1"
|
||||
xinference-client = "0.15.2"
|
||||
yarl = "~1.9.4"
|
||||
zhipuai = "1.0.7"
|
||||
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
|
||||
|
||||
############################################################
|
||||
# Related transparent dependencies with pinned version
|
||||
# required by main implementations
|
||||
############################################################
|
||||
azure-ai-ml = "^1.19.0"
|
||||
azure-ai-inference = "^1.0.0b3"
|
||||
volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
|
||||
oci = "^2.133.0"
|
||||
tos = "^2.7.1"
|
||||
|
|
@ -231,20 +219,7 @@ safetensors = "~0.4.3"
|
|||
############################################################
|
||||
|
||||
[tool.poetry.group.tool.dependencies]
|
||||
arxiv = "2.1.0"
|
||||
cloudscraper = "1.2.71"
|
||||
matplotlib = "~3.8.2"
|
||||
newspaper3k = "0.2.8"
|
||||
duckduckgo-search = "^6.2.6"
|
||||
jsonpath-ng = "1.6.1"
|
||||
numexpr = "~2.9.0"
|
||||
opensearch-py = "2.4.0"
|
||||
qrcode = "~7.4.2"
|
||||
twilio = "~9.0.4"
|
||||
vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] }
|
||||
wikipedia = "1.4.0"
|
||||
yfinance = "~0.2.40"
|
||||
nltk = "3.8.1"
|
||||
############################################################
|
||||
# VDB dependencies required by vector store clients
|
||||
############################################################
|
||||
|
|
|
|||
|
|
@ -1,98 +0,0 @@
|
|||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from anthropic import Anthropic, Stream
|
||||
from anthropic.resources import Messages
|
||||
from anthropic.types import (
|
||||
ContentBlock,
|
||||
ContentBlockDeltaEvent,
|
||||
Message,
|
||||
MessageDeltaEvent,
|
||||
MessageDeltaUsage,
|
||||
MessageParam,
|
||||
MessageStartEvent,
|
||||
MessageStopEvent,
|
||||
MessageStreamEvent,
|
||||
TextDelta,
|
||||
Usage,
|
||||
)
|
||||
from anthropic.types.message_delta_event import Delta
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
|
||||
|
||||
|
||||
class MockAnthropicClass:
|
||||
@staticmethod
|
||||
def mocked_anthropic_chat_create_sync(model: str) -> Message:
|
||||
return Message(
|
||||
id="msg-123",
|
||||
type="message",
|
||||
role="assistant",
|
||||
content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")],
|
||||
model=model,
|
||||
stop_reason="stop_sequence",
|
||||
usage=Usage(input_tokens=1, output_tokens=1),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]:
|
||||
full_response_text = "hello, I'm a chatbot from anthropic"
|
||||
|
||||
yield MessageStartEvent(
|
||||
type="message_start",
|
||||
message=Message(
|
||||
id="msg-123",
|
||||
content=[],
|
||||
role="assistant",
|
||||
model=model,
|
||||
stop_reason=None,
|
||||
type="message",
|
||||
usage=Usage(input_tokens=1, output_tokens=1),
|
||||
),
|
||||
)
|
||||
|
||||
index = 0
|
||||
for i in range(0, len(full_response_text)):
|
||||
yield ContentBlockDeltaEvent(
|
||||
type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index
|
||||
)
|
||||
|
||||
index += 1
|
||||
|
||||
yield MessageDeltaEvent(
|
||||
type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1)
|
||||
)
|
||||
|
||||
yield MessageStopEvent(type="message_stop")
|
||||
|
||||
def mocked_anthropic(
|
||||
self: Messages,
|
||||
*,
|
||||
max_tokens: int,
|
||||
messages: Iterable[MessageParam],
|
||||
model: str,
|
||||
stream: Literal[True],
|
||||
**kwargs: Any,
|
||||
) -> Union[Message, Stream[MessageStreamEvent]]:
|
||||
if len(self._client.api_key) < 18:
|
||||
raise anthropic.AuthenticationError("Invalid API key")
|
||||
|
||||
if stream:
|
||||
return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
|
||||
else:
|
||||
return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
|
||||
def mock_get(*args, **kwargs):
|
||||
if kwargs.get("headers", {}).get("Authorization") != "Bearer test":
|
||||
raise httpx.HTTPStatusError(
|
||||
"Invalid API key",
|
||||
request=httpx.Request("GET", ""),
|
||||
response=httpx.Response(401),
|
||||
)
|
||||
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"items": [
|
||||
{"title": "Model 1", "_id": "model1"},
|
||||
{"title": "Model 2", "_id": "model2"},
|
||||
]
|
||||
},
|
||||
request=httpx.Request("GET", ""),
|
||||
)
|
||||
|
||||
|
||||
def mock_stream(*args, **kwargs):
|
||||
class MockStreamResponse:
|
||||
def __init__(self):
|
||||
self.status_code = 200
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
def iter_bytes(self):
|
||||
yield b"Mocked audio data"
|
||||
|
||||
return MockStreamResponse()
|
||||
|
||||
|
||||
def mock_fishaudio(
|
||||
monkeypatch: MonkeyPatch,
|
||||
methods: list[Literal["list-models", "tts"]],
|
||||
) -> Callable[[], None]:
|
||||
"""
|
||||
mock fishaudio module
|
||||
|
||||
:param monkeypatch: pytest monkeypatch fixture
|
||||
:return: unpatch function
|
||||
"""
|
||||
|
||||
def unpatch() -> None:
|
||||
monkeypatch.undo()
|
||||
|
||||
if "list-models" in methods:
|
||||
monkeypatch.setattr(httpx, "get", mock_get)
|
||||
|
||||
if "tts" in methods:
|
||||
monkeypatch.setattr(httpx, "stream", mock_stream)
|
||||
|
||||
return unpatch
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_fishaudio_mock(request, monkeypatch):
|
||||
methods = request.param if hasattr(request, "param") else []
|
||||
if MOCK:
|
||||
unpatch = mock_fishaudio(monkeypatch, methods=methods)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
unpatch()
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
import google.generativeai.types.generation_types as generation_config_types
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from google.ai import generativelanguage as glm
|
||||
from google.ai.generativelanguage_v1beta.types import content as gag_content
|
||||
from google.generativeai import GenerativeModel
|
||||
from google.generativeai.client import _ClientManager, configure
|
||||
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
|
||||
from google.generativeai.types.generation_types import BaseGenerateContentResponse
|
||||
|
||||
current_api_key = ""
|
||||
|
||||
|
||||
class MockGoogleResponseClass:
|
||||
_done = False
|
||||
|
||||
def __iter__(self):
|
||||
full_response_text = "it's google!"
|
||||
|
||||
for i in range(0, len(full_response_text) + 1, 1):
|
||||
if i == len(full_response_text):
|
||||
self._done = True
|
||||
yield GenerateContentResponse(
|
||||
done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
|
||||
)
|
||||
else:
|
||||
yield GenerateContentResponse(
|
||||
done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
|
||||
)
|
||||
|
||||
|
||||
class MockGoogleResponseCandidateClass:
|
||||
finish_reason = "stop"
|
||||
|
||||
@property
|
||||
def content(self) -> gag_content.Content:
|
||||
return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
|
||||
|
||||
|
||||
class MockGoogleClass:
|
||||
@staticmethod
|
||||
def generate_content_sync() -> GenerateContentResponse:
|
||||
return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])
|
||||
|
||||
@staticmethod
|
||||
def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
|
||||
return MockGoogleResponseClass()
|
||||
|
||||
def generate_content(
|
||||
self: GenerativeModel,
|
||||
contents: content_types.ContentsType,
|
||||
*,
|
||||
generation_config: generation_config_types.GenerationConfigType | None = None,
|
||||
safety_settings: safety_types.SafetySettingOptions | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> GenerateContentResponse:
|
||||
global current_api_key
|
||||
|
||||
if len(current_api_key) < 16:
|
||||
raise Exception("Invalid API key")
|
||||
|
||||
if stream:
|
||||
return MockGoogleClass.generate_content_stream()
|
||||
|
||||
return MockGoogleClass.generate_content_sync()
|
||||
|
||||
@property
|
||||
def generative_response_text(self) -> str:
|
||||
return "it's google!"
|
||||
|
||||
@property
|
||||
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
|
||||
return [MockGoogleResponseCandidateClass()]
|
||||
|
||||
def make_client(self: _ClientManager, name: str):
|
||||
global current_api_key
|
||||
|
||||
if name.endswith("_async"):
|
||||
name = name.split("_")[0]
|
||||
cls = getattr(glm, name.title() + "ServiceAsyncClient")
|
||||
else:
|
||||
cls = getattr(glm, name.title() + "ServiceClient")
|
||||
|
||||
# Attempt to configure using defaults.
|
||||
if not self.client_config:
|
||||
configure()
|
||||
|
||||
client_options = self.client_config.get("client_options", None)
|
||||
if client_options:
|
||||
current_api_key = client_options.api_key
|
||||
|
||||
def nop(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
original_init = cls.__init__
|
||||
cls.__init__ = nop
|
||||
client: glm.GenerativeServiceClient = cls(**self.client_config)
|
||||
cls.__init__ = original_init
|
||||
|
||||
if not self.default_metadata:
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_google_mock(request, monkeypatch: MonkeyPatch):
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
|
||||
monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
|
||||
monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
|
||||
|
||||
yield
|
||||
|
||||
monkeypatch.undo()
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from huggingface_hub import InferenceClient
|
||||
from huggingface_hub.inference._text_generation import (
|
||||
Details,
|
||||
StreamDetails,
|
||||
TextGenerationResponse,
|
||||
TextGenerationStreamResponse,
|
||||
Token,
|
||||
)
|
||||
from huggingface_hub.utils import BadRequestError
|
||||
|
||||
|
||||
class MockHuggingfaceChatClass:
|
||||
@staticmethod
|
||||
def generate_create_sync(model: str) -> TextGenerationResponse:
|
||||
response = TextGenerationResponse(
|
||||
generated_text="You can call me Miku Miku o~e~o~",
|
||||
details=Details(
|
||||
finish_reason="length",
|
||||
generated_tokens=6,
|
||||
tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)],
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]:
|
||||
full_text = "You can call me Miku Miku o~e~o~"
|
||||
|
||||
for i in range(0, len(full_text)):
|
||||
response = TextGenerationStreamResponse(
|
||||
token=Token(id=i, text=full_text[i], logprob=0.0, special=False),
|
||||
)
|
||||
response.generated_text = full_text[i]
|
||||
response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1)
|
||||
|
||||
yield response
|
||||
|
||||
def text_generation(
|
||||
self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any
|
||||
) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
|
||||
# check if key is valid
|
||||
if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]):
|
||||
raise BadRequestError("Invalid API key")
|
||||
|
||||
if model is None:
|
||||
raise BadRequestError("Invalid model")
|
||||
|
||||
if stream:
|
||||
return MockHuggingfaceChatClass.generate_create_stream(model)
|
||||
return MockHuggingfaceChatClass.generate_create_sync(model)
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
|
||||
|
||||
|
||||
class MockTEIClass:
|
||||
@staticmethod
|
||||
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
|
||||
# During mock, we don't have a real server to query, so we just return a dummy value
|
||||
if "rerank" in model_name:
|
||||
model_type = "reranker"
|
||||
else:
|
||||
model_type = "embedding"
|
||||
|
||||
return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
|
||||
|
||||
@staticmethod
|
||||
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
|
||||
# Use space as token separator, and split the text into tokens
|
||||
tokenized_texts = []
|
||||
for text in texts:
|
||||
tokens = text.split(" ")
|
||||
current_index = 0
|
||||
tokenized_text = []
|
||||
for idx, token in enumerate(tokens):
|
||||
s_token = {
|
||||
"id": idx,
|
||||
"text": token,
|
||||
"special": False,
|
||||
"start": current_index,
|
||||
"stop": current_index + len(token),
|
||||
}
|
||||
current_index += len(token) + 1
|
||||
tokenized_text.append(s_token)
|
||||
tokenized_texts.append(tokenized_text)
|
||||
return tokenized_texts
|
||||
|
||||
@staticmethod
|
||||
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
|
||||
# {
|
||||
# "object": "list",
|
||||
# "data": [
|
||||
# {
|
||||
# "object": "embedding",
|
||||
# "embedding": [...],
|
||||
# "index": 0
|
||||
# }
|
||||
# ],
|
||||
# "model": "MODEL_NAME",
|
||||
# "usage": {
|
||||
# "prompt_tokens": 3,
|
||||
# "total_tokens": 3
|
||||
# }
|
||||
# }
|
||||
embeddings = []
|
||||
for idx in range(len(texts)):
|
||||
embedding = [0.1] * 768
|
||||
embeddings.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": embedding,
|
||||
"index": idx,
|
||||
}
|
||||
)
|
||||
return {
|
||||
"object": "list",
|
||||
"data": embeddings,
|
||||
"model": "MODEL_NAME",
|
||||
"usage": {
|
||||
"prompt_tokens": sum(len(text.split(" ")) for text in texts),
|
||||
"total_tokens": sum(len(text.split(" ")) for text in texts),
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
|
||||
# Example response:
|
||||
# [
|
||||
# {
|
||||
# "index": 0,
|
||||
# "text": "Deep Learning is ...",
|
||||
# "score": 0.9950755
|
||||
# }
|
||||
# ]
|
||||
reranked_docs = []
|
||||
for idx, text in enumerate(texts):
|
||||
reranked_docs.append(
|
||||
{
|
||||
"index": idx,
|
||||
"text": text,
|
||||
"score": 0.9,
|
||||
}
|
||||
)
|
||||
# For mock, only return the first document
|
||||
break
|
||||
return reranked_docs
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
import pytest
|
||||
|
||||
# import monkeypatch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from nomic import embed
|
||||
|
||||
|
||||
def create_embedding(texts: list[str], model: str, **kwargs: Any) -> dict:
|
||||
texts_len = len(texts)
|
||||
|
||||
foo_embedding_sample = 0.123456
|
||||
|
||||
combined = {
|
||||
"embeddings": [[foo_embedding_sample for _ in range(768)] for _ in range(texts_len)],
|
||||
"usage": {"prompt_tokens": texts_len, "total_tokens": texts_len},
|
||||
"model": model,
|
||||
"inference_mode": "remote",
|
||||
}
|
||||
|
||||
return combined
|
||||
|
||||
|
||||
def mock_nomic(
|
||||
monkeypatch: MonkeyPatch,
|
||||
methods: list[Literal["text_embedding"]],
|
||||
) -> Callable[[], None]:
|
||||
"""
|
||||
mock nomic module
|
||||
|
||||
:param monkeypatch: pytest monkeypatch fixture
|
||||
:return: unpatch function
|
||||
"""
|
||||
|
||||
def unpatch() -> None:
|
||||
monkeypatch.undo()
|
||||
|
||||
if "text_embedding" in methods:
|
||||
monkeypatch.setattr(embed, "text", create_embedding)
|
||||
|
||||
return unpatch
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_nomic_mock(request, monkeypatch):
|
||||
methods = request.param if hasattr(request, "param") else []
|
||||
if MOCK:
|
||||
unpatch = mock_nomic(monkeypatch, methods=methods)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
unpatch()
|
||||
|
|
@ -6,19 +6,9 @@ import pytest
|
|||
|
||||
# import monkeypatch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from openai.resources.audio.transcriptions import Transcriptions
|
||||
from openai.resources.chat import Completions as ChatCompletions
|
||||
from openai.resources.completions import Completions
|
||||
from openai.resources.embeddings import Embeddings
|
||||
from openai.resources.models import Models
|
||||
from openai.resources.moderations import Moderations
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_remote import MockModelClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
|
||||
|
||||
|
||||
def mock_openai(
|
||||
|
|
@ -35,24 +25,9 @@ def mock_openai(
|
|||
def unpatch() -> None:
|
||||
monkeypatch.undo()
|
||||
|
||||
if "completion" in methods:
|
||||
monkeypatch.setattr(Completions, "create", MockCompletionsClass.completion_create)
|
||||
|
||||
if "chat" in methods:
|
||||
monkeypatch.setattr(ChatCompletions, "create", MockChatClass.chat_create)
|
||||
|
||||
if "remote" in methods:
|
||||
monkeypatch.setattr(Models, "list", MockModelClass.list)
|
||||
|
||||
if "moderation" in methods:
|
||||
monkeypatch.setattr(Moderations, "create", MockModerationClass.moderation_create)
|
||||
|
||||
if "speech2text" in methods:
|
||||
monkeypatch.setattr(Transcriptions, "create", MockSpeech2TextClass.speech2text_create)
|
||||
|
||||
if "text_embedding" in methods:
|
||||
monkeypatch.setattr(Embeddings, "create", MockEmbeddingsClass.create_embeddings)
|
||||
|
||||
return unpatch
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,269 +0,0 @@
|
|||
import re
|
||||
from collections.abc import Generator
|
||||
from json import dumps, loads
|
||||
from time import time
|
||||
|
||||
# import monkeypatch
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
from openai.resources.chat.completions import Completions
|
||||
from openai.types import Completion as CompletionMessage
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCall,
|
||||
ChatCompletionToolChoiceOptionParam,
|
||||
ChatCompletionToolParam,
|
||||
completion_create_params,
|
||||
)
|
||||
from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion
|
||||
from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
Choice,
|
||||
ChoiceDelta,
|
||||
ChoiceDeltaFunctionCall,
|
||||
ChoiceDeltaToolCall,
|
||||
ChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage, FunctionCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class MockChatClass:
|
||||
@staticmethod
|
||||
def generate_function_call(
|
||||
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
) -> Optional[FunctionCall]:
|
||||
if not functions or len(functions) == 0:
|
||||
return None
|
||||
function: completion_create_params.Function = functions[0]
|
||||
function_name = function["name"]
|
||||
function_description = function["description"]
|
||||
function_parameters = function["parameters"]
|
||||
function_parameters_type = function_parameters["type"]
|
||||
if function_parameters_type != "object":
|
||||
return None
|
||||
function_parameters_properties = function_parameters["properties"]
|
||||
function_parameters_required = function_parameters["required"]
|
||||
parameters = {}
|
||||
for parameter_name, parameter in function_parameters_properties.items():
|
||||
if parameter_name not in function_parameters_required:
|
||||
continue
|
||||
parameter_type = parameter["type"]
|
||||
if parameter_type == "string":
|
||||
if "enum" in parameter:
|
||||
if len(parameter["enum"]) == 0:
|
||||
continue
|
||||
parameters[parameter_name] = parameter["enum"][0]
|
||||
else:
|
||||
parameters[parameter_name] = "kawaii"
|
||||
elif parameter_type == "integer":
|
||||
parameters[parameter_name] = 114514
|
||||
elif parameter_type == "number":
|
||||
parameters[parameter_name] = 1919810.0
|
||||
elif parameter_type == "boolean":
|
||||
parameters[parameter_name] = True
|
||||
|
||||
return FunctionCall(name=function_name, arguments=dumps(parameters))
|
||||
|
||||
@staticmethod
|
||||
def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
|
||||
list_tool_calls = []
|
||||
if not tools or len(tools) == 0:
|
||||
return None
|
||||
tool = tools[0]
|
||||
|
||||
if "type" in tools and tools["type"] != "function":
|
||||
return None
|
||||
|
||||
function = tool["function"]
|
||||
|
||||
function_call = MockChatClass.generate_function_call(functions=[function])
|
||||
if function_call is None:
|
||||
return None
|
||||
|
||||
list_tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id="sakurajima-mai",
|
||||
function=Function(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
)
|
||||
|
||||
return list_tool_calls
|
||||
|
||||
@staticmethod
|
||||
def mocked_openai_chat_create_sync(
|
||||
model: str,
|
||||
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
) -> CompletionMessage:
|
||||
tool_calls = []
|
||||
function_call = MockChatClass.generate_function_call(functions=functions)
|
||||
if not function_call:
|
||||
tool_calls = MockChatClass.generate_tool_calls(tools=tools)
|
||||
|
||||
return _ChatCompletion(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
choices=[
|
||||
_ChatCompletionChoice(
|
||||
finish_reason="content_filter",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls
|
||||
),
|
||||
)
|
||||
],
|
||||
created=int(time()),
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
system_fingerprint="",
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=1,
|
||||
total_tokens=3,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_openai_chat_create_stream(
|
||||
model: str,
|
||||
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
) -> Generator[ChatCompletionChunk, None, None]:
|
||||
tool_calls = []
|
||||
function_call = MockChatClass.generate_function_call(functions=functions)
|
||||
if not function_call:
|
||||
tool_calls = MockChatClass.generate_tool_calls(tools=tools)
|
||||
|
||||
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
|
||||
for i in range(0, len(full_text) + 1):
|
||||
if i == len(full_text):
|
||||
yield ChatCompletionChunk(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
choices=[
|
||||
Choice(
|
||||
delta=ChoiceDelta(
|
||||
content="",
|
||||
function_call=ChoiceDeltaFunctionCall(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
)
|
||||
if function_call
|
||||
else None,
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id="misaka-mikoto",
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name=tool_calls[0].function.name,
|
||||
arguments=tool_calls[0].function.arguments,
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
]
|
||||
if tool_calls and len(tool_calls) > 0
|
||||
else None,
|
||||
),
|
||||
finish_reason="function_call",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=int(time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
system_fingerprint="",
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=17,
|
||||
total_tokens=19,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionChunk(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
choices=[
|
||||
Choice(
|
||||
delta=ChoiceDelta(
|
||||
content=full_text[i],
|
||||
role="assistant",
|
||||
),
|
||||
finish_reason="content_filter",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=int(time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def chat_create(
|
||||
self: Completions,
|
||||
*,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model: Union[
|
||||
str,
|
||||
Literal[
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0301",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
],
|
||||
],
|
||||
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
|
||||
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
|
||||
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any,
|
||||
):
|
||||
openai_models = [
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0301",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
]
|
||||
azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"]
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
|
||||
raise InvokeAuthorizationError("Invalid base url")
|
||||
if model in openai_models + azure_openai_models:
|
||||
if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
|
||||
# sometime, provider use OpenAI compatible API will not have api key or have different api key format
|
||||
# so we only check if model is in openai_models
|
||||
raise InvokeAuthorizationError("Invalid api key")
|
||||
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
|
||||
raise InvokeAuthorizationError("Invalid api key")
|
||||
if stream:
|
||||
return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
|
||||
|
||||
return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
|
||||
|
|
@ -1,130 +0,0 @@
|
|||
import re
|
||||
from collections.abc import Generator
|
||||
from time import time
|
||||
|
||||
# import monkeypatch
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from openai import AzureOpenAI, BadRequestError, OpenAI
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
from openai.resources.completions import Completions
|
||||
from openai.types import Completion as CompletionMessage
|
||||
from openai.types.completion import CompletionChoice
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class MockCompletionsClass:
|
||||
@staticmethod
|
||||
def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
|
||||
return CompletionMessage(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
object="text_completion",
|
||||
created=int(time()),
|
||||
model=model,
|
||||
system_fingerprint="",
|
||||
choices=[
|
||||
CompletionChoice(
|
||||
text="mock",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=1,
|
||||
total_tokens=3,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
|
||||
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
|
||||
for i in range(0, len(full_text) + 1):
|
||||
if i == len(full_text):
|
||||
yield CompletionMessage(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
object="text_completion",
|
||||
created=int(time()),
|
||||
model=model,
|
||||
system_fingerprint="",
|
||||
choices=[
|
||||
CompletionChoice(
|
||||
text="",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=17,
|
||||
total_tokens=19,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield CompletionMessage(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
object="text_completion",
|
||||
created=int(time()),
|
||||
model=model,
|
||||
system_fingerprint="",
|
||||
choices=[
|
||||
CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter")
|
||||
],
|
||||
)
|
||||
|
||||
def completion_create(
|
||||
self: Completions,
|
||||
*,
|
||||
model: Union[
|
||||
str,
|
||||
Literal[
|
||||
"babbage-002",
|
||||
"davinci-002",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"text-davinci-003",
|
||||
"text-davinci-002",
|
||||
"text-davinci-001",
|
||||
"code-davinci-002",
|
||||
"text-curie-001",
|
||||
"text-babbage-001",
|
||||
"text-ada-001",
|
||||
],
|
||||
],
|
||||
prompt: Union[str, list[str], list[int], list[list[int]], None],
|
||||
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any,
|
||||
):
|
||||
openai_models = [
|
||||
"babbage-002",
|
||||
"davinci-002",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"text-davinci-003",
|
||||
"text-davinci-002",
|
||||
"text-davinci-001",
|
||||
"code-davinci-002",
|
||||
"text-curie-001",
|
||||
"text-babbage-001",
|
||||
"text-ada-001",
|
||||
]
|
||||
azure_openai_models = ["gpt-35-turbo-instruct"]
|
||||
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
|
||||
raise InvokeAuthorizationError("Invalid base url")
|
||||
if model in openai_models + azure_openai_models:
|
||||
if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
|
||||
# sometime, provider use OpenAI compatible API will not have api key or have different api key format
|
||||
# so we only check if model is in openai_models
|
||||
raise InvokeAuthorizationError("Invalid api key")
|
||||
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
|
||||
raise InvokeAuthorizationError("Invalid api key")
|
||||
|
||||
if not prompt:
|
||||
raise BadRequestError("Invalid prompt")
|
||||
if stream:
|
||||
return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
|
||||
|
||||
return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -1,23 +0,0 @@
|
|||
from time import time
|
||||
|
||||
from openai.resources.models import Models
|
||||
from openai.types.model import Model
|
||||
|
||||
|
||||
class MockModelClass:
|
||||
"""
|
||||
mock class for openai.models.Models
|
||||
"""
|
||||
|
||||
def list(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> list[Model]:
|
||||
return [
|
||||
Model(
|
||||
id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ",
|
||||
created=int(time()),
|
||||
object="model",
|
||||
owned_by="organization:org-123",
|
||||
)
|
||||
]
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
import re
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from openai._types import NOT_GIVEN, FileTypes, NotGiven
|
||||
from openai.resources.audio.transcriptions import Transcriptions
|
||||
from openai.types.audio.transcription import Transcription
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class MockSpeech2TextClass:
|
||||
def speech2text_create(
|
||||
self: Transcriptions,
|
||||
*,
|
||||
file: FileTypes,
|
||||
model: Union[str, Literal["whisper-1"]],
|
||||
language: str | NotGiven = NOT_GIVEN,
|
||||
prompt: str | NotGiven = NOT_GIVEN,
|
||||
response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
|
||||
temperature: float | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any,
|
||||
) -> Transcription:
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
|
||||
raise InvokeAuthorizationError("Invalid base url")
|
||||
|
||||
if len(self._client.api_key) < 18:
|
||||
raise InvokeAuthorizationError("Invalid API key")
|
||||
|
||||
return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10")
|
||||
|
|
@ -1,170 +0,0 @@
|
|||
import os
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from requests import Response
|
||||
from requests.exceptions import ConnectionError
|
||||
from requests.sessions import Session
|
||||
from xinference_client.client.restful.restful_client import (
|
||||
Client,
|
||||
RESTfulChatModelHandle,
|
||||
RESTfulEmbeddingModelHandle,
|
||||
RESTfulGenerateModelHandle,
|
||||
RESTfulRerankModelHandle,
|
||||
)
|
||||
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
|
||||
|
||||
|
||||
class MockXinferenceClass:
|
||||
def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
|
||||
if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
|
||||
raise RuntimeError("404 Not Found")
|
||||
|
||||
if "generate" == model_uid:
|
||||
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
|
||||
if "chat" == model_uid:
|
||||
return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
|
||||
if "embedding" == model_uid:
|
||||
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
|
||||
if "rerank" == model_uid:
|
||||
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
|
||||
raise RuntimeError("404 Not Found")
|
||||
|
||||
def get(self: Session, url: str, **kwargs):
|
||||
response = Response()
|
||||
if "v1/models/" in url:
|
||||
# get model uid
|
||||
model_uid = url.split("/")[-1] or ""
|
||||
if not re.match(
|
||||
r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
|
||||
) and model_uid not in {"generate", "chat", "embedding", "rerank"}:
|
||||
response.status_code = 404
|
||||
response._content = b"{}"
|
||||
return response
|
||||
|
||||
# check if url is valid
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
|
||||
response.status_code = 404
|
||||
response._content = b"{}"
|
||||
return response
|
||||
|
||||
if model_uid in {"generate", "chat"}:
|
||||
response.status_code = 200
|
||||
response._content = b"""{
|
||||
"model_type": "LLM",
|
||||
"address": "127.0.0.1:43877",
|
||||
"accelerators": [
|
||||
"0",
|
||||
"1"
|
||||
],
|
||||
"model_name": "chatglm3-6b",
|
||||
"model_lang": [
|
||||
"en"
|
||||
],
|
||||
"model_ability": [
|
||||
"generate",
|
||||
"chat"
|
||||
],
|
||||
"model_description": "latest chatglm3",
|
||||
"model_format": "pytorch",
|
||||
"model_size_in_billions": 7,
|
||||
"quantization": "none",
|
||||
"model_hub": "huggingface",
|
||||
"revision": null,
|
||||
"context_length": 2048,
|
||||
"replica": 1
|
||||
}"""
|
||||
return response
|
||||
|
||||
elif model_uid == "embedding":
|
||||
response.status_code = 200
|
||||
response._content = b"""{
|
||||
"model_type": "embedding",
|
||||
"address": "127.0.0.1:43877",
|
||||
"accelerators": [
|
||||
"0",
|
||||
"1"
|
||||
],
|
||||
"model_name": "bge",
|
||||
"model_lang": [
|
||||
"en"
|
||||
],
|
||||
"revision": null,
|
||||
"max_tokens": 512
|
||||
}"""
|
||||
return response
|
||||
|
||||
elif "v1/cluster/auth" in url:
|
||||
response.status_code = 200
|
||||
response._content = b"""{
|
||||
"auth": true
|
||||
}"""
|
||||
return response
|
||||
|
||||
def _check_cluster_authenticated(self):
|
||||
self._cluster_authed = True
|
||||
|
||||
def rerank(
|
||||
self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
|
||||
) -> dict:
|
||||
# check if self._model_uid is a valid uuid
|
||||
if (
|
||||
not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
|
||||
and self._model_uid != "rerank"
|
||||
):
|
||||
raise RuntimeError("404 Not Found")
|
||||
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
|
||||
raise RuntimeError("404 Not Found")
|
||||
|
||||
if top_n is None:
|
||||
top_n = 1
|
||||
|
||||
return {
|
||||
"results": [
|
||||
{"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
|
||||
]
|
||||
}
|
||||
|
||||
def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
|
||||
# check if self._model_uid is a valid uuid
|
||||
if (
|
||||
not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
|
||||
and self._model_uid != "embedding"
|
||||
):
|
||||
raise RuntimeError("404 Not Found")
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
ipt_len = len(input)
|
||||
|
||||
embedding = Embedding(
|
||||
object="list",
|
||||
model=self._model_uid,
|
||||
data=[
|
||||
EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
|
||||
for i in range(ipt_len)
|
||||
],
|
||||
usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
|
||||
)
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
|
||||
monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
|
||||
monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
|
||||
monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
|
||||
monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
|
|
@ -1,92 +0,0 @@
|
|||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeLanguageModel
|
||||
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
|
||||
def test_validate_credentials(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(
|
||||
model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
|
||||
def test_invoke_model(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="claude-instant-1.2",
|
||||
credentials={
|
||||
"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"),
|
||||
"anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"),
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
|
||||
def test_invoke_stream_model(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="claude-instant-1.2",
|
||||
credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="claude-instant-1.2",
|
||||
credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 18
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue