Merge main

This commit is contained in:
Yeuoly 2024-09-10 14:05:20 +08:00
commit 9c7bcd5abc
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
650 changed files with 15950 additions and 4747 deletions

View File

@ -20,7 +20,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v44
uses: tj-actions/changed-files@v45
with:
files: api/**
@ -66,7 +66,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v44
uses: tj-actions/changed-files@v45
with:
files: web/**
@ -97,7 +97,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v44
uses: tj-actions/changed-files@v45
with:
files: |
**.sh
@ -107,7 +107,7 @@ jobs:
dev/**
- name: Super-linter
uses: super-linter/super-linter/slim@v6
uses: super-linter/super-linter/slim@v7
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning

View File

@ -0,0 +1,54 @@
name: Check i18n Files and Create PR
on:
pull_request:
types: [closed]
branches: [main]
jobs:
check-and-update:
if: github.event.pull_request.merged == true
runs-on: ubuntu-latest
defaults:
run:
working-directory: web
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 2 # last 2 commits
- name: Check for file changes in i18n/en-US
id: check_files
run: |
recent_commit_sha=$(git rev-parse HEAD)
second_recent_commit_sha=$(git rev-parse HEAD~1)
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v2
with:
node-version: 'lts/*'
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
run: yarn install --frozen-lockfile
- name: Run npm script
if: env.FILES_CHANGED == 'true'
run: npm run auto-gen-i18n
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files'
body: This PR was automatically created to update i18n files based on changes in en-US locale.
branch: chore/automated-i18n-updates
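
For reference, the "Check for file changes in i18n/en-US" step above amounts to diffing the two most recent commits, scoped to the English locale files. A minimal Python sketch of the same check (paths come from the workflow; everything else is illustrative):

import subprocess

def en_us_files_changed(repo_dir: str = "web") -> bool:
    # Mirror the workflow step: diff HEAD~1..HEAD, limited to i18n/en-US/*.ts
    result = subprocess.run(
        ["git", "diff", "--name-only", "HEAD~1", "HEAD", "--", "i18n/en-US/*.ts"],
        cwd=repo_dir, capture_output=True, text=True, check=True,
    )
    changed = result.stdout.strip()
    print(f"Changed files: {changed}")
    return bool(changed)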

View File

@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.

View File

@ -39,7 +39,7 @@ DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: local, s3, azure-blob, google-storage
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos
STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage
S3_USE_AWS_MANAGED_IAM=false
@ -60,7 +60,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
ALIYUN_OSS_ENDPOINT=your-endpoint
ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
@ -72,6 +73,12 @@ TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
# OCI Storage configuration
OCI_ENDPOINT=your-endpoint
OCI_BUCKET_NAME=your-bucket-name
@ -79,6 +86,13 @@ OCI_ACCESS_KEY=your-access-key
OCI_SECRET_KEY=your-secret-key
OCI_REGION=your-region
# Volcengine tos Storage configuration
VOLCENGINE_TOS_ENDPOINT=your-endpoint
VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
VOLCENGINE_TOS_ACCESS_KEY=your-access-key
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
VOLCENGINE_TOS_REGION=your-region
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@ -100,11 +114,10 @@ QDRANT_GRPC_ENABLED=false
QDRANT_GRPC_PORT=6334
# Milvus configuration
MILVUS_HOST=127.0.0.1
MILVUS_PORT=19530
MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# MyScale configuration
MYSCALE_HOST=127.0.0.1

View File

@ -55,7 +55,7 @@ RUN apt-get update \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Annotated, Optional
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
@ -46,7 +46,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
"""
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="endpoint URL of code execution servcie",
description="endpoint URL of code execution service",
default="http://sandbox:8194",
)
@ -230,20 +230,17 @@ class HttpConfig(BaseSettings):
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field(
description="",
default=300,
)
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
] = 10
HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field(
description="",
default=600,
)
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
] = 60
HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field(
description="",
default=600,
)
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="write timeout in seconds for HTTP request")
] = 20
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
description="",
@ -431,7 +428,7 @@ class MailConfig(BaseSettings):
"""
MAIL_TYPE: Optional[str] = Field(
description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
default=None,
)
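
A note on the `Annotated[PositiveInt, Field(ge=...)] = default` pattern the timeouts above switch to: pydantic-settings keeps reading the value from the environment, but now rejects anything below the floor. A self-contained sketch (class name and values are illustrative):

from typing import Annotated

from pydantic import Field, PositiveInt, ValidationError
from pydantic_settings import BaseSettings

class TimeoutConfig(BaseSettings):
    # same shape as the fields above: default 10, overrides must be >= 10
    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
        PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
    ] = 10

print(TimeoutConfig().HTTP_REQUEST_MAX_CONNECT_TIMEOUT)  # 10 when the env var is unset
try:
    TimeoutConfig(HTTP_REQUEST_MAX_CONNECT_TIMEOUT=5)  # below the floor
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # greater_than_equal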

View File

@ -1,7 +1,7 @@
from typing import Any, Optional
from urllib.parse import quote_plus
from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig
@ -9,8 +9,10 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
@ -157,6 +159,21 @@ class CeleryConfig(DatabaseConfig):
default=None,
)
CELERY_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel mode",
default=False,
)
CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
description="Redis Sentinel master name",
default=None,
)
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Redis Sentinel socket timeout",
default=0.1,
)
@computed_field
@property
def CELERY_RESULT_BACKEND(self) -> str | None:
@ -184,6 +201,8 @@ class MiddlewareConfig(
AzureBlobStorageConfig,
GoogleCloudStorageConfig,
TencentCloudCOSStorageConfig,
HuaweiCloudOBSStorageConfig,
VolcengineTOSStorageConfig,
S3StorageConfig,
OCIStorageConfig,
# configs of vdb and vdb providers
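
The new CELERY_* sentinel fields above describe a Sentinel-backed broker. As a rough illustration of how Celery's Redis transport consumes such settings (this is not Dify's actual wiring; hosts and the master name are placeholders):

from celery import Celery

app = Celery("dify")
# one sentinel:// URL per Sentinel node, instead of a single redis:// broker
app.conf.broker_url = "sentinel://10.0.0.1:26379;sentinel://10.0.0.2:26379"
app.conf.broker_transport_options = {
    "master_name": "mymaster",                   # CELERY_SENTINEL_MASTER_NAME
    "sentinel_kwargs": {"socket_timeout": 0.1},  # CELERY_SENTINEL_SOCKET_TIMEOUT
}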

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
from pydantic_settings import BaseSettings
@ -38,3 +38,33 @@ class RedisConfig(BaseSettings):
description="whether to use SSL for Redis connection",
default=False,
)
REDIS_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel mode",
default=False,
)
REDIS_SENTINELS: Optional[str] = Field(
description="Redis Sentinel nodes",
default=None,
)
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
description="Redis Sentinel service name",
default=None,
)
REDIS_SENTINEL_USERNAME: Optional[str] = Field(
description="Redis Sentinel username",
default=None,
)
REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
description="Redis Sentinel password",
default=None,
)
REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Redis Sentinel socket timeout",
default=0.1,
)
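
Settings like these are typically turned into a client through redis-py's Sentinel support. A sketch under the assumption that REDIS_SENTINELS is a comma-separated host:port list (names and parsing are illustrative, not the actual extension code):

from typing import Optional

from redis.sentinel import Sentinel

def client_from_sentinel(sentinels: str, service_name: str,
                         username: Optional[str] = None,
                         password: Optional[str] = None,
                         socket_timeout: float = 0.1):
    # "host1:26379,host2:26379" -> [("host1", 26379), ("host2", 26379)]
    nodes = [(host, int(port)) for host, port in
             (item.split(":") for item in sentinels.split(","))]
    sentinel = Sentinel(nodes, socket_timeout=socket_timeout,
                        sentinel_kwargs={"username": username, "password": password})
    return sentinel.master_for(service_name, socket_timeout=socket_timeout)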

View File

@ -38,3 +38,8 @@ class AliyunOSSStorageConfig(BaseSettings):
description="Aliyun OSS authentication version",
default=None,
)
ALIYUN_OSS_PATH: Optional[str] = Field(
description="Aliyun OSS path",
default=None,
)
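
The matching .env comment warns that OSS object names must not start with '/'. A small illustrative helper (not the actual storage code) for joining the configured ALIYUN_OSS_PATH prefix with a key under that rule:

from typing import Optional

def build_object_key(oss_path: Optional[str], filename: str) -> str:
    # strip slashes so the resulting object name never begins with '/'
    prefix = (oss_path or "").strip("/")
    return f"{prefix}/{filename}" if prefix else filename

assert build_object_key("dify", "upload/a.png") == "dify/upload/a.png"
assert build_object_key(None, "a.png") == "a.png"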

View File

@ -0,0 +1,29 @@
from typing import Optional
from pydantic import BaseModel, Field
class HuaweiCloudOBSStorageConfig(BaseModel):
"""
Huawei Cloud OBS storage configs
"""
HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
description="Huawei Cloud OBS bucket name",
default=None,
)
HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
description="Huawei Cloud OBS Access key",
default=None,
)
HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
description="Huawei Cloud OBS Secret key",
default=None,
)
HUAWEI_OBS_SERVER: Optional[str] = Field(
description="Huawei Cloud OBS server URL",
default=None,
)
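
For orientation, the four HUAWEI_OBS_* fields map directly onto Huawei's OBS SDK client. A sketch with placeholder values, assuming the esdk-obs-python package (this is not the storage extension itself):

from obs import ObsClient  # esdk-obs-python

client = ObsClient(
    access_key_id="your-access-key",                 # HUAWEI_OBS_ACCESS_KEY
    secret_access_key="your-secret-key",             # HUAWEI_OBS_SECRET_KEY
    server="https://obs.example.myhuaweicloud.com",  # HUAWEI_OBS_SERVER
)
resp = client.getObject("your-bucket-name", "path/to/object")  # HUAWEI_OBS_BUCKET_NAME
print(resp.status)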

View File

@ -0,0 +1,34 @@
from typing import Optional
from pydantic import BaseModel, Field
class VolcengineTOSStorageConfig(BaseModel):
"""
Volcengine tos storage configs
"""
VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
description="Volcengine TOS Bucket Name",
default=None,
)
VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
description="Volcengine TOS Access Key",
default=None,
)
VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
description="Volcengine TOS Secret Key",
default=None,
)
VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
description="Volcengine TOS Endpoint URL",
default=None,
)
VOLCENGINE_TOS_REGION: Optional[str] = Field(
description="Volcengine TOS Region",
default=None,
)

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic import Field
from pydantic_settings import BaseSettings
@ -9,14 +9,14 @@ class MilvusConfig(BaseSettings):
Milvus configs
"""
MILVUS_HOST: Optional[str] = Field(
description="Milvus host",
default=None,
MILVUS_URI: Optional[str] = Field(
description="Milvus uri",
default="http://127.0.0.1:19530",
)
MILVUS_PORT: PositiveInt = Field(
description="Milvus RestFul API port",
default=9091,
MILVUS_TOKEN: Optional[str] = Field(
description="Milvus token",
default=None,
)
MILVUS_USER: Optional[str] = Field(
@ -29,11 +29,6 @@ class MilvusConfig(BaseSettings):
default=None,
)
MILVUS_SECURE: bool = Field(
description="whether to use SSL connection for Milvus",
default=False,
)
MILVUS_DATABASE: str = Field(
description="Milvus database, default to `default`",
default="default",

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.7.2",
default="0.7.3",
)
COMMIT_SHA: str = Field(

File diff suppressed because one or more lines are too long

View File

@ -174,6 +174,7 @@ class AppApi(Resource):
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("max_active_requests", type=int, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
args = parser.parse_args()
app_service = AppService()

View File

@ -201,7 +201,11 @@ class ChatConversationApi(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
query = query.where(Conversation.created_at >= start_datetime_utc)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
@ -210,7 +214,11 @@ class ChatConversationApi(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
query = query.where(Conversation.created_at < end_datetime_utc)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join(

View File

@ -34,6 +34,7 @@ def parse_app_site_args():
)
parser.add_argument("prompt_public", type=bool, required=False, location="json")
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
return parser.parse_args()
@ -68,6 +69,7 @@ class AppSite(Resource):
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
if value is not None:

View File

@ -18,7 +18,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields

View File

@ -302,6 +302,8 @@ class DatasetInitApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@ -309,6 +311,8 @@ class DatasetInitApi(Resource):
raise Forbidden()
if args["indexing_technique"] == "high_quality":
if args["embedding_model"] is None or args["embedding_model_provider"] is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager()
model_manager.get_default_model_instance(

View File

@ -36,6 +36,10 @@ class SegmentApi(DatasetApiResource):
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.indexing_status != "completed":
raise NotFound("Document is not completed.")
if not document.enabled:
raise NotFound("Document is disabled.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
@ -63,7 +67,7 @@ class SegmentApi(DatasetApiResource):
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segemtns is required"}, 400
return {"error": "Segments is required"}, 400
def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""

View File

@ -39,6 +39,7 @@ class AppSiteApi(WebApiResource):
"default_language": fields.String,
"prompt_public": fields.Boolean,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}
app_fields = {

View File

@ -93,7 +93,7 @@ class DatasetConfigManager:
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
)
)

View File

@ -65,7 +65,7 @@ class Extensible:
if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip())
position_map[extension_name] = position
position_map[extension_name] = position
if (extension_name + '.py') not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")

View File

@ -79,7 +79,7 @@ def is_filtered(
name_func: Callable[[Any], str],
) -> bool:
"""
Chcek if the object should be filtered out.
Check if the object should be filtered out.
Overall logic: exclude > include > pin
:param include_set: the set of names to be included
:param exclude_set: the set of names to be excluded

View File

@ -16,9 +16,7 @@ from configs import dify_config
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
@ -255,11 +253,8 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
tokens = 0
preview_texts = []
total_segments = 0
total_price = 0
currency = 'USD'
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
@ -286,54 +281,22 @@ class IndexingRunner:
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model_instance:
tokens += embedding_model_instance.get_text_embedding_num_tokens(
texts=[self.filter_string(document.page_content)]
)
if doc_form and doc_form == 'qa_model':
model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM
)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
price_info = model_type_instance.get_price(
model=model_instance.model,
credentials=model_instance.credentials,
price_type=PriceType.INPUT,
tokens=total_segments * 2000,
)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(price_info.total_amount),
"currency": price_info.currency,
"qa_preview": document_qa_list,
"preview": preview_texts
}
if embedding_model_instance:
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
total_price = '{:f}'.format(embedding_price_info.total_amount)
currency = embedding_price_info.currency
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": total_price,
"currency": currency,
"preview": preview_texts
}
@ -531,7 +494,7 @@ class IndexingRunner:
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
# delete Spliter character
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""):
page_content = page_content[1:]

View File

@ -87,7 +87,7 @@ Here is a task description for which I would like you to create a high-quality p
{{TASK_DESCRIPTION}}
</task_description>
Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include:
- Do not inlcude <input> or <output> section and variables in the prompt, assume user will add them at their own will.
- Do not include <input> or <output> section and variables in the prompt, assume user will add them at their own will.
- Clear instructions for the AI that will be using this prompt, demarcated with <instructions> tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
- Relevant examples if needed to clarify the task further, demarcated with <example> tags. Do not include variables in the prompt. Give three pairs of input and output examples.
- Include other relevant sections demarcated with appropriate XML tags like <examples>, <instructions>.

View File

@ -52,7 +52,7 @@
- `mode` (string) voice model (available for model type `tts`)
- `name` (string) voice model display name (available for model type `tts`)
- `language` (string) the voice model supports languages (available for model type `tts`)
- `word_limit` (int) Single conversion word limit, paragraphwise by default (available for model type `tts`)
- `word_limit` (int) Single conversion word limit, paragraph-wise by default (available for model type `tts`)
- `audio_type` (string) Support audio file extension format, e.g. mp3, wav (available for model type `tts`)
- `max_workers` (int) Number of concurrent workers supporting text and audio conversion (available for model type `tts`)
- `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
@ -150,7 +150,7 @@
- `input` (float) Input price, i.e., Prompt price
- `output` (float) Output price, i.e., returned content price
- `unit` (float) Pricing unit, e.g., if the price is meausred in 1M tokens, the corresponding token amount for the unit price is `0.000001`.
- `unit` (float) Pricing unit, e.g., if the price is measured in 1M tokens, the corresponding token amount for the unit price is `0.000001`.
- `currency` (string) Currency unit
### ProviderCredentialSchema

View File

@ -33,6 +33,22 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
'max': 1.0,
'precision': 2,
},
DefaultParameterName.TOP_K: {
'label': {
'en_US': 'Top K',
'zh_Hans': 'Top K',
},
'type': 'int',
'help': {
'en_US': 'Limits the number of tokens to consider for each step by keeping only the k most likely tokens.',
'zh_Hans': '通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。',
},
'required': False,
'default': 50,
'min': 1,
'max': 100,
'precision': 0,
},
DefaultParameterName.PRESENCE_PENALTY: {
'label': {
'en_US': 'Presence Penalty',

View File

@ -85,12 +85,13 @@ class ModelFeature(Enum):
STREAM_TOOL_CALL = "stream-tool-call"
class DefaultParameterName(Enum):
class DefaultParameterName(str, Enum):
"""
Enum class for parameter template variable.
"""
TEMPERATURE = "temperature"
TOP_P = "top_p"
TOP_K = "top_k"
PRESENCE_PENALTY = "presence_penalty"
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
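
Making DefaultParameterName a str mixin means members compare and serialize as their raw string values, so they can stand in for plain parameter-name strings. A standalone demonstration mirroring two of the members above:

import json
from enum import Enum

class DefaultParameterName(str, Enum):
    TEMPERATURE = "temperature"
    TOP_K = "top_k"

assert DefaultParameterName.TOP_K == "top_k"              # equal to the raw string
print(json.dumps({"param": DefaultParameterName.TOP_K}))  # {"param": "top_k"}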

View File

@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class TTSModel(AIModel):
"""
Model class for ttstext model.
Model class for TTS model.
"""
model_type: ModelType = ModelType.TTS

View File

@ -19,9 +19,9 @@ class AnthropicProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
# Use `claude-instant-1` model for validate,
# Use `claude-3-opus-20240229` model for validate,
model_instance.validate_credentials(
model='claude-instant-1.2',
model='claude-3-opus-20240229',
credentials=credentials
)
except CredentialsValidateFailedError as ex:

View File

@ -33,3 +33,4 @@ pricing:
output: '5.51'
unit: '0.000001'
currency: USD
deprecated: true

View File

@ -637,7 +637,19 @@ LLM_BASE_MODELS = [
en_US='specifying the format that the model must output'
),
required=False,
options=['text', 'json_object']
options=['text', 'json_object', 'json_schema']
),
ParameterRule(
name='json_schema',
label=I18nObject(
en_US='JSON Schema'
),
type='text',
help=I18nObject(
zh_Hans='设置返回的json schema，llm将按照它返回',
en_US='Set a response json schema will ensure LLM to adhere it.'
),
required=False
),
],
pricing=PriceConfig(
@ -800,6 +812,94 @@ LLM_BASE_MODELS = [
)
)
),
AzureBaseModel(
base_model_name='gpt-4o-2024-08-06',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label',
),
model_type=ModelType.LLM,
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.VISION,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name='temperature',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
ParameterRule(
name='seed',
label=I18nObject(
zh_Hans='种子',
en_US='Seed'
),
type='int',
help=I18nObject(
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
),
required=False,
precision=2,
min=0,
max=1,
),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
),
required=False,
options=['text', 'json_object', 'json_schema']
),
ParameterRule(
name='json_schema',
label=I18nObject(
en_US='JSON Schema'
),
type='text',
help=I18nObject(
zh_Hans='设置返回的json schema，llm将按照它返回',
en_US='Set a response json schema will ensure LLM to adhere it.'
),
required=False
),
],
pricing=PriceConfig(
input=5.00,
output=15.00,
unit=0.000001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='gpt-4-turbo',
entity=AIModelEntity(

View File

@ -138,6 +138,12 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-08-06
value: gpt-4o-2024-08-06
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo
value: gpt-4-turbo

View File

@ -1,4 +1,5 @@
import copy
import json
import logging
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
@ -276,12 +277,18 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_object":
response_format = {"type": "json_object"}
if response_format == "json_schema":
json_schema = model_parameters.get("json_schema")
if not json_schema:
raise ValueError("Must define JSON Schema when the response format is json_schema")
try:
schema = json.loads(json_schema)
except:
raise ValueError(f"not correct json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
response_format = {"type": "text"}
model_parameters["response_format"] = response_format
model_parameters["response_format"] = {"type": response_format}
extra_model_kwargs = {}
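
To illustrate what the json_schema branch above produces: the `json_schema` parameter arrives as text, is parsed, and is wrapped into an OpenAI-style structured-output payload. A standalone sketch with a made-up schema:

import json

model_parameters = {
    "response_format": "json_schema",
    "json_schema": json.dumps({
        "name": "answer",
        "schema": {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        },
    }),
}

schema = json.loads(model_parameters.pop("json_schema"))
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
print(model_parameters["response_format"])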

View File

@ -27,11 +27,3 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: secret_key
label:
en_US: Secret Key
type: secret-input
required: false
placeholder:
zh_Hans: 在此输入您的 Secret Key
en_US: Enter your Secret Key

View File

@ -43,3 +43,4 @@ parameter_rules:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
deprecated: true

View File

@ -43,3 +43,4 @@ parameter_rules:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
deprecated: true

View File

@ -4,36 +4,32 @@ label:
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.3
- name: top_p
use_template: top_p
default: 0.85
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
min: 0
max: 20
default: 5
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8000
min: 1
max: 192000
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
default: 1
min: 1
max: 2
default: 2048
- name: with_search_enhance
label:
zh_Hans: 搜索增强

View File

@ -4,36 +4,44 @@ label:
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.3
- name: top_p
use_template: top_p
default: 0.85
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
min: 0
max: 20
default: 5
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8000
min: 1
max: 128000
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
default: 1
min: 1
max: 2
default: 2048
- name: res_format
label:
zh_Hans: 回复格式
en_US: response format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
- name: with_search_enhance
label:
zh_Hans: 搜索增强

View File

@ -4,36 +4,44 @@ label:
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.3
- name: top_p
use_template: top_p
default: 0.85
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
min: 0
max: 20
default: 5
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8000
min: 1
max: 32000
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
default: 1
min: 1
max: 2
default: 2048
- name: res_format
label:
zh_Hans: 回复格式
en_US: response format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
- name: with_search_enhance
label:
zh_Hans: 搜索增强

View File

@ -4,36 +4,44 @@ label:
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.3
- name: top_p
use_template: top_p
default: 0.85
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
min: 0
max: 20
default: 5
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8000
min: 1
max: 32000
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
default: 1
min: 1
max: 2
default: 2048
- name: res_format
label:
zh_Hans: 回复格式
en_US: response format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
- name: with_search_enhance
label:
zh_Hans: 搜索增强

View File

@ -1,11 +1,10 @@
from collections.abc import Generator
from enum import Enum
from hashlib import md5
from json import dumps, loads
from typing import Any, Union
import json
from collections.abc import Iterator
from typing import Any, Optional, Union
from requests import post
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalance,
@ -16,203 +15,133 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
)
class BaichuanMessage:
class Role(Enum):
USER = 'user'
ASSISTANT = 'assistant'
# Baichuan does not have system message
_SYSTEM = 'system'
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
stop_reason: str = ''
def to_dict(self) -> dict[str, Any]:
return {
'role': self.role,
'content': self.content,
}
def __init__(self, content: str, role: str = 'user') -> None:
self.content = content
self.role = role
class BaichuanModel:
api_key: str
secret_key: str
def __init__(self, api_key: str, secret_key: str = '') -> None:
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.secret_key = secret_key
def _model_mapping(self, model: str) -> str:
@property
def _model_mapping(self) -> dict:
return {
'baichuan2-turbo': 'Baichuan2-Turbo',
'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
'baichuan2-53b': 'Baichuan2-53B',
'baichuan3-turbo': 'Baichuan3-Turbo',
'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
'baichuan4': 'Baichuan4',
}[model]
"baichuan2-turbo": "Baichuan2-Turbo",
"baichuan3-turbo": "Baichuan3-Turbo",
"baichuan3-turbo-128k": "Baichuan3-Turbo-128k",
"baichuan4": "Baichuan4",
}
def _handle_chat_generate_response(self, response) -> BaichuanMessage:
resp = response.json()
choices = resp.get('choices', [])
message = BaichuanMessage(content='', role='assistant')
for choice in choices:
message.content += choice['message']['content']
message.role = choice['message']['role']
if choice['finish_reason']:
message.stop_reason = choice['finish_reason']
@property
def request_headers(self) -> dict[str, Any]:
return {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
}
if 'usage' in resp:
message.usage = {
'prompt_tokens': resp['usage']['prompt_tokens'],
'completion_tokens': resp['usage']['completion_tokens'],
'total_tokens': resp['usage']['total_tokens'],
}
def _build_parameters(
self,
model: str,
stream: bool,
messages: list[dict],
parameters: dict[str, Any],
tools: Optional[list[PromptMessageTool]] = None,
) -> dict[str, Any]:
if model in self._model_mapping.keys():
# the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
# we need to rename it to res_format to get its value
if parameters.get("res_format") == "json_object":
parameters["response_format"] = {"type": "json_object"}
return message
def _handle_chat_stream_generate_response(self, response) -> Generator:
for line in response.iter_lines():
if not line:
continue
line = line.decode('utf-8')
# remove the first `data: ` prefix
if line.startswith('data:'):
line = line[5:].strip()
try:
data = loads(line)
except Exception as e:
if line.strip() == '[DONE]':
return
choices = data.get('choices', [])
# save stop reason temporarily
stop_reason = ''
for choice in choices:
if choice.get('finish_reason'):
stop_reason = choice['finish_reason']
if tools or parameters.get("with_search_enhance") is True:
parameters["tools"] = []
if len(choice['delta']['content']) == 0:
continue
yield BaichuanMessage(**choice['delta'])
# if there is usage, the response is the last one, yield it and return
if 'usage' in data:
message = BaichuanMessage(content='', role='assistant')
message.usage = {
'prompt_tokens': data['usage']['prompt_tokens'],
'completion_tokens': data['usage']['completion_tokens'],
'total_tokens': data['usage']['total_tokens'],
}
message.stop_reason = stop_reason
yield message
def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
parameters: dict[str, Any]) \
-> dict[str, Any]:
if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
prompt_messages = []
for message in messages:
if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
# check if the latest message is a user message
if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value:
prompt_messages[-1]['content'] += message.content
else:
prompt_messages.append({
'content': message.content,
'role': BaichuanMessage.Role.USER.value,
})
elif message.role == BaichuanMessage.Role.ASSISTANT.value:
prompt_messages.append({
'content': message.content,
'role': message.role,
})
# [baichuan] frequency_penalty must be between 1 and 2
if 'frequency_penalty' in parameters:
if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
parameters['frequency_penalty'] = 1
# with_search_enhance is deprecated, use web_search instead
if parameters.get("with_search_enhance") is True:
parameters["tools"].append(
{
"type": "web_search",
"web_search": {"enable": True},
}
)
if tools:
for tool in tools:
parameters["tools"].append(
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
},
}
)
# turbo api accepts flat parameters
return {
'model': self._model_mapping(model),
'stream': stream,
'messages': prompt_messages,
"model": self._model_mapping.get(model),
"stream": stream,
"messages": messages,
**parameters,
}
else:
raise BadRequestError(f"Unknown model: {model}")
def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
# there is no secret key for turbo api
return {
'Content-Type': 'application/json',
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ',
'Authorization': 'Bearer ' + self.api_key,
}
else:
raise BadRequestError(f"Unknown model: {model}")
def _calculate_md5(self, input_string):
return md5(input_string.encode('utf-8')).hexdigest()
def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
parameters: dict[str, Any], timeout: int) \
-> Union[Generator, BaichuanMessage]:
if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
def generate(
self,
model: str,
stream: bool,
messages: list[dict],
parameters: dict[str, Any],
timeout: int,
tools: Optional[list[PromptMessageTool]] = None,
) -> Union[Iterator, dict]:
if model in self._model_mapping.keys():
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
else:
raise BadRequestError(f"Unknown model: {model}")
try:
data = self._build_parameters(model, stream, messages, parameters)
headers = self._build_headers(model, data)
except KeyError:
raise InternalServerError(f"Failed to build parameters for model: {model}")
data = self._build_parameters(model, stream, messages, parameters, tools)
try:
response = post(
url=api_base,
headers=headers,
data=dumps(data),
headers=self.request_headers,
data=json.dumps(data),
timeout=timeout,
stream=stream
stream=stream,
)
except Exception as e:
raise InternalServerError(f"Failed to invoke model: {e}")
if response.status_code != 200:
try:
resp = response.json()
# try to parse error message
err = resp['error']['code']
msg = resp['error']['message']
err = resp["error"]["type"]
msg = resp["error"]["message"]
except Exception as e:
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
raise InternalServerError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
if err == 'invalid_api_key':
if err == "invalid_api_key":
raise InvalidAPIKeyError(msg)
elif err == 'insufficient_quota':
elif err == "insufficient_quota":
raise InsufficientAccountBalance(msg)
elif err == 'invalid_authentication':
elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg)
elif 'rate' in err:
elif err == "invalid_request_error":
raise BadRequestError(msg)
elif "rate" in err:
raise RateLimitReachedError(msg)
elif 'internal' in err:
elif "internal" in err:
raise InternalServerError(msg)
elif err == 'api_key_empty':
elif err == "api_key_empty":
raise InvalidAPIKeyError(msg)
else:
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
if stream:
return self._handle_chat_stream_generate_response(response)
return response.iter_lines()
else:
return self._handle_chat_generate_response(response)
return response.json()
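
After the refactor, _build_parameters emits a flat, OpenAI-compatible request body. For reference, a request assembled from the pieces above would look roughly like this (values and the function tool are made up for illustration):

payload = {
    "model": "Baichuan4",  # mapped from "baichuan4" via _model_mapping
    "stream": False,
    "messages": [{"role": "user", "content": "ping"}],
    "max_tokens": 1,
    "tools": [
        # with_search_enhance=True is translated into a web_search tool
        {"type": "web_search", "web_search": {"enable": True}},
        # each PromptMessageTool becomes a function tool (this one is hypothetical)
        {"type": "function", "function": {
            "name": "get_weather",
            "description": "hypothetical example tool",
            "parameters": {"type": "object", "properties": {}},
        }},
    ],
}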

View File

@ -1,7 +1,12 @@
from collections.abc import Generator
import json
from collections.abc import Generator, Iterator
from typing import cast
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -21,7 +26,7 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalance,
@ -32,20 +37,41 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
)
class BaichuanLarguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
class BaichuanLanguageModel(LargeLanguageModel):
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stream=stream,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
def _num_tokens_from_messages(
self,
messages: list[PromptMessage],
) -> int:
"""Calculate num tokens for baichuan model"""
def tokens(text: str):
@ -59,10 +85,10 @@ class BaichuanLarguageModel(LargeLanguageModel):
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ''
text = ""
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
@ -84,19 +110,18 @@ class BaichuanLarguageModel(LargeLanguageModel):
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [tool_call.dict() for tool_call in
message.tool_calls]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "user", "content": message.content}
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
# copy from core/model_runtime/model_providers/anthropic/llm/llm.py
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": message.tool_call_id,
"content": message.content
}]
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id
}
else:
raise ValueError(f"Unknown message type {type(message)}")
@ -105,102 +130,159 @@ class BaichuanLarguageModel(LargeLanguageModel):
def validate_credentials(self, model: str, credentials: dict) -> None:
# ping
instance = BaichuanModel(
api_key=credentials['api_key'],
secret_key=credentials.get('secret_key', '')
)
instance = BaichuanModel(api_key=credentials["api_key"])
try:
instance.generate(model=model, stream=False, messages=[
BaichuanMessage(content='ping', role='user')
], parameters={
'max_tokens': 1,
}, timeout=60)
instance.generate(
model=model,
stream=False,
messages=[{"content": "ping", "role": "user"}],
parameters={
"max_tokens": 1,
},
timeout=60,
)
except Exception as e:
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
if tools is not None and len(tools) > 0:
raise InvokeBadRequestError("Baichuan model doesn't support tools")
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stream: bool = True,
) -> LLMResult | Generator:
instance = BaichuanModel(
api_key=credentials['api_key'],
secret_key=credentials.get('secret_key', '')
)
# convert prompt messages to baichuan messages
messages = [
BaichuanMessage(
content=message.content if isinstance(message.content, str) else ''.join([
content.data for content in message.content
]),
role=message.role.value
) for message in prompt_messages
]
instance = BaichuanModel(api_key=credentials["api_key"])
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
# invoke model
response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
timeout=60)
response = instance.generate(
model=model,
stream=stream,
messages=messages,
parameters=model_parameters,
timeout=60,
tools=tools,
)
if stream:
return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
return self._handle_chat_generate_stream_response(
model, prompt_messages, credentials, response
)
return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
return self._handle_chat_generate_response(
model, prompt_messages, credentials, response
)
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: dict,
) -> LLMResult:
choices = response.get("choices", [])
assistant_message = AssistantPromptMessage(content='', tool_calls=[])
if choices and choices[0]["finish_reason"] == "tool_calls":
for choice in choices:
for tool_call in choice["message"]["tool_calls"]:
tool = AssistantPromptMessage.ToolCall(
id=tool_call.get("id", ""),
type=tool_call.get("type", ""),
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call.get("function", {}).get("name", ""),
arguments=tool_call.get("function", {}).get("arguments", "")
),
)
assistant_message.tool_calls.append(tool)
else:
for choice in choices:
assistant_message.content += choice["message"]["content"]
assistant_message.role = choice["message"]["role"]
usage = response.get("usage")
if usage:
# transform usage
prompt_tokens = usage["prompt_tokens"]
completion_tokens = usage["completion_tokens"]
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(prompt_messages)
completion_tokens = self._num_tokens_from_messages([assistant_message])
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
def _handle_chat_generate_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: BaichuanMessage) -> LLMResult:
# convert baichuan message to llm result
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=response.usage['prompt_tokens'],
completion_tokens=response.usage['completion_tokens'])
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=response.content,
tool_calls=[]
),
message=assistant_message,
usage=usage,
)
def _handle_chat_generate_stream_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Generator[BaichuanMessage, None, None]) -> Generator:
for message in response:
if message.usage:
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=message.usage['prompt_tokens'],
completion_tokens=message.usage['completion_tokens'])
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Iterator,
) -> Generator:
for line in response:
if not line:
continue
line = line.decode("utf-8")
# remove the first `data: ` prefix
if line.startswith("data:"):
line = line[5:].strip()
try:
data = json.loads(line)
except Exception as e:
if line.strip() == "[DONE]":
return
choices = data.get("choices", [])
stop_reason = ""
for choice in choices:
if choice.get("finish_reason"):
stop_reason = choice["finish_reason"]
if len(choice["delta"]["content"]) == 0:
continue
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=message.content,
tool_calls=[]
content=choice["delta"]["content"], tool_calls=[]
),
usage=usage,
finish_reason=message.stop_reason if message.stop_reason else None,
finish_reason=stop_reason,
),
)
else:
# if there is usage, the response is the last one, yield it and return
if "usage" in data:
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=data["usage"]["prompt_tokens"],
completion_tokens=data["usage"]["completion_tokens"],
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=message.content,
tool_calls=[]
),
finish_reason=message.stop_reason if message.stop_reason else None,
message=AssistantPromptMessage(content="", tool_calls=[]),
usage=usage,
finish_reason=stop_reason,
),
)
@ -215,21 +297,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeConnectionError: [],
InvokeServerUnavailableError: [InternalServerError],
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
InvokeBadRequestError: [BadRequestError, KeyError],
}

View File

@ -60,7 +60,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
token_usage = 0
for chunk in chunks:
# embeding chunk
# embedding chunk
chunk_embeddings, chunk_usage = self.embedding(
model=model,
api_key=api_key,

View File

@ -793,11 +793,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
:return: Invoke emd = genai.GenerativeModel(model) error mapping
"""
return {
InvokeConnectionError: [],

View File

@ -130,11 +130,11 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
:return: Invoke emd = genai.GenerativeModel(model) error mapping
"""
return {
InvokeConnectionError: [],

View File

@ -0,0 +1 @@


View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill:#b1b3b4;fill-rule:evenodd"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill:#000;fill-rule:evenodd"></path></svg>


View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill:#b1b3b4;fill-rule:evenodd"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill:#000;fill-rule:evenodd"></path></svg>


View File

@ -0,0 +1,28 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class FishAudioProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
        If validation fails, a CredentialsValidateFailedError is raised.
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.TTS)
model_instance.validate_credentials(
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex

View File

@ -0,0 +1,76 @@
provider: fishaudio
label:
en_US: Fish Audio
description:
  en_US: Models provided by Fish Audio; currently only TTS is supported.
zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。
icon_small:
en_US: fishaudio_s_en.svg
icon_large:
en_US: fishaudio_l_en.svg
background: "#E5E7EB"
help:
title:
en_US: Get your API key from Fish Audio
zh_Hans: 从 Fish Audio 获取你的 API Key
url:
en_US: https://fish.audio/go-api/
supported_model_types:
- tts
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: api_base
label:
en_US: API URL
type: text-input
required: false
default: https://api.fish.audio
placeholder:
en_US: Enter your API URL
zh_Hans: 在此输入您的 API URL
- variable: use_public_models
label:
en_US: Use Public Models
type: select
required: false
default: "false"
placeholder:
en_US: Toggle to use public models
zh_Hans: 切换以使用公共模型
options:
- value: "true"
label:
en_US: Allow Public Models
zh_Hans: 使用公共模型
- value: "false"
label:
en_US: Private Models Only
zh_Hans: 仅使用私有模型
- variable: latency
label:
en_US: Latency
type: select
required: false
default: "normal"
placeholder:
      en_US: Toggle to choose latency
zh_Hans: 切换以调整延迟
options:
- value: "balanced"
label:
en_US: Low (may affect quality)
zh_Hans: 低延迟 (可能降低质量)
- value: "normal"
label:
en_US: Normal
zh_Hans: 标准
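A minimal sketch (not part of the diff) of a credentials dict that satisfies the Fish Audio schema above; the api_key value is a placeholder.

fishaudio_credentials = {
    "api_key": "fa-xxxxxxxxxxxxxxxx",       # secret-input, required
    "api_base": "https://api.fish.audio",   # optional, defaults to the public endpoint
    "use_public_models": "false",           # "true" allows public voices, "false" restricts to private ones
    "latency": "normal",                    # "balanced" lowers latency at a possible cost to quality
}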

View File

@ -0,0 +1,174 @@
from typing import Optional
import httpx
from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
class FishAudioText2SpeechModel(TTSModel):
"""
Model class for Fish.audio Text to Speech model.
"""
def get_tts_model_voices(
self, model: str, credentials: dict, language: Optional[str] = None
) -> list:
api_base = credentials.get("api_base", "https://api.fish.audio")
api_key = credentials.get("api_key")
use_public_models = credentials.get("use_public_models", "false") == "true"
params = {
"self": str(not use_public_models).lower(),
"page_size": "100",
}
if language is not None:
if "-" in language:
language = language.split("-")[0]
params["language"] = language
results = httpx.get(
f"{api_base}/model",
headers={"Authorization": f"Bearer {api_key}"},
params=params,
)
results.raise_for_status()
data = results.json()
return [{"name": i["title"], "value": i["_id"]} for i in data["items"]]
def _invoke(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: Optional[str] = None,
) -> any:
"""
Invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param user: unique user id
:return: generator yielding audio chunks
"""
return self._tts_invoke_streaming(
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)
def validate_credentials(
self, credentials: dict, user: Optional[str] = None
) -> None:
"""
Validate credentials for text2speech model
:param credentials: model credentials
:param user: unique user id
"""
try:
self.get_tts_model_voices(
None,
credentials={
"api_key": credentials["api_key"],
"api_base": credentials["api_base"],
# Disable public models will trigger a 403 error if user is not logged in
"use_public_models": "false",
},
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke_streaming(
self, model: str, credentials: dict, content_text: str, voice: str
) -> any:
"""
Invoke streaming text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: ID of the reference audio (if any)
:return: generator yielding audio chunks
"""
try:
word_limit = self._get_model_word_limit(model, credentials)
if len(content_text) > word_limit:
sentences = self._split_text_into_sentences(
content_text, max_length=word_limit
)
else:
sentences = [content_text.strip()]
for i in range(len(sentences)):
yield from self._tts_invoke_streaming_sentence(
credentials=credentials, content_text=sentences[i], voice=voice
)
except Exception as ex:
raise InvokeBadRequestError(str(ex))
def _tts_invoke_streaming_sentence(
self, credentials: dict, content_text: str, voice: Optional[str] = None
) -> any:
"""
Invoke streaming text2speech model
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: ID of the reference audio (if any)
:return: generator yielding audio chunks
"""
api_key = credentials.get("api_key")
api_url = credentials.get("api_base", "https://api.fish.audio")
latency = credentials.get("latency")
if not api_key:
raise InvokeBadRequestError("API key is required")
with httpx.stream(
"POST",
api_url + "/v1/tts",
json={
"text": content_text,
"reference_id": voice,
"latency": latency
},
headers={
"Authorization": f"Bearer {api_key}",
},
timeout=None,
) as response:
if response.status_code != 200:
raise InvokeBadRequestError(
f"Error: {response.status_code} - {response.text}"
)
yield from response.iter_bytes()
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeBadRequestError: [
httpx.HTTPStatusError,
],
}
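A standalone sketch (not part of the diff) of the streaming request that _tts_invoke_streaming_sentence performs above; the api_key and output path are placeholders and error handling is omitted.

import httpx

api_key = "fa-xxxxxxxx"                      # placeholder credential
api_base = "https://api.fish.audio"

with httpx.stream(
    "POST",
    f"{api_base}/v1/tts",
    json={"text": "Hello from Dify.", "reference_id": None, "latency": "normal"},
    headers={"Authorization": f"Bearer {api_key}"},
    timeout=None,
) as response:
    response.raise_for_status()
    with open("output.mp3", "wb") as f:
        for chunk in response.iter_bytes():  # bytes fragments of the mp3 stream
            f.write(chunk)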

View File

@ -0,0 +1,5 @@
model: tts-default
model_type: tts
model_properties:
word_limit: 1000
audio_type: 'mp3'

View File

@ -416,11 +416,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
:return: Invoke emd = genai.GenerativeModel(model) error mapping
"""
return {
InvokeConnectionError: [

View File

@ -86,7 +86,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
Calculate num tokens for minimax model
not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way
to caculate the num tokens, so we use str() to convert the prompt to string
to calculate the num tokens, so we use str() to convert the prompt to string
Minimax does not provide their own tokenizer of adab5.5 and abab5 model
therefore, we use gpt2 tokenizer instead

View File

@ -10,6 +10,7 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _update_endpoint_url(self, credentials: dict):
credentials['endpoint_url'] = "https://api.novita.ai/v3/openai"
credentials['extra_headers'] = { 'X-Novita-Source': 'dify.ai' }
return credentials

View File

@ -54,7 +54,6 @@ class NvidiaRerankModel(RerankModel):
"query": {"text": query},
"passages": [{"text": doc} for doc in docs],
}
session = requests.Session()
response = session.post(invoke_url, headers=headers, json=payload)
response.raise_for_status()
@ -71,7 +70,10 @@ class NvidiaRerankModel(RerankModel):
)
rerank_documents.append(rerank_document)
if rerank_documents:
rerank_documents = sorted(rerank_documents, key=lambda x: x.score, reverse=True)
if top_n:
rerank_documents = rerank_documents[:top_n]
return RerankResult(model=model, docs=rerank_documents)
except requests.HTTPError as e:
raise InvokeServerUnavailableError(str(e))
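A small self-contained illustration (assumed data, not part of the diff) of the sorting and top_n truncation that this hunk introduces.

from dataclasses import dataclass

@dataclass
class Doc:
    index: int
    score: float

docs = [Doc(0, 0.12), Doc(1, 0.87), Doc(2, 0.45)]
top_n = 2

docs = sorted(docs, key=lambda d: d.score, reverse=True)  # highest score first
if top_n:
    docs = docs[:top_n]                                   # keep only the best top_n

print([d.index for d in docs])  # -> [1, 2]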

View File

@ -0,0 +1 @@
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>


View File

@ -0,0 +1 @@
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>


View File

@ -0,0 +1,52 @@
model: cohere.command-r-16k
label:
en_US: cohere.command-r-16k v1.2
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 1
max: 1.0
- name: topP
use_template: top_p
default: 0.75
min: 0
max: 1
- name: topK
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presencePenalty
use_template: presence_penalty
min: 0
max: 1
default: 0
- name: frequencyPenalty
use_template: frequency_penalty
min: 0
max: 1
default: 0
- name: maxTokens
use_template: max_tokens
default: 600
max: 4000
pricing:
input: '0.004'
output: '0.004'
unit: '0.0001'
currency: USD

View File

@ -0,0 +1,52 @@
model: cohere.command-r-plus
label:
en_US: cohere.command-r-plus v1.2
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 1
max: 1.0
- name: topP
use_template: top_p
default: 0.75
min: 0
max: 1
- name: topK
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presencePenalty
use_template: presence_penalty
min: 0
max: 1
default: 0
- name: frequencyPenalty
use_template: frequency_penalty
min: 0
max: 1
default: 0
- name: maxTokens
use_template: max_tokens
default: 600
max: 4000
pricing:
input: '0.0219'
output: '0.0219'
unit: '0.0001'
currency: USD

View File

@ -0,0 +1,461 @@
import base64
import copy
import json
import logging
from collections.abc import Generator
from typing import Optional, Union
import oci
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
request_template = {
"compartmentId": "",
"servingMode": {
"modelId": "cohere.command-r-plus",
"servingType": "ON_DEMAND"
},
"chatRequest": {
"apiFormat": "COHERE",
#"preambleOverride": "You are a helpful assistant.",
#"message": "Hello!",
#"chatHistory": [],
"maxTokens": 600,
"isStream": False,
"frequencyPenalty": 0,
"presencePenalty": 0,
"temperature": 1,
"topP": 0.75
}
}
oci_config_template = {
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": ""
}
class OCILargeLanguageModel(LargeLanguageModel):
# https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
_supported_models = {
"meta.llama-3-70b-instruct": {
"system": True,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
"cohere.command-r-16k": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
"cohere.command-r-plus": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
}
def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool:
feature = self._supported_models.get(model_id)
if not feature:
return False
return feature["stream_tool_call"] if stream else feature["tool_call"]
def _is_multimodal_supported(self, model_id: str) -> bool:
feature = self._supported_models.get(model_id)
if not feature:
return False
return feature["multimodal"]
def _is_system_prompt_supported(self, model_id: str) -> bool:
feature = self._supported_models.get(model_id)
if not feature:
return False
return feature["system"]
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
#print("model"+"*"*20)
#print(model)
#print("credentials"+"*"*20)
#print(credentials)
#print("model_parameters"+"*"*20)
#print(model_parameters)
#print("prompt_messages"+"*"*200)
#print(prompt_messages)
#print("tools"+"*"*20)
#print(tools)
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
        :return: number of tokens
"""
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
        Get number of characters for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
        :return: number of characters
"""
prompt = self._convert_messages_to_prompt(prompt_messages)
return len(prompt)
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
"""
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)
return text.rstrip()
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
# Setup basic variables
# Auth Config
try:
ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"maxTokens": 5})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# config_kwargs = model_parameters.copy()
# config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
# if stop:
# config_kwargs["stop_sequences"] = stop
# initialize client
# ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat
oci_config = copy.deepcopy(oci_config_template)
if "oci_config_content" in credentials:
oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
config_items = oci_config_content.split("/")
if len(config_items) != 5:
raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
oci_config["user"] = config_items[0]
oci_config["fingerprint"] = config_items[1]
oci_config["tenancy"] = config_items[2]
oci_config["region"] = config_items[3]
oci_config["compartment_id"] = config_items[4]
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
if "oci_key_content" in credentials:
oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
#oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
compartment_id = oci_config["compartment_id"]
client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
        # build the chat request
request_args = copy.deepcopy(request_template)
request_args["compartmentId"] = compartment_id
request_args["servingMode"]["modelId"] = model
chat_history = []
system_prompts = []
#if "meta.llama" in model:
# request_args["chatRequest"]["apiFormat"] = "GENERIC"
request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600)
request_args["chatRequest"].update(model_parameters)
frequency_penalty = model_parameters.get("frequencyPenalty", 0)
presence_penalty = model_parameters.get("presencePenalty", 0)
if frequency_penalty > 0 and presence_penalty > 0:
raise InvokeBadRequestError("Cannot set both frequency penalty and presence penalty")
# for msg in prompt_messages: # makes message roles strictly alternating
# content = self._format_message_to_glm_content(msg)
# if history and history[-1]["role"] == content["role"]:
# history[-1]["parts"].extend(content["parts"])
# else:
# history.append(content)
# temporary not implement the tool call function
valid_value = self._is_tool_call_supported(model, stream)
if tools is not None and len(tools) > 0:
if not valid_value:
raise InvokeBadRequestError("Does not support function calling")
if model.startswith("cohere"):
#print("run cohere " * 10)
for message in prompt_messages[:-1]:
text = ""
if isinstance(message.content, str):
text = message.content
if isinstance(message, UserPromptMessage):
chat_history.append({"role": "USER", "message": text})
else:
chat_history.append({"role": "CHATBOT", "message": text})
if isinstance(message, SystemPromptMessage):
if isinstance(message.content, str):
system_prompts.append(message.content)
args = {"apiFormat": "COHERE",
"preambleOverride": ' '.join(system_prompts),
"message": prompt_messages[-1].content,
"chatHistory": chat_history, }
request_args["chatRequest"].update(args)
elif model.startswith("meta"):
#print("run meta " * 10)
meta_messages = []
for message in prompt_messages:
text = message.content
meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]})
args = {"apiFormat": "GENERIC",
"messages": meta_messages,
"numGenerations": 1,
"topK": -1}
request_args["chatRequest"].update(args)
if stream:
request_args["chatRequest"]["isStream"] = True
#print("final request" + "|" * 20)
#print(request_args)
response = client.chat(request_args)
#print(vars(response))
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.data.chat_response.text
)
# calculate num tokens
prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
)
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
index = -1
events = response.data.events()
for stream in events:
chunk = json.loads(stream.data)
#print(chunk)
#chunk: {'apiFormat': 'COHERE', 'text': 'Hello'}
#for chunk in response:
#for part in chunk.parts:
#if part.function_call:
# assistant_prompt_message.tool_calls = [
# AssistantPromptMessage.ToolCall(
# id=part.function_call.name,
# type='function',
# function=AssistantPromptMessage.ToolCall.ToolCallFunction(
# name=part.function_call.name,
# arguments=json.dumps(dict(part.function_call.args.items()))
# )
# )
# ]
if "finishReason" not in chunk:
assistant_prompt_message = AssistantPromptMessage(
content=''
)
if model.startswith("cohere"):
if chunk["text"]:
assistant_prompt_message.content += chunk["text"]
elif model.startswith("meta"):
assistant_prompt_message.content += chunk["message"]["content"][0]["text"]
index += 1
# transform assistant message to prompt message
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
)
else:
# calculate num tokens
prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
finish_reason=str(chunk["finishReason"]),
usage=usage
)
)
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.
:param message: PromptMessage to convert.
:return: String representation of the message.
"""
human_prompt = "\n\nuser:"
ai_prompt = "\n\nmodel:"
content = message.content
if isinstance(content, list):
content = "".join(
c.data for c in content if c.type != PromptMessageContentType.IMAGE
)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, ToolPromptMessage):
message_text = f"{human_prompt} {content}"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: []
}
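For reference, a sketch (not part of the diff) of the request_args dict that _generate assembles for a cohere.* model before calling client.chat(); all OCIDs and messages are placeholders.

request_args = {
    "compartmentId": "ocid1.compartment.oc1..xxxx",
    "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"},
    "chatRequest": {
        "apiFormat": "COHERE",
        "preambleOverride": "You are a helpful assistant.",
        "message": "What is Dify?",
        "chatHistory": [
            {"role": "USER", "message": "Hello"},
            {"role": "CHATBOT", "message": "Hi, how can I help?"},
        ],
        "maxTokens": 600,
        "isStream": False,
        "frequencyPenalty": 0,
        "presencePenalty": 0,
        "temperature": 1,
        "topP": 0.75,
    },
}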

View File

@ -0,0 +1,51 @@
model: meta.llama-3-70b-instruct
label:
zh_Hans: meta.llama-3-70b-instruct
en_US: meta.llama-3-70b-instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
default: 1
max: 2.0
- name: topP
use_template: top_p
default: 0.75
min: 0
max: 1
- name: topK
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presencePenalty
use_template: presence_penalty
min: -2
max: 2
default: 0
- name: frequencyPenalty
use_template: frequency_penalty
min: -2
max: 2
default: 0
- name: maxTokens
use_template: max_tokens
default: 600
max: 8000
pricing:
input: '0.015'
output: '0.015'
unit: '0.0001'
currency: USD

View File

@ -0,0 +1,34 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class OCIGENAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
            # Use the `cohere.command-r-plus` model for validation,
model_instance.validate_credentials(
model='cohere.command-r-plus',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex

View File

@ -0,0 +1,42 @@
provider: oci
label:
en_US: OCIGenerativeAI
description:
en_US: Models provided by OCI, such as Cohere Command R and Cohere Command R+.
zh_Hans: OCI 提供的模型,例如 Cohere Command R 和 Cohere Command R+。
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#FFFFFF"
help:
title:
en_US: Get your API Key from OCI
zh_Hans: 从 OCI 获取 API Key
url:
en_US: https://docs.cloud.oracle.com/Content/API/Concepts/sdkconfig.htm
supported_model_types:
- llm
- text-embedding
#- rerank
configurate_methods:
- predefined-model
#- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: oci_config_content
label:
en_US: oci api key config file's content
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 oci api key config 文件的内容(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')) )
en_US: Enter your oci api key config file's content(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')) )
- variable: oci_key_content
label:
en_US: oci api key file's content
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 oci api key 文件的内容(base64.b64encode("pem file content".encode('utf-8')))
en_US: Enter your oci api key file's content(base64.b64encode("pem file content".encode('utf-8')))
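A minimal sketch (not part of the diff) of how the two credential fields above can be produced; the OCIDs, region and key path are placeholders.

import base64

config_str = "/".join([
    "ocid1.user.oc1..xxxx",         # user_ocid
    "aa:bb:cc:dd:ee:ff",            # API key fingerprint
    "ocid1.tenancy.oc1..xxxx",      # tenancy_ocid
    "us-ashburn-1",                 # region
    "ocid1.compartment.oc1..xxxx",  # compartment_ocid
])
oci_config_content = base64.b64encode(config_str.encode("utf-8")).decode("utf-8")

with open("oci_api_key.pem", "rb") as f:     # private key in PEM format
    oci_key_content = base64.b64encode(f.read()).decode("utf-8")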

View File

@ -0,0 +1,5 @@
- cohere.embed-english-light-v2.0
- cohere.embed-english-light-v3.0
- cohere.embed-english-v3.0
- cohere.embed-multilingual-light-v3.0
- cohere.embed-multilingual-v3.0

View File

@ -0,0 +1,9 @@
model: cohere.embed-english-light-v2.0
model_type: text-embedding
model_properties:
context_size: 1024
max_chunks: 48
pricing:
input: '0.001'
unit: '0.0001'
currency: USD

View File

@ -0,0 +1,9 @@
model: cohere.embed-english-light-v3.0
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 48
pricing:
input: '0.001'
unit: '0.0001'
currency: USD

View File

@ -0,0 +1,9 @@
model: cohere.embed-english-v3.0
model_type: text-embedding
model_properties:
context_size: 1024
max_chunks: 48
pricing:
input: '0.001'
unit: '0.0001'
currency: USD

View File

@ -0,0 +1,9 @@
model: cohere.embed-multilingual-light-v3.0
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 48
pricing:
input: '0.001'
unit: '0.0001'
currency: USD

View File

@ -0,0 +1,9 @@
model: cohere.embed-multilingual-v3.0
model_type: text-embedding
model_properties:
context_size: 1024
max_chunks: 48
pricing:
input: '0.001'
unit: '0.0001'
currency: USD

View File

@ -0,0 +1,242 @@
import base64
import copy
import time
from typing import Optional
import numpy as np
import oci
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
request_template = {
"compartmentId": "",
"servingMode": {
"modelId": "cohere.embed-english-light-v3.0",
"servingType": "ON_DEMAND"
},
"truncate": "NONE",
"inputs": [""]
}
oci_config_template = {
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": ""
}
class OCITextEmbeddingModel(TextEmbeddingModel):
"""
    Model class for OCI text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
# get model properties
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
inputs = []
indices = []
used_tokens = 0
for i, text in enumerate(texts):
# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)
if num_tokens >= context_size:
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0: cutoff])
else:
inputs.append(text)
indices += [i]
batched_embeddings = []
_iter = range(0, len(inputs), max_chunks)
for i in _iter:
# call embedding model
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
credentials=credentials,
texts=inputs[i: i + max_chunks]
)
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
)
return TextEmbeddingResult(
embeddings=batched_embeddings,
usage=usage,
model=model
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
        Get number of characters for given texts
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
characters = 0
for text in texts:
characters += len(text)
return characters
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
# call embedding model
self._embedding_invoke(
model=model,
credentials=credentials,
texts=['ping']
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
"""
Invoke embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return: embeddings and used tokens
"""
# oci
# initialize client
oci_config = copy.deepcopy(oci_config_template)
if "oci_config_content" in credentials:
oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
config_items = oci_config_content.split("/")
if len(config_items) != 5:
raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
oci_config["user"] = config_items[0]
oci_config["fingerprint"] = config_items[1]
oci_config["tenancy"] = config_items[2]
oci_config["region"] = config_items[3]
oci_config["compartment_id"] = config_items[4]
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
if "oci_key_content" in credentials:
oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
# oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
compartment_id = oci_config["compartment_id"]
client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
# call embedding model
request_args = copy.deepcopy(request_template)
request_args["compartmentId"] = compartment_id
request_args["servingMode"]["modelId"] = model
request_args["inputs"] = texts
response = client.embed_text(request_args)
return response.data.embeddings, self.get_num_characters(model=model, credentials=credentials, texts=texts)
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
}
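Illustrative only (not part of the diff): the embed_text request body that _embedding_invoke builds above, with a placeholder compartment id.

embed_request = {
    "compartmentId": "ocid1.compartment.oc1..xxxx",
    "servingMode": {"modelId": "cohere.embed-multilingual-v3.0", "servingType": "ON_DEMAND"},
    "truncate": "NONE",
    "inputs": ["first text to embed", "second text to embed"],
}
# client.embed_text(embed_request) returns response.data.embeddings,
# a list of float vectors aligned with the inputs.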

View File

@ -89,7 +89,8 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
endpoint_url,
headers=headers,
data=json.dumps(payload),
timeout=(10, 300)
timeout=(10, 300),
options={"use_mmap": "true"}
)
response.raise_for_status() # Raise an exception for HTTP errors

View File

@ -552,7 +552,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
try:
schema = json.loads(json_schema)
except:
raise ValueError(f"not currect json_schema format: {json_schema}")
raise ValueError(f"not correct json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
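A hedged example (not part of the diff) of the json_schema model parameter that this hunk parses; the schema itself is illustrative.

import json

model_parameters = {
    "json_schema": json.dumps({
        "name": "weather_report",
        "schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}, "temp_c": {"type": "number"}},
            "required": ["city", "temp_c"],
        },
    })
}
# After the code above runs, json_schema is popped and model_parameters contains
# {"response_format": {"type": "json_schema", "json_schema": {...}}}.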

View File

@ -1,17 +1,36 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union
import re
from collections.abc import Generator, Iterator
from typing import Any, Optional, Union, cast
# from openai.types.chat import ChatCompletion, ChatCompletionChunk
import boto3
from sagemaker import Predictor, serializers
from sagemaker.session import Session
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
I18nObject,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@ -25,12 +44,140 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
logger = logging.getLogger(__name__)
def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], stop:list, stream=False):
"""
params:
predictor : Sagemaker Predictor
messages (List[Dict[str,Any]]): message list
messages = [
{"role": "system", "content":"please answer in Chinese"},
{"role": "user", "content": "who are you? what are you doing?"},
]
params (Dict[str,Any]): model parameters for LLM
stream (bool): False by default
response:
result of inference if stream is False
Iterator of Chunks if stream is True
"""
payload = {
"model" : params.get('model_name'),
"stop" : stop,
"messages": messages,
"stream" : stream,
"max_tokens" : params.get('max_new_tokens', params.get('max_tokens', 2048)),
"temperature" : params.get('temperature', 0.1),
"top_p" : params.get('top_p', 0.9),
}
if not stream:
response = predictor.predict(payload)
return response
else:
response_stream = predictor.predict_stream(payload)
return response_stream
class SageMakerLargeLanguageModel(LargeLanguageModel):
"""
    Model class for SageMaker large language model.
"""
sagemaker_client: Any = None
sagemaker_sess : Any = None
predictor : Any = None
def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: bytes) -> LLMResult:
"""
handle normal chat generate response
"""
resp_obj = json.loads(resp.decode('utf-8'))
resp_str = resp_obj.get('choices')[0].get('message').get('content')
if len(resp_str) == 0:
raise InvokeServerUnavailableError("Empty response")
assistant_prompt_message = AssistantPromptMessage(
content=resp_str,
tool_calls=[]
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
usage=usage,
message=assistant_prompt_message,
)
return response
def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: Iterator[bytes]) -> Generator:
"""
handle stream chat generate response
"""
full_response = ''
buffer = ""
for chunk_bytes in resp:
buffer += chunk_bytes.decode('utf-8')
last_idx = 0
for match in re.finditer(r'^data:\s*(.+?)(\n\n)', buffer):
try:
data = json.loads(match.group(1).strip())
last_idx = match.span()[1]
if "content" in data["choices"][0]["delta"]:
chunk_content = data["choices"][0]["delta"]["content"]
assistant_prompt_message = AssistantPromptMessage(
content=chunk_content,
tool_calls=[]
)
if data["choices"][0]['finish_reason'] is not None:
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=[]
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
finish_reason=data["choices"][0]['finish_reason'],
usage=usage
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message
),
)
full_response += chunk_content
except (json.JSONDecodeError, KeyError, IndexError) as e:
logger.info("json parse exception, content: {}".format(match.group(1).strip()))
pass
buffer = buffer[last_idx:]
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
@ -50,9 +197,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# get model mode
model_mode = self.get_model_mode(model, credentials)
if not self.sagemaker_client:
access_key = credentials.get('access_key')
secret_key = credentials.get('secret_key')
@ -68,37 +212,132 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client)
self.predictor = Predictor(
endpoint_name=credentials.get('sagemaker_endpoint'),
sagemaker_session=sagemaker_session,
serializer=serializers.JSONSerializer(),
)
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=sagemaker_endpoint,
Body=json.dumps(
{
"inputs": prompt_messages[0].content,
"parameters": { "stop" : stop},
"history" : []
}
),
ContentType="application/json",
)
assistant_text = response_model['Body'].read().decode('utf8')
messages:list[dict[str,Any]] = [ {"role": p.role.value, "content": p.content} for p in prompt_messages ]
response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
if stream:
if tools and len(tools) > 0:
                raise InvokeBadRequestError(f"{model}'s tool calls do not support stream mode")
usage = self._calc_response_usage(model, credentials, 0, 0)
return self._handle_chat_stream_response(model=model, credentials=credentials,
prompt_messages=prompt_messages,
tools=tools, resp=response)
return self._handle_chat_generate_response(model=model, credentials=credentials,
prompt_messages=prompt_messages,
tools=tools, resp=response)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
)
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for OpenAI Compatibility API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
"detail": message_content.detail.value
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return response
return message_dict
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
is_completion_model: bool = False) -> int:
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
if is_completion_model:
return sum(tokens(str(message.content)) for message in messages)
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ''
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
value = text
if key == "tool_calls":
for tool_call in value:
for t_key, t_value in tool_call.items():
num_tokens += tokens(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += tokens(f_key)
num_tokens += tokens(f_value)
else:
num_tokens += tokens(t_key)
num_tokens += tokens(t_value)
if key == "function_call":
for t_key, t_value in value.items():
num_tokens += tokens(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += tokens(f_key)
num_tokens += tokens(f_value)
else:
num_tokens += tokens(t_key)
num_tokens += tokens(t_value)
else:
num_tokens += tokens(str(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@ -112,10 +351,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
:return:
"""
# get model mode
model_mode = self.get_model_mode(model)
try:
return 0
return self._num_tokens_from_messages(prompt_messages, tools)
except Exception as e:
raise self._transform_invoke_error(e)
@ -129,7 +366,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
"""
try:
# get model mode
model_mode = self.get_model_mode(model)
pass
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -200,13 +437,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
)
]
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type == LLMMode.CHAT:
print(f"completion_type : {LLMMode.CHAT.value}")
if completion_type == LLMMode.COMPLETION:
print(f"completion_type : {LLMMode.COMPLETION.value}")
completion_type = LLMMode.value_of(credentials["mode"]).value
features = []
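A minimal sketch (not part of the diff; endpoint and model names are assumed) of how the inference() helper above is driven with an OpenAI-style message list.

import boto3
from sagemaker import Predictor, serializers
from sagemaker.session import Session

session = Session(sagemaker_runtime_client=boto3.client("sagemaker-runtime"))
predictor = Predictor(
    endpoint_name="my-llm-endpoint",         # placeholder SageMaker endpoint
    sagemaker_session=session,
    serializer=serializers.JSONSerializer(),
)

messages = [
    {"role": "system", "content": "please answer in Chinese"},
    {"role": "user", "content": "who are you? what are you doing?"},
]
params = {"model_name": "my-model", "temperature": 0.1, "top_p": 0.9, "max_tokens": 2048}

response = inference(predictor, messages, params, stop=[], stream=False)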

View File

@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class SageMakerRerankModel(RerankModel):
"""
Model class for Cohere rerank model.
Model class for SageMaker rerank model.
"""
sagemaker_client: Any = None

View File

@ -1,10 +1,11 @@
import logging
import uuid
from typing import IO, Any
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class SageMakerProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
@ -15,3 +16,28 @@ class SageMakerProvider(ModelProvider):
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
pass
def buffer_to_s3(s3_client:Any, file: IO[bytes], bucket:str, s3_prefix:str) -> str:
'''
return s3_uri of this file
'''
s3_key = f'{s3_prefix}{uuid.uuid4()}.mp3'
s3_client.put_object(
Body=file.read(),
Bucket=bucket,
Key=s3_key,
ContentType='audio/mp3'
)
return s3_key
def generate_presigned_url(s3_client:Any, file: IO[bytes], bucket_name:str, s3_prefix:str, expiration=600) -> str:
object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix)
try:
response = s3_client.generate_presigned_url('get_object',
Params={'Bucket': bucket_name, 'Key': object_key},
ExpiresIn=expiration)
except Exception as e:
print(f"Error generating presigned URL: {e}")
return None
return response
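Illustrative usage (not part of the diff) of the S3 helpers above; the bucket name and audio file are placeholders.

import boto3

s3_client = boto3.client("s3", region_name="us-east-1")
with open("sample.mp3", "rb") as audio_file:
    url = generate_presigned_url(
        s3_client,
        audio_file,
        bucket_name="sagemaker-us-east-1-xxxx",
        s3_prefix="dify/speech2text/",
        expiration=600,
    )
# url is a time-limited GET link to the uploaded object, or None on failure.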

View File

@ -21,6 +21,8 @@ supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
@ -45,14 +47,10 @@ model_credential_schema:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
zh_Hans: Chat
- variable: sagemaker_endpoint
label:
en_US: sagemaker endpoint
@ -61,6 +59,76 @@ model_credential_schema:
placeholder:
zh_Hans: 请输出你的Sagemaker推理端点
en_US: Enter your Sagemaker Inference endpoint
- variable: audio_s3_cache_bucket
show_on:
- variable: __model_type
value: speech2text
label:
zh_Hans: 音频缓存桶(s3 bucket)
en_US: audio cache bucket(s3 bucket)
type: text-input
required: true
placeholder:
zh_Hans: sagemaker-us-east-1-******207838
en_US: sagemaker-us-east-1-*******7838
- variable: audio_model_type
show_on:
- variable: __model_type
value: tts
label:
en_US: Audio model type
type: select
required: true
placeholder:
zh_Hans: 语音模型类型
en_US: Audio model type
options:
- value: PresetVoice
label:
en_US: preset voice
zh_Hans: 内置音色
- value: CloneVoice
label:
en_US: clone voice
zh_Hans: 克隆音色
- value: CloneVoice_CrossLingual
label:
en_US: crosslingual clone voice
zh_Hans: 跨语种克隆音色
- value: InstructVoice
label:
en_US: Instruct voice
zh_Hans: 文字指令音色
- variable: prompt_audio
show_on:
- variable: __model_type
value: tts
label:
en_US: Mock Audio Source
type: text-input
required: false
placeholder:
zh_Hans: 被模仿的音色音频
en_US: source audio to be mocked
- variable: prompt_text
show_on:
- variable: __model_type
value: tts
label:
en_US: Prompt Audio Text
type: text-input
required: false
placeholder:
zh_Hans: 模仿音色的对应文本
en_US: text for the mocked source audio
- variable: instruct_text
show_on:
- variable: __model_type
value: tts
label:
en_US: instruct text for speaker
type: text-input
required: false
- variable: aws_access_key_id
required: false
label:
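Illustrative only (not part of the diff): a credential set for a SageMaker TTS model using the new fields above; the endpoint, audio source and texts are placeholders.

tts_credentials = {
    "sagemaker_endpoint": "my-tts-endpoint",
    "audio_model_type": "CloneVoice",
    "prompt_audio": "https://example.com/source-voice.wav",  # audio to be mimicked
    "prompt_text": "text spoken in the source audio",
    "aws_region": "us-east-1",                               # optional
}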

View File

@ -0,0 +1,142 @@
import json
import logging
from typing import IO, Any, Optional
import boto3
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url
logger = logging.getLogger(__name__)
class SageMakerSpeech2TextModel(Speech2TextModel):
"""
Model class for SageMaker speech to text model.
"""
sagemaker_client: Any = None
s3_client : Any = None
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
asr_text = None
try:
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
self.s3_client = boto3.client("s3",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
self.s3_client = boto3.client("s3", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
self.s3_client = boto3.client("s3")
s3_prefix='dify/speech2text/'
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
bucket = credentials.get('audio_s3_cache_bucket')
s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
payload = {
"audio_s3_presign_uri" : s3_presign_url
}
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=sagemaker_endpoint,
Body=json.dumps(payload),
ContentType="application/json"
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
asr_text = json_obj['text']
except Exception as e:
logger.exception(f'Exception: {e}')
return asr_text
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
pass
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.SPEECH2TEXT,
model_properties={ },
parameter_rules=[]
)
return entity
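The class above delegates transcription to a SageMaker endpoint that accepts a presigned audio URL and returns JSON with a text field; a minimal standalone call mirroring that contract might look as follows (the endpoint name and URL are assumptions, and the payload format is taken from the code above):

import json
import boto3

# Hypothetical direct call mirroring the contract used by _invoke above.
runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")
payload = {"audio_s3_presign_uri": "https://example-bucket.s3.amazonaws.com/audio.mp3?X-Amz-Signature=..."}

response = runtime.invoke_endpoint(
    EndpointName="my-asr-endpoint",          # assumed endpoint name
    Body=json.dumps(payload),
    ContentType="application/json",
)
result = json.loads(response["Body"].read().decode("utf8"))
print(result["text"])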

View File

@ -0,0 +1,287 @@
import concurrent.futures
import copy
import json
import logging
from enum import Enum
from typing import Any, Optional
import boto3
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.tts_model import TTSModel
logger = logging.getLogger(__name__)
class TTSModelType(Enum):
PresetVoice = "PresetVoice"
CloneVoice = "CloneVoice"
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
InstructVoice = "InstructVoice"
class SageMakerText2SpeechModel(TTSModel):
sagemaker_client: Any = None
s3_client : Any = None
comprehend_client : Any = None
def __init__(self):
# preset voices; custom voice support is still needed
self.model_voices = {
'__default': {
'all': [
{'name': 'Default', 'value': 'default'},
]
},
'CosyVoice': {
'zh-Hans': [
{'name': '中文男', 'value': '中文男'},
{'name': '中文女', 'value': '中文女'},
{'name': '粤语女', 'value': '粤语女'},
],
'zh-Hant': [
{'name': '中文男', 'value': '中文男'},
{'name': '中文女', 'value': '中文女'},
{'name': '粤语女', 'value': '粤语女'},
],
'en-US': [
{'name': '英文男', 'value': '英文男'},
{'name': '英文女', 'value': '英文女'},
],
'ja-JP': [
{'name': '日语男', 'value': '日语男'},
],
'ko-KR': [
{'name': '韩语女', 'value': '韩语女'},
]
}
}
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
pass
def _detect_lang_code(self, content: str, map_dict: dict = None):
map_dict = map_dict or {
"zh" : "<|zh|>",
"en" : "<|en|>",
"ja" : "<|jp|>",
"zh-TW" : "<|yue|>",
"ko" : "<|ko|>"
}
response = self.comprehend_client.detect_dominant_language(Text=content)
language_code = response['Languages'][0]['LanguageCode']
return map_dict.get(language_code, '<|zh|>')
def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
if model_type == TTSModelType.PresetVoice.value and model_role:
return { "tts_text" : content_text, "role" : model_role }
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
lang_tag = self._detect_lang_code(content_text)
return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
raise RuntimeError(f"Invalid params for {model_type}")
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None):
"""
_invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param user: unique user id
:return: text translated to audio file
"""
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
self.s3_client = boto3.client("s3",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
self.comprehend_client = boto3.client('comprehend',
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
self.s3_client = boto3.client("s3", region_name=aws_region)
self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
self.s3_client = boto3.client("s3")
self.comprehend_client = boto3.client('comprehend')
model_type = credentials.get('audio_model_type', 'PresetVoice')
prompt_text = credentials.get('prompt_text')
prompt_audio = credentials.get('prompt_audio')
instruct_text = credentials.get('instruct_text')
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
payload = self._build_tts_payload(
model_type,
content_text,
voice,
prompt_text,
prompt_audio,
instruct_text
)
return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={},
parameter_rules=[]
)
return entity
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
return ""
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
return 15
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
return "mp3"
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
return 5
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
audio_model_name = 'CosyVoice'
for key, voices in self.model_voices.items():
if key in audio_model_name:
if language and language in voices:
return voices[language]
elif 'all' in voices:
return voices['all']
return self.model_voices['__default']['all']
def _invoke_sagemaker(self, payload:dict, endpoint:str):
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=endpoint,
Body=json.dumps(payload),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
return json_obj
def _tts_invoke_streaming(self, model_type:str, payload:dict, sagemaker_endpoint:str) -> any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:return: text translated to audio file
"""
try:
lang_tag = ''
if model_type == TTSModelType.CloneVoice_CrossLingual.value:
lang_tag = payload.pop('lang_tag')
word_limit = self._get_model_word_limit(model='', credentials={})
content_text = payload.get("tts_text")
if len(content_text) > word_limit:
split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
sentences = [ f"{lang_tag}{s}" for s in split_sentences if len(s) ]
len_sent = len(sentences)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent))
payloads = [ copy.deepcopy(payload) for i in range(len_sent) ]
for idx in range(len_sent):
payloads[idx]["tts_text"] = sentences[idx]
futures = [ executor.submit(
self._invoke_sagemaker,
payload=payload,
endpoint=sagemaker_endpoint,
)
for payload in payloads]
for index, future in enumerate(futures):
resp = future.result()
audio_bytes = requests.get(resp.get('s3_presign_url')).content
for i in range(0, len(audio_bytes), 1024):
yield audio_bytes[i:i + 1024]
else:
resp = self._invoke_sagemaker(payload, sagemaker_endpoint)
audio_bytes = requests.get(resp.get('s3_presign_url')).content
for i in range(0, len(audio_bytes), 1024):
yield audio_bytes[i:i + 1024]
except Exception as ex:
raise InvokeBadRequestError(str(ex))
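Because _tts_invoke_streaming yields raw audio bytes in 1024-byte chunks, a caller only needs to concatenate them; a rough sketch of driving it directly, assuming the class can be used standalone and that the named endpoint exists (both are assumptions):

import boto3

tts_model = SageMakerText2SpeechModel()
# The boto3 client is normally created inside _invoke; set it explicitly for this sketch.
tts_model.sagemaker_client = boto3.client("sagemaker-runtime", region_name="us-east-1")

chunks = tts_model._tts_invoke_streaming(
    model_type="PresetVoice",
    payload={"tts_text": "Hello Dify", "role": "英文女"},   # preset role from model_voices above
    sagemaker_endpoint="my-cosyvoice-endpoint",             # assumed endpoint name
)

with open("output.mp3", "wb") as f:
    for chunk in chunks:
        f.write(chunk)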

View File

@ -19,27 +19,25 @@ class SparkLLMClient:
endpoint = 'chat'
if api_domain:
domain = api_domain
if model == 'spark-v3':
endpoint = 'multimodal'
model_api_configs = {
'spark-1.5': {
'spark-lite': {
'version': 'v1.1',
'chat_domain': 'general'
},
'spark-2': {
'version': 'v2.1',
'chat_domain': 'generalv2'
},
'spark-3': {
'spark-pro': {
'version': 'v3.1',
'chat_domain': 'generalv3'
},
'spark-3.5': {
'spark-pro-128k': {
'version': 'pro-128k',
'chat_domain': 'pro-128k'
},
'spark-max': {
'version': 'v3.5',
'chat_domain': 'generalv3.5'
},
'spark-4': {
'spark-4.0-ultra': {
'version': 'v4.0',
'chat_domain': '4.0Ultra'
}
@ -48,7 +46,12 @@ class SparkLLMClient:
api_version = model_api_configs[model]['version']
self.chat_domain = model_api_configs[model]['chat_domain']
self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
if model == 'spark-pro-128k':
self.api_base = f"wss://{domain}/{endpoint}/{api_version}"
else:
self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,
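
The branch above exists because the pro-128k endpoint reverses the usual path order (endpoint before version); the resulting URLs look roughly like this (the host shown is the default Spark API domain and is an assumption here):

# Quick illustration of the URL shapes produced by the branch above; the host is an assumption.
domain = "spark-api.xf-yun.com"
print(f"wss://{domain}/chat/pro-128k")   # spark-pro-128k: endpoint first, then version
print(f"wss://{domain}/v4.0/chat")       # e.g. spark-4.0-ultra: version first, then endpoint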

View File

@ -1,3 +1,8 @@
- spark-4.0-ultra
- spark-max
- spark-pro-128k
- spark-pro
- spark-lite
- spark-4
- spark-3.5
- spark-3

View File

@ -1,4 +1,5 @@
model: spark-1.5
deprecated: true
label:
en_US: Spark V1.5
model_type: llm

View File

@ -1,4 +1,5 @@
model: spark-3.5
deprecated: true
label:
en_US: Spark V3.5
model_type: llm

View File

@ -1,4 +1,5 @@
model: spark-3
deprecated: true
label:
en_US: Spark V3.0
model_type: llm

View File

@ -0,0 +1,42 @@
model: spark-4.0-ultra
label:
en_US: Spark 4.0 Ultra
model_type: llm
model_properties:
mode: chat
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
help:
zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
- name: max_tokens
use_template: max_tokens
default: 4096
min: 1
max: 8192
help:
zh_Hans: 模型回答的tokens的最大长度。
en_US: Maximum length of tokens for the model response.
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
default: 4
min: 1
max: 6
help:
zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
en_US: Randomly select one from k candidates (non-equal probability).
required: false
- name: show_ref_label
label:
zh_Hans: 联网检索
en_US: web search
type: boolean
default: false
help:
zh_Hans: 该参数仅4.0 Ultra版本支持,当设置为true时,如果输入内容触发联网检索插件,会先返回检索信源列表,然后再返回星火回复结果,否则仅返回星火回复结果
en_US: The parameter is only supported in the 4.0 Ultra version. When set to true, if the input triggers the online search plugin, it will first return a list of search sources and then return the Spark response. Otherwise, it will only return the Spark response.
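
These parameter rules surface as the model_parameters dict passed along with an LLM invocation; for the 4.0 Ultra config above that might look like the following (the values are illustrative, within the ranges defined above):

# Illustrative model_parameters matching the rules above.
model_parameters = {
    "temperature": 0.5,
    "max_tokens": 4096,
    "top_k": 4,
    "show_ref_label": False,
}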

View File

@ -1,4 +1,5 @@
model: spark-4
deprecated: true
label:
en_US: Spark V4.0
model_type: llm

View File

@ -0,0 +1,33 @@
model: spark-lite
label:
en_US: Spark Lite
model_type: llm
model_properties:
mode: chat
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
help:
zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
- name: max_tokens
use_template: max_tokens
default: 4096
min: 1
max: 4096
help:
zh_Hans: 模型回答的tokens的最大长度。
en_US: Maximum length of tokens for the model response.
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
default: 4
min: 1
max: 6
help:
zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
en_US: Randomly select one from k candidates (non-equal probability).
required: false

View File

@ -0,0 +1,33 @@
model: spark-max
label:
en_US: Spark Max
model_type: llm
model_properties:
mode: chat
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
help:
zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
- name: max_tokens
use_template: max_tokens
default: 4096
min: 1
max: 8192
help:
zh_Hans: 模型回答的tokens的最大长度。
en_US: Maximum length of tokens for the model response.
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
default: 4
min: 1
max: 6
help:
zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
en_US: Randomly select one from k candidates (non-equal probability).
required: false

View File

@ -0,0 +1,33 @@
model: spark-pro-128k
label:
en_US: Spark Pro-128K
model_type: llm
model_properties:
mode: chat
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
help:
zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
- name: max_tokens
use_template: max_tokens
default: 4096
min: 1
max: 4096
help:
zh_Hans: 模型回答的tokens的最大长度。
en_US: Maximum length of tokens for the model response.
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
default: 4
min: 1
max: 6
help:
zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
en_US: Randomly select one from k candidates (non-equal probability).
required: false

View File

@ -0,0 +1,33 @@
model: spark-pro
label:
en_US: Spark Pro
model_type: llm
model_properties:
mode: chat
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
help:
zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
- name: max_tokens
use_template: max_tokens
default: 4096
min: 1
max: 8192
help:
zh_Hans: 模型回答的tokens的最大长度。
en_US: Maximum length of tokens for the model response.
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
default: 4
min: 1
max: 6
help:
zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
en_US: Randomly select one from k candidates (non-equal probability).
required: false

View File

@ -67,7 +67,7 @@ class FlashRecognitionRequest:
class FlashRecognizer:
"""
reponse:
response:
request_id string
status Integer
message String
@ -132,9 +132,9 @@ class FlashRecognizer:
signstr = self._format_sign_string(query)
signature = self._sign(signstr, secret_key)
header["Authorization"] = signature
requrl = "https://"
requrl += signstr[4::]
return requrl
req_url = "https://"
req_url += signstr[4::]
return req_url
def _create_query_arr(self, req):
return {

View File

@ -17,7 +17,6 @@ from dashscope.common.error import (
UnsupportedModel,
)
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@ -64,88 +63,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# invoke model
# invoke model without code wrapper
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _code_block_mode_wrapper(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
-> LLMResult | Generator:
"""
Wrapper for code block mode
"""
block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
You should also complete the text started with ``` but not tell ``` directly.
"""
code_block = model_parameters.get("response_format", "")
if not code_block:
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
model_parameters.pop("response_format")
stop = stop or []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", f"Please output a valid {code_block} with markdown codeblocks.")
))
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# add ```JSON\n to the last message
prompt_messages[-1].content += f"\n```{code_block}\n"
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content=f"```{code_block}\n"
))
response = self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
if isinstance(response, Generator):
return self._code_block_mode_stream_processor_with_backtick(
model=model,
prompt_messages=prompt_messages,
input_generator=response
)
return response
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""

View File

@ -24,7 +24,7 @@ parameter_rules:
type: int
default: 2000
min: 1
max: 2000
max: 6000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.

Some files were not shown because too many files have changed in this diff.