diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 407bd47d9b..6daaaf5791 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -125,7 +125,7 @@ jobs: with: images: ${{ env[matrix.image_name_env] }} tags: | - type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }} + type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }} type=ref,event=branch type=sha,enable=true,priority=100,prefix=,suffix=,format=long type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }} diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 94206a1b1c..897b6fd063 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -231,7 +231,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc except Exception as e: logger.error(e) break - yield MessageAudioEndStreamResponse(audio="", task_id=task_id) + if tts_publisher: + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( self, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 93edf8e0e8..798847a507 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -212,7 +212,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa except Exception as e: logger.error(e) break - yield MessageAudioEndStreamResponse(audio="", task_id=task_id) + if tts_publisher: + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( self, diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 8f834b6458..917649f34e 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -248,7 +248,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan else: start_listener_time = time.time() yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id) - yield MessageAudioEndStreamResponse(audio="", task_id=task_id) + if publisher: + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 92da53c9a4..6bd9325785 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -13,7 +14,7 @@ _TEXT_COLOR_MAPPING = { } -class Callback: +class Callback(ABC): """ Base class for callbacks. Only for LLM. @@ -21,6 +22,7 @@ class Callback: raise_error: bool = False + @abstractmethod def on_before_invoke( self, llm_instance: AIModel, @@ -48,6 +50,7 @@ class Callback: """ raise NotImplementedError() + @abstractmethod def on_new_chunk( self, llm_instance: AIModel, @@ -77,6 +80,7 @@ class Callback: """ raise NotImplementedError() + @abstractmethod def on_after_invoke( self, llm_instance: AIModel, @@ -106,6 +110,7 @@ class Callback: """ raise NotImplementedError() + @abstractmethod def on_invoke_error( self, llm_instance: AIModel, diff --git a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md new file mode 100644 index 0000000000..f5b806ade6 --- /dev/null +++ b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md @@ -0,0 +1,310 @@ +## Custom Integration of Pre-defined Models + +### Introduction + +After completing the vendors integration, the next step is to connect the vendor's models. To illustrate the entire connection process, we will use Xinference as an example to demonstrate a complete vendor integration. + +It is important to note that for custom models, each model connection requires a complete vendor credential. + +Unlike pre-defined models, a custom vendor integration always includes the following two parameters, which do not need to be defined in the vendor YAML file. + +![](images/index/image-3.png) + +As mentioned earlier, vendors do not need to implement validate_provider_credential. The runtime will automatically call the corresponding model layer's validate_credentials to validate the credentials based on the model type and name selected by the user. + +### Writing the Vendor YAML + +First, we need to identify the types of models supported by the vendor we are integrating. + +Currently supported model types are as follows: + +- `llm` Text Generation Models + +- `text_embedding` Text Embedding Models + +- `rerank` Rerank Models + +- `speech2text` Speech-to-Text + +- `tts` Text-to-Speech + +- `moderation` Moderation + +Xinference supports LLM, Text Embedding, and Rerank. So we will start by writing xinference.yaml. + +```yaml +provider: xinference #Define the vendor identifier +label: # Vendor display name, supports both en_US (English) and zh_Hans (Simplified Chinese). If zh_Hans is not set, it will use en_US by default. + en_US: Xorbits Inference +icon_small: # Small icon, refer to other vendors' icons stored in the _assets directory within the vendor implementation directory; follows the same language policy as the label + en_US: icon_s_en.svg +icon_large: # Large icon + en_US: icon_l_en.svg +help: # Help information + title: + en_US: How to deploy Xinference + zh_Hans: 如何部署 Xinference + url: + en_US: https://github.com/xorbitsai/inference +supported_model_types: # Supported model types. Xinference supports LLM, Text Embedding, and Rerank +- llm +- text-embedding +- rerank +configurate_methods: # Since Xinference is a locally deployed vendor with no predefined models, users need to deploy whatever models they need according to Xinference documentation. Thus, it only supports custom models. +- customizable-model +provider_credential_schema: + credential_form_schemas: +``` + + +Then, we need to determine what credentials are required to define a model in Xinference. + +- Since it supports three different types of models, we need to specify the model_type to denote the model type. Here is how we can define it: + +```yaml +provider_credential_schema: + credential_form_schemas: + - variable: model_type + type: select + label: + en_US: Model type + zh_Hans: 模型类型 + required: true + options: + - value: text-generation + label: + en_US: Language Model + zh_Hans: 语言模型 + - value: embeddings + label: + en_US: Text Embedding + - value: reranking + label: + en_US: Rerank +``` + +- Next, each model has its own model_name, so we need to define that here: + +```yaml + - variable: model_name + type: text-input + label: + en_US: Model name + zh_Hans: 模型名称 + required: true + placeholder: + zh_Hans: 填写模型名称 + en_US: Input model name +``` + +- Specify the Xinference local deployment address: + +```yaml + - variable: server_url + label: + zh_Hans: 服务器URL + en_US: Server url + type: text-input + required: true + placeholder: + zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx + en_US: Enter the url of your Xinference, for example https://example.com/xxx +``` + +- Each model has a unique model_uid, so we also need to define that here: + +```yaml + - variable: model_uid + label: + zh_Hans: 模型UID + en_US: Model uid + type: text-input + required: true + placeholder: + zh_Hans: 在此输入您的Model UID + en_US: Enter the model uid +``` + +Now, we have completed the basic definition of the vendor. + +### Writing the Model Code + +Next, let's take the `llm` type as an example and write `xinference.llm.llm.py`. + +In `llm.py`, create a Xinference LLM class, we name it `XinferenceAILargeLanguageModel` (this can be arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods: + +- LLM Invocation + +Implement the core method for LLM invocation, supporting both stream and synchronous responses. + +```python +def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool usage + :param stop: stop words + :param stream: is the response a stream + :param user: unique user id + :return: full response or stream response chunk generator result + """ +``` + +When implementing, ensure to use two functions to return data separately for synchronous and stream responses. This is important because Python treats functions containing the `yield` keyword as generator functions, mandating them to return `Generator` types. Here’s an example (note that the example uses simplified parameters; in real implementation, use the parameter list as defined above): + +```python +def _invoke(self, stream: bool, **kwargs) \ + -> Union[LLMResult, Generator]: + if stream: + return self._handle_stream_response(**kwargs) + return self._handle_sync_response(**kwargs) + +def _handle_stream_response(self, **kwargs) -> Generator: + for chunk in response: + yield chunk +def _handle_sync_response(self, **kwargs) -> LLMResult: + return LLMResult(**response) +``` + +- Pre-compute Input Tokens + +If the model does not provide an interface for pre-computing tokens, you can return 0 directly. + +```python +def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool usage + :return: token count + """ +``` + + +Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens. This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate. + +- Model Credentials Validation + +Similar to vendor credentials validation, this method validates individual model credentials. + +```python +def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: None + """ +``` + +- Model Parameter Schema + +Unlike custom types, since the YAML file does not define which parameters a model supports, we need to dynamically generate the model parameter schema. + +For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` parameters. + +However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below: + +```python + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + rules = [ + ParameterRule( + name='temperature', type=ParameterType.FLOAT, + use_template='temperature', + label=I18nObject( + zh_Hans='温度', en_US='Temperature' + ) + ), + ParameterRule( + name='top_p', type=ParameterType.FLOAT, + use_template='top_p', + label=I18nObject( + zh_Hans='Top P', en_US='Top P' + ) + ), + ParameterRule( + name='max_tokens', type=ParameterType.INT, + use_template='max_tokens', + min=1, + default=512, + label=I18nObject( + zh_Hans='最大生成长度', en_US='Max Tokens' + ) + ) + ] + + # if model is A, add top_k to rules + if model == 'A': + rules.append( + ParameterRule( + name='top_k', type=ParameterType.INT, + use_template='top_k', + min=1, + default=50, + label=I18nObject( + zh_Hans='Top K', en_US='Top K' + ) + ) + ) + + """ + some NOT IMPORTANT code here + """ + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=model_type, + model_properties={ + ModelPropertyKey.MODE: ModelType.LLM, + }, + parameter_rules=rules + ) + + return entity +``` + +- Exception Error Mapping + +When a model invocation error occurs, it should be mapped to the runtime's specified `InvokeError` type, enabling Dify to handle different errors appropriately. + +Runtime Errors: + +- `InvokeConnectionError` Connection error during invocation +- `InvokeServerUnavailableError` Service provider unavailable +- `InvokeRateLimitError` Rate limit reached +- `InvokeAuthorizationError` Authorization failure +- `InvokeBadRequestError` Invalid request parameters + +```python + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ +``` + +For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py). \ No newline at end of file diff --git a/api/core/model_runtime/docs/en_US/images/index/image-1.png b/api/core/model_runtime/docs/en_US/images/index/image-1.png new file mode 100644 index 0000000000..b158d44b29 Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image-1.png differ diff --git a/api/core/model_runtime/docs/en_US/images/index/image-2.png b/api/core/model_runtime/docs/en_US/images/index/image-2.png new file mode 100644 index 0000000000..c70cd3da5e Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image-2.png differ diff --git a/api/core/model_runtime/docs/en_US/images/index/image-3.png b/api/core/model_runtime/docs/en_US/images/index/image-3.png new file mode 100644 index 0000000000..bf0b9a7f47 Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image-3.png differ diff --git a/api/core/model_runtime/docs/en_US/images/index/image.png b/api/core/model_runtime/docs/en_US/images/index/image.png new file mode 100644 index 0000000000..eb63d107e1 Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image.png differ diff --git a/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md b/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md new file mode 100644 index 0000000000..3e16257452 --- /dev/null +++ b/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md @@ -0,0 +1,173 @@ +## Predefined Model Integration + +After completing the vendor integration, the next step is to integrate the models from the vendor. + +First, we need to determine the type of model to be integrated and create the corresponding model type `module` under the respective vendor's directory. + +Currently supported model types are: + +- `llm` Text Generation Model +- `text_embedding` Text Embedding Model +- `rerank` Rerank Model +- `speech2text` Speech-to-Text +- `tts` Text-to-Speech +- `moderation` Moderation + +Continuing with `Anthropic` as an example, `Anthropic` only supports LLM, so create a `module` named `llm` under `model_providers.anthropic`. + +For predefined models, we first need to create a YAML file named after the model under the `llm` `module`, such as `claude-2.1.yaml`. + +### Prepare Model YAML + +```yaml +model: claude-2.1 # Model identifier +# Display name of the model, which can be set to en_US English or zh_Hans Chinese. If zh_Hans is not set, it will default to en_US. +# This can also be omitted, in which case the model identifier will be used as the label +label: + en_US: claude-2.1 +model_type: llm # Model type, claude-2.1 is an LLM +features: # Supported features, agent-thought supports Agent reasoning, vision supports image understanding +- agent-thought +model_properties: # Model properties + mode: chat # LLM mode, complete for text completion models, chat for conversation models + context_size: 200000 # Maximum context size +parameter_rules: # Parameter rules for the model call; only LLM requires this +- name: temperature # Parameter variable name + # Five default configuration templates are provided: temperature/top_p/max_tokens/presence_penalty/frequency_penalty + # The template variable name can be set directly in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE + # Additional configuration parameters will override the default configuration if set + use_template: temperature +- name: top_p + use_template: top_p +- name: top_k + label: # Display name of the parameter + zh_Hans: 取样数量 + en_US: Top k + type: int # Parameter type, supports float/int/string/boolean + help: # Help information, describing the parameter's function + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false # Whether the parameter is mandatory; can be omitted +- name: max_tokens_to_sample + use_template: max_tokens + default: 4096 # Default value of the parameter + min: 1 # Minimum value of the parameter, applicable to float/int only + max: 4096 # Maximum value of the parameter, applicable to float/int only +pricing: # Pricing information + input: '8.00' # Input unit price, i.e., prompt price + output: '24.00' # Output unit price, i.e., response content price + unit: '0.000001' # Price unit, meaning the above prices are per 100K + currency: USD # Price currency +``` + +It is recommended to prepare all model configurations before starting the implementation of the model code. + +You can also refer to the YAML configuration information under the corresponding model type directories of other vendors in the `model_providers` directory. For the complete YAML rules, refer to: [Schema](schema.md#aimodelentity). + +### Implement the Model Call Code + +Next, create a Python file named `llm.py` under the `llm` `module` to write the implementation code. + +Create an Anthropic LLM class named `AnthropicLargeLanguageModel` (or any other name), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods: + +- LLM Call + +Implement the core method for calling the LLM, supporting both streaming and synchronous responses. + +```python + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ +``` + +Ensure to use two functions for returning data, one for synchronous returns and the other for streaming returns, because Python identifies functions containing the `yield` keyword as generator functions, fixing the return type to `Generator`. Thus, synchronous and streaming returns need to be implemented separately, as shown below (note that the example uses simplified parameters, for actual implementation follow the above parameter list): + +```python + def _invoke(self, stream: bool, **kwargs) \ + -> Union[LLMResult, Generator]: + if stream: + return self._handle_stream_response(**kwargs) + return self._handle_sync_response(**kwargs) + + def _handle_stream_response(self, **kwargs) -> Generator: + for chunk in response: + yield chunk + def _handle_sync_response(self, **kwargs) -> LLMResult: + return LLMResult(**response) +``` + +- Pre-compute Input Tokens + +If the model does not provide an interface to precompute tokens, return 0 directly. + +```python + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ +``` + +- Validate Model Credentials + +Similar to vendor credential validation, but specific to a single model. + +```python + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ +``` + +- Map Invoke Errors + +When a model call fails, map it to a specific `InvokeError` type as required by Runtime, allowing Dify to handle different errors accordingly. + +Runtime Errors: + +- `InvokeConnectionError` Connection error + +- `InvokeServerUnavailableError` Service provider unavailable +- `InvokeRateLimitError` Rate limit reached +- `InvokeAuthorizationError` Authorization failed +- `InvokeBadRequestError` Parameter error + +```python + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ +``` + +For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py). \ No newline at end of file diff --git a/api/core/model_runtime/docs/en_US/provider_scale_out.md b/api/core/model_runtime/docs/en_US/provider_scale_out.md index ba356c5cab..07be5811d3 100644 --- a/api/core/model_runtime/docs/en_US/provider_scale_out.md +++ b/api/core/model_runtime/docs/en_US/provider_scale_out.md @@ -58,7 +58,7 @@ provider_credential_schema: # Provider credential rules, as Anthropic only supp en_US: Enter your API URL ``` -You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#Provider). +You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#provider). ### Implementing Provider Code diff --git a/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md b/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md index b34544c789..78aad8876f 100644 --- a/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md +++ b/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md @@ -117,7 +117,7 @@ model_credential_schema: en_US: Enter your API Base ``` -也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#Provider)。 +也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。 #### 实现供应商代码 diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index 80db22ea84..89fccef659 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -40,3 +40,4 @@ - fireworks - mixedbread - nomic +- voyage diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py deleted file mode 100644 index d9c5726592..0000000000 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ /dev/null @@ -1,238 +0,0 @@ -import json -import logging -import time -from typing import Optional - -import boto3 -from botocore.config import Config -from botocore.exceptions import ( - ClientError, - EndpointConnectionError, - NoRegionError, - ServiceNotInRegionError, - UnknownServiceError, -) - -from core.embedding.embedding_constant import EmbeddingInputType -from core.model_runtime.entities.model_entities import PriceType -from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel - -logger = logging.getLogger(__name__) - - -class BedrockTextEmbeddingModel(TextEmbeddingModel): - def _invoke( - self, - model: str, - credentials: dict, - texts: list[str], - user: Optional[str] = None, - input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> TextEmbeddingResult: - """ - Invoke text embedding model - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param user: unique user id - :param input_type: input type - :return: embeddings result - """ - client_config = Config(region_name=credentials["aws_region"]) - - bedrock_runtime = boto3.client( - service_name="bedrock-runtime", - config=client_config, - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - ) - - embeddings = [] - token_usage = 0 - - model_prefix = model.split(".")[0] - - if model_prefix == "amazon": - for text in texts: - body = { - "inputText": text, - } - response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend([response_body.get("embedding")]) - token_usage += response_body.get("inputTextTokenCount") - logger.warning(f"Total Tokens: {token_usage}") - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), - ) - return result - - if model_prefix == "cohere": - input_type = "search_document" if len(texts) > 1 else "search_query" - for text in texts: - body = { - "texts": [text], - "input_type": input_type, - } - response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend(response_body.get("embeddings")) - token_usage += len(text) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), - ) - return result - - # others - raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") - - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :return: - """ - num_tokens = 0 - for text in texts: - num_tokens += self._get_num_tokens_by_gpt2(text) - return num_tokens - - def validate_credentials(self, model: str, credentials: dict) -> None: - """ - Validate model credentials - - :param model: model name - :param credentials: model credentials - :return: - """ - - @property - def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - """ - Map model invoke error to unified error - The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller - The value is the md = genai.GenerativeModel(model) error type thrown by the model, - which needs to be converted into a unified error type for the caller. - - :return: Invoke emd = genai.GenerativeModel(model) error mapping - """ - return { - InvokeConnectionError: [], - InvokeServerUnavailableError: [], - InvokeRateLimitError: [], - InvokeAuthorizationError: [], - InvokeBadRequestError: [], - } - - def _create_payload( - self, - model_prefix: str, - texts: list[str], - model_parameters: dict, - stop: Optional[list[str]] = None, - stream: bool = True, - ): - """ - Create payload for bedrock api call depending on model provider - """ - payload = {} - - if model_prefix == "amazon": - payload["inputText"] = texts - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: - """ - Calculate response usage - - :param model: model name - :param credentials: model credentials - :param tokens: input tokens - :return: usage - """ - # get input price info - input_price_info = self.get_price( - model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens - ) - - # transform usage - usage = EmbeddingUsage( - tokens=tokens, - total_tokens=tokens, - unit_price=input_price_info.unit_price, - price_unit=input_price_info.unit, - total_price=input_price_info.total_amount, - currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at, - ) - - return usage - - def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]: - """ - Map client error to invoke error - - :param error_code: error code - :param error_msg: error message - :return: invoke error - """ - - if error_code == "AccessDeniedException": - return InvokeAuthorizationError(error_msg) - elif error_code in {"ResourceNotFoundException", "ValidationException"}: - return InvokeBadRequestError(error_msg) - elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: - return InvokeRateLimitError(error_msg) - elif error_code in { - "ModelTimeoutException", - "ModelErrorException", - "InternalServerException", - "ModelNotReadyException", - }: - return InvokeServerUnavailableError(error_msg) - elif error_code == "ModelStreamErrorException": - return InvokeConnectionError(error_msg) - - return InvokeError(error_msg) - - def _invoke_bedrock_embedding( - self, - model: str, - bedrock_runtime, - body: dict, - ): - accept = "application/json" - content_type = "application/json" - try: - response = bedrock_runtime.invoke_model( - body=json.dumps(body), modelId=model, accept=accept, contentType=content_type - ) - response_body = json.loads(response.get("body").read().decode("utf-8")) - return response_body - except ClientError as ex: - error_code = ex.response["Error"]["Code"] - full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" - raise self._map_client_to_invoke_error(error_code, full_error_msg) - - except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: - raise InvokeConnectionError(str(ex)) - - except UnknownServiceError as ex: - raise InvokeServerUnavailableError(str(ex)) - - except Exception as ex: - raise InvokeError(str(ex)) diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py deleted file mode 100644 index 1e86f351c8..0000000000 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ /dev/null @@ -1,223 +0,0 @@ -import json -import time -from decimal import Decimal -from typing import Optional -from urllib.parse import urljoin - -import numpy as np -import requests - -from core.embedding.embedding_constant import EmbeddingInputType -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ( - AIModelEntity, - FetchFrom, - ModelPropertyKey, - ModelType, - PriceConfig, - PriceType, -) -from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat - - -class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): - """ - Model class for an OpenAI API-compatible text embedding model. - """ - - def _invoke( - self, - model: str, - credentials: dict, - texts: list[str], - user: Optional[str] = None, - input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> TextEmbeddingResult: - """ - Invoke text embedding model - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param user: unique user id - :param input_type: input type - :return: embeddings result - """ - - # Prepare headers and payload for the request - headers = {"Content-Type": "application/json"} - - api_key = credentials.get("api_key") - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": - endpoint_url = "https://cloud.perfxlab.cn/v1/" - else: - endpoint_url = credentials.get("endpoint_url") - if not endpoint_url.endswith("/"): - endpoint_url += "/" - - endpoint_url = urljoin(endpoint_url, "embeddings") - - extra_model_kwargs = {} - if user: - extra_model_kwargs["user"] = user - - extra_model_kwargs["encoding_format"] = "float" - - # get model properties - context_size = self._get_context_size(model, credentials) - max_chunks = self._get_max_chunks(model, credentials) - - inputs = [] - indices = [] - used_tokens = 0 - - for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer - # TODO: Optimize for better token estimation and chunking - num_tokens = self._get_num_tokens_by_gpt2(text) - - if num_tokens >= context_size: - cutoff = int(np.floor(len(text) * (context_size / num_tokens))) - # if num tokens is larger than context length, only use the start - inputs.append(text[0:cutoff]) - else: - inputs.append(text) - indices += [i] - - batched_embeddings = [] - _iter = range(0, len(inputs), max_chunks) - - for i in _iter: - # Prepare the payload for the request - payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} - - # Make the request to the OpenAI API - response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) - - response.raise_for_status() # Raise an exception for HTTP errors - response_data = response.json() - - # Extract embeddings and used tokens from the response - embeddings_batch = [data["embedding"] for data in response_data["data"]] - embedding_used_tokens = response_data["usage"]["total_tokens"] - - used_tokens += embedding_used_tokens - batched_embeddings += embeddings_batch - - # calc usage - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - - return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) - - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: - """ - Approximate number of tokens for given messages using GPT2 tokenizer - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :return: - """ - return sum(self._get_num_tokens_by_gpt2(text) for text in texts) - - def validate_credentials(self, model: str, credentials: dict) -> None: - """ - Validate model credentials - - :param model: model name - :param credentials: model credentials - :return: - """ - try: - headers = {"Content-Type": "application/json"} - - api_key = credentials.get("api_key") - - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": - endpoint_url = "https://cloud.perfxlab.cn/v1/" - else: - endpoint_url = credentials.get("endpoint_url") - if not endpoint_url.endswith("/"): - endpoint_url += "/" - - endpoint_url = urljoin(endpoint_url, "embeddings") - - payload = {"input": "ping", "model": model} - - response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) - - if response.status_code != 200: - raise CredentialsValidateFailedError( - f"Credentials validation failed with status code {response.status_code}" - ) - - try: - json_result = response.json() - except json.JSONDecodeError as e: - raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - - if "model" not in json_result: - raise CredentialsValidateFailedError("Credentials validation failed: invalid response") - except CredentialsValidateFailedError: - raise - except Exception as ex: - raise CredentialsValidateFailedError(str(ex)) - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: - """ - generate custom model entities from credentials - """ - entity = AIModelEntity( - model=model, - label=I18nObject(en_US=model), - model_type=ModelType.TEXT_EMBEDDING, - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), - ModelPropertyKey.MAX_CHUNKS: 1, - }, - parameter_rules=[], - pricing=PriceConfig( - input=Decimal(credentials.get("input_price", 0)), - unit=Decimal(credentials.get("unit", 0)), - currency=credentials.get("currency", "USD"), - ), - ) - - return entity - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: - """ - Calculate response usage - - :param model: model name - :param credentials: model credentials - :param tokens: input tokens - :return: usage - """ - # get input price info - input_price_info = self.get_price( - model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens - ) - - # transform usage - usage = EmbeddingUsage( - tokens=tokens, - total_tokens=tokens, - unit_price=input_price_info.unit_price, - price_unit=input_price_info.unit, - total_price=input_price_info.total_amount, - currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at, - ) - - return usage diff --git a/api/core/model_runtime/model_providers/voyage/__init__.py b/api/core/model_runtime/model_providers/voyage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg new file mode 100644 index 0000000000..a961f5e435 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg @@ -0,0 +1,21 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg new file mode 100644 index 0000000000..2c4e121dd7 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg @@ -0,0 +1,8 @@ + + + voyage + + + + + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/voyage/rerank/__init__.py b/api/core/model_runtime/model_providers/voyage/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml new file mode 100644 index 0000000000..9c894eda85 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml @@ -0,0 +1,4 @@ +model: rerank-1 +model_type: rerank +model_properties: + context_size: 8000 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml new file mode 100644 index 0000000000..b052d6f000 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml @@ -0,0 +1,4 @@ +model: rerank-lite-1 +model_type: rerank +model_properties: + context_size: 4000 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank.py b/api/core/model_runtime/model_providers/voyage/rerank/rerank.py new file mode 100644 index 0000000000..33fdebbb45 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank.py @@ -0,0 +1,123 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class VoyageRerankModel(RerankModel): + """ + Model class for Voyage rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + base_url = credentials.get("base_url", "https://api.voyageai.com/v1") + base_url = base_url.removesuffix("/") + + try: + response = httpx.post( + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"}, + ) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results["data"]: + rerank_document = RerankDocument( + index=result["index"], + text=result["document"], + score=result["relevance_score"], + ) + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py b/api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py new file mode 100644 index 0000000000..a8a4d3c15b --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py @@ -0,0 +1,172 @@ +import time +from json import JSONDecodeError, dumps +from typing import Optional + +import requests + +from core.embedding.embedding_constant import EmbeddingInputType +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + + +class VoyageTextEmbeddingModel(TextEmbeddingModel): + """ + Model class for Voyage text embedding model. + """ + + api_base: str = "https://api.voyageai.com/v1" + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + api_key = credentials["api_key"] + if not api_key: + raise CredentialsValidateFailedError("api_key is required") + + base_url = credentials.get("base_url", self.api_base) + base_url = base_url.removesuffix("/") + + url = base_url + "/embeddings" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} + voyage_input_type = "null" + if input_type is not None: + voyage_input_type = input_type.value + data = {"model": model, "input": texts, "input_type": voyage_input_type} + + try: + response = requests.post(url, headers=headers, data=dumps(data)) + except Exception as e: + raise InvokeConnectionError(str(e)) + + if response.status_code != 200: + try: + resp = response.json() + msg = resp["detail"] + if response.status_code == 401: + raise InvokeAuthorizationError(msg) + elif response.status_code == 429: + raise InvokeRateLimitError(msg) + elif response.status_code == 500: + raise InvokeServerUnavailableError(msg) + else: + raise InvokeBadRequestError(msg) + except JSONDecodeError as e: + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) + + try: + resp = response.json() + embeddings = resp["data"] + usage = resp["usage"] + except Exception as e: + raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) + + result = TextEmbeddingResult( + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage + ) + + return result + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return sum(self._get_num_tokens_by_gpt2(text) for text in texts) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke(model=model, credentials=credentials, texts=["ping"]) + except Exception as e: + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError, InvokeBadRequestError], + } + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at, + ) + + return usage + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml new file mode 100644 index 0000000000..a06bb7639f --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml @@ -0,0 +1,8 @@ +model: voyage-3-lite +model_type: text-embedding +model_properties: + context_size: 32000 +pricing: + input: '0.00002' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml new file mode 100644 index 0000000000..117afbcaf3 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml @@ -0,0 +1,8 @@ +model: voyage-3 +model_type: text-embedding +model_properties: + context_size: 32000 +pricing: + input: '0.00006' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/voyage/voyage.py b/api/core/model_runtime/model_providers/voyage/voyage.py new file mode 100644 index 0000000000..3e33b45e11 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/voyage.py @@ -0,0 +1,28 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class VoyageProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING) + + # Use `voyage-3` model for validate, + # no matter what model you pass in, text completion model or chat model + model_instance.validate_credentials(model="voyage-3", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/voyage/voyage.yaml b/api/core/model_runtime/model_providers/voyage/voyage.yaml new file mode 100644 index 0000000000..c64707800e --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/voyage.yaml @@ -0,0 +1,31 @@ +provider: voyage +label: + en_US: Voyage +description: + en_US: Embedding and Rerank Model Supported +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#EFFDFD" +help: + title: + en_US: Get your API key from Voyage AI + zh_Hans: 从 Voyage 获取 API Key + url: + en_US: https://dash.voyageai.com/ +supported_model_types: + - text-embedding + - rerank +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py deleted file mode 100644 index 14a529dddf..0000000000 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ /dev/null @@ -1,142 +0,0 @@ -import time -from typing import Optional - -from core.embedding.embedding_constant import EmbeddingInputType -from core.model_runtime.entities.model_entities import PriceType -from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI -from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI - - -class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): - """ - Model class for ZhipuAI text embedding model. - """ - - def _invoke( - self, - model: str, - credentials: dict, - texts: list[str], - user: Optional[str] = None, - input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> TextEmbeddingResult: - """ - Invoke text embedding model - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param user: unique user id - :param input_type: input type - :return: embeddings result - """ - credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI(api_key=credentials_kwargs["api_key"]) - - embeddings, embedding_used_tokens = self.embed_documents(model, client, texts) - - return TextEmbeddingResult( - embeddings=embeddings, - usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), - model=model, - ) - - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :return: - """ - if len(texts) == 0: - return 0 - - total_num_tokens = 0 - for text in texts: - total_num_tokens += self._get_num_tokens_by_gpt2(text) - - return total_num_tokens - - def validate_credentials(self, model: str, credentials: dict) -> None: - """ - Validate model credentials - - :param model: model name - :param credentials: model credentials - :return: - """ - try: - # transform credentials to kwargs for model instance - credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI(api_key=credentials_kwargs["api_key"]) - - # call embedding model - self.embed_documents( - model=model, - client=client, - texts=["ping"], - ) - except Exception as ex: - raise CredentialsValidateFailedError(str(ex)) - - def embed_documents(self, model: str, client: ZhipuAI, texts: list[str]) -> tuple[list[list[float]], int]: - """Call out to ZhipuAI's embedding endpoint. - - Args: - texts: The list of texts to embed. - - Returns: - List of embeddings, one for each text. - """ - embeddings = [] - embedding_used_tokens = 0 - - for text in texts: - response = client.embeddings.create(model=model, input=text) - data = response.data[0] - embeddings.append(data.embedding) - embedding_used_tokens += response.usage.total_tokens - - return [list(map(float, e)) for e in embeddings], embedding_used_tokens - - def embed_query(self, text: str) -> list[float]: - """Call out to ZhipuAI's embedding endpoint. - - Args: - text: The text to embed. - - Returns: - Embeddings for the text. - """ - return self.embed_documents([text])[0] - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: - """ - Calculate response usage - - :param model: model name - :param tokens: input tokens - :return: usage - """ - # get input price info - input_price_info = self.get_price( - model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens - ) - - # transform usage - usage = EmbeddingUsage( - tokens=tokens, - total_tokens=tokens, - unit_price=input_price_info.unit_price, - price_unit=input_price_info.unit, - total_price=input_price_info.total_amount, - currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at, - ) - - return usage diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 3073100746..a0153c1e58 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -45,7 +45,7 @@ class Jieba(BaseKeyword): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() - keywords_list = kwargs.get("keywords_list", None) + keywords_list = kwargs.get("keywords_list") for i in range(len(texts)): text = texts[i] if keywords_list: diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 54f6a76e16..5af45e1e50 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -14,7 +14,7 @@ from models.dataset import Document @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get("document_ids", None) + document_ids = kwargs.get("document_ids") documents = [] start_at = time.perf_counter() for document_id in document_ids: diff --git a/api/poetry.lock b/api/poetry.lock index bce21fb547..85c68cd75f 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -8074,29 +8074,29 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.6.5" +version = "0.6.8" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.6.5-py3-none-linux_armv6l.whl", hash = "sha256:7e4e308f16e07c95fc7753fc1aaac690a323b2bb9f4ec5e844a97bb7fbebd748"}, - {file = "ruff-0.6.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:932cd69eefe4daf8c7d92bd6689f7e8182571cb934ea720af218929da7bd7d69"}, - {file = "ruff-0.6.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3a8d42d11fff8d3143ff4da41742a98f8f233bf8890e9fe23077826818f8d680"}, - {file = "ruff-0.6.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a50af6e828ee692fb10ff2dfe53f05caecf077f4210fae9677e06a808275754f"}, - {file = "ruff-0.6.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:794ada3400a0d0b89e3015f1a7e01f4c97320ac665b7bc3ade24b50b54cb2972"}, - {file = "ruff-0.6.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:381413ec47f71ce1d1c614f7779d88886f406f1fd53d289c77e4e533dc6ea200"}, - {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:52e75a82bbc9b42e63c08d22ad0ac525117e72aee9729a069d7c4f235fc4d276"}, - {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09c72a833fd3551135ceddcba5ebdb68ff89225d30758027280968c9acdc7810"}, - {file = "ruff-0.6.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:800c50371bdcb99b3c1551d5691e14d16d6f07063a518770254227f7f6e8c178"}, - {file = "ruff-0.6.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e25ddd9cd63ba1f3bd51c1f09903904a6adf8429df34f17d728a8fa11174253"}, - {file = "ruff-0.6.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7291e64d7129f24d1b0c947ec3ec4c0076e958d1475c61202497c6aced35dd19"}, - {file = "ruff-0.6.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9ad7dfbd138d09d9a7e6931e6a7e797651ce29becd688be8a0d4d5f8177b4b0c"}, - {file = "ruff-0.6.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:005256d977021790cc52aa23d78f06bb5090dc0bfbd42de46d49c201533982ae"}, - {file = "ruff-0.6.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:482c1e6bfeb615eafc5899127b805d28e387bd87db38b2c0c41d271f5e58d8cc"}, - {file = "ruff-0.6.5-py3-none-win32.whl", hash = "sha256:cf4d3fa53644137f6a4a27a2b397381d16454a1566ae5335855c187fbf67e4f5"}, - {file = "ruff-0.6.5-py3-none-win_amd64.whl", hash = "sha256:3e42a57b58e3612051a636bc1ac4e6b838679530235520e8f095f7c44f706ff9"}, - {file = "ruff-0.6.5-py3-none-win_arm64.whl", hash = "sha256:51935067740773afdf97493ba9b8231279e9beef0f2a8079188c4776c25688e0"}, - {file = "ruff-0.6.5.tar.gz", hash = "sha256:4d32d87fab433c0cf285c3683dd4dae63be05fd7a1d65b3f5bf7cdd05a6b96fb"}, + {file = "ruff-0.6.8-py3-none-linux_armv6l.whl", hash = "sha256:77944bca110ff0a43b768f05a529fecd0706aac7bcce36d7f1eeb4cbfca5f0f2"}, + {file = "ruff-0.6.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:27b87e1801e786cd6ede4ada3faa5e254ce774de835e6723fd94551464c56b8c"}, + {file = "ruff-0.6.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:cd48f945da2a6334f1793d7f701725a76ba93bf3d73c36f6b21fb04d5338dcf5"}, + {file = "ruff-0.6.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:677e03c00f37c66cea033274295a983c7c546edea5043d0c798833adf4cf4c6f"}, + {file = "ruff-0.6.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9f1476236b3eacfacfc0f66aa9e6cd39f2a624cb73ea99189556015f27c0bdeb"}, + {file = "ruff-0.6.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f5a2f17c7d32991169195d52a04c95b256378bbf0de8cb98478351eb70d526f"}, + {file = "ruff-0.6.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5fd0d4b7b1457c49e435ee1e437900ced9b35cb8dc5178921dfb7d98d65a08d0"}, + {file = "ruff-0.6.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8034b19b993e9601f2ddf2c517451e17a6ab5cdb1c13fdff50c1442a7171d87"}, + {file = "ruff-0.6.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6cfb227b932ba8ef6e56c9f875d987973cd5e35bc5d05f5abf045af78ad8e098"}, + {file = "ruff-0.6.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ef0411eccfc3909269fed47c61ffebdcb84a04504bafa6b6df9b85c27e813b0"}, + {file = "ruff-0.6.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:007dee844738c3d2e6c24ab5bc7d43c99ba3e1943bd2d95d598582e9c1b27750"}, + {file = "ruff-0.6.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ce60058d3cdd8490e5e5471ef086b3f1e90ab872b548814e35930e21d848c9ce"}, + {file = "ruff-0.6.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1085c455d1b3fdb8021ad534379c60353b81ba079712bce7a900e834859182fa"}, + {file = "ruff-0.6.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:70edf6a93b19481affd287d696d9e311388d808671bc209fb8907b46a8c3af44"}, + {file = "ruff-0.6.8-py3-none-win32.whl", hash = "sha256:792213f7be25316f9b46b854df80a77e0da87ec66691e8f012f887b4a671ab5a"}, + {file = "ruff-0.6.8-py3-none-win_amd64.whl", hash = "sha256:ec0517dc0f37cad14a5319ba7bba6e7e339d03fbf967a6d69b0907d61be7a263"}, + {file = "ruff-0.6.8-py3-none-win_arm64.whl", hash = "sha256:8d3bb2e3fbb9875172119021a13eed38849e762499e3cfde9588e4b4d70968dc"}, + {file = "ruff-0.6.8.tar.gz", hash = "sha256:a5bf44b1aa0adaf6d9d20f86162b34f7c593bfedabc51239953e446aefc8ce18"}, ] [[package]] @@ -10501,4 +10501,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "69b42bb1ff033f14e199fee8335356275099421d72bbd7037b7a991ea65cae08" +content-hash = "c4580c22e2b220c8c80dbc3f765060a09e14874ed29b690c13a533bf0365e789" diff --git a/api/pyproject.toml b/api/pyproject.toml index f004865d5f..e737761f3b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -123,6 +123,7 @@ FIRECRAWL_API_KEY = "fc-" TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451" TEI_RERANK_SERVER_URL = "http://a.abc.com:11451" MIXEDBREAD_API_KEY = "mk-aaaaaaaaaaaaaaaaaaaa" +VOYAGE_API_KEY = "va-aaaaaaaaaaaaaaaaaaaa" [tool.poetry] name = "dify-api" @@ -286,4 +287,4 @@ optional = true [tool.poetry.group.lint.dependencies] dotenv-linter = "~0.5.0" -ruff = "~0.6.5" +ruff = "~0.6.8" diff --git a/api/tests/integration_tests/model_runtime/voyage/__init__.py b/api/tests/integration_tests/model_runtime/voyage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/voyage/test_provider.py b/api/tests/integration_tests/model_runtime/voyage/test_provider.py new file mode 100644 index 0000000000..08978c88a9 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/voyage/test_provider.py @@ -0,0 +1,25 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.voyage.voyage import VoyageProvider + + +def test_validate_provider_credentials(): + provider = VoyageProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}], + "model": "voyage-3", + "usage": {"total_tokens": 1}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("VOYAGE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/voyage/test_rerank.py b/api/tests/integration_tests/model_runtime/voyage/test_rerank.py new file mode 100644 index 0000000000..e97a9e4c81 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/voyage/test_rerank.py @@ -0,0 +1,92 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel + + +def test_validate_credentials(): + model = VoyageRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="rerank-lite-1", + credentials={"api_key": "invalid_key"}, + ) + with patch("httpx.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [ + { + "relevance_score": 0.546875, + "index": 0, + "document": "Carson City is the capital city of the American state of Nevada. At the 2010 United " + "States Census, Carson City had a population of 55,274.", + }, + { + "relevance_score": 0.4765625, + "index": 1, + "document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the " + "Pacific Ocean that are a political division controlled by the United States. Its " + "capital is Saipan.", + }, + ], + "model": "rerank-lite-1", + "usage": {"total_tokens": 96}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + model.validate_credentials( + model="rerank-lite-1", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = VoyageRerankModel() + with patch("httpx.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [ + { + "relevance_score": 0.84375, + "index": 0, + "document": "Kasumi is a girl name of Japanese origin meaning mist.", + }, + { + "relevance_score": 0.4765625, + "index": 1, + "document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she " + "leads a team named PopiParty.", + }, + ], + "model": "rerank-lite-1", + "usage": {"total_tokens": 59}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + result = model.invoke( + model="rerank-lite-1", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + query="Who is Kasumi?", + docs=[ + "Kasumi is a girl name of Japanese origin meaning mist.", + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named " + "PopiParty.", + ], + score_threshold=0.5, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.5 diff --git a/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py b/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py new file mode 100644 index 0000000000..75719672a9 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py @@ -0,0 +1,70 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.voyage.text_embedding.text_embedding import VoyageTextEmbeddingModel + + +def test_validate_credentials(): + model = VoyageTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="voyage-3", credentials={"api_key": "invalid_key"}) + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}], + "model": "voyage-3", + "usage": {"total_tokens": 1}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + model.validate_credentials(model="voyage-3", credentials={"api_key": os.environ.get("VOYAGE_API_KEY")}) + + +def test_invoke_model(): + model = VoyageTextEmbeddingModel() + + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [ + {"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}, + {"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 1}, + ], + "model": "voyage-3", + "usage": {"total_tokens": 2}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + result = model.invoke( + model="voyage-3", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +def test_get_num_tokens(): + model = VoyageTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="voyage-3", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + texts=["ping"], + ) + + assert num_tokens == 1 diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py index d3c1f3101c..42cf87e317 100644 --- a/api/tests/integration_tests/tools/__mock/http.py +++ b/api/tests/integration_tests/tools/__mock/http.py @@ -17,7 +17,7 @@ class MockedHttp: request = httpx.Request( method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies") ) - data = kwargs.get("data", None) + data = kwargs.get("data") resp = json.dumps(data).encode("utf-8") if data else b"OK" response = httpx.Response( status_code=200, diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index f1ab23b002..ec013183b7 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -22,8 +22,8 @@ class MockedHttp: return response # get data, files - data = kwargs.get("data", None) - files = kwargs.get("files", None) + data = kwargs.get("data") + files = kwargs.get("files") if data is not None: resp = dumps(data).encode("utf-8") elif files is not None: diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh index b60ff64fdc..63891eb9f8 100755 --- a/dev/pytest/pytest_model_runtime.sh +++ b/dev/pytest/pytest_model_runtime.sh @@ -9,4 +9,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \ api/tests/integration_tests/model_runtime/upstage \ api/tests/integration_tests/model_runtime/fireworks \ api/tests/integration_tests/model_runtime/nomic \ - api/tests/integration_tests/model_runtime/mixedbread + api/tests/integration_tests/model_runtime/mixedbread \ + api/tests/integration_tests/model_runtime/voyage \ No newline at end of file