diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
index 407bd47d9b..6daaaf5791 100644
--- a/.github/workflows/build-push.yml
+++ b/.github/workflows/build-push.yml
@@ -125,7 +125,7 @@ jobs:
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
- type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
+ type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 94206a1b1c..897b6fd063 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -231,7 +231,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e:
logger.error(e)
break
- yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
+ if tts_publisher:
+ yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index 93edf8e0e8..798847a507 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -212,7 +212,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
except Exception as e:
logger.error(e)
break
- yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
+ if tts_publisher:
+ yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index 8f834b6458..917649f34e 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -248,7 +248,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
- yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
+ if publisher:
+ yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py
index 92da53c9a4..6bd9325785 100644
--- a/api/core/model_runtime/callbacks/base_callback.py
+++ b/api/core/model_runtime/callbacks/base_callback.py
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -13,7 +14,7 @@ _TEXT_COLOR_MAPPING = {
}
-class Callback:
+class Callback(ABC):
"""
Base class for callbacks.
Only for LLM.
@@ -21,6 +22,7 @@ class Callback:
raise_error: bool = False
+ @abstractmethod
def on_before_invoke(
self,
llm_instance: AIModel,
@@ -48,6 +50,7 @@ class Callback:
"""
raise NotImplementedError()
+ @abstractmethod
def on_new_chunk(
self,
llm_instance: AIModel,
@@ -77,6 +80,7 @@ class Callback:
"""
raise NotImplementedError()
+ @abstractmethod
def on_after_invoke(
self,
llm_instance: AIModel,
@@ -106,6 +110,7 @@ class Callback:
"""
raise NotImplementedError()
+ @abstractmethod
def on_invoke_error(
self,
llm_instance: AIModel,
diff --git a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
new file mode 100644
index 0000000000..f5b806ade6
--- /dev/null
+++ b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
@@ -0,0 +1,310 @@
+## Custom Integration of Pre-defined Models
+
+### Introduction
+
+After completing the vendors integration, the next step is to connect the vendor's models. To illustrate the entire connection process, we will use Xinference as an example to demonstrate a complete vendor integration.
+
+It is important to note that for custom models, each model connection requires a complete vendor credential.
+
+Unlike pre-defined models, a custom vendor integration always includes the following two parameters, which do not need to be defined in the vendor YAML file.
+
+
+
+As mentioned earlier, vendors do not need to implement validate_provider_credential. The runtime will automatically call the corresponding model layer's validate_credentials to validate the credentials based on the model type and name selected by the user.
+
+### Writing the Vendor YAML
+
+First, we need to identify the types of models supported by the vendor we are integrating.
+
+Currently supported model types are as follows:
+
+- `llm` Text Generation Models
+
+- `text_embedding` Text Embedding Models
+
+- `rerank` Rerank Models
+
+- `speech2text` Speech-to-Text
+
+- `tts` Text-to-Speech
+
+- `moderation` Moderation
+
+Xinference supports LLM, Text Embedding, and Rerank. So we will start by writing xinference.yaml.
+
+```yaml
+provider: xinference #Define the vendor identifier
+label: # Vendor display name, supports both en_US (English) and zh_Hans (Simplified Chinese). If zh_Hans is not set, it will use en_US by default.
+ en_US: Xorbits Inference
+icon_small: # Small icon, refer to other vendors' icons stored in the _assets directory within the vendor implementation directory; follows the same language policy as the label
+ en_US: icon_s_en.svg
+icon_large: # Large icon
+ en_US: icon_l_en.svg
+help: # Help information
+ title:
+ en_US: How to deploy Xinference
+ zh_Hans: 如何部署 Xinference
+ url:
+ en_US: https://github.com/xorbitsai/inference
+supported_model_types: # Supported model types. Xinference supports LLM, Text Embedding, and Rerank
+- llm
+- text-embedding
+- rerank
+configurate_methods: # Since Xinference is a locally deployed vendor with no predefined models, users need to deploy whatever models they need according to Xinference documentation. Thus, it only supports custom models.
+- customizable-model
+provider_credential_schema:
+ credential_form_schemas:
+```
+
+
+Then, we need to determine what credentials are required to define a model in Xinference.
+
+- Since it supports three different types of models, we need to specify the model_type to denote the model type. Here is how we can define it:
+
+```yaml
+provider_credential_schema:
+ credential_form_schemas:
+ - variable: model_type
+ type: select
+ label:
+ en_US: Model type
+ zh_Hans: 模型类型
+ required: true
+ options:
+ - value: text-generation
+ label:
+ en_US: Language Model
+ zh_Hans: 语言模型
+ - value: embeddings
+ label:
+ en_US: Text Embedding
+ - value: reranking
+ label:
+ en_US: Rerank
+```
+
+- Next, each model has its own model_name, so we need to define that here:
+
+```yaml
+ - variable: model_name
+ type: text-input
+ label:
+ en_US: Model name
+ zh_Hans: 模型名称
+ required: true
+ placeholder:
+ zh_Hans: 填写模型名称
+ en_US: Input model name
+```
+
+- Specify the Xinference local deployment address:
+
+```yaml
+ - variable: server_url
+ label:
+ zh_Hans: 服务器URL
+ en_US: Server url
+ type: text-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
+ en_US: Enter the url of your Xinference, for example https://example.com/xxx
+```
+
+- Each model has a unique model_uid, so we also need to define that here:
+
+```yaml
+ - variable: model_uid
+ label:
+ zh_Hans: 模型UID
+ en_US: Model uid
+ type: text-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入您的Model UID
+ en_US: Enter the model uid
+```
+
+Now, we have completed the basic definition of the vendor.
+
+### Writing the Model Code
+
+Next, let's take the `llm` type as an example and write `xinference.llm.llm.py`.
+
+In `llm.py`, create a Xinference LLM class, we name it `XinferenceAILargeLanguageModel` (this can be arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
+
+- LLM Invocation
+
+Implement the core method for LLM invocation, supporting both stream and synchronous responses.
+
+```python
+def _invoke(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+ stream: bool = True, user: Optional[str] = None) \
+ -> Union[LLMResult, Generator]:
+ """
+ Invoke large language model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param tools: tools for tool usage
+ :param stop: stop words
+ :param stream: is the response a stream
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+```
+
+When implementing, ensure to use two functions to return data separately for synchronous and stream responses. This is important because Python treats functions containing the `yield` keyword as generator functions, mandating them to return `Generator` types. Here’s an example (note that the example uses simplified parameters; in real implementation, use the parameter list as defined above):
+
+```python
+def _invoke(self, stream: bool, **kwargs) \
+ -> Union[LLMResult, Generator]:
+ if stream:
+ return self._handle_stream_response(**kwargs)
+ return self._handle_sync_response(**kwargs)
+
+def _handle_stream_response(self, **kwargs) -> Generator:
+ for chunk in response:
+ yield chunk
+def _handle_sync_response(self, **kwargs) -> LLMResult:
+ return LLMResult(**response)
+```
+
+- Pre-compute Input Tokens
+
+If the model does not provide an interface for pre-computing tokens, you can return 0 directly.
+
+```python
+def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],tools: Optional[list[PromptMessageTool]] = None) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool usage
+ :return: token count
+ """
+```
+
+
+Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens. This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate.
+
+- Model Credentials Validation
+
+Similar to vendor credentials validation, this method validates individual model credentials.
+
+```python
+def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return: None
+ """
+```
+
+- Model Parameter Schema
+
+Unlike custom types, since the YAML file does not define which parameters a model supports, we need to dynamically generate the model parameter schema.
+
+For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` parameters.
+
+However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:
+
+```python
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+ """
+ used to define customizable model schema
+ """
+ rules = [
+ ParameterRule(
+ name='temperature', type=ParameterType.FLOAT,
+ use_template='temperature',
+ label=I18nObject(
+ zh_Hans='温度', en_US='Temperature'
+ )
+ ),
+ ParameterRule(
+ name='top_p', type=ParameterType.FLOAT,
+ use_template='top_p',
+ label=I18nObject(
+ zh_Hans='Top P', en_US='Top P'
+ )
+ ),
+ ParameterRule(
+ name='max_tokens', type=ParameterType.INT,
+ use_template='max_tokens',
+ min=1,
+ default=512,
+ label=I18nObject(
+ zh_Hans='最大生成长度', en_US='Max Tokens'
+ )
+ )
+ ]
+
+ # if model is A, add top_k to rules
+ if model == 'A':
+ rules.append(
+ ParameterRule(
+ name='top_k', type=ParameterType.INT,
+ use_template='top_k',
+ min=1,
+ default=50,
+ label=I18nObject(
+ zh_Hans='Top K', en_US='Top K'
+ )
+ )
+ )
+
+ """
+ some NOT IMPORTANT code here
+ """
+
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(
+ en_US=model
+ ),
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_type=model_type,
+ model_properties={
+ ModelPropertyKey.MODE: ModelType.LLM,
+ },
+ parameter_rules=rules
+ )
+
+ return entity
+```
+
+- Exception Error Mapping
+
+When a model invocation error occurs, it should be mapped to the runtime's specified `InvokeError` type, enabling Dify to handle different errors appropriately.
+
+Runtime Errors:
+
+- `InvokeConnectionError` Connection error during invocation
+- `InvokeServerUnavailableError` Service provider unavailable
+- `InvokeRateLimitError` Rate limit reached
+- `InvokeAuthorizationError` Authorization failure
+- `InvokeBadRequestError` Invalid request parameters
+
+```python
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+```
+
+For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
\ No newline at end of file
diff --git a/api/core/model_runtime/docs/en_US/images/index/image-1.png b/api/core/model_runtime/docs/en_US/images/index/image-1.png
new file mode 100644
index 0000000000..b158d44b29
Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image-1.png differ
diff --git a/api/core/model_runtime/docs/en_US/images/index/image-2.png b/api/core/model_runtime/docs/en_US/images/index/image-2.png
new file mode 100644
index 0000000000..c70cd3da5e
Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image-2.png differ
diff --git a/api/core/model_runtime/docs/en_US/images/index/image-3.png b/api/core/model_runtime/docs/en_US/images/index/image-3.png
new file mode 100644
index 0000000000..bf0b9a7f47
Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image-3.png differ
diff --git a/api/core/model_runtime/docs/en_US/images/index/image.png b/api/core/model_runtime/docs/en_US/images/index/image.png
new file mode 100644
index 0000000000..eb63d107e1
Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image.png differ
diff --git a/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md b/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md
new file mode 100644
index 0000000000..3e16257452
--- /dev/null
+++ b/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md
@@ -0,0 +1,173 @@
+## Predefined Model Integration
+
+After completing the vendor integration, the next step is to integrate the models from the vendor.
+
+First, we need to determine the type of model to be integrated and create the corresponding model type `module` under the respective vendor's directory.
+
+Currently supported model types are:
+
+- `llm` Text Generation Model
+- `text_embedding` Text Embedding Model
+- `rerank` Rerank Model
+- `speech2text` Speech-to-Text
+- `tts` Text-to-Speech
+- `moderation` Moderation
+
+Continuing with `Anthropic` as an example, `Anthropic` only supports LLM, so create a `module` named `llm` under `model_providers.anthropic`.
+
+For predefined models, we first need to create a YAML file named after the model under the `llm` `module`, such as `claude-2.1.yaml`.
+
+### Prepare Model YAML
+
+```yaml
+model: claude-2.1 # Model identifier
+# Display name of the model, which can be set to en_US English or zh_Hans Chinese. If zh_Hans is not set, it will default to en_US.
+# This can also be omitted, in which case the model identifier will be used as the label
+label:
+ en_US: claude-2.1
+model_type: llm # Model type, claude-2.1 is an LLM
+features: # Supported features, agent-thought supports Agent reasoning, vision supports image understanding
+- agent-thought
+model_properties: # Model properties
+ mode: chat # LLM mode, complete for text completion models, chat for conversation models
+ context_size: 200000 # Maximum context size
+parameter_rules: # Parameter rules for the model call; only LLM requires this
+- name: temperature # Parameter variable name
+ # Five default configuration templates are provided: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
+ # The template variable name can be set directly in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
+ # Additional configuration parameters will override the default configuration if set
+ use_template: temperature
+- name: top_p
+ use_template: top_p
+- name: top_k
+ label: # Display name of the parameter
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int # Parameter type, supports float/int/string/boolean
+ help: # Help information, describing the parameter's function
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ required: false # Whether the parameter is mandatory; can be omitted
+- name: max_tokens_to_sample
+ use_template: max_tokens
+ default: 4096 # Default value of the parameter
+ min: 1 # Minimum value of the parameter, applicable to float/int only
+ max: 4096 # Maximum value of the parameter, applicable to float/int only
+pricing: # Pricing information
+ input: '8.00' # Input unit price, i.e., prompt price
+ output: '24.00' # Output unit price, i.e., response content price
+ unit: '0.000001' # Price unit, meaning the above prices are per 100K
+ currency: USD # Price currency
+```
+
+It is recommended to prepare all model configurations before starting the implementation of the model code.
+
+You can also refer to the YAML configuration information under the corresponding model type directories of other vendors in the `model_providers` directory. For the complete YAML rules, refer to: [Schema](schema.md#aimodelentity).
+
+### Implement the Model Call Code
+
+Next, create a Python file named `llm.py` under the `llm` `module` to write the implementation code.
+
+Create an Anthropic LLM class named `AnthropicLargeLanguageModel` (or any other name), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
+
+- LLM Call
+
+Implement the core method for calling the LLM, supporting both streaming and synchronous responses.
+
+```python
+ def _invoke(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+ stream: bool = True, user: Optional[str] = None) \
+ -> Union[LLMResult, Generator]:
+ """
+ Invoke large language model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param tools: tools for tool calling
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+```
+
+Ensure to use two functions for returning data, one for synchronous returns and the other for streaming returns, because Python identifies functions containing the `yield` keyword as generator functions, fixing the return type to `Generator`. Thus, synchronous and streaming returns need to be implemented separately, as shown below (note that the example uses simplified parameters, for actual implementation follow the above parameter list):
+
+```python
+ def _invoke(self, stream: bool, **kwargs) \
+ -> Union[LLMResult, Generator]:
+ if stream:
+ return self._handle_stream_response(**kwargs)
+ return self._handle_sync_response(**kwargs)
+
+ def _handle_stream_response(self, **kwargs) -> Generator:
+ for chunk in response:
+ yield chunk
+ def _handle_sync_response(self, **kwargs) -> LLMResult:
+ return LLMResult(**response)
+```
+
+- Pre-compute Input Tokens
+
+If the model does not provide an interface to precompute tokens, return 0 directly.
+
+```python
+ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool calling
+ :return:
+ """
+```
+
+- Validate Model Credentials
+
+Similar to vendor credential validation, but specific to a single model.
+
+```python
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+```
+
+- Map Invoke Errors
+
+When a model call fails, map it to a specific `InvokeError` type as required by Runtime, allowing Dify to handle different errors accordingly.
+
+Runtime Errors:
+
+- `InvokeConnectionError` Connection error
+
+- `InvokeServerUnavailableError` Service provider unavailable
+- `InvokeRateLimitError` Rate limit reached
+- `InvokeAuthorizationError` Authorization failed
+- `InvokeBadRequestError` Parameter error
+
+```python
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+```
+
+For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
\ No newline at end of file
diff --git a/api/core/model_runtime/docs/en_US/provider_scale_out.md b/api/core/model_runtime/docs/en_US/provider_scale_out.md
index ba356c5cab..07be5811d3 100644
--- a/api/core/model_runtime/docs/en_US/provider_scale_out.md
+++ b/api/core/model_runtime/docs/en_US/provider_scale_out.md
@@ -58,7 +58,7 @@ provider_credential_schema: # Provider credential rules, as Anthropic only supp
en_US: Enter your API URL
```
-You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#Provider).
+You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#provider).
### Implementing Provider Code
diff --git a/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md b/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md
index b34544c789..78aad8876f 100644
--- a/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md
+++ b/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md
@@ -117,7 +117,7 @@ model_credential_schema:
en_US: Enter your API Base
```
-也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#Provider)。
+也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。
#### 实现供应商代码
diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml
index 80db22ea84..89fccef659 100644
--- a/api/core/model_runtime/model_providers/_position.yaml
+++ b/api/core/model_runtime/model_providers/_position.yaml
@@ -40,3 +40,4 @@
- fireworks
- mixedbread
- nomic
+- voyage
diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
deleted file mode 100644
index d9c5726592..0000000000
--- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
+++ /dev/null
@@ -1,238 +0,0 @@
-import json
-import logging
-import time
-from typing import Optional
-
-import boto3
-from botocore.config import Config
-from botocore.exceptions import (
- ClientError,
- EndpointConnectionError,
- NoRegionError,
- ServiceNotInRegionError,
- UnknownServiceError,
-)
-
-from core.embedding.embedding_constant import EmbeddingInputType
-from core.model_runtime.entities.model_entities import PriceType
-from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
-from core.model_runtime.errors.invoke import (
- InvokeAuthorizationError,
- InvokeBadRequestError,
- InvokeConnectionError,
- InvokeError,
- InvokeRateLimitError,
- InvokeServerUnavailableError,
-)
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-
-logger = logging.getLogger(__name__)
-
-
-class BedrockTextEmbeddingModel(TextEmbeddingModel):
- def _invoke(
- self,
- model: str,
- credentials: dict,
- texts: list[str],
- user: Optional[str] = None,
- input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
- ) -> TextEmbeddingResult:
- """
- Invoke text embedding model
-
- :param model: model name
- :param credentials: model credentials
- :param texts: texts to embed
- :param user: unique user id
- :param input_type: input type
- :return: embeddings result
- """
- client_config = Config(region_name=credentials["aws_region"])
-
- bedrock_runtime = boto3.client(
- service_name="bedrock-runtime",
- config=client_config,
- aws_access_key_id=credentials.get("aws_access_key_id"),
- aws_secret_access_key=credentials.get("aws_secret_access_key"),
- )
-
- embeddings = []
- token_usage = 0
-
- model_prefix = model.split(".")[0]
-
- if model_prefix == "amazon":
- for text in texts:
- body = {
- "inputText": text,
- }
- response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
- embeddings.extend([response_body.get("embedding")])
- token_usage += response_body.get("inputTextTokenCount")
- logger.warning(f"Total Tokens: {token_usage}")
- result = TextEmbeddingResult(
- model=model,
- embeddings=embeddings,
- usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
- )
- return result
-
- if model_prefix == "cohere":
- input_type = "search_document" if len(texts) > 1 else "search_query"
- for text in texts:
- body = {
- "texts": [text],
- "input_type": input_type,
- }
- response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
- embeddings.extend(response_body.get("embeddings"))
- token_usage += len(text)
- result = TextEmbeddingResult(
- model=model,
- embeddings=embeddings,
- usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
- )
- return result
-
- # others
- raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
-
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
- """
- Get number of tokens for given prompt messages
-
- :param model: model name
- :param credentials: model credentials
- :param texts: texts to embed
- :return:
- """
- num_tokens = 0
- for text in texts:
- num_tokens += self._get_num_tokens_by_gpt2(text)
- return num_tokens
-
- def validate_credentials(self, model: str, credentials: dict) -> None:
- """
- Validate model credentials
-
- :param model: model name
- :param credentials: model credentials
- :return:
- """
-
- @property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
- """
- Map model invoke error to unified error
- The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
- The value is the md = genai.GenerativeModel(model) error type thrown by the model,
- which needs to be converted into a unified error type for the caller.
-
- :return: Invoke emd = genai.GenerativeModel(model) error mapping
- """
- return {
- InvokeConnectionError: [],
- InvokeServerUnavailableError: [],
- InvokeRateLimitError: [],
- InvokeAuthorizationError: [],
- InvokeBadRequestError: [],
- }
-
- def _create_payload(
- self,
- model_prefix: str,
- texts: list[str],
- model_parameters: dict,
- stop: Optional[list[str]] = None,
- stream: bool = True,
- ):
- """
- Create payload for bedrock api call depending on model provider
- """
- payload = {}
-
- if model_prefix == "amazon":
- payload["inputText"] = texts
-
- def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
- """
- Calculate response usage
-
- :param model: model name
- :param credentials: model credentials
- :param tokens: input tokens
- :return: usage
- """
- # get input price info
- input_price_info = self.get_price(
- model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
- )
-
- # transform usage
- usage = EmbeddingUsage(
- tokens=tokens,
- total_tokens=tokens,
- unit_price=input_price_info.unit_price,
- price_unit=input_price_info.unit,
- total_price=input_price_info.total_amount,
- currency=input_price_info.currency,
- latency=time.perf_counter() - self.started_at,
- )
-
- return usage
-
- def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
- """
- Map client error to invoke error
-
- :param error_code: error code
- :param error_msg: error message
- :return: invoke error
- """
-
- if error_code == "AccessDeniedException":
- return InvokeAuthorizationError(error_msg)
- elif error_code in {"ResourceNotFoundException", "ValidationException"}:
- return InvokeBadRequestError(error_msg)
- elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
- return InvokeRateLimitError(error_msg)
- elif error_code in {
- "ModelTimeoutException",
- "ModelErrorException",
- "InternalServerException",
- "ModelNotReadyException",
- }:
- return InvokeServerUnavailableError(error_msg)
- elif error_code == "ModelStreamErrorException":
- return InvokeConnectionError(error_msg)
-
- return InvokeError(error_msg)
-
- def _invoke_bedrock_embedding(
- self,
- model: str,
- bedrock_runtime,
- body: dict,
- ):
- accept = "application/json"
- content_type = "application/json"
- try:
- response = bedrock_runtime.invoke_model(
- body=json.dumps(body), modelId=model, accept=accept, contentType=content_type
- )
- response_body = json.loads(response.get("body").read().decode("utf-8"))
- return response_body
- except ClientError as ex:
- error_code = ex.response["Error"]["Code"]
- full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
- raise self._map_client_to_invoke_error(error_code, full_error_msg)
-
- except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
- raise InvokeConnectionError(str(ex))
-
- except UnknownServiceError as ex:
- raise InvokeServerUnavailableError(str(ex))
-
- except Exception as ex:
- raise InvokeError(str(ex))
diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py
deleted file mode 100644
index 1e86f351c8..0000000000
--- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py
+++ /dev/null
@@ -1,223 +0,0 @@
-import json
-import time
-from decimal import Decimal
-from typing import Optional
-from urllib.parse import urljoin
-
-import numpy as np
-import requests
-
-from core.embedding.embedding_constant import EmbeddingInputType
-from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import (
- AIModelEntity,
- FetchFrom,
- ModelPropertyKey,
- ModelType,
- PriceConfig,
- PriceType,
-)
-from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
-from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
-
-
-class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
- """
- Model class for an OpenAI API-compatible text embedding model.
- """
-
- def _invoke(
- self,
- model: str,
- credentials: dict,
- texts: list[str],
- user: Optional[str] = None,
- input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
- ) -> TextEmbeddingResult:
- """
- Invoke text embedding model
-
- :param model: model name
- :param credentials: model credentials
- :param texts: texts to embed
- :param user: unique user id
- :param input_type: input type
- :return: embeddings result
- """
-
- # Prepare headers and payload for the request
- headers = {"Content-Type": "application/json"}
-
- api_key = credentials.get("api_key")
- if api_key:
- headers["Authorization"] = f"Bearer {api_key}"
-
- if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
- endpoint_url = "https://cloud.perfxlab.cn/v1/"
- else:
- endpoint_url = credentials.get("endpoint_url")
- if not endpoint_url.endswith("/"):
- endpoint_url += "/"
-
- endpoint_url = urljoin(endpoint_url, "embeddings")
-
- extra_model_kwargs = {}
- if user:
- extra_model_kwargs["user"] = user
-
- extra_model_kwargs["encoding_format"] = "float"
-
- # get model properties
- context_size = self._get_context_size(model, credentials)
- max_chunks = self._get_max_chunks(model, credentials)
-
- inputs = []
- indices = []
- used_tokens = 0
-
- for i, text in enumerate(texts):
- # Here token count is only an approximation based on the GPT2 tokenizer
- # TODO: Optimize for better token estimation and chunking
- num_tokens = self._get_num_tokens_by_gpt2(text)
-
- if num_tokens >= context_size:
- cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
- # if num tokens is larger than context length, only use the start
- inputs.append(text[0:cutoff])
- else:
- inputs.append(text)
- indices += [i]
-
- batched_embeddings = []
- _iter = range(0, len(inputs), max_chunks)
-
- for i in _iter:
- # Prepare the payload for the request
- payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs}
-
- # Make the request to the OpenAI API
- response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
-
- response.raise_for_status() # Raise an exception for HTTP errors
- response_data = response.json()
-
- # Extract embeddings and used tokens from the response
- embeddings_batch = [data["embedding"] for data in response_data["data"]]
- embedding_used_tokens = response_data["usage"]["total_tokens"]
-
- used_tokens += embedding_used_tokens
- batched_embeddings += embeddings_batch
-
- # calc usage
- usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
-
- return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
-
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
- """
- Approximate number of tokens for given messages using GPT2 tokenizer
-
- :param model: model name
- :param credentials: model credentials
- :param texts: texts to embed
- :return:
- """
- return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
-
- def validate_credentials(self, model: str, credentials: dict) -> None:
- """
- Validate model credentials
-
- :param model: model name
- :param credentials: model credentials
- :return:
- """
- try:
- headers = {"Content-Type": "application/json"}
-
- api_key = credentials.get("api_key")
-
- if api_key:
- headers["Authorization"] = f"Bearer {api_key}"
-
- if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
- endpoint_url = "https://cloud.perfxlab.cn/v1/"
- else:
- endpoint_url = credentials.get("endpoint_url")
- if not endpoint_url.endswith("/"):
- endpoint_url += "/"
-
- endpoint_url = urljoin(endpoint_url, "embeddings")
-
- payload = {"input": "ping", "model": model}
-
- response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
-
- if response.status_code != 200:
- raise CredentialsValidateFailedError(
- f"Credentials validation failed with status code {response.status_code}"
- )
-
- try:
- json_result = response.json()
- except json.JSONDecodeError as e:
- raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
-
- if "model" not in json_result:
- raise CredentialsValidateFailedError("Credentials validation failed: invalid response")
- except CredentialsValidateFailedError:
- raise
- except Exception as ex:
- raise CredentialsValidateFailedError(str(ex))
-
- def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
- """
- generate custom model entities from credentials
- """
- entity = AIModelEntity(
- model=model,
- label=I18nObject(en_US=model),
- model_type=ModelType.TEXT_EMBEDDING,
- fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
- model_properties={
- ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
- ModelPropertyKey.MAX_CHUNKS: 1,
- },
- parameter_rules=[],
- pricing=PriceConfig(
- input=Decimal(credentials.get("input_price", 0)),
- unit=Decimal(credentials.get("unit", 0)),
- currency=credentials.get("currency", "USD"),
- ),
- )
-
- return entity
-
- def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
- """
- Calculate response usage
-
- :param model: model name
- :param credentials: model credentials
- :param tokens: input tokens
- :return: usage
- """
- # get input price info
- input_price_info = self.get_price(
- model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
- )
-
- # transform usage
- usage = EmbeddingUsage(
- tokens=tokens,
- total_tokens=tokens,
- unit_price=input_price_info.unit_price,
- price_unit=input_price_info.unit,
- total_price=input_price_info.total_amount,
- currency=input_price_info.currency,
- latency=time.perf_counter() - self.started_at,
- )
-
- return usage
diff --git a/api/core/model_runtime/model_providers/voyage/__init__.py b/api/core/model_runtime/model_providers/voyage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg
new file mode 100644
index 0000000000..a961f5e435
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg
@@ -0,0 +1,21 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..2c4e121dd7
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg
@@ -0,0 +1,8 @@
+
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/voyage/rerank/__init__.py b/api/core/model_runtime/model_providers/voyage/rerank/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml
new file mode 100644
index 0000000000..9c894eda85
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml
@@ -0,0 +1,4 @@
+model: rerank-1
+model_type: rerank
+model_properties:
+ context_size: 8000
diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml
new file mode 100644
index 0000000000..b052d6f000
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml
@@ -0,0 +1,4 @@
+model: rerank-lite-1
+model_type: rerank
+model_properties:
+ context_size: 4000
diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank.py b/api/core/model_runtime/model_providers/voyage/rerank/rerank.py
new file mode 100644
index 0000000000..33fdebbb45
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank.py
@@ -0,0 +1,123 @@
+from typing import Optional
+
+import httpx
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class VoyageRerankModel(RerankModel):
+ """
+ Model class for Voyage rerank model.
+ """
+
+ def _invoke(
+ self,
+ model: str,
+ credentials: dict,
+ query: str,
+ docs: list[str],
+ score_threshold: Optional[float] = None,
+ top_n: Optional[int] = None,
+ user: Optional[str] = None,
+ ) -> RerankResult:
+ """
+ Invoke rerank model
+ :param model: model name
+ :param credentials: model credentials
+ :param query: search query
+ :param docs: docs for reranking
+ :param score_threshold: score threshold
+ :param top_n: top n documents to return
+ :param user: unique user id
+ :return: rerank result
+ """
+ if len(docs) == 0:
+ return RerankResult(model=model, docs=[])
+
+ base_url = credentials.get("base_url", "https://api.voyageai.com/v1")
+ base_url = base_url.removesuffix("/")
+
+ try:
+ response = httpx.post(
+ base_url + "/rerank",
+ json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True},
+ headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"},
+ )
+ response.raise_for_status()
+ results = response.json()
+
+ rerank_documents = []
+ for result in results["data"]:
+ rerank_document = RerankDocument(
+ index=result["index"],
+ text=result["document"],
+ score=result["relevance_score"],
+ )
+ if score_threshold is None or result["relevance_score"] >= score_threshold:
+ rerank_documents.append(rerank_document)
+
+ return RerankResult(model=model, docs=rerank_documents)
+ except httpx.HTTPStatusError as e:
+ raise InvokeServerUnavailableError(str(e))
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ self._invoke(
+ model=model,
+ credentials=credentials,
+ query="What is the capital of the United States?",
+ docs=[
+ "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+ "Census, Carson City had a population of 55,274.",
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+ "are a political division controlled by the United States. Its capital is Saipan.",
+ ],
+ score_threshold=0.8,
+ )
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ """
+ return {
+ InvokeConnectionError: [httpx.ConnectError],
+ InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+ InvokeRateLimitError: [],
+ InvokeAuthorizationError: [httpx.HTTPStatusError],
+ InvokeBadRequestError: [httpx.RequestError],
+ }
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ generate custom model entities from credentials
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.RERANK,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))},
+ )
+
+ return entity
diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py b/api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..a8a4d3c15b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py
@@ -0,0 +1,172 @@
+import time
+from json import JSONDecodeError, dumps
+from typing import Optional
+
+import requests
+
+from core.embedding.embedding_constant import EmbeddingInputType
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+
+class VoyageTextEmbeddingModel(TextEmbeddingModel):
+ """
+ Model class for Voyage text embedding model.
+ """
+
+ api_base: str = "https://api.voyageai.com/v1"
+
+ def _invoke(
+ self,
+ model: str,
+ credentials: dict,
+ texts: list[str],
+ user: Optional[str] = None,
+ input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
+ ) -> TextEmbeddingResult:
+ """
+ Invoke text embedding model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :param user: unique user id
+ :param input_type: input type
+ :return: embeddings result
+ """
+ api_key = credentials["api_key"]
+ if not api_key:
+ raise CredentialsValidateFailedError("api_key is required")
+
+ base_url = credentials.get("base_url", self.api_base)
+ base_url = base_url.removesuffix("/")
+
+ url = base_url + "/embeddings"
+ headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
+ voyage_input_type = "null"
+ if input_type is not None:
+ voyage_input_type = input_type.value
+ data = {"model": model, "input": texts, "input_type": voyage_input_type}
+
+ try:
+ response = requests.post(url, headers=headers, data=dumps(data))
+ except Exception as e:
+ raise InvokeConnectionError(str(e))
+
+ if response.status_code != 200:
+ try:
+ resp = response.json()
+ msg = resp["detail"]
+ if response.status_code == 401:
+ raise InvokeAuthorizationError(msg)
+ elif response.status_code == 429:
+ raise InvokeRateLimitError(msg)
+ elif response.status_code == 500:
+ raise InvokeServerUnavailableError(msg)
+ else:
+ raise InvokeBadRequestError(msg)
+ except JSONDecodeError as e:
+ raise InvokeServerUnavailableError(
+ f"Failed to convert response to json: {e} with text: {response.text}"
+ )
+
+ try:
+ resp = response.json()
+ embeddings = resp["data"]
+ usage = resp["usage"]
+ except Exception as e:
+ raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
+
+ usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
+
+ result = TextEmbeddingResult(
+ model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
+ )
+
+ return result
+
+ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :return:
+ """
+ return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ self._invoke(model=model, credentials=credentials, texts=["ping"])
+ except Exception as e:
+ raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ return {
+ InvokeConnectionError: [InvokeConnectionError],
+ InvokeServerUnavailableError: [InvokeServerUnavailableError],
+ InvokeRateLimitError: [InvokeRateLimitError],
+ InvokeAuthorizationError: [InvokeAuthorizationError],
+ InvokeBadRequestError: [KeyError, InvokeBadRequestError],
+ }
+
+ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+ """
+ Calculate response usage
+
+ :param model: model name
+ :param credentials: model credentials
+ :param tokens: input tokens
+ :return: usage
+ """
+ # get input price info
+ input_price_info = self.get_price(
+ model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
+ )
+
+ # transform usage
+ usage = EmbeddingUsage(
+ tokens=tokens,
+ total_tokens=tokens,
+ unit_price=input_price_info.unit_price,
+ price_unit=input_price_info.unit,
+ total_price=input_price_info.total_amount,
+ currency=input_price_info.currency,
+ latency=time.perf_counter() - self.started_at,
+ )
+
+ return usage
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ generate custom model entities from credentials
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.TEXT_EMBEDDING,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
+ )
+
+ return entity
diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml
new file mode 100644
index 0000000000..a06bb7639f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml
@@ -0,0 +1,8 @@
+model: voyage-3-lite
+model_type: text-embedding
+model_properties:
+ context_size: 32000
+pricing:
+ input: '0.00002'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml
new file mode 100644
index 0000000000..117afbcaf3
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml
@@ -0,0 +1,8 @@
+model: voyage-3
+model_type: text-embedding
+model_properties:
+ context_size: 32000
+pricing:
+ input: '0.00006'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/voyage/voyage.py b/api/core/model_runtime/model_providers/voyage/voyage.py
new file mode 100644
index 0000000000..3e33b45e11
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/voyage.py
@@ -0,0 +1,28 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class VoyageProvider(ModelProvider):
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ """
+ Validate provider credentials
+ if validate failed, raise exception
+
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+ """
+ try:
+ model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
+
+ # Use `voyage-3` model for validate,
+ # no matter what model you pass in, text completion model or chat model
+ model_instance.validate_credentials(model="voyage-3", credentials=credentials)
+ except CredentialsValidateFailedError as ex:
+ raise ex
+ except Exception as ex:
+ logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
+ raise ex
diff --git a/api/core/model_runtime/model_providers/voyage/voyage.yaml b/api/core/model_runtime/model_providers/voyage/voyage.yaml
new file mode 100644
index 0000000000..c64707800e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/voyage.yaml
@@ -0,0 +1,31 @@
+provider: voyage
+label:
+ en_US: Voyage
+description:
+ en_US: Embedding and Rerank Model Supported
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+background: "#EFFDFD"
+help:
+ title:
+ en_US: Get your API key from Voyage AI
+ zh_Hans: 从 Voyage 获取 API Key
+ url:
+ en_US: https://dash.voyageai.com/
+supported_model_types:
+ - text-embedding
+ - rerank
+configurate_methods:
+ - predefined-model
+provider_credential_schema:
+ credential_form_schemas:
+ - variable: api_key
+ label:
+ en_US: API Key
+ type: secret-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入您的 API Key
+ en_US: Enter your API Key
diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
deleted file mode 100644
index 14a529dddf..0000000000
--- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
+++ /dev/null
@@ -1,142 +0,0 @@
-import time
-from typing import Optional
-
-from core.embedding.embedding_constant import EmbeddingInputType
-from core.model_runtime.entities.model_entities import PriceType
-from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
-from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
-from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
-
-
-class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
- """
- Model class for ZhipuAI text embedding model.
- """
-
- def _invoke(
- self,
- model: str,
- credentials: dict,
- texts: list[str],
- user: Optional[str] = None,
- input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
- ) -> TextEmbeddingResult:
- """
- Invoke text embedding model
-
- :param model: model name
- :param credentials: model credentials
- :param texts: texts to embed
- :param user: unique user id
- :param input_type: input type
- :return: embeddings result
- """
- credentials_kwargs = self._to_credential_kwargs(credentials)
- client = ZhipuAI(api_key=credentials_kwargs["api_key"])
-
- embeddings, embedding_used_tokens = self.embed_documents(model, client, texts)
-
- return TextEmbeddingResult(
- embeddings=embeddings,
- usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens),
- model=model,
- )
-
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
- """
- Get number of tokens for given prompt messages
-
- :param model: model name
- :param credentials: model credentials
- :param texts: texts to embed
- :return:
- """
- if len(texts) == 0:
- return 0
-
- total_num_tokens = 0
- for text in texts:
- total_num_tokens += self._get_num_tokens_by_gpt2(text)
-
- return total_num_tokens
-
- def validate_credentials(self, model: str, credentials: dict) -> None:
- """
- Validate model credentials
-
- :param model: model name
- :param credentials: model credentials
- :return:
- """
- try:
- # transform credentials to kwargs for model instance
- credentials_kwargs = self._to_credential_kwargs(credentials)
- client = ZhipuAI(api_key=credentials_kwargs["api_key"])
-
- # call embedding model
- self.embed_documents(
- model=model,
- client=client,
- texts=["ping"],
- )
- except Exception as ex:
- raise CredentialsValidateFailedError(str(ex))
-
- def embed_documents(self, model: str, client: ZhipuAI, texts: list[str]) -> tuple[list[list[float]], int]:
- """Call out to ZhipuAI's embedding endpoint.
-
- Args:
- texts: The list of texts to embed.
-
- Returns:
- List of embeddings, one for each text.
- """
- embeddings = []
- embedding_used_tokens = 0
-
- for text in texts:
- response = client.embeddings.create(model=model, input=text)
- data = response.data[0]
- embeddings.append(data.embedding)
- embedding_used_tokens += response.usage.total_tokens
-
- return [list(map(float, e)) for e in embeddings], embedding_used_tokens
-
- def embed_query(self, text: str) -> list[float]:
- """Call out to ZhipuAI's embedding endpoint.
-
- Args:
- text: The text to embed.
-
- Returns:
- Embeddings for the text.
- """
- return self.embed_documents([text])[0]
-
- def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
- """
- Calculate response usage
-
- :param model: model name
- :param tokens: input tokens
- :return: usage
- """
- # get input price info
- input_price_info = self.get_price(
- model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
- )
-
- # transform usage
- usage = EmbeddingUsage(
- tokens=tokens,
- total_tokens=tokens,
- unit_price=input_price_info.unit_price,
- price_unit=input_price_info.unit,
- total_price=input_price_info.total_amount,
- currency=input_price_info.currency,
- latency=time.perf_counter() - self.started_at,
- )
-
- return usage
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index 3073100746..a0153c1e58 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -45,7 +45,7 @@ class Jieba(BaseKeyword):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
- keywords_list = kwargs.get("keywords_list", None)
+ keywords_list = kwargs.get("keywords_list")
for i in range(len(texts)):
text = texts[i]
if keywords_list:
diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py
index 54f6a76e16..5af45e1e50 100644
--- a/api/events/event_handlers/create_document_index.py
+++ b/api/events/event_handlers/create_document_index.py
@@ -14,7 +14,7 @@ from models.dataset import Document
@document_index_created.connect
def handle(sender, **kwargs):
dataset_id = sender
- document_ids = kwargs.get("document_ids", None)
+ document_ids = kwargs.get("document_ids")
documents = []
start_at = time.perf_counter()
for document_id in document_ids:
diff --git a/api/poetry.lock b/api/poetry.lock
index bce21fb547..85c68cd75f 100644
--- a/api/poetry.lock
+++ b/api/poetry.lock
@@ -8074,29 +8074,29 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
-version = "0.6.5"
+version = "0.6.8"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
- {file = "ruff-0.6.5-py3-none-linux_armv6l.whl", hash = "sha256:7e4e308f16e07c95fc7753fc1aaac690a323b2bb9f4ec5e844a97bb7fbebd748"},
- {file = "ruff-0.6.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:932cd69eefe4daf8c7d92bd6689f7e8182571cb934ea720af218929da7bd7d69"},
- {file = "ruff-0.6.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3a8d42d11fff8d3143ff4da41742a98f8f233bf8890e9fe23077826818f8d680"},
- {file = "ruff-0.6.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a50af6e828ee692fb10ff2dfe53f05caecf077f4210fae9677e06a808275754f"},
- {file = "ruff-0.6.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:794ada3400a0d0b89e3015f1a7e01f4c97320ac665b7bc3ade24b50b54cb2972"},
- {file = "ruff-0.6.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:381413ec47f71ce1d1c614f7779d88886f406f1fd53d289c77e4e533dc6ea200"},
- {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:52e75a82bbc9b42e63c08d22ad0ac525117e72aee9729a069d7c4f235fc4d276"},
- {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09c72a833fd3551135ceddcba5ebdb68ff89225d30758027280968c9acdc7810"},
- {file = "ruff-0.6.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:800c50371bdcb99b3c1551d5691e14d16d6f07063a518770254227f7f6e8c178"},
- {file = "ruff-0.6.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e25ddd9cd63ba1f3bd51c1f09903904a6adf8429df34f17d728a8fa11174253"},
- {file = "ruff-0.6.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7291e64d7129f24d1b0c947ec3ec4c0076e958d1475c61202497c6aced35dd19"},
- {file = "ruff-0.6.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9ad7dfbd138d09d9a7e6931e6a7e797651ce29becd688be8a0d4d5f8177b4b0c"},
- {file = "ruff-0.6.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:005256d977021790cc52aa23d78f06bb5090dc0bfbd42de46d49c201533982ae"},
- {file = "ruff-0.6.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:482c1e6bfeb615eafc5899127b805d28e387bd87db38b2c0c41d271f5e58d8cc"},
- {file = "ruff-0.6.5-py3-none-win32.whl", hash = "sha256:cf4d3fa53644137f6a4a27a2b397381d16454a1566ae5335855c187fbf67e4f5"},
- {file = "ruff-0.6.5-py3-none-win_amd64.whl", hash = "sha256:3e42a57b58e3612051a636bc1ac4e6b838679530235520e8f095f7c44f706ff9"},
- {file = "ruff-0.6.5-py3-none-win_arm64.whl", hash = "sha256:51935067740773afdf97493ba9b8231279e9beef0f2a8079188c4776c25688e0"},
- {file = "ruff-0.6.5.tar.gz", hash = "sha256:4d32d87fab433c0cf285c3683dd4dae63be05fd7a1d65b3f5bf7cdd05a6b96fb"},
+ {file = "ruff-0.6.8-py3-none-linux_armv6l.whl", hash = "sha256:77944bca110ff0a43b768f05a529fecd0706aac7bcce36d7f1eeb4cbfca5f0f2"},
+ {file = "ruff-0.6.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:27b87e1801e786cd6ede4ada3faa5e254ce774de835e6723fd94551464c56b8c"},
+ {file = "ruff-0.6.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:cd48f945da2a6334f1793d7f701725a76ba93bf3d73c36f6b21fb04d5338dcf5"},
+ {file = "ruff-0.6.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:677e03c00f37c66cea033274295a983c7c546edea5043d0c798833adf4cf4c6f"},
+ {file = "ruff-0.6.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9f1476236b3eacfacfc0f66aa9e6cd39f2a624cb73ea99189556015f27c0bdeb"},
+ {file = "ruff-0.6.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f5a2f17c7d32991169195d52a04c95b256378bbf0de8cb98478351eb70d526f"},
+ {file = "ruff-0.6.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5fd0d4b7b1457c49e435ee1e437900ced9b35cb8dc5178921dfb7d98d65a08d0"},
+ {file = "ruff-0.6.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8034b19b993e9601f2ddf2c517451e17a6ab5cdb1c13fdff50c1442a7171d87"},
+ {file = "ruff-0.6.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6cfb227b932ba8ef6e56c9f875d987973cd5e35bc5d05f5abf045af78ad8e098"},
+ {file = "ruff-0.6.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ef0411eccfc3909269fed47c61ffebdcb84a04504bafa6b6df9b85c27e813b0"},
+ {file = "ruff-0.6.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:007dee844738c3d2e6c24ab5bc7d43c99ba3e1943bd2d95d598582e9c1b27750"},
+ {file = "ruff-0.6.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ce60058d3cdd8490e5e5471ef086b3f1e90ab872b548814e35930e21d848c9ce"},
+ {file = "ruff-0.6.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1085c455d1b3fdb8021ad534379c60353b81ba079712bce7a900e834859182fa"},
+ {file = "ruff-0.6.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:70edf6a93b19481affd287d696d9e311388d808671bc209fb8907b46a8c3af44"},
+ {file = "ruff-0.6.8-py3-none-win32.whl", hash = "sha256:792213f7be25316f9b46b854df80a77e0da87ec66691e8f012f887b4a671ab5a"},
+ {file = "ruff-0.6.8-py3-none-win_amd64.whl", hash = "sha256:ec0517dc0f37cad14a5319ba7bba6e7e339d03fbf967a6d69b0907d61be7a263"},
+ {file = "ruff-0.6.8-py3-none-win_arm64.whl", hash = "sha256:8d3bb2e3fbb9875172119021a13eed38849e762499e3cfde9588e4b4d70968dc"},
+ {file = "ruff-0.6.8.tar.gz", hash = "sha256:a5bf44b1aa0adaf6d9d20f86162b34f7c593bfedabc51239953e446aefc8ce18"},
]
[[package]]
@@ -10501,4 +10501,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
-content-hash = "69b42bb1ff033f14e199fee8335356275099421d72bbd7037b7a991ea65cae08"
+content-hash = "c4580c22e2b220c8c80dbc3f765060a09e14874ed29b690c13a533bf0365e789"
diff --git a/api/pyproject.toml b/api/pyproject.toml
index f004865d5f..e737761f3b 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -123,6 +123,7 @@ FIRECRAWL_API_KEY = "fc-"
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
MIXEDBREAD_API_KEY = "mk-aaaaaaaaaaaaaaaaaaaa"
+VOYAGE_API_KEY = "va-aaaaaaaaaaaaaaaaaaaa"
[tool.poetry]
name = "dify-api"
@@ -286,4 +287,4 @@ optional = true
[tool.poetry.group.lint.dependencies]
dotenv-linter = "~0.5.0"
-ruff = "~0.6.5"
+ruff = "~0.6.8"
diff --git a/api/tests/integration_tests/model_runtime/voyage/__init__.py b/api/tests/integration_tests/model_runtime/voyage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/model_runtime/voyage/test_provider.py b/api/tests/integration_tests/model_runtime/voyage/test_provider.py
new file mode 100644
index 0000000000..08978c88a9
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/voyage/test_provider.py
@@ -0,0 +1,25 @@
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.voyage.voyage import VoyageProvider
+
+
+def test_validate_provider_credentials():
+ provider = VoyageProvider()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
+ with patch("requests.post") as mock_post:
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "object": "list",
+ "data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
+ "model": "voyage-3",
+ "usage": {"total_tokens": 1},
+ }
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+ provider.validate_provider_credentials(credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
diff --git a/api/tests/integration_tests/model_runtime/voyage/test_rerank.py b/api/tests/integration_tests/model_runtime/voyage/test_rerank.py
new file mode 100644
index 0000000000..e97a9e4c81
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/voyage/test_rerank.py
@@ -0,0 +1,92 @@
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+from core.model_runtime.entities.rerank_entities import RerankResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel
+
+
+def test_validate_credentials():
+ model = VoyageRerankModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model="rerank-lite-1",
+ credentials={"api_key": "invalid_key"},
+ )
+ with patch("httpx.post") as mock_post:
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "object": "list",
+ "data": [
+ {
+ "relevance_score": 0.546875,
+ "index": 0,
+ "document": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
+ "States Census, Carson City had a population of 55,274.",
+ },
+ {
+ "relevance_score": 0.4765625,
+ "index": 1,
+ "document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the "
+ "Pacific Ocean that are a political division controlled by the United States. Its "
+ "capital is Saipan.",
+ },
+ ],
+ "model": "rerank-lite-1",
+ "usage": {"total_tokens": 96},
+ }
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+ model.validate_credentials(
+ model="rerank-lite-1",
+ credentials={
+ "api_key": os.environ.get("VOYAGE_API_KEY"),
+ },
+ )
+
+
+def test_invoke_model():
+ model = VoyageRerankModel()
+ with patch("httpx.post") as mock_post:
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "object": "list",
+ "data": [
+ {
+ "relevance_score": 0.84375,
+ "index": 0,
+ "document": "Kasumi is a girl name of Japanese origin meaning mist.",
+ },
+ {
+ "relevance_score": 0.4765625,
+ "index": 1,
+ "document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she "
+ "leads a team named PopiParty.",
+ },
+ ],
+ "model": "rerank-lite-1",
+ "usage": {"total_tokens": 59},
+ }
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+ result = model.invoke(
+ model="rerank-lite-1",
+ credentials={
+ "api_key": os.environ.get("VOYAGE_API_KEY"),
+ },
+ query="Who is Kasumi?",
+ docs=[
+ "Kasumi is a girl name of Japanese origin meaning mist.",
+ "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
+ "PopiParty.",
+ ],
+ score_threshold=0.5,
+ )
+
+ assert isinstance(result, RerankResult)
+ assert len(result.docs) == 1
+ assert result.docs[0].index == 0
+ assert result.docs[0].score >= 0.5
diff --git a/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py b/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py
new file mode 100644
index 0000000000..75719672a9
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py
@@ -0,0 +1,70 @@
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.voyage.text_embedding.text_embedding import VoyageTextEmbeddingModel
+
+
+def test_validate_credentials():
+ model = VoyageTextEmbeddingModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(model="voyage-3", credentials={"api_key": "invalid_key"})
+ with patch("requests.post") as mock_post:
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "object": "list",
+ "data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
+ "model": "voyage-3",
+ "usage": {"total_tokens": 1},
+ }
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+ model.validate_credentials(model="voyage-3", credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
+
+
+def test_invoke_model():
+ model = VoyageTextEmbeddingModel()
+
+ with patch("requests.post") as mock_post:
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "object": "list",
+ "data": [
+ {"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0},
+ {"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 1},
+ ],
+ "model": "voyage-3",
+ "usage": {"total_tokens": 2},
+ }
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+ result = model.invoke(
+ model="voyage-3",
+ credentials={
+ "api_key": os.environ.get("VOYAGE_API_KEY"),
+ },
+ texts=["hello", "world"],
+ user="abc-123",
+ )
+
+ assert isinstance(result, TextEmbeddingResult)
+ assert len(result.embeddings) == 2
+ assert result.usage.total_tokens == 2
+
+
+def test_get_num_tokens():
+ model = VoyageTextEmbeddingModel()
+
+ num_tokens = model.get_num_tokens(
+ model="voyage-3",
+ credentials={
+ "api_key": os.environ.get("VOYAGE_API_KEY"),
+ },
+ texts=["ping"],
+ )
+
+ assert num_tokens == 1
diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py
index d3c1f3101c..42cf87e317 100644
--- a/api/tests/integration_tests/tools/__mock/http.py
+++ b/api/tests/integration_tests/tools/__mock/http.py
@@ -17,7 +17,7 @@ class MockedHttp:
request = httpx.Request(
method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies")
)
- data = kwargs.get("data", None)
+ data = kwargs.get("data")
resp = json.dumps(data).encode("utf-8") if data else b"OK"
response = httpx.Response(
status_code=200,
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py
index f1ab23b002..ec013183b7 100644
--- a/api/tests/integration_tests/workflow/nodes/__mock/http.py
+++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py
@@ -22,8 +22,8 @@ class MockedHttp:
return response
# get data, files
- data = kwargs.get("data", None)
- files = kwargs.get("files", None)
+ data = kwargs.get("data")
+ files = kwargs.get("files")
if data is not None:
resp = dumps(data).encode("utf-8")
elif files is not None:
diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh
index b60ff64fdc..63891eb9f8 100755
--- a/dev/pytest/pytest_model_runtime.sh
+++ b/dev/pytest/pytest_model_runtime.sh
@@ -9,4 +9,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \
api/tests/integration_tests/model_runtime/upstage \
api/tests/integration_tests/model_runtime/fireworks \
api/tests/integration_tests/model_runtime/nomic \
- api/tests/integration_tests/model_runtime/mixedbread
+ api/tests/integration_tests/model_runtime/mixedbread \
+ api/tests/integration_tests/model_runtime/voyage
\ No newline at end of file