mirror of https://github.com/langgenius/dify.git
Merge main
This commit is contained in:
commit 9c7bcd5abc
@@ -20,7 +20,7 @@ jobs:
       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v44
+        uses: tj-actions/changed-files@v45
         with:
           files: api/**
@@ -66,7 +66,7 @@ jobs:
       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v44
+        uses: tj-actions/changed-files@v45
         with:
           files: web/**
@@ -97,7 +97,7 @@ jobs:
       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v44
+        uses: tj-actions/changed-files@v45
         with:
           files: |
             **.sh
@@ -107,7 +107,7 @@ jobs:
             dev/**
       - name: Super-linter
-        uses: super-linter/super-linter/slim@v6
+        uses: super-linter/super-linter/slim@v7
         if: steps.changed-files.outputs.any_changed == 'true'
         env:
           BASH_SEVERITY: warning
@@ -0,0 +1,54 @@
+name: Check i18n Files and Create PR
+
+on:
+  pull_request:
+    types: [closed]
+    branches: [main]
+
+jobs:
+  check-and-update:
+    if: github.event.pull_request.merged == true
+    runs-on: ubuntu-latest
+    defaults:
+      run:
+        working-directory: web
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 2 # last 2 commits
+
+      - name: Check for file changes in i18n/en-US
+        id: check_files
+        run: |
+          recent_commit_sha=$(git rev-parse HEAD)
+          second_recent_commit_sha=$(git rev-parse HEAD~1)
+          changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
+          echo "Changed files: $changed_files"
+          if [ -n "$changed_files" ]; then
+            echo "FILES_CHANGED=true" >> $GITHUB_ENV
+          else
+            echo "FILES_CHANGED=false" >> $GITHUB_ENV
+          fi
+
+      - name: Set up Node.js
+        if: env.FILES_CHANGED == 'true'
+        uses: actions/setup-node@v2
+        with:
+          node-version: 'lts/*'
+
+      - name: Install dependencies
+        if: env.FILES_CHANGED == 'true'
+        run: yarn install --frozen-lockfile
+
+      - name: Run npm script
+        if: env.FILES_CHANGED == 'true'
+        run: npm run auto-gen-i18n
+
+      - name: Create Pull Request
+        if: env.FILES_CHANGED == 'true'
+        uses: peter-evans/create-pull-request@v6
+        with:
+          commit-message: Update i18n files based on en-US changes
+          title: 'chore: translate i18n files'
+          body: This PR was automatically created to update i18n files based on changes in en-US locale.
+          branch: chore/automated-i18n-updates
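A quick way to sanity-check the detection step above locally before pushing: a minimal Python sketch (assumption: run from the repository's web/ directory with at least two commits of history), mirroring the same git diff call.

    import subprocess

    def en_us_i18n_changed() -> bool:
        # Same comparison as the workflow step: HEAD vs its parent,
        # restricted to the en-US locale files.
        out = subprocess.run(
            ["git", "diff", "--name-only", "HEAD~1", "HEAD", "--", "i18n/en-US/*.ts"],
            capture_output=True, text=True, check=True,
        ).stdout.strip()
        return bool(out)

    print("FILES_CHANGED=true" if en_us_i18n_changed() else "FILES_CHANGED=false")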
 LICENSE | 2 +-
@@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con

 1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:

-a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
+a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
    - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.

 b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.
@@ -39,7 +39,7 @@ DB_DATABASE=dify

 # Storage configuration
 # use for store upload files, private keys...
-# storage type: local, s3, azure-blob, google-storage
+# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos
 STORAGE_TYPE=local
 STORAGE_LOCAL_PATH=storage
 S3_USE_AWS_MANAGED_IAM=false
@@ -60,7 +60,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
 ALIYUN_OSS_ENDPOINT=your-endpoint
 ALIYUN_OSS_AUTH_VERSION=v1
 ALIYUN_OSS_REGION=your-region
+# Don't start with '/'. OSS doesn't support leading slash in object names.
+ALIYUN_OSS_PATH=your-path
 # Google Storage configuration
 GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
 GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
@@ -72,6 +73,12 @@ TENCENT_COS_SECRET_ID=your-secret-id
 TENCENT_COS_REGION=your-region
 TENCENT_COS_SCHEME=your-scheme

+# Huawei OBS Storage Configuration
+HUAWEI_OBS_BUCKET_NAME=your-bucket-name
+HUAWEI_OBS_SECRET_KEY=your-secret-key
+HUAWEI_OBS_ACCESS_KEY=your-access-key
+HUAWEI_OBS_SERVER=your-server-url
+
 # OCI Storage configuration
 OCI_ENDPOINT=your-endpoint
 OCI_BUCKET_NAME=your-bucket-name
@@ -79,6 +86,13 @@ OCI_ACCESS_KEY=your-access-key
 OCI_SECRET_KEY=your-secret-key
 OCI_REGION=your-region

+# Volcengine tos Storage configuration
+VOLCENGINE_TOS_ENDPOINT=your-endpoint
+VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
+VOLCENGINE_TOS_ACCESS_KEY=your-access-key
+VOLCENGINE_TOS_SECRET_KEY=your-secret-key
+VOLCENGINE_TOS_REGION=your-region
+
 # CORS configuration
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@@ -100,11 +114,10 @@ QDRANT_GRPC_ENABLED=false
 QDRANT_GRPC_PORT=6334

 # Milvus configuration
-MILVUS_HOST=127.0.0.1
-MILVUS_PORT=19530
+MILVUS_URI=http://127.0.0.1:19530
+MILVUS_TOKEN=
 MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
-MILVUS_SECURE=false

 # MyScale configuration
 MYSCALE_HOST=127.0.0.1
@@ -55,7 +55,7 @@ RUN apt-get update \
    && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
    && apt-get update \
    # For Security
-   && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
+   && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
    && apt-get autoremove -y \
    && rm -rf /var/lib/apt/lists/*
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Annotated, Optional

 from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
 from pydantic_settings import BaseSettings
@@ -46,7 +46,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
     """

     CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
-        description="endpoint URL of code execution servcie",
+        description="endpoint URL of code execution service",
         default="http://sandbox:8194",
     )
@@ -230,20 +230,17 @@ class HttpConfig(BaseSettings):
     def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
         return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")

-    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field(
-        description="",
-        default=300,
-    )
+    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
+        PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
+    ] = 10

-    HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field(
-        description="",
-        default=600,
-    )
+    HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
+        PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
+    ] = 60

-    HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field(
-        description="",
-        default=600,
-    )
+    HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
+        PositiveInt, Field(ge=10, description="write timeout in seconds for HTTP request")
+    ] = 20

     HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
         description="",
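For readers unfamiliar with the pattern introduced here: `Annotated[PositiveInt, Field(ge=...)]` attaches the constraint to the type, while the plain right-hand side stays the default. A self-contained sketch (class and field names are illustrative, not Dify's):

    from typing import Annotated

    from pydantic import Field, PositiveInt, ValidationError
    from pydantic_settings import BaseSettings

    class DemoHttpConfig(BaseSettings):
        # ge=10 enforces a floor even when the value comes from the environment.
        CONNECT_TIMEOUT: Annotated[PositiveInt, Field(ge=10)] = 10

    try:
        DemoHttpConfig(CONNECT_TIMEOUT=3)  # below the floor -> rejected
    except ValidationError as e:
        print(e.errors()[0]["type"])  # greater_than_equal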
@@ -431,7 +428,7 @@ class MailConfig(BaseSettings):
     """

     MAIL_TYPE: Optional[str] = Field(
-        description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
+        description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
         default=None,
     )
@@ -1,7 +1,7 @@
 from typing import Any, Optional
 from urllib.parse import quote_plus

-from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
+from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
 from pydantic_settings import BaseSettings

 from configs.middleware.cache.redis_config import RedisConfig
@@ -9,8 +9,10 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag
 from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
 from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
 from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
+from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
 from configs.middleware.storage.oci_storage_config import OCIStorageConfig
 from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
+from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
 from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
 from configs.middleware.vdb.chroma_config import ChromaConfig
 from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
@@ -157,6 +159,21 @@ class CeleryConfig(DatabaseConfig):
         default=None,
     )

+    CELERY_USE_SENTINEL: Optional[bool] = Field(
+        description="Whether to use Redis Sentinel mode",
+        default=False,
+    )
+
+    CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
+        description="Redis Sentinel master name",
+        default=None,
+    )
+
+    CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
+        description="Redis Sentinel socket timeout",
+        default=0.1,
+    )
+
     @computed_field
     @property
     def CELERY_RESULT_BACKEND(self) -> str | None:
@@ -184,6 +201,8 @@ class MiddlewareConfig(
     AzureBlobStorageConfig,
     GoogleCloudStorageConfig,
     TencentCloudCOSStorageConfig,
+    HuaweiCloudOBSStorageConfig,
+    VolcengineTOSStorageConfig,
     S3StorageConfig,
     OCIStorageConfig,
     # configs of vdb and vdb providers
@@ -1,6 +1,6 @@
 from typing import Optional

-from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
 from pydantic_settings import BaseSettings
@@ -38,3 +38,33 @@ class RedisConfig(BaseSettings):
         description="whether to use SSL for Redis connection",
         default=False,
     )
+
+    REDIS_USE_SENTINEL: Optional[bool] = Field(
+        description="Whether to use Redis Sentinel mode",
+        default=False,
+    )
+
+    REDIS_SENTINELS: Optional[str] = Field(
+        description="Redis Sentinel nodes",
+        default=None,
+    )
+
+    REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
+        description="Redis Sentinel service name",
+        default=None,
+    )
+
+    REDIS_SENTINEL_USERNAME: Optional[str] = Field(
+        description="Redis Sentinel username",
+        default=None,
+    )
+
+    REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
+        description="Redis Sentinel password",
+        default=None,
+    )
+
+    REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
+        description="Redis Sentinel socket timeout",
+        default=0.1,
+    )
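A minimal sketch of how settings like these feed redis-py's Sentinel client. Two assumptions not taken from this diff: redis-py is installed, and REDIS_SENTINELS encodes comma-separated host:port pairs.

    from redis.sentinel import Sentinel

    REDIS_SENTINELS = "10.0.0.1:26379,10.0.0.2:26379"   # hypothetical value
    REDIS_SENTINEL_SERVICE_NAME = "mymaster"
    REDIS_SENTINEL_SOCKET_TIMEOUT = 0.1

    nodes = [addr.rsplit(":", 1) for addr in REDIS_SENTINELS.split(",")]
    sentinel = Sentinel(
        [(host, int(port)) for host, port in nodes],
        socket_timeout=REDIS_SENTINEL_SOCKET_TIMEOUT,
    )
    # Sentinel resolves the current master for the named service and
    # hands back an ordinary Redis client bound to it.
    master = sentinel.master_for(REDIS_SENTINEL_SERVICE_NAME)
    master.ping()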
@@ -38,3 +38,8 @@ class AliyunOSSStorageConfig(BaseSettings):
         description="Aliyun OSS authentication version",
         default=None,
     )
+
+    ALIYUN_OSS_PATH: Optional[str] = Field(
+        description="Aliyun OSS path",
+        default=None,
+    )
@@ -0,0 +1,29 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class HuaweiCloudOBSStorageConfig(BaseModel):
+    """
+    Huawei Cloud OBS storage configs
+    """
+
+    HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
+        description="Huawei Cloud OBS bucket name",
+        default=None,
+    )
+
+    HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
+        description="Huawei Cloud OBS Access key",
+        default=None,
+    )
+
+    HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
+        description="Huawei Cloud OBS Secret key",
+        default=None,
+    )
+
+    HUAWEI_OBS_SERVER: Optional[str] = Field(
+        description="Huawei Cloud OBS server URL",
+        default=None,
+    )
@@ -0,0 +1,34 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class VolcengineTOSStorageConfig(BaseModel):
+    """
+    Volcengine tos storage configs
+    """
+
+    VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
+        description="Volcengine TOS Bucket Name",
+        default=None,
+    )
+
+    VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
+        description="Volcengine TOS Access Key",
+        default=None,
+    )
+
+    VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
+        description="Volcengine TOS Secret Key",
+        default=None,
+    )
+
+    VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
+        description="Volcengine TOS Endpoint URL",
+        default=None,
+    )
+
+    VOLCENGINE_TOS_REGION: Optional[str] = Field(
+        description="Volcengine TOS Region",
+        default=None,
+    )
@@ -1,6 +1,6 @@
 from typing import Optional

-from pydantic import Field, PositiveInt
+from pydantic import Field
 from pydantic_settings import BaseSettings
@@ -9,14 +9,14 @@ class MilvusConfig(BaseSettings):
     Milvus configs
     """

-    MILVUS_HOST: Optional[str] = Field(
-        description="Milvus host",
-        default=None,
+    MILVUS_URI: Optional[str] = Field(
+        description="Milvus uri",
+        default="http://127.0.0.1:19530",
     )

-    MILVUS_PORT: PositiveInt = Field(
-        description="Milvus RestFul API port",
-        default=9091,
+    MILVUS_TOKEN: Optional[str] = Field(
+        description="Milvus token",
+        default=None,
     )

     MILVUS_USER: Optional[str] = Field(
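The URI/token pair replaces the old host/port/secure trio; a sketch of what that enables with pymilvus (values are the defaults above, not requirements):

    from pymilvus import connections

    connections.connect(
        alias="default",
        uri="http://127.0.0.1:19530",  # MILVUS_URI
        token="",                      # MILVUS_TOKEN, e.g. "user:password" when auth is enabled
        db_name="default",             # MILVUS_DATABASE
    )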
@@ -29,11 +29,6 @@ class MilvusConfig(BaseSettings):
         default=None,
     )

-    MILVUS_SECURE: bool = Field(
-        description="whether to use SSL connection for Milvus",
-        default=False,
-    )
-
     MILVUS_DATABASE: str = Field(
         description="Milvus database, default to `default`",
         default="default",
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.7.2",
+        default="0.7.3",
     )

     COMMIT_SHA: str = Field(

File diff suppressed because one or more lines are too long
@@ -174,6 +174,7 @@ class AppApi(Resource):
         parser.add_argument("icon", type=str, location="json")
         parser.add_argument("icon_background", type=str, location="json")
+        parser.add_argument("max_active_requests", type=int, location="json")
         parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
         args = parser.parse_args()

         app_service = AppService()
@@ -201,7 +201,11 @@ class ChatConversationApi(Resource):
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            query = query.where(Conversation.created_at >= start_datetime_utc)
+            match args["sort_by"]:
+                case "updated_at" | "-updated_at":
+                    query = query.where(Conversation.updated_at >= start_datetime_utc)
+                case "created_at" | "-created_at" | _:
+                    query = query.where(Conversation.created_at >= start_datetime_utc)

         if args["end"]:
             end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
@@ -210,7 +214,11 @@ class ChatConversationApi(Resource):
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            query = query.where(Conversation.created_at < end_datetime_utc)
+            match args["sort_by"]:
+                case "updated_at" | "-updated_at":
+                    query = query.where(Conversation.updated_at <= end_datetime_utc)
+                case "created_at" | "-created_at" | _:
+                    query = query.where(Conversation.created_at <= end_datetime_utc)

         if args["annotation_status"] == "annotated":
             query = query.options(joinedload(Conversation.message_annotations)).join(
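The trailing `_` alternative is an irrefutable catch-all, which Python only permits as the last alternative of the last case. A standalone illustration of the dispatch used in both hunks above:

    def date_column(sort_by: str) -> str:
        # Explicit names first, wildcard last -- same shape as the query code.
        match sort_by:
            case "updated_at" | "-updated_at":
                return "updated_at"
            case "created_at" | "-created_at" | _:
                return "created_at"

    assert date_column("-updated_at") == "updated_at"
    assert date_column("anything else") == "created_at"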
@@ -34,6 +34,7 @@ def parse_app_site_args():
     )
     parser.add_argument("prompt_public", type=bool, required=False, location="json")
     parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
+    parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
     return parser.parse_args()
@@ -68,6 +69,7 @@ class AppSite(Resource):
             "customize_token_strategy",
             "prompt_public",
             "show_workflow_steps",
+            "use_icon_as_answer_icon",
         ]:
             value = args.get(attr_name)
             if value is not None:
@@ -18,7 +18,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
-from core.rag.retrieval.retrival_methods import RetrievalMethod
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -302,6 +302,8 @@ class DatasetInitApi(Resource):
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+        parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         args = parser.parse_args()

         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -309,6 +311,8 @@ class DatasetInitApi(Resource):
             raise Forbidden()

         if args["indexing_technique"] == "high_quality":
+            if args["embedding_model"] is None or args["embedding_model_provider"] is None:
+                raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
             try:
                 model_manager = ModelManager()
                 model_manager.get_default_model_instance(
@@ -36,6 +36,10 @@ class SegmentApi(DatasetApiResource):
         document = DocumentService.get_document(dataset.id, document_id)
         if not document:
             raise NotFound("Document not found.")
+        if document.indexing_status != "completed":
+            raise NotFound("Document is not completed.")
+        if not document.enabled:
+            raise NotFound("Document is disabled.")
         # check embedding model setting
         if dataset.indexing_technique == "high_quality":
             try:
@@ -63,7 +67,7 @@ class SegmentApi(DatasetApiResource):
             segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
             return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
         else:
-            return {"error": "Segemtns is required"}, 400
+            return {"error": "Segments is required"}, 400

     def get(self, tenant_id, dataset_id, document_id):
         """Create single segment."""
@@ -39,6 +39,7 @@ class AppSiteApi(WebApiResource):
         "default_language": fields.String,
         "prompt_public": fields.Boolean,
         "show_workflow_steps": fields.Boolean,
+        "use_icon_as_answer_icon": fields.Boolean,
     }

     app_fields = {
@@ -93,7 +93,7 @@ class DatasetConfigManager:
                 reranking_model=dataset_configs.get('reranking_model'),
                 weights=dataset_configs.get('weights'),
                 reranking_enabled=dataset_configs.get('reranking_enabled', True),
-                rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
+                rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
             )
         )
@@ -65,7 +65,7 @@ class Extensible:
                 if os.path.exists(builtin_file_path):
                     with open(builtin_file_path, encoding='utf-8') as f:
                         position = int(f.read().strip())
-                        position_map[extension_name] = position
+                    position_map[extension_name] = position

                 if (extension_name + '.py') not in file_names:
                     logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
@@ -79,7 +79,7 @@ def is_filtered(
     name_func: Callable[[Any], str],
 ) -> bool:
     """
-    Chcek if the object should be filtered out.
+    Check if the object should be filtered out.
     Overall logic: exclude > include > pin
     :param include_set: the set of names to be included
     :param exclude_set: the set of names to be excluded
@@ -16,9 +16,7 @@ from configs import dify_config
 from core.errors.error import ProviderTokenNotInitError
 from core.llm_generator.llm_generator import LLMGenerator
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.model_entities import ModelType, PriceType
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -255,11 +253,8 @@ class IndexingRunner:
                 tenant_id=tenant_id,
                 model_type=ModelType.TEXT_EMBEDDING,
             )
-        tokens = 0
         preview_texts = []
         total_segments = 0
-        total_price = 0
-        currency = 'USD'
         index_type = doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
         all_text_docs = []
@@ -286,54 +281,22 @@ class IndexingRunner:
         for document in documents:
             if len(preview_texts) < 5:
                 preview_texts.append(document.page_content)
-            if indexing_technique == 'high_quality' or embedding_model_instance:
-                tokens += embedding_model_instance.get_text_embedding_num_tokens(
-                    texts=[self.filter_string(document.page_content)]
-                )

         if doc_form and doc_form == 'qa_model':
-            model_instance = self.model_manager.get_default_model_instance(
-                tenant_id=tenant_id,
-                model_type=ModelType.LLM
-            )
-
-            model_type_instance = model_instance.model_type_instance
-            model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
             if len(preview_texts) > 0:
                 # qa model document
                 response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                              doc_language)
                 document_qa_list = self.format_split_text(response)
-                price_info = model_type_instance.get_price(
-                    model=model_instance.model,
-                    credentials=model_instance.credentials,
-                    price_type=PriceType.INPUT,
-                    tokens=total_segments * 2000,
-                )

                 return {
                     "total_segments": total_segments * 20,
-                    "tokens": total_segments * 2000,
-                    "total_price": '{:f}'.format(price_info.total_amount),
-                    "currency": price_info.currency,
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
                 }
-        if embedding_model_instance:
-            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
-            embedding_price_info = embedding_model_type_instance.get_price(
-                model=embedding_model_instance.model,
-                credentials=embedding_model_instance.credentials,
-                price_type=PriceType.INPUT,
-                tokens=tokens
-            )
-            total_price = '{:f}'.format(embedding_price_info.total_amount)
-            currency = embedding_price_info.currency
         return {
             "total_segments": total_segments,
-            "tokens": tokens,
-            "total_price": total_price,
-            "currency": currency,
             "preview": preview_texts
         }
@@ -531,7 +494,7 @@
                 hash = helper.generate_text_hash(document_node.page_content)
                 document_node.metadata['doc_id'] = doc_id
                 document_node.metadata['doc_hash'] = hash
-                # delete Spliter character
+                # delete Splitter character
                 page_content = document_node.page_content
                 if page_content.startswith(".") or page_content.startswith("。"):
                     page_content = page_content[1:]
@@ -87,7 +87,7 @@ Here is a task description for which I would like you to create a high-quality p
 {{TASK_DESCRIPTION}}
 </task_description>
 Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include:
-- Do not inlcude <input> or <output> section and variables in the prompt, assume user will add them at their own will.
+- Do not include <input> or <output> section and variables in the prompt, assume user will add them at their own will.
 - Clear instructions for the AI that will be using this prompt, demarcated with <instructions> tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
 - Relevant examples if needed to clarify the task further, demarcated with <example> tags. Do not include variables in the prompt. Give three pairs of input and output examples.
 - Include other relevant sections demarcated with appropriate XML tags like <examples>, <instructions>.
@@ -52,7 +52,7 @@
 - `mode` (string) voice model.(available for model type `tts`)
 - `name` (string) voice model display name.(available for model type `tts`)
 - `language` (string) the voice model supports languages.(available for model type `tts`)
-- `word_limit` (int) Single conversion word limit, paragraphwise by default(available for model type `tts`)
+- `word_limit` (int) Single conversion word limit, paragraph-wise by default(available for model type `tts`)
 - `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`)
 - `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`)
 - `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
@@ -150,7 +150,7 @@

 - `input` (float) Input price, i.e., Prompt price
 - `output` (float) Output price, i.e., returned content price
-- `unit` (float) Pricing unit, e.g., if the price is meausred in 1M tokens, the corresponding token amount for the unit price is `0.000001`.
+- `unit` (float) Pricing unit, e.g., if the price is measured in 1M tokens, the corresponding token amount for the unit price is `0.000001`.
 - `currency` (string) Currency unit

 ### ProviderCredentialSchema
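A worked example of the `unit` semantics (numbers illustrative): with a price quoted per 1M tokens, the per-token price is input_price * unit.

    input_price = 5.00   # price quoted per 1M tokens
    unit = 0.000001      # one token, expressed in 1M-token units
    tokens = 250_000
    cost = tokens * input_price * unit
    print(cost)          # 1.25, in the configured `currency`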
@@ -33,6 +33,22 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
         'max': 1.0,
         'precision': 2,
     },
+    DefaultParameterName.TOP_K: {
+        'label': {
+            'en_US': 'Top K',
+            'zh_Hans': 'Top K',
+        },
+        'type': 'int',
+        'help': {
+            'en_US': 'Limits the number of tokens to consider for each step by keeping only the k most likely tokens.',
+            'zh_Hans': '通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。',
+        },
+        'required': False,
+        'default': 50,
+        'min': 1,
+        'max': 100,
+        'precision': 0,
+    },
     DefaultParameterName.PRESENCE_PENALTY: {
         'label': {
             'en_US': 'Presence Penalty',
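Entries in this template are consumed by splatting into `ParameterRule(...)`, as the Azure model definitions later in this commit do. A stripped-down sketch of the mechanism with stand-in types (not Dify's real classes):

    from dataclasses import dataclass, field

    @dataclass
    class Rule:  # stand-in for the real ParameterRule
        name: str
        label: dict = field(default_factory=dict)
        type: str = "float"
        help: dict = field(default_factory=dict)
        required: bool = False
        default: float | None = None
        min: float | None = None
        max: float | None = None
        precision: int | None = None

    TEMPLATE = {
        "top_k": {
            "label": {"en_US": "Top K"},
            "type": "int",
            "required": False,
            "default": 50,
            "min": 1,
            "max": 100,
            "precision": 0,
        }
    }

    rule = Rule(name="top_k", **TEMPLATE["top_k"])  # template keys become kwargs
    print(rule.default, rule.max)  # 50 100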
@@ -85,12 +85,13 @@ class ModelFeature(Enum):
     STREAM_TOOL_CALL = "stream-tool-call"


-class DefaultParameterName(Enum):
+class DefaultParameterName(str, Enum):
     """
     Enum class for parameter template variable.
     """
     TEMPERATURE = "temperature"
     TOP_P = "top_p"
+    TOP_K = "top_k"
     PRESENCE_PENALTY = "presence_penalty"
     FREQUENCY_PENALTY = "frequency_penalty"
     MAX_TOKENS = "max_tokens"
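The new `str` mixin is what lets enum members compare equal to raw strings and serialize as plain values. A standalone sketch of the difference:

    from enum import Enum

    class Plain(Enum):
        TOP_K = "top_k"

    class Mixed(str, Enum):
        TOP_K = "top_k"

    print(Plain.TOP_K == "top_k")  # False: plain Enum members never equal raw strings
    print(Mixed.TOP_K == "top_k")  # True: the str mixin makes the member *be* the string
    print(Mixed.TOP_K.upper())     # str methods work too: TOP_K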
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)

 class TTSModel(AIModel):
     """
-    Model class for ttstext model.
+    Model class for TTS model.
     """
     model_type: ModelType = ModelType.TTS
@@ -19,9 +19,9 @@ class AnthropicProvider(ModelProvider):
         try:
             model_instance = self.get_model_instance(ModelType.LLM)

-            # Use `claude-instant-1` model for validate,
+            # Use `claude-3-opus-20240229` model for validate,
             model_instance.validate_credentials(
-                model='claude-instant-1.2',
+                model='claude-3-opus-20240229',
                 credentials=credentials
             )
         except CredentialsValidateFailedError as ex:
@@ -33,3 +33,4 @@ pricing:
   output: '5.51'
   unit: '0.000001'
   currency: USD
+deprecated: true
@@ -637,7 +637,19 @@ LLM_BASE_MODELS = [
                         en_US='specifying the format that the model must output'
                     ),
                     required=False,
-                    options=['text', 'json_object']
+                    options=['text', 'json_object', 'json_schema']
+                ),
+                ParameterRule(
+                    name='json_schema',
+                    label=I18nObject(
+                        en_US='JSON Schema'
+                    ),
+                    type='text',
+                    help=I18nObject(
+                        zh_Hans='设置返回的json schema,llm将按照它返回',
+                        en_US='Set a response json schema will ensure LLM to adhere it.'
+                    ),
+                    required=False
                 ),
             ],
             pricing=PriceConfig(
@@ -800,6 +812,94 @@ LLM_BASE_MODELS = [
             )
         )
     ),
+    AzureBaseModel(
+        base_model_name='gpt-4o-2024-08-06',
+        entity=AIModelEntity(
+            model='fake-deployment-name',
+            label=I18nObject(
+                en_US='fake-deployment-name-label',
+            ),
+            model_type=ModelType.LLM,
+            features=[
+                ModelFeature.AGENT_THOUGHT,
+                ModelFeature.VISION,
+                ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
+            ],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+                ModelPropertyKey.CONTEXT_SIZE: 128000,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name='temperature',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
+                ),
+                ParameterRule(
+                    name='top_p',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
+                ),
+                ParameterRule(
+                    name='presence_penalty',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
+                ),
+                ParameterRule(
+                    name='frequency_penalty',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=4096),
+                ParameterRule(
+                    name='seed',
+                    label=I18nObject(
+                        zh_Hans='种子',
+                        en_US='Seed'
+                    ),
+                    type='int',
+                    help=I18nObject(
+                        zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
+                        en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
+                    ),
+                    required=False,
+                    precision=2,
+                    min=0,
+                    max=1,
+                ),
+                ParameterRule(
+                    name='response_format',
+                    label=I18nObject(
+                        zh_Hans='回复格式',
+                        en_US='response_format'
+                    ),
+                    type='string',
+                    help=I18nObject(
+                        zh_Hans='指定模型必须输出的格式',
+                        en_US='specifying the format that the model must output'
+                    ),
+                    required=False,
+                    options=['text', 'json_object', 'json_schema']
+                ),
+                ParameterRule(
+                    name='json_schema',
+                    label=I18nObject(
+                        en_US='JSON Schema'
+                    ),
+                    type='text',
+                    help=I18nObject(
+                        zh_Hans='设置返回的json schema,llm将按照它返回',
+                        en_US='Set a response json schema will ensure LLM to adhere it.'
+                    ),
+                    required=False
+                ),
+            ],
+            pricing=PriceConfig(
+                input=5.00,
+                output=15.00,
+                unit=0.000001,
+                currency='USD',
+            )
+        )
+    ),
     AzureBaseModel(
         base_model_name='gpt-4-turbo',
         entity=AIModelEntity(
@@ -138,6 +138,12 @@ model_credential_schema:
       show_on:
         - variable: __model_type
           value: llm
+    - label:
+        en_US: gpt-4o-2024-08-06
+      value: gpt-4o-2024-08-06
+      show_on:
+        - variable: __model_type
+          value: llm
     - label:
         en_US: gpt-4-turbo
       value: gpt-4-turbo
@@ -1,4 +1,5 @@
+import copy
 import json
 import logging
 from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast
@@ -276,12 +277,18 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         response_format = model_parameters.get("response_format")
         if response_format:
-            if response_format == "json_object":
-                response_format = {"type": "json_object"}
+            if response_format == "json_schema":
+                json_schema = model_parameters.get("json_schema")
+                if not json_schema:
+                    raise ValueError("Must define JSON Schema when the response format is json_schema")
+                try:
+                    schema = json.loads(json_schema)
+                except:
+                    raise ValueError(f"not correct json_schema format: {json_schema}")
+                model_parameters.pop("json_schema")
+                model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
             else:
-                response_format = {"type": "text"}
-
-            model_parameters["response_format"] = response_format
+                model_parameters["response_format"] = {"type": response_format}

         extra_model_kwargs = {}
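A sketch of the caller's side once this lands: `json_schema` arrives as a string in model_parameters and is parsed with json.loads before being forwarded. The schema content below is illustrative.

    import json

    model_parameters = {
        "response_format": "json_schema",
        "json_schema": json.dumps({
            "name": "person",
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                "required": ["name", "age"],
            },
        }),
    }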
@@ -27,11 +27,3 @@ provider_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
-    - variable: secret_key
-      label:
-        en_US: Secret Key
-      type: secret-input
-      required: false
-      placeholder:
-        zh_Hans: 在此输入您的 Secret Key
-        en_US: Enter your Secret Key
@@ -43,3 +43,4 @@ parameter_rules:
       zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
       en_US: Allow the model to perform external search to enhance the generation results.
     required: false
+deprecated: true
@@ -43,3 +43,4 @@ parameter_rules:
      zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
      en_US: Allow the model to perform external search to enhance the generation results.
    required: false
+deprecated: true
@@ -4,36 +4,32 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
     default: 0.3
   - name: top_p
     use_template: top_p
     default: 0.85
-  - name: top_k
-    label:
-      zh_Hans: 取样数量
-      en_US: Top k
-    type: int
-    min: 0
-    max: 20
-    default: 5
-    help:
-      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
-      en_US: Only sample from the top K options for each subsequent token.
-    required: false
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 8000
+    default: 2048
     min: 1
     max: 192000
   - name: presence_penalty
     use_template: presence_penalty
   - name: frequency_penalty
     use_template: frequency_penalty
+    default: 1
+    min: 1
+    max: 2
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强
@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 128000
 parameter_rules:
   - name: temperature
     use_template: temperature
     default: 0.3
   - name: top_p
     use_template: top_p
     default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
     min: 0
     max: 20
     default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 8000
+    default: 2048
     min: 1
     max: 128000
   - name: presence_penalty
     use_template: presence_penalty
   - name: frequency_penalty
     use_template: frequency_penalty
+    default: 1
+    min: 1
+    max: 2
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强
@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
     default: 0.3
   - name: top_p
     use_template: top_p
     default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
     min: 0
     max: 20
     default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 8000
+    default: 2048
     min: 1
     max: 32000
   - name: presence_penalty
     use_template: presence_penalty
   - name: frequency_penalty
     use_template: frequency_penalty
+    default: 1
+    min: 1
+    max: 2
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强
@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
     default: 0.3
   - name: top_p
     use_template: top_p
     default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
     min: 0
     max: 20
     default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 8000
+    default: 2048
     min: 1
     max: 32000
   - name: presence_penalty
     use_template: presence_penalty
   - name: frequency_penalty
     use_template: frequency_penalty
+    default: 1
+    min: 1
+    max: 2
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强
@@ -1,11 +1,10 @@
-from collections.abc import Generator
-from enum import Enum
-from hashlib import md5
-from json import dumps, loads
-from typing import Any, Union
+import json
+from collections.abc import Iterator
+from typing import Any, Optional, Union

 from requests import post

+from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
     BadRequestError,
     InsufficientAccountBalance,
@@ -16,203 +15,133 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
 )


-class BaichuanMessage:
-    class Role(Enum):
-        USER = 'user'
-        ASSISTANT = 'assistant'
-        # Baichuan does not have system message
-        _SYSTEM = 'system'
-
-    role: str = Role.USER.value
-    content: str
-    usage: dict[str, int] = None
-    stop_reason: str = ''
-
-    def to_dict(self) -> dict[str, Any]:
-        return {
-            'role': self.role,
-            'content': self.content,
-        }
-
-    def __init__(self, content: str, role: str = 'user') -> None:
-        self.content = content
-        self.role = role
-
 class BaichuanModel:
     api_key: str
-    secret_key: str

-    def __init__(self, api_key: str, secret_key: str = '') -> None:
+    def __init__(self, api_key: str) -> None:
         self.api_key = api_key
-        self.secret_key = secret_key

-    def _model_mapping(self, model: str) -> str:
+    @property
+    def _model_mapping(self) -> dict:
         return {
-            'baichuan2-turbo': 'Baichuan2-Turbo',
-            'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
-            'baichuan2-53b': 'Baichuan2-53B',
-            'baichuan3-turbo': 'Baichuan3-Turbo',
-            'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
-            'baichuan4': 'Baichuan4',
-        }[model]
+            "baichuan2-turbo": "Baichuan2-Turbo",
+            "baichuan3-turbo": "Baichuan3-Turbo",
+            "baichuan3-turbo-128k": "Baichuan3-Turbo-128k",
+            "baichuan4": "Baichuan4",
+        }

-    def _handle_chat_generate_response(self, response) -> BaichuanMessage:
-        resp = response.json()
-        choices = resp.get('choices', [])
-        message = BaichuanMessage(content='', role='assistant')
-        for choice in choices:
-            message.content += choice['message']['content']
-            message.role = choice['message']['role']
-            if choice['finish_reason']:
-                message.stop_reason = choice['finish_reason']
+    @property
+    def request_headers(self) -> dict[str, Any]:
+        return {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer " + self.api_key,
+        }

-        if 'usage' in resp:
-            message.usage = {
-                'prompt_tokens': resp['usage']['prompt_tokens'],
-                'completion_tokens': resp['usage']['completion_tokens'],
-                'total_tokens': resp['usage']['total_tokens'],
-            }
+    def _build_parameters(
+        self,
+        model: str,
+        stream: bool,
+        messages: list[dict],
+        parameters: dict[str, Any],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> dict[str, Any]:
+        if model in self._model_mapping.keys():
+            # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
+            # we need to rename it to res_format to get its value
+            if parameters.get("res_format") == "json_object":
+                parameters["response_format"] = {"type": "json_object"}

-        return message
-
-    def _handle_chat_stream_generate_response(self, response) -> Generator:
-        for line in response.iter_lines():
-            if not line:
-                continue
-            line = line.decode('utf-8')
-            # remove the first `data: ` prefix
-            if line.startswith('data:'):
-                line = line[5:].strip()
-            try:
-                data = loads(line)
-            except Exception as e:
-                if line.strip() == '[DONE]':
-                    return
-            choices = data.get('choices', [])
-            # save stop reason temporarily
-            stop_reason = ''
-            for choice in choices:
-                if choice.get('finish_reason'):
-                    stop_reason = choice['finish_reason']
+            if tools or parameters.get("with_search_enhance") is True:
+                parameters["tools"] = []

-                if len(choice['delta']['content']) == 0:
-                    continue
-                yield BaichuanMessage(**choice['delta'])
-
-            # if there is usage, the response is the last one, yield it and return
-            if 'usage' in data:
-                message = BaichuanMessage(content='', role='assistant')
-                message.usage = {
-                    'prompt_tokens': data['usage']['prompt_tokens'],
-                    'completion_tokens': data['usage']['completion_tokens'],
-                    'total_tokens': data['usage']['total_tokens'],
-                }
-                message.stop_reason = stop_reason
-                yield message
-
-    def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
-                          parameters: dict[str, Any]) \
-            -> dict[str, Any]:
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
-            prompt_messages = []
-            for message in messages:
-                if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
-                    # check if the latest message is a user message
-                    if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value:
-                        prompt_messages[-1]['content'] += message.content
-                    else:
-                        prompt_messages.append({
-                            'content': message.content,
-                            'role': BaichuanMessage.Role.USER.value,
-                        })
-                elif message.role == BaichuanMessage.Role.ASSISTANT.value:
-                    prompt_messages.append({
-                        'content': message.content,
-                        'role': message.role,
-                    })
-            # [baichuan] frequency_penalty must be between 1 and 2
-            if 'frequency_penalty' in parameters:
-                if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
-                    parameters['frequency_penalty'] = 1
+            # with_search_enhance is deprecated, use web_search instead
+            if parameters.get("with_search_enhance") is True:
+                parameters["tools"].append(
+                    {
+                        "type": "web_search",
+                        "web_search": {"enable": True},
+                    }
+                )
+            if tools:
+                for tool in tools:
+                    parameters["tools"].append(
+                        {
+                            "type": "function",
+                            "function": {
+                                "name": tool.name,
+                                "description": tool.description,
+                                "parameters": tool.parameters,
+                            },
+                        }
+                    )

             # turbo api accepts flat parameters
             return {
-                'model': self._model_mapping(model),
-                'stream': stream,
-                'messages': prompt_messages,
+                "model": self._model_mapping.get(model),
+                "stream": stream,
+                "messages": messages,
                 **parameters,
             }
         else:
             raise BadRequestError(f"Unknown model: {model}")

-    def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
-            # there is no secret key for turbo api
-            return {
-                'Content-Type': 'application/json',
-                'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ',
-                'Authorization': 'Bearer ' + self.api_key,
-            }
-        else:
-            raise BadRequestError(f"Unknown model: {model}")
-
-    def _calculate_md5(self, input_string):
-        return md5(input_string.encode('utf-8')).hexdigest()
-
-    def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
-                 parameters: dict[str, Any], timeout: int) \
-            -> Union[Generator, BaichuanMessage]:
-
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
-            api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
+    def generate(
+        self,
+        model: str,
+        stream: bool,
+        messages: list[dict],
+        parameters: dict[str, Any],
+        timeout: int,
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> Union[Iterator, dict]:
+
+        if model in self._model_mapping.keys():
+            api_base = "https://api.baichuan-ai.com/v1/chat/completions"
         else:
             raise BadRequestError(f"Unknown model: {model}")

-        try:
-            data = self._build_parameters(model, stream, messages, parameters)
-            headers = self._build_headers(model, data)
-        except KeyError:
-            raise InternalServerError(f"Failed to build parameters for model: {model}")
+        data = self._build_parameters(model, stream, messages, parameters, tools)

         try:
             response = post(
                 url=api_base,
-                headers=headers,
-                data=dumps(data),
+                headers=self.request_headers,
+                data=json.dumps(data),
                 timeout=timeout,
-                stream=stream
+                stream=stream,
             )
         except Exception as e:
             raise InternalServerError(f"Failed to invoke model: {e}")

         if response.status_code != 200:
             try:
                 resp = response.json()
                 # try to parse error message
-                err = resp['error']['code']
-                msg = resp['error']['message']
+                err = resp["error"]["type"]
+                msg = resp["error"]["message"]
             except Exception as e:
-                raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
+                raise InternalServerError(
+                    f"Failed to convert response to json: {e} with text: {response.text}"
+                )

-            if err == 'invalid_api_key':
+            if err == "invalid_api_key":
                 raise InvalidAPIKeyError(msg)
-            elif err == 'insufficient_quota':
+            elif err == "insufficient_quota":
                 raise InsufficientAccountBalance(msg)
-            elif err == 'invalid_authentication':
+            elif err == "invalid_authentication":
                 raise InvalidAuthenticationError(msg)
-            elif 'rate' in err:
+            elif err == "invalid_request_error":
+                raise BadRequestError(msg)
+            elif "rate" in err:
                 raise RateLimitReachedError(msg)
-            elif 'internal' in err:
+            elif "internal" in err:
                 raise InternalServerError(msg)
-            elif err == 'api_key_empty':
+            elif err == "api_key_empty":
                 raise InvalidAPIKeyError(msg)
             else:
                 raise InternalServerError(f"Unknown error: {err} with message: {msg}")

         if stream:
-            return self._handle_chat_stream_generate_response(response)
+            return response.iter_lines()
         else:
-            return self._handle_chat_generate_response(response)
+            return response.json()
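A usage sketch of the refactored interface (the API key is a placeholder): messages are plain dicts now, and the return value is either `response.json()` or an iterator of SSE lines.

    model = BaichuanModel(api_key="sk-...")
    result = model.generate(
        model="baichuan4",
        stream=False,
        messages=[{"role": "user", "content": "ping"}],
        parameters={"max_tokens": 8},
        timeout=60,
    )
    print(result["choices"][0]["message"]["content"])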
@@ -1,7 +1,12 @@
-from collections.abc import Generator
+import json
+from collections.abc import Generator, Iterator
 from typing import cast

-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -21,7 +26,7 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
-from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel
+from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
 from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
     BadRequestError,
     InsufficientAccountBalance,
@@ -32,20 +37,41 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
 )


-class BaichuanLarguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
-                stream: bool = True, user: str | None = None) \
-            -> LLMResult | Generator:
-        return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
+class BaichuanLanguageModel(LargeLanguageModel):

-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: list[PromptMessageTool] | None = None) -> int:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: list[PromptMessageTool] | None = None,
+        stop: list[str] | None = None,
+        stream: bool = True,
+        user: str | None = None,
+    ) -> LLMResult | Generator:
+        return self._generate(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stream=stream,
+        )
+
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: list[PromptMessageTool] | None = None,
+    ) -> int:
         return self._num_tokens_from_messages(prompt_messages)

-    def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
+    def _num_tokens_from_messages(
+        self,
+        messages: list[PromptMessage],
+    ) -> int:
         """Calculate num tokens for baichuan model"""

         def tokens(text: str):
@@ -59,10 +85,10 @@ class BaichuanLarguageModel(LargeLanguageModel):
             num_tokens += tokens_per_message
             for key, value in message.items():
                 if isinstance(value, list):
-                    text = ''
+                    text = ""
                     for item in value:
-                        if isinstance(item, dict) and item['type'] == 'text':
-                            text += item['text']
+                        if isinstance(item, dict) and item["type"] == "text":
+                            text += item["text"]

                     value = text
@@ -84,19 +110,18 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = [tool_call.dict() for tool_call in
|
||||
message.tool_calls]
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
# copy from core/model_runtime/model_providers/anthropic/llm/llm.py
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.tool_call_id,
|
||||
"content": message.content
|
||||
}]
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown message type {type(message)}")
|
||||
|
|
@@ -105,102 +130,159 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
# ping
|
||||
instance = BaichuanModel(
|
||||
api_key=credentials['api_key'],
|
||||
secret_key=credentials.get('secret_key', '')
|
||||
)
|
||||
instance = BaichuanModel(api_key=credentials["api_key"])
|
||||
|
||||
try:
|
||||
instance.generate(model=model, stream=False, messages=[
|
||||
BaichuanMessage(content='ping', role='user')
|
||||
], parameters={
|
||||
'max_tokens': 1,
|
||||
}, timeout=60)
|
||||
instance.generate(
|
||||
model=model,
|
||||
stream=False,
|
||||
messages=[{"content": "ping", "role": "user"}],
|
||||
parameters={
|
||||
"max_tokens": 1,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
|
||||
|
||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
if tools is not None and len(tools) > 0:
|
||||
raise InvokeBadRequestError("Baichuan model doesn't support tools")
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stream: bool = True,
|
||||
) -> LLMResult | Generator:
|
||||
|
||||
instance = BaichuanModel(
|
||||
api_key=credentials['api_key'],
|
||||
secret_key=credentials.get('secret_key', '')
|
||||
)
|
||||
|
||||
# convert prompt messages to baichuan messages
|
||||
messages = [
|
||||
BaichuanMessage(
|
||||
content=message.content if isinstance(message.content, str) else ''.join([
|
||||
content.data for content in message.content
|
||||
]),
|
||||
role=message.role.value
|
||||
) for message in prompt_messages
|
||||
]
|
||||
instance = BaichuanModel(api_key=credentials["api_key"])
|
||||
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
|
||||
|
||||
# invoke model
|
||||
response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
|
||||
timeout=60)
|
||||
response = instance.generate(
|
||||
model=model,
|
||||
stream=stream,
|
||||
messages=messages,
|
||||
parameters=model_parameters,
|
||||
timeout=60,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
|
||||
return self._handle_chat_generate_stream_response(
|
||||
model, prompt_messages, credentials, response
|
||||
)
|
||||
|
||||
return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
|
||||
return self._handle_chat_generate_response(
|
||||
model, prompt_messages, credentials, response
|
||||
)
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: dict,
|
||||
) -> LLMResult:
|
||||
choices = response.get("choices", [])
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
if choices and choices[0]["finish_reason"] == "tool_calls":
|
||||
for choice in choices:
|
||||
for tool_call in choice["message"]["tool_calls"]:
|
||||
tool = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call.get("id", ""),
|
||||
type=tool_call.get("type", ""),
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_call.get("function", {}).get("name", ""),
|
||||
arguments=tool_call.get("function", {}).get("arguments", "")
|
||||
),
|
||||
)
|
||||
assistant_message.tool_calls.append(tool)
|
||||
else:
|
||||
for choice in choices:
|
||||
assistant_message.content += choice["message"]["content"]
|
||||
assistant_message.role = choice["message"]["role"]
|
||||
|
||||
usage = response.get("usage")
|
||||
if usage:
|
||||
# transform usage
|
||||
prompt_tokens = usage["prompt_tokens"]
|
||||
completion_tokens = usage["completion_tokens"]
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(prompt_messages)
|
||||
completion_tokens = self._num_tokens_from_messages([assistant_message])
|
||||
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
def _handle_chat_generate_response(self, model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: BaichuanMessage) -> LLMResult:
|
||||
# convert baichuan message to llm result
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=response.usage['prompt_tokens'],
|
||||
completion_tokens=response.usage['completion_tokens'])
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=response.content,
|
||||
tool_calls=[]
|
||||
),
|
||||
message=assistant_message,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Generator[BaichuanMessage, None, None]) -> Generator:
|
||||
for message in response:
|
||||
if message.usage:
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=message.usage['prompt_tokens'],
|
||||
completion_tokens=message.usage['completion_tokens'])
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Iterator,
|
||||
) -> Generator:
|
||||
for line in response:
|
||||
if not line:
|
||||
continue
|
||||
line = line.decode("utf-8")
|
||||
# remove the first `data: ` prefix
|
||||
if line.startswith("data:"):
|
||||
line = line[5:].strip()
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except Exception:
|
||||
if line.strip() == "[DONE]":
|
||||
return
|
||||
# skip lines that are neither JSON nor the [DONE] sentinel, so `data` below is never unbound
|
||||
continue
|
||||
choices = data.get("choices", [])
|
||||
|
||||
stop_reason = ""
|
||||
for choice in choices:
|
||||
if choice.get("finish_reason"):
|
||||
stop_reason = choice["finish_reason"]
|
||||
|
||||
if len(choice["delta"]["content"]) == 0:
|
||||
continue
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=message.content,
|
||||
tool_calls=[]
|
||||
content=choice["delta"]["content"], tool_calls=[]
|
||||
),
|
||||
usage=usage,
|
||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
||||
finish_reason=stop_reason,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
||||
# if there is usage, the response is the last one, yield it and return
|
||||
if "usage" in data:
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=data["usage"]["prompt_tokens"],
|
||||
completion_tokens=data["usage"]["completion_tokens"],
|
||||
)
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=message.content,
|
||||
tool_calls=[]
|
||||
),
|
||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
||||
message=AssistantPromptMessage(content="", tool_calls=[]),
|
||||
usage=usage,
|
||||
finish_reason=stop_reason,
|
||||
),
|
||||
)
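The handler above walks Baichuan's SSE wire format line by line. A minimal standalone sketch of that parsing loop, assuming a bytes iterator such as requests' iter_lines() (names here are illustrative):

import json

def iter_sse_events(lines):
    for raw in lines:
        if not raw:
            continue  # skip SSE keep-alive blank lines
        line = raw.decode("utf-8")
        if line.startswith("data:"):
            line = line[5:].strip()  # drop the `data: ` prefix
        if line == "[DONE]":
            return  # sentinel marking the end of the stream
        try:
            yield json.loads(line)
        except json.JSONDecodeError:
            continue  # ignore partial or non-JSON lines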
|
||||
|
||||
|
|
@@ -215,21 +297,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [InternalServerError],
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -60,7 +60,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||
token_usage = 0
|
||||
|
||||
for chunk in chunks:
|
||||
# embeding chunk
|
||||
# embedding chunk
|
||||
chunk_embeddings, chunk_usage = self.embedding(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
|
|
|
|||
|
|
@@ -793,11 +793,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
|
|
|
|||
|
|
@@ -130,11 +130,11 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
|||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
|
|
|
|||
|
|
@@ -0,0 +1 @@
|
|||
|
||||
|
|
@@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill:#b1b3b4;fill-rule:evenodd"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill:#000;fill-rule:evenodd"></path></svg>
|
||||
|
|
|
@@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill:#b1b3b4;fill-rule:evenodd"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill:#000;fill-rule:evenodd"></path></svg>
|
||||
|
|
|
@@ -0,0 +1,28 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FishAudioProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
If validation fails, a CredentialsValidateFailedError is raised.
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.TTS)
|
||||
model_instance.validate_credentials(
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
|
|
@@ -0,0 +1,76 @@
|
|||
provider: fishaudio
|
||||
label:
|
||||
en_US: Fish Audio
|
||||
description:
|
||||
en_US: Models provided by Fish Audio, currently only support TTS.
|
||||
zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。
|
||||
icon_small:
|
||||
en_US: fishaudio_s_en.svg
|
||||
icon_large:
|
||||
en_US: fishaudio_l_en.svg
|
||||
background: "#E5E7EB"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API key from Fish Audio
|
||||
zh_Hans: 从 Fish Audio 获取你的 API Key
|
||||
url:
|
||||
en_US: https://fish.audio/go-api/
|
||||
supported_model_types:
|
||||
- tts
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: api_base
|
||||
label:
|
||||
en_US: API URL
|
||||
type: text-input
|
||||
required: false
|
||||
default: https://api.fish.audio
|
||||
placeholder:
|
||||
en_US: Enter your API URL
|
||||
zh_Hans: 在此输入您的 API URL
|
||||
- variable: use_public_models
|
||||
label:
|
||||
en_US: Use Public Models
|
||||
type: select
|
||||
required: false
|
||||
default: "false"
|
||||
placeholder:
|
||||
en_US: Toggle to use public models
|
||||
zh_Hans: 切换以使用公共模型
|
||||
options:
|
||||
- value: "true"
|
||||
label:
|
||||
en_US: Allow Public Models
|
||||
zh_Hans: 使用公共模型
|
||||
- value: "false"
|
||||
label:
|
||||
en_US: Private Models Only
|
||||
zh_Hans: 仅使用私有模型
|
||||
- variable: latency
|
||||
label:
|
||||
en_US: Latency
|
||||
type: select
|
||||
required: false
|
||||
default: "normal"
|
||||
placeholder:
|
||||
en_US: Toggle to choose latency
|
||||
zh_Hans: 切换以调整延迟
|
||||
options:
|
||||
- value: "balanced"
|
||||
label:
|
||||
en_US: Low (may affect quality)
|
||||
zh_Hans: 低延迟 (可能降低质量)
|
||||
- value: "normal"
|
||||
label:
|
||||
en_US: Normal
|
||||
zh_Hans: 标准
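Assembled from the schema above, the credentials dict handed to the TTS model looks like this (values illustrative; note the select options are stored as strings, not booleans):

credentials = {
    "api_key": "fa-xxxxxxxxxxxx",          # required, secret-input
    "api_base": "https://api.fish.audio",  # optional, this is the default
    "use_public_models": "false",          # "true" also lists public voices
    "latency": "normal",                   # "balanced" trades quality for speed
}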
|
||||
|
|
@@ -0,0 +1,174 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
|
||||
|
||||
class FishAudioText2SpeechModel(TTSModel):
|
||||
"""
|
||||
Model class for Fish.audio Text to Speech model.
|
||||
"""
|
||||
|
||||
def get_tts_model_voices(
|
||||
self, model: str, credentials: dict, language: Optional[str] = None
|
||||
) -> list:
|
||||
api_base = credentials.get("api_base", "https://api.fish.audio")
|
||||
api_key = credentials.get("api_key")
|
||||
use_public_models = credentials.get("use_public_models", "false") == "true"
|
||||
|
||||
params = {
|
||||
"self": str(not use_public_models).lower(),
|
||||
"page_size": "100",
|
||||
}
|
||||
|
||||
if language is not None:
|
||||
if "-" in language:
|
||||
language = language.split("-")[0]
|
||||
params["language"] = language
|
||||
|
||||
results = httpx.get(
|
||||
f"{api_base}/model",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
params=params,
|
||||
)
|
||||
|
||||
results.raise_for_status()
|
||||
data = results.json()
|
||||
|
||||
return [{"name": i["title"], "value": i["_id"]} for i in data["items"]]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
tenant_id: str,
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
user: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Invoke text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param voice: model timbre
|
||||
:param content_text: text content to be translated
|
||||
:param user: unique user id
|
||||
:return: generator yielding audio chunks
|
||||
"""
|
||||
|
||||
return self._tts_invoke_streaming(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def validate_credentials(
|
||||
self, credentials: dict, user: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Validate credentials for text2speech model
|
||||
|
||||
:param credentials: model credentials
|
||||
:param user: unique user id
|
||||
"""
|
||||
|
||||
try:
|
||||
self.get_tts_model_voices(
|
||||
None,
|
||||
credentials={
|
||||
"api_key": credentials["api_key"],
|
||||
"api_base": credentials["api_base"],
|
||||
# Disable public models will trigger a 403 error if user is not logged in
|
||||
"use_public_models": "false",
|
||||
},
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _tts_invoke_streaming(
|
||||
self, model: str, credentials: dict, content_text: str, voice: str
|
||||
) -> Any:
|
||||
"""
|
||||
Invoke streaming text2speech model
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: ID of the reference audio (if any)
|
||||
:return: generator yielding audio chunks
|
||||
"""
|
||||
|
||||
try:
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
if len(content_text) > word_limit:
|
||||
sentences = self._split_text_into_sentences(
|
||||
content_text, max_length=word_limit
|
||||
)
|
||||
else:
|
||||
sentences = [content_text.strip()]
|
||||
|
||||
for i in range(len(sentences)):
|
||||
yield from self._tts_invoke_streaming_sentence(
|
||||
credentials=credentials, content_text=sentences[i], voice=voice
|
||||
)
|
||||
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
def _tts_invoke_streaming_sentence(
|
||||
self, credentials: dict, content_text: str, voice: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Invoke streaming text2speech model
|
||||
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: ID of the reference audio (if any)
|
||||
:return: generator yielding audio chunks
|
||||
"""
|
||||
api_key = credentials.get("api_key")
|
||||
api_url = credentials.get("api_base", "https://api.fish.audio")
|
||||
latency = credentials.get("latency")
|
||||
|
||||
if not api_key:
|
||||
raise InvokeBadRequestError("API key is required")
|
||||
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
api_url + "/v1/tts",
|
||||
json={
|
||||
"text": content_text,
|
||||
"reference_id": voice,
|
||||
"latency": latency
|
||||
},
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
},
|
||||
timeout=None,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise InvokeBadRequestError(
|
||||
f"Error: {response.status_code} - {response.text}"
|
||||
)
|
||||
yield from response.iter_bytes()
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeBadRequestError: [
|
||||
httpx.HTTPStatusError,
|
||||
],
|
||||
}
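Because `_tts_invoke_streaming` is a generator, callers can write the MP3 bytes out as they arrive. A hedged usage sketch (assuming the class can be instantiated directly; the voice ID and key are placeholders):

model = FishAudioText2SpeechModel()
chunks = model._tts_invoke_streaming(
    model="tts-default",
    credentials={"api_key": "fa-xxxxxxxxxxxx", "api_base": "https://api.fish.audio"},
    content_text="Hello from Dify.",
    voice="your-reference-audio-id",
)
with open("out.mp3", "wb") as f:
    for chunk in chunks:
        f.write(chunk)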
|
||||
|
|
@@ -0,0 +1,5 @@
|
|||
model: tts-default
|
||||
model_type: tts
|
||||
model_properties:
|
||||
word_limit: 1000
|
||||
audio_type: 'mp3'
|
||||
|
|
@@ -416,11 +416,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
|
|
|
|||
|
|
@@ -86,7 +86,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
|||
Calculate num tokens for minimax model
|
||||
|
||||
not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way
|
||||
to caculate the num tokens, so we use str() to convert the prompt to string
|
||||
to calculate the num tokens, so we use str() to convert the prompt to string
|
||||
|
||||
Minimax does not provide its own tokenizer for the abab5.5 and abab5 models
|
||||
therefore, we use gpt2 tokenizer instead
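In practice the fallback amounts to stringifying the prompt and counting GPT-2 tokens. A standalone sketch of the idea, using the Hugging Face GPT-2 tokenizer as a stand-in for the shared base-class helper (an assumption, not the exact implementation):

from transformers import GPT2TokenizerFast

def approx_num_tokens(prompt_messages: list) -> int:
    # no native Minimax tokenizer, so approximate with GPT-2 on str(prompt)
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    return len(tokenizer.encode(str(prompt_messages)))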
|
||||
|
|
|
|||
|
|
@@ -10,6 +10,7 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
|
|||
class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
|
||||
def _update_endpoint_url(self, credentials: dict):
|
||||
|
||||
credentials['endpoint_url'] = "https://api.novita.ai/v3/openai"
|
||||
credentials['extra_headers'] = { 'X-Novita-Source': 'dify.ai' }
|
||||
return credentials
|
||||
|
|
|
|||
|
|
@@ -54,7 +54,6 @@ class NvidiaRerankModel(RerankModel):
|
|||
"query": {"text": query},
|
||||
"passages": [{"text": doc} for doc in docs],
|
||||
}
|
||||
|
||||
session = requests.Session()
|
||||
response = session.post(invoke_url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
|
|
@@ -71,7 +70,10 @@ class NvidiaRerankModel(RerankModel):
|
|||
)
|
||||
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
if rerank_documents:
|
||||
rerank_documents = sorted(rerank_documents, key=lambda x: x.score, reverse=True)
|
||||
if top_n:
|
||||
rerank_documents = rerank_documents[:top_n]
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
except requests.HTTPError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
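The post-processing above is a plain descending sort on score followed by an optional cut at top_n; in isolation (scores are hypothetical):

docs = [("doc-a", 0.12), ("doc-b", 0.87), ("doc-c", 0.55)]
ranked = sorted(docs, key=lambda d: d[1], reverse=True)  # highest score first
top_n = 2
ranked = ranked[:top_n]  # [("doc-b", 0.87), ("doc-c", 0.55)]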
|
||||
|
|
|
|||
|
|
@@ -0,0 +1 @@
|
|||
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>
|
||||
|
|
|
@@ -0,0 +1 @@
|
|||
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>
|
||||
|
|
|
@@ -0,0 +1,52 @@
|
|||
model: cohere.command-r-16k
|
||||
label:
|
||||
en_US: cohere.command-r-16k v1.2
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1
|
||||
max: 1.0
|
||||
- name: topP
|
||||
use_template: top_p
|
||||
default: 0.75
|
||||
min: 0
|
||||
max: 1
|
||||
- name: topK
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 0
|
||||
min: 0
|
||||
max: 500
|
||||
- name: presencePenalty
|
||||
use_template: presence_penalty
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0
|
||||
- name: frequencyPenalty
|
||||
use_template: frequency_penalty
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0
|
||||
- name: maxTokens
|
||||
use_template: max_tokens
|
||||
default: 600
|
||||
max: 4000
|
||||
pricing:
|
||||
input: '0.004'
|
||||
output: '0.004'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
|
|
@@ -0,0 +1,52 @@
|
|||
model: cohere.command-r-plus
|
||||
label:
|
||||
en_US: cohere.command-r-plus v1.2
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1
|
||||
max: 1.0
|
||||
- name: topP
|
||||
use_template: top_p
|
||||
default: 0.75
|
||||
min: 0
|
||||
max: 1
|
||||
- name: topK
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 0
|
||||
min: 0
|
||||
max: 500
|
||||
- name: presencePenalty
|
||||
use_template: presence_penalty
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0
|
||||
- name: frequencyPenalty
|
||||
use_template: frequency_penalty
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0
|
||||
- name: maxTokens
|
||||
use_template: max_tokens
|
||||
default: 600
|
||||
max: 4000
|
||||
pricing:
|
||||
input: '0.0219'
|
||||
output: '0.0219'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
|
|
@@ -0,0 +1,461 @@
|
|||
import base64
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
import oci
|
||||
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
request_template = {
|
||||
"compartmentId": "",
|
||||
"servingMode": {
|
||||
"modelId": "cohere.command-r-plus",
|
||||
"servingType": "ON_DEMAND"
|
||||
},
|
||||
"chatRequest": {
|
||||
"apiFormat": "COHERE",
|
||||
#"preambleOverride": "You are a helpful assistant.",
|
||||
#"message": "Hello!",
|
||||
#"chatHistory": [],
|
||||
"maxTokens": 600,
|
||||
"isStream": False,
|
||||
"frequencyPenalty": 0,
|
||||
"presencePenalty": 0,
|
||||
"temperature": 1,
|
||||
"topP": 0.75
|
||||
}
|
||||
}
|
||||
oci_config_template = {
|
||||
"user": "",
|
||||
"fingerprint": "",
|
||||
"tenancy": "",
|
||||
"region": "",
|
||||
"compartment_id": "",
|
||||
"key_content": ""
|
||||
}
|
||||
|
||||
class OCILargeLanguageModel(LargeLanguageModel):
|
||||
# https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
|
||||
_supported_models = {
|
||||
"meta.llama-3-70b-instruct": {
|
||||
"system": True,
|
||||
"multimodal": False,
|
||||
"tool_call": False,
|
||||
"stream_tool_call": False,
|
||||
},
|
||||
"cohere.command-r-16k": {
|
||||
"system": True,
|
||||
"multimodal": False,
|
||||
"tool_call": True,
|
||||
"stream_tool_call": False,
|
||||
},
|
||||
"cohere.command-r-plus": {
|
||||
"system": True,
|
||||
"multimodal": False,
|
||||
"tool_call": True,
|
||||
"stream_tool_call": False,
|
||||
},
|
||||
}
|
||||
|
||||
def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool:
|
||||
feature = self._supported_models.get(model_id)
|
||||
if not feature:
|
||||
return False
|
||||
return feature["stream_tool_call"] if stream else feature["tool_call"]
|
||||
|
||||
def _is_multimodal_supported(self, model_id: str) -> bool:
|
||||
feature = self._supported_models.get(model_id)
|
||||
if not feature:
|
||||
return False
|
||||
return feature["multimodal"]
|
||||
|
||||
def _is_system_prompt_supported(self, model_id: str) -> bool:
|
||||
feature = self._supported_models.get(model_id)
|
||||
if not feature:
|
||||
return False
|
||||
return feature["system"]
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
#print("model"+"*"*20)
|
||||
#print(model)
|
||||
#print("credentials"+"*"*20)
|
||||
#print(credentials)
|
||||
#print("model_parameters"+"*"*20)
|
||||
#print(model_parameters)
|
||||
#print("prompt_messages"+"*"*200)
|
||||
#print(prompt_messages)
|
||||
#print("tools"+"*"*20)
|
||||
#print(tools)
|
||||
|
||||
# invoke model
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of characters for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return: number of characters
|
||||
"""
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
|
||||
return len(prompt)
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
"""
|
||||
:param messages: List of PromptMessage to combine.
|
||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(
|
||||
self._convert_one_message_to_text(message)
|
||||
for message in messages
|
||||
)
|
||||
|
||||
return text.rstrip()
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
# Setup basic variables
|
||||
# Auth Config
|
||||
try:
|
||||
ping_message = SystemPromptMessage(content="ping")
|
||||
self._generate(model, credentials, [ping_message], {"maxTokens": 5})
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials kwargs
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# config_kwargs = model_parameters.copy()
|
||||
# config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
|
||||
# if stop:
|
||||
# config_kwargs["stop_sequences"] = stop
|
||||
|
||||
# initialize client
|
||||
# ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat
|
||||
oci_config = copy.deepcopy(oci_config_template)
|
||||
if "oci_config_content" in credentials:
|
||||
oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
|
||||
config_items = oci_config_content.split("/")
|
||||
if len(config_items) != 5:
|
||||
raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
|
||||
oci_config["user"] = config_items[0]
|
||||
oci_config["fingerprint"] = config_items[1]
|
||||
oci_config["tenancy"] = config_items[2]
|
||||
oci_config["region"] = config_items[3]
|
||||
oci_config["compartment_id"] = config_items[4]
|
||||
else:
|
||||
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
|
||||
if "oci_key_content" in credentials:
|
||||
oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
|
||||
oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
|
||||
else:
|
||||
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
|
||||
|
||||
#oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
|
||||
compartment_id = oci_config["compartment_id"]
|
||||
client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
|
||||
# call embedding model
|
||||
request_args = copy.deepcopy(request_template)
|
||||
request_args["compartmentId"] = compartment_id
|
||||
request_args["servingMode"]["modelId"] = model
|
||||
|
||||
chat_history = []
|
||||
system_prompts = []
|
||||
#if "meta.llama" in model:
|
||||
# request_args["chatRequest"]["apiFormat"] = "GENERIC"
|
||||
request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600)
|
||||
request_args["chatRequest"].update(model_parameters)
|
||||
frequency_penalty = model_parameters.get("frequencyPenalty", 0)
|
||||
presence_penalty = model_parameters.get("presencePenalty", 0)
|
||||
if frequency_penalty > 0 and presence_penalty > 0:
|
||||
raise InvokeBadRequestError("Cannot set both frequency penalty and presence penalty")
|
||||
|
||||
# for msg in prompt_messages: # makes message roles strictly alternating
|
||||
# content = self._format_message_to_glm_content(msg)
|
||||
# if history and history[-1]["role"] == content["role"]:
|
||||
# history[-1]["parts"].extend(content["parts"])
|
||||
# else:
|
||||
# history.append(content)
|
||||
|
||||
# tool calling is not implemented yet; reject it for models that do not support it
|
||||
valid_value = self._is_tool_call_supported(model, stream)
|
||||
if tools is not None and len(tools) > 0:
|
||||
if not valid_value:
|
||||
raise InvokeBadRequestError("Does not support function calling")
|
||||
if model.startswith("cohere"):
|
||||
#print("run cohere " * 10)
|
||||
for message in prompt_messages[:-1]:
|
||||
text = ""
|
||||
if isinstance(message.content, str):
|
||||
text = message.content
|
||||
if isinstance(message, UserPromptMessage):
|
||||
chat_history.append({"role": "USER", "message": text})
|
||||
else:
|
||||
chat_history.append({"role": "CHATBOT", "message": text})
|
||||
if isinstance(message, SystemPromptMessage):
|
||||
if isinstance(message.content, str):
|
||||
system_prompts.append(message.content)
|
||||
args = {"apiFormat": "COHERE",
|
||||
"preambleOverride": ' '.join(system_prompts),
|
||||
"message": prompt_messages[-1].content,
|
||||
"chatHistory": chat_history, }
|
||||
request_args["chatRequest"].update(args)
|
||||
elif model.startswith("meta"):
|
||||
#print("run meta " * 10)
|
||||
meta_messages = []
|
||||
for message in prompt_messages:
|
||||
text = message.content
|
||||
meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]})
|
||||
args = {"apiFormat": "GENERIC",
|
||||
"messages": meta_messages,
|
||||
"numGenerations": 1,
|
||||
"topK": -1}
|
||||
request_args["chatRequest"].update(args)
|
||||
|
||||
if stream:
|
||||
request_args["chatRequest"]["isStream"] = True
|
||||
#print("final request" + "|" * 20)
|
||||
#print(request_args)
|
||||
response = client.chat(request_args)
|
||||
#print(vars(response))
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=response.data.chat_response.text
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse,
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator result
|
||||
"""
|
||||
index = -1
|
||||
events = response.data.events()
|
||||
for stream in events:
|
||||
chunk = json.loads(stream.data)
|
||||
#print(chunk)
|
||||
#chunk: {'apiFormat': 'COHERE', 'text': 'Hello'}
|
||||
|
||||
|
||||
|
||||
#for chunk in response:
|
||||
#for part in chunk.parts:
|
||||
#if part.function_call:
|
||||
# assistant_prompt_message.tool_calls = [
|
||||
# AssistantPromptMessage.ToolCall(
|
||||
# id=part.function_call.name,
|
||||
# type='function',
|
||||
# function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
# name=part.function_call.name,
|
||||
# arguments=json.dumps(dict(part.function_call.args.items()))
|
||||
# )
|
||||
# )
|
||||
# ]
|
||||
|
||||
if "finishReason" not in chunk:
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=''
|
||||
)
|
||||
if model.startswith("cohere"):
|
||||
if chunk["text"]:
|
||||
assistant_prompt_message.content += chunk["text"]
|
||||
elif model.startswith("meta"):
|
||||
assistant_prompt_message.content += chunk["message"]["content"][0]["text"]
|
||||
index += 1
|
||||
# transform assistant message to prompt message
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message
|
||||
)
|
||||
)
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=str(chunk["finishReason"]),
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
|
||||
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
||||
"""
|
||||
Convert a single message to a string.
|
||||
|
||||
:param message: PromptMessage to convert.
|
||||
:return: String representation of the message.
|
||||
"""
|
||||
human_prompt = "\n\nuser:"
|
||||
ai_prompt = "\n\nmodel:"
|
||||
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
content = "".join(
|
||||
c.data for c in content if c.type != PromptMessageContentType.IMAGE
|
||||
)
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_text
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: []
|
||||
}
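Putting the request template and the cohere branch above together, a fully assembled payload passed to client.chat() would look roughly like this (the compartment OCID and chat history are illustrative):

request_args = {
    "compartmentId": "ocid1.compartment.oc1..example",
    "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"},
    "chatRequest": {
        "apiFormat": "COHERE",
        "preambleOverride": "You are a helpful assistant.",
        "message": "Summarize our conversation so far.",
        "chatHistory": [
            {"role": "USER", "message": "Hello!"},
            {"role": "CHATBOT", "message": "Hi, how can I help?"},
        ],
        "maxTokens": 600,
        "isStream": False,
        "frequencyPenalty": 0,
        "presencePenalty": 0,
        "temperature": 1,
        "topP": 0.75,
    },
}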
|
||||
|
|
@@ -0,0 +1,51 @@
|
|||
model: meta.llama-3-70b-instruct
|
||||
label:
|
||||
zh_Hans: meta.llama-3-70b-instruct
|
||||
en_US: meta.llama-3-70b-instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1
|
||||
max: 2.0
|
||||
- name: topP
|
||||
use_template: top_p
|
||||
default: 0.75
|
||||
min: 0
|
||||
max: 1
|
||||
- name: topK
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 0
|
||||
min: 0
|
||||
max: 500
|
||||
- name: presencePenalty
|
||||
use_template: presence_penalty
|
||||
min: -2
|
||||
max: 2
|
||||
default: 0
|
||||
- name: frequencyPenalty
|
||||
use_template: frequency_penalty
|
||||
min: -2
|
||||
max: 2
|
||||
default: 0
|
||||
- name: maxTokens
|
||||
use_template: max_tokens
|
||||
default: 600
|
||||
max: 8000
|
||||
pricing:
|
||||
input: '0.015'
|
||||
output: '0.015'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
|
|
@@ -0,0 +1,34 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OCIGENAIProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use the `cohere.command-r-plus` model for validation,
|
||||
model_instance.validate_credentials(
|
||||
model='cohere.command-r-plus',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
|
||||
|
||||
|
|
@@ -0,0 +1,42 @@
|
|||
provider: oci
|
||||
label:
|
||||
en_US: OCIGenerativeAI
|
||||
description:
|
||||
en_US: Models provided by OCI, such as Cohere Command R and Cohere Command R+.
|
||||
zh_Hans: OCI 提供的模型,例如 Cohere Command R 和 Cohere Command R+。
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#FFFFFF"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API Key from OCI
|
||||
zh_Hans: 从 OCI 获取 API Key
|
||||
url:
|
||||
en_US: https://docs.cloud.oracle.com/Content/API/Concepts/sdkconfig.htm
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
#- rerank
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
#- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: oci_config_content
|
||||
label:
|
||||
en_US: oci api key config file's content
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 oci api key config 文件的内容(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')) )
|
||||
en_US: Enter your oci api key config file's content(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')) )
|
||||
- variable: oci_key_content
|
||||
label:
|
||||
en_US: oci api key file's content
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 oci api key 文件的内容(base64.b64encode("pem file content".encode('utf-8')))
|
||||
en_US: Enter your oci api key file's content(base64.b64encode("pem file content".encode('utf-8')))
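Both placeholders above expect base64-encoded blobs in the exact "/"-separated shape that llm.py later splits apart. A small helper that produces them (all OCIDs and the key path are placeholders):

import base64

def encode_oci_credentials(user_ocid, fingerprint, tenancy_ocid,
                           region, compartment_ocid, pem_path):
    # oci_config_content: user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid
    parts = "/".join([user_ocid, fingerprint, tenancy_ocid, region, compartment_ocid])
    oci_config_content = base64.b64encode(parts.encode("utf-8")).decode()
    # oci_key_content: the raw PEM private key, base64-encoded
    with open(pem_path) as f:
        oci_key_content = base64.b64encode(f.read().encode("utf-8")).decode()
    return oci_config_content, oci_key_content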
|
||||
|
|
@@ -0,0 +1,5 @@
|
|||
- cohere.embed-english-light-v2.0
|
||||
- cohere.embed-english-light-v3.0
|
||||
- cohere.embed-english-v3.0
|
||||
- cohere.embed-multilingual-light-v3.0
|
||||
- cohere.embed-multilingual-v3.0
|
||||
|
|
@@ -0,0 +1,9 @@
|
|||
model: cohere.embed-english-light-v2.0
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 1024
|
||||
max_chunks: 48
|
||||
pricing:
|
||||
input: '0.001'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
|
|
@@ -0,0 +1,9 @@
model: cohere.embed-english-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@ -0,0 +1,9 @@
model: cohere.embed-english-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@ -0,0 +1,9 @@
model: cohere.embed-multilingual-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@ -0,0 +1,9 @@
model: cohere.embed-multilingual-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
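As a rough sketch of what these pricing fields imply -- assuming the usual total = tokens * input_price * price_unit formula this runtime applies when computing usage elsewhere -- the numbers above work out as follows.

from decimal import Decimal

# Illustrative only; the formula is an assumption about how `input` and
# `unit` combine, mirrored from the usage calculation pattern in this runtime.
tokens = 10_000
input_price = Decimal("0.001")
price_unit = Decimal("0.0001")

total = tokens * input_price * price_unit
print(total)  # 0.0010 -> USD 0.001 per 10,000 tokens, i.e. ~USD 0.10 per million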
@ -0,0 +1,242 @@
import base64
import copy
import time
from typing import Optional

import numpy as np
import oci

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

request_template = {
    "compartmentId": "",
    "servingMode": {
        "modelId": "cohere.embed-english-light-v3.0",
        "servingType": "ON_DEMAND"
    },
    "truncate": "NONE",
    "inputs": [""]
}
oci_config_template = {
    "user": "",
    "fingerprint": "",
    "tenancy": "",
    "region": "",
    "compartment_id": "",
    "key_content": ""
}


class OCITextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for OCI text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        # get model properties
        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        indices = []
        used_tokens = 0

        for i, text in enumerate(texts):

            # Here token count is only an approximation based on the GPT-2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
                # if num tokens is larger than context length, only use the start
                inputs.append(text[0: cutoff])
            else:
                inputs.append(text)
            indices += [i]

        batched_embeddings = []
        _iter = range(0, len(inputs), max_chunks)

        for i in _iter:
            # call embedding model
            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                model=model,
                credentials=credentials,
                texts=inputs[i: i + max_chunks]
            )

            used_tokens += embedding_used_tokens
            batched_embeddings += embeddings_batch

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=batched_embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of characters for the given texts

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        return sum(len(text) for text in texts)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # call embedding model
            self._embedding_invoke(
                model=model,
                credentials=credentials,
                texts=['ping']
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
        """
        Invoke embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: embeddings and used tokens
        """

        # initialize the OCI client from the base64-encoded credentials
        oci_config = copy.deepcopy(oci_config_template)
        if "oci_config_content" in credentials:
            oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
            config_items = oci_config_content.split("/")
            if len(config_items) != 5:
                raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
            oci_config["user"] = config_items[0]
            oci_config["fingerprint"] = config_items[1]
            oci_config["tenancy"] = config_items[2]
            oci_config["region"] = config_items[3]
            oci_config["compartment_id"] = config_items[4]
        else:
            raise CredentialsValidateFailedError("need to set oci_config_content in credentials")
        if "oci_key_content" in credentials:
            oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
            oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
        else:
            raise CredentialsValidateFailedError("need to set oci_key_content in credentials")
        # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
        compartment_id = oci_config["compartment_id"]
        client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
        # call embedding model
        request_args = copy.deepcopy(request_template)
        request_args["compartmentId"] = compartment_id
        request_args["servingMode"]["modelId"] = model
        request_args["inputs"] = texts
        response = client.embed_text(request_args)
        return response.data.embeddings, self.get_num_characters(model=model, credentials=credentials, texts=texts)

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.
        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                KeyError
            ]
        }
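A quick worked example of the character-based truncation in _invoke above; the numbers are illustrative, and the formula scales the character length by the ratio of context size to the (approximate) token count.

import numpy as np

# Illustrative: a 3000-character text estimated at 2048 tokens, against a
# 1024-token context window.
text = "x" * 3000
num_tokens, context_size = 2048, 1024

cutoff = int(np.floor(len(text) * (context_size / num_tokens)))  # 3000 * 0.5 -> 1500
print(cutoff, len(text[:cutoff]))  # 1500 1500 -- keep roughly the first half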
@ -89,7 +89,8 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
            endpoint_url,
            headers=headers,
            data=json.dumps(payload),
            timeout=(10, 300)
            timeout=(10, 300),
            options={"use_mmap": "true"}
        )

        response.raise_for_status()  # Raise an exception for HTTP errors
@ -552,7 +552,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
            try:
                schema = json.loads(json_schema)
            except:
                raise ValueError(f"not currect json_schema format: {json_schema}")
                raise ValueError(f"not correct json_schema format: {json_schema}")
            model_parameters.pop("json_schema")
            model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
        else:
@ -1,17 +1,36 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union
import re
from collections.abc import Generator, Iterator
from typing import Any, Optional, Union, cast

# from openai.types.chat import ChatCompletion, ChatCompletionChunk
import boto3
from sagemaker import Predictor, serializers
from sagemaker.session import Session

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    I18nObject,
    ModelFeature,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
    ParameterType,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
@ -25,12 +44,140 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
logger = logging.getLogger(__name__)

def inference(predictor, messages: list[dict[str, Any]], params: dict[str, Any], stop: list, stream=False):
    """
    params:
        predictor : SageMaker Predictor
        messages (List[Dict[str,Any]]): message list.
            messages = [
                {"role": "system", "content": "please answer in Chinese"},
                {"role": "user", "content": "who are you? what are you doing?"},
            ]
        params (Dict[str,Any]): model parameters for the LLM.
        stream (bool): False by default.

    response:
        result of inference if stream is False
        Iterator of Chunks if stream is True
    """
    payload = {
        "model": params.get('model_name'),
        "stop": stop,
        "messages": messages,
        "stream": stream,
        "max_tokens": params.get('max_new_tokens', params.get('max_tokens', 2048)),
        "temperature": params.get('temperature', 0.1),
        "top_p": params.get('top_p', 0.9),
    }

    if not stream:
        response = predictor.predict(payload)
        return response
    else:
        response_stream = predictor.predict_stream(payload)
        return response_stream

class SageMakerLargeLanguageModel(LargeLanguageModel):
    """
    Model class for SageMaker large language model.
    """
    sagemaker_client: Any = None
    sagemaker_sess: Any = None
    predictor: Any = None

    def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                       tools: list[PromptMessageTool],
                                       resp: bytes) -> LLMResult:
        """
        handle normal chat generate response
        """
        resp_obj = json.loads(resp.decode('utf-8'))
        resp_str = resp_obj.get('choices')[0].get('message').get('content')

        if len(resp_str) == 0:
            raise InvokeServerUnavailableError("Empty response")

        assistant_prompt_message = AssistantPromptMessage(
            content=resp_str,
            tool_calls=[]
        )

        prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
        completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)

        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
                                          completion_tokens=completion_tokens)

        response = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            system_fingerprint=None,
            usage=usage,
            message=assistant_prompt_message,
        )

        return response

    def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                     tools: list[PromptMessageTool],
                                     resp: Iterator[bytes]) -> Generator:
        """
        handle stream chat generate response
        """
        full_response = ''
        buffer = ""
        for chunk_bytes in resp:
            buffer += chunk_bytes.decode('utf-8')
            last_idx = 0
            for match in re.finditer(r'^data:\s*(.+?)(\n\n)', buffer):
                try:
                    data = json.loads(match.group(1).strip())
                    last_idx = match.span()[1]

                    if "content" in data["choices"][0]["delta"]:
                        chunk_content = data["choices"][0]["delta"]["content"]
                        assistant_prompt_message = AssistantPromptMessage(
                            content=chunk_content,
                            tool_calls=[]
                        )

                        if data["choices"][0]['finish_reason'] is not None:
                            temp_assistant_prompt_message = AssistantPromptMessage(
                                content=full_response,
                                tool_calls=[]
                            )
                            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
                            completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
                            usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)

                            yield LLMResultChunk(
                                model=model,
                                prompt_messages=prompt_messages,
                                system_fingerprint=None,
                                delta=LLMResultChunkDelta(
                                    index=0,
                                    message=assistant_prompt_message,
                                    finish_reason=data["choices"][0]['finish_reason'],
                                    usage=usage
                                ),
                            )
                        else:
                            yield LLMResultChunk(
                                model=model,
                                prompt_messages=prompt_messages,
                                system_fingerprint=None,
                                delta=LLMResultChunkDelta(
                                    index=0,
                                    message=assistant_prompt_message
                                ),
                            )

                        full_response += chunk_content
                except (json.JSONDecodeError, KeyError, IndexError) as e:
                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))

            buffer = buffer[last_idx:]

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
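The buffering in _handle_chat_stream_response above can be isolated into a small sketch: complete "data: ...\n\n" events are parsed and consumed, while a partial event stays in the buffer until the next network chunk completes it. The sample stream is hypothetical, and the anchor (^) is dropped here so multiple events in one buffer all match.

import json
import re

# A hypothetical two-chunk SSE stream where the second event arrives split
# across chunk boundaries.
chunks = [b'data: {"delta": "Hel"}\n\ndata: {"del', b'ta": "lo"}\n\n']

buffer = ""
for chunk_bytes in chunks:
    buffer += chunk_bytes.decode("utf-8")
    last_idx = 0
    for match in re.finditer(r"data:\s*(.+?)\n\n", buffer):
        print(json.loads(match.group(1)))   # {'delta': 'Hel'} then {'delta': 'lo'}
        last_idx = match.span()[1]
    buffer = buffer[last_idx:]              # keep only the unfinished tail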
@ -50,9 +197,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # get model mode
        model_mode = self.get_model_mode(model, credentials)

        if not self.sagemaker_client:
            access_key = credentials.get('access_key')
            secret_key = credentials.get('secret_key')
@ -68,37 +212,132 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
            else:
                self.sagemaker_client = boto3.client("sagemaker-runtime")

            sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client)
            self.predictor = Predictor(
                endpoint_name=credentials.get('sagemaker_endpoint'),
                sagemaker_session=sagemaker_session,
                serializer=serializers.JSONSerializer(),
            )

        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=sagemaker_endpoint,
            Body=json.dumps(
                {
                    "inputs": prompt_messages[0].content,
                    "parameters": {"stop": stop},
                    "history": []
                }
            ),
            ContentType="application/json",
        )

        assistant_text = response_model['Body'].read().decode('utf8')
        messages: list[dict[str, Any]] = [{"role": p.role.value, "content": p.content} for p in prompt_messages]
        response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream)

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_text
        )
        if stream:
            if tools and len(tools) > 0:
                raise InvokeBadRequestError(f"{model}'s tool calls do not support stream mode")

            usage = self._calc_response_usage(model, credentials, 0, 0)
            return self._handle_chat_stream_response(model=model, credentials=credentials,
                                                     prompt_messages=prompt_messages,
                                                     tools=tools, resp=response)
        return self._handle_chat_generate_response(model=model, credentials=credentials,
                                                   prompt_messages=prompt_messages,
                                                   tools=tools, resp=response)

        response = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage
        )
    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict for OpenAI Compatibility API
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                sub_messages = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(PromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "text",
                            "text": message_content.data
                        }
                        sub_messages.append(sub_message_dict)
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "image_url",
                            "image_url": {
                                "url": message_content.data,
                                "detail": message_content.detail.value
                            }
                        }
                        sub_messages.append(sub_message_dict)
                message_dict = {"role": "user", "content": sub_messages}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls and len(message.tool_calls) > 0:
                message_dict["function_call"] = {
                    "name": message.tool_calls[0].function.name,
                    "arguments": message.tool_calls[0].function.arguments
                }
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
        else:
            raise ValueError(f"Unknown message type {type(message)}")

        return response
        return message_dict

    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
                                  is_completion_model: bool = False) -> int:
        def tokens(text: str):
            return self._get_num_tokens_by_gpt2(text)

        if is_completion_model:
            return sum(tokens(str(message.content)) for message in messages)

        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                if isinstance(value, list):
                    text = ''
                    for item in value:
                        if isinstance(item, dict) and item['type'] == 'text':
                            text += item['text']

                    value = text

                if key == "tool_calls":
                    for tool_call in value:
                        for t_key, t_value in tool_call.items():
                            num_tokens += tokens(t_key)
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += tokens(f_key)
                                    num_tokens += tokens(f_value)
                            else:
                                num_tokens += tokens(t_key)
                                num_tokens += tokens(t_value)
                if key == "function_call":
                    for t_key, t_value in value.items():
                        num_tokens += tokens(t_key)
                        if t_key == "function":
                            for f_key, f_value in t_value.items():
                                num_tokens += tokens(f_key)
                                num_tokens += tokens(f_value)
                        else:
                            num_tokens += tokens(t_key)
                            num_tokens += tokens(t_value)
                else:
                    num_tokens += tokens(str(value))

                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3

        if tools:
            num_tokens += self._num_tokens_for_tools(tools)

        return num_tokens

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
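A simplified sketch of the per-message accounting in _num_tokens_from_messages above: each message carries a 3-token overhead, every field value is tokenized, and a final 3-token reply primer is added (the real method also adds one token per "name" field; the whitespace tokenizer below is only a stand-in for the GPT-2 tokenizer).

def tokens(text: str) -> int:
    return len(text.split())  # stand-in for the GPT-2 tokenizer

messages = [
    {"role": "system", "content": "please answer in Chinese"},
    {"role": "user", "content": "who are you? what are you doing?"},
]

# 3 tokens of overhead per message + tokenized field values + 3-token primer.
total = sum(3 + sum(tokens(str(v)) for v in m.values()) for m in messages) + 3
print(total)  # 22 with this toy tokenizer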
@ -112,10 +351,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
        :return:
        """
        # get model mode
        model_mode = self.get_model_mode(model)

        try:
            return 0
            return self._num_tokens_from_messages(prompt_messages, tools)
        except Exception as e:
            raise self._transform_invoke_error(e)
@ -129,7 +366,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
        """
        try:
            # get model mode
            model_mode = self.get_model_mode(model)
            pass
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))
@ -200,13 +437,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
            )
        ]

        completion_type = LLMMode.value_of(credentials["mode"])

        if completion_type == LLMMode.CHAT:
            print(f"completion_type : {LLMMode.CHAT.value}")

        if completion_type == LLMMode.COMPLETION:
            print(f"completion_type : {LLMMode.COMPLETION.value}")
        completion_type = LLMMode.value_of(credentials["mode"]).value

        features = []
@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)

class SageMakerRerankModel(RerankModel):
    """
    Model class for Cohere rerank model.
    Model class for SageMaker rerank model.
    """
    sagemaker_client: Any = None
@ -1,10 +1,11 @@
import logging
import uuid
from typing import IO, Any

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class SageMakerProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
@ -15,3 +16,28 @@ class SageMakerProvider(ModelProvider):
        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        pass

def buffer_to_s3(s3_client: Any, file: IO[bytes], bucket: str, s3_prefix: str) -> str:
    '''
    return the S3 key of the uploaded file
    '''
    s3_key = f'{s3_prefix}{uuid.uuid4()}.mp3'
    s3_client.put_object(
        Body=file.read(),
        Bucket=bucket,
        Key=s3_key,
        ContentType='audio/mp3'
    )
    return s3_key

def generate_presigned_url(s3_client: Any, file: IO[bytes], bucket_name: str, s3_prefix: str, expiration=600) -> str:
    object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix)
    try:
        response = s3_client.generate_presigned_url('get_object',
                                                    Params={'Bucket': bucket_name, 'Key': object_key},
                                                    ExpiresIn=expiration)
    except Exception as e:
        print(f"Error generating presigned URL: {e}")
        return None

    return response
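A minimal sketch of how these two helpers chain together, assuming the functions above are importable and that valid AWS credentials are configured; the bucket name and file path are placeholders.

import boto3

# Upload a local mp3 via buffer_to_s3 (called inside generate_presigned_url)
# and obtain a time-limited download URL for it.
s3_client = boto3.client("s3")
with open("sample.mp3", "rb") as f:  # placeholder file
    url = generate_presigned_url(s3_client, f, "my-audio-bucket", "dify/speech2text/", expiration=600)
print(url)  # e.g. https://my-audio-bucket.s3.amazonaws.com/dify/speech2text/<uuid>.mp3?...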
@ -21,6 +21,8 @@ supported_model_types:
  - llm
  - text-embedding
  - rerank
  - speech2text
  - tts
configurate_methods:
  - customizable-model
model_credential_schema:
@ -45,14 +47,10 @@ model_credential_schema:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
            zh_Hans: Chat
    - variable: sagemaker_endpoint
      label:
        en_US: sagemaker endpoint
@ -61,6 +59,76 @@ model_credential_schema:
      placeholder:
        zh_Hans: 请输入你的 SageMaker 推理端点
        en_US: Enter your SageMaker inference endpoint
    - variable: audio_s3_cache_bucket
      show_on:
        - variable: __model_type
          value: speech2text
      label:
        zh_Hans: 音频缓存桶(s3 bucket)
        en_US: audio cache bucket (s3 bucket)
      type: text-input
      required: true
      placeholder:
        zh_Hans: sagemaker-us-east-1-******207838
        en_US: sagemaker-us-east-1-*******7838
    - variable: audio_model_type
      show_on:
        - variable: __model_type
          value: tts
      label:
        en_US: Audio model type
      type: select
      required: true
      placeholder:
        zh_Hans: 语音模型类型
        en_US: Audio model type
      options:
        - value: PresetVoice
          label:
            en_US: preset voice
            zh_Hans: 内置音色
        - value: CloneVoice
          label:
            en_US: clone voice
            zh_Hans: 克隆音色
        - value: CloneVoice_CrossLingual
          label:
            en_US: cross-lingual clone voice
            zh_Hans: 跨语种克隆音色
        - value: InstructVoice
          label:
            en_US: instruct voice
            zh_Hans: 文字指令音色
    - variable: prompt_audio
      show_on:
        - variable: __model_type
          value: tts
      label:
        en_US: Mock Audio Source
      type: text-input
      required: false
      placeholder:
        zh_Hans: 被模仿的音色音频
        en_US: source audio to be mimicked
    - variable: prompt_text
      show_on:
        - variable: __model_type
          value: tts
      label:
        en_US: Prompt Audio Text
      type: text-input
      required: false
      placeholder:
        zh_Hans: 模仿音色的对应文本
        en_US: text for the mimicked source audio
    - variable: instruct_text
      show_on:
        - variable: __model_type
          value: tts
      label:
        en_US: instruct text for speaker
      type: text-input
      required: false
    - variable: aws_access_key_id
      required: false
      label:
@ -0,0 +1,142 @@
import json
import logging
from typing import IO, Any, Optional

import boto3

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url

logger = logging.getLogger(__name__)

class SageMakerSpeech2TextModel(Speech2TextModel):
    """
    Model class for SageMaker speech-to-text model.
    """
    sagemaker_client: Any = None
    s3_client: Any = None

    def _invoke(self, model: str, credentials: dict,
                file: IO[bytes], user: Optional[str] = None) \
            -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        asr_text = None

        try:
            if not self.sagemaker_client:
                access_key = credentials.get('aws_access_key_id')
                secret_key = credentials.get('aws_secret_access_key')
                aws_region = credentials.get('aws_region')
                if aws_region:
                    if access_key and secret_key:
                        self.sagemaker_client = boto3.client("sagemaker-runtime",
                                                             aws_access_key_id=access_key,
                                                             aws_secret_access_key=secret_key,
                                                             region_name=aws_region)
                        self.s3_client = boto3.client("s3",
                                                      aws_access_key_id=access_key,
                                                      aws_secret_access_key=secret_key,
                                                      region_name=aws_region)
                    else:
                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                        self.s3_client = boto3.client("s3", region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime")
                    self.s3_client = boto3.client("s3")

            s3_prefix = 'dify/speech2text/'
            sagemaker_endpoint = credentials.get('sagemaker_endpoint')
            bucket = credentials.get('audio_s3_cache_bucket')

            s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
            payload = {
                "audio_s3_presign_uri": s3_presign_url
            }

            response_model = self.sagemaker_client.invoke_endpoint(
                EndpointName=sagemaker_endpoint,
                Body=json.dumps(payload),
                ContentType="application/json"
            )
            json_str = response_model['Body'].read().decode('utf8')
            json_obj = json.loads(json_str)
            asr_text = json_obj['text']
        except Exception as e:
            logger.exception(f'Exception {e}')

        return asr_text

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        pass

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.SPEECH2TEXT,
            model_properties={},
            parameter_rules=[]
        )

        return entity
@ -0,0 +1,287 @@
import concurrent.futures
import copy
import json
import logging
from enum import Enum
from typing import Any, Optional

import boto3
import requests

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.tts_model import TTSModel

logger = logging.getLogger(__name__)

class TTSModelType(Enum):
    PresetVoice = "PresetVoice"
    CloneVoice = "CloneVoice"
    CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
    InstructVoice = "InstructVoice"

class SageMakerText2SpeechModel(TTSModel):

    sagemaker_client: Any = None
    s3_client: Any = None
    comprehend_client: Any = None

    def __init__(self):
        # preset voices; custom voices still need to be supported
        self.model_voices = {
            '__default': {
                'all': [
                    {'name': 'Default', 'value': 'default'},
                ]
            },
            'CosyVoice': {
                'zh-Hans': [
                    {'name': '中文男', 'value': '中文男'},
                    {'name': '中文女', 'value': '中文女'},
                    {'name': '粤语女', 'value': '粤语女'},
                ],
                'zh-Hant': [
                    {'name': '中文男', 'value': '中文男'},
                    {'name': '中文女', 'value': '中文女'},
                    {'name': '粤语女', 'value': '粤语女'},
                ],
                'en-US': [
                    {'name': '英文男', 'value': '英文男'},
                    {'name': '英文女', 'value': '英文女'},
                ],
                'ja-JP': [
                    {'name': '日语男', 'value': '日语男'},
                ],
                'ko-KR': [
                    {'name': '韩语女', 'value': '韩语女'},
                ]
            }
        }

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        pass

    def _detect_lang_code(self, content: str, map_dict: dict = None):
        map_dict = {
            "zh": "<|zh|>",
            "en": "<|en|>",
            "ja": "<|jp|>",
            "zh-TW": "<|yue|>",
            "ko": "<|ko|>"
        }

        response = self.comprehend_client.detect_dominant_language(Text=content)
        language_code = response['Languages'][0]['LanguageCode']

        return map_dict.get(language_code, '<|zh|>')

    def _build_tts_payload(self, model_type: str, content_text: str, model_role: str, prompt_text: str, prompt_audio: str, instruct_text: str):
        if model_type == TTSModelType.PresetVoice.value and model_role:
            return {"tts_text": content_text, "role": model_role}
        if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
            return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
        if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
            lang_tag = self._detect_lang_code(content_text)
            return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
        if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
            return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}

        raise RuntimeError(f"Invalid params for {model_type}")

    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
                user: Optional[str] = None):
        """
        _invoke text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be translated
        :param user: unique user id
        :return: text translated to audio file
        """
        if not self.sagemaker_client:
            access_key = credentials.get('aws_access_key_id')
            secret_key = credentials.get('aws_secret_access_key')
            aws_region = credentials.get('aws_region')
            if aws_region:
                if access_key and secret_key:
                    self.sagemaker_client = boto3.client("sagemaker-runtime",
                                                         aws_access_key_id=access_key,
                                                         aws_secret_access_key=secret_key,
                                                         region_name=aws_region)
                    self.s3_client = boto3.client("s3",
                                                  aws_access_key_id=access_key,
                                                  aws_secret_access_key=secret_key,
                                                  region_name=aws_region)
                    self.comprehend_client = boto3.client('comprehend',
                                                          aws_access_key_id=access_key,
                                                          aws_secret_access_key=secret_key,
                                                          region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                    self.s3_client = boto3.client("s3", region_name=aws_region)
                    self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
            else:
                self.sagemaker_client = boto3.client("sagemaker-runtime")
                self.s3_client = boto3.client("s3")
                self.comprehend_client = boto3.client('comprehend')

        model_type = credentials.get('audio_model_type', 'PresetVoice')
        prompt_text = credentials.get('prompt_text')
        prompt_audio = credentials.get('prompt_audio')
        instruct_text = credentials.get('instruct_text')
        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
        payload = self._build_tts_payload(
            model_type,
            content_text,
            voice,
            prompt_text,
            prompt_audio,
            instruct_text
        )

        return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={},
            parameter_rules=[]
        )

        return entity

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
        return ""

    def _get_model_word_limit(self, model: str, credentials: dict) -> int:
        return 15

    def _get_model_audio_type(self, model: str, credentials: dict) -> str:
        return "mp3"

    def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
        return 5

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        audio_model_name = 'CosyVoice'
        for key, voices in self.model_voices.items():
            if key in audio_model_name:
                if language and language in voices:
                    return voices[language]
                elif 'all' in voices:
                    return voices['all']

        return self.model_voices['__default']['all']

    def _invoke_sagemaker(self, payload: dict, endpoint: str):
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=endpoint,
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        return json_obj

    def _tts_invoke_streaming(self, model_type: str, payload: dict, sagemaker_endpoint: str) -> Any:
        """
        _tts_invoke_streaming text2speech model

        :param model_type: audio model type
        :param payload: request payload for the endpoint
        :param sagemaker_endpoint: SageMaker endpoint name
        :return: audio bytes, yielded in 1024-byte chunks
        """
        try:
            lang_tag = ''
            if model_type == TTSModelType.CloneVoice_CrossLingual.value:
                lang_tag = payload.pop('lang_tag')

            word_limit = self._get_model_word_limit(model='', credentials={})
            content_text = payload.get("tts_text")
            if len(content_text) > word_limit:
                split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
                sentences = [f"{lang_tag}{s}" for s in split_sentences if len(s)]
                len_sent = len(sentences)
                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent))
                payloads = [copy.deepcopy(payload) for i in range(len_sent)]
                for idx in range(len_sent):
                    payloads[idx]["tts_text"] = sentences[idx]

                futures = [executor.submit(
                    self._invoke_sagemaker,
                    payload=payload,
                    endpoint=sagemaker_endpoint,
                )
                    for payload in payloads]

                for index, future in enumerate(futures):
                    resp = future.result()
                    audio_bytes = requests.get(resp.get('s3_presign_url')).content
                    for i in range(0, len(audio_bytes), 1024):
                        yield audio_bytes[i:i + 1024]
            else:
                resp = self._invoke_sagemaker(payload, sagemaker_endpoint)
                audio_bytes = requests.get(resp.get('s3_presign_url')).content

                for i in range(0, len(audio_bytes), 1024):
                    yield audio_bytes[i:i + 1024]
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))
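A minimal sketch of the fan-out pattern used in _tts_invoke_streaming above: one payload per sentence, submitted to a thread pool in parallel, with results consumed in submission order so the audio chunks come back in the right sequence. `fake_synthesize` stands in for the SageMaker call and is purely illustrative.

import concurrent.futures

def fake_synthesize(payload):
    # placeholder for the per-sentence endpoint invocation
    return f"audio<{payload['tts_text']}>"

sentences = ["Hello.", "How are you?", "Goodbye."]
payloads = [{"tts_text": s} for s in sentences]

with concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len(payloads))) as executor:
    futures = [executor.submit(fake_synthesize, p) for p in payloads]
    for future in futures:          # iteration order == submission order
        print(future.result())      # audio<Hello.> audio<How are you?> audio<Goodbye.>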
@ -19,27 +19,25 @@ class SparkLLMClient:
        endpoint = 'chat'
        if api_domain:
            domain = api_domain
        if model == 'spark-v3':
            endpoint = 'multimodal'

        model_api_configs = {
            'spark-1.5': {
            'spark-lite': {
                'version': 'v1.1',
                'chat_domain': 'general'
            },
            'spark-2': {
                'version': 'v2.1',
                'chat_domain': 'generalv2'
            },
            'spark-3': {
            'spark-pro': {
                'version': 'v3.1',
                'chat_domain': 'generalv3'
            },
            'spark-3.5': {
            'spark-pro-128k': {
                'version': 'pro-128k',
                'chat_domain': 'pro-128k'
            },
            'spark-max': {
                'version': 'v3.5',
                'chat_domain': 'generalv3.5'
            },
            'spark-4': {
            'spark-4.0-ultra': {
                'version': 'v4.0',
                'chat_domain': '4.0Ultra'
            }
@ -48,7 +46,12 @@ class SparkLLMClient:
        api_version = model_api_configs[model]['version']

        self.chat_domain = model_api_configs[model]['chat_domain']
        self.api_base = f"wss://{domain}/{api_version}/{endpoint}"

        if model == 'spark-pro-128k':
            self.api_base = f"wss://{domain}/{endpoint}/{api_version}"
        else:
            self.api_base = f"wss://{domain}/{api_version}/{endpoint}"

        self.app_id = app_id
        self.ws_url = self.create_url(
            urlparse(self.api_base).netloc,
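To make the branch above concrete, a quick sketch of the two URL shapes it produces; the domain and version values below are illustrative, not pinned by this diff.

# spark-pro-128k swaps the path segments: /<endpoint>/<version>
domain, endpoint = "spark-api.xf-yun.com", "chat"
print(f"wss://{domain}/{endpoint}/pro-128k")  # wss://spark-api.xf-yun.com/chat/pro-128k

# every other model keeps version first: /<version>/<endpoint>
print(f"wss://{domain}/v3.5/{endpoint}")      # wss://spark-api.xf-yun.com/v3.5/chat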
@ -1,3 +1,8 @@
- spark-4.0-ultra
- spark-max
- spark-pro-128k
- spark-pro
- spark-lite
- spark-4
- spark-3.5
- spark-3
@ -1,4 +1,5 @@
model: spark-1.5
deprecated: true
label:
  en_US: Spark V1.5
model_type: llm
@ -1,4 +1,5 @@
model: spark-3.5
deprecated: true
label:
  en_US: Spark V3.5
model_type: llm
@ -1,4 +1,5 @@
model: spark-3
deprecated: true
label:
  en_US: Spark V3.0
model_type: llm
@ -0,0 +1,42 @@
model: spark-4.0-ultra
label:
  en_US: Spark 4.0 Ultra
model_type: llm
model_properties:
  mode: chat
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.5
    help:
      zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
      en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
  - name: max_tokens
    use_template: max_tokens
    default: 4096
    min: 1
    max: 8192
    help:
      zh_Hans: 模型回答的tokens的最大长度。
      en_US: Maximum length of tokens for the model response.
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    default: 4
    min: 1
    max: 6
    help:
      zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
      en_US: Randomly select one from k candidates (non-equal probability).
    required: false
  - name: show_ref_label
    label:
      zh_Hans: 联网检索
      en_US: web search
    type: boolean
    default: false
    help:
      zh_Hans: 该参数仅4.0 Ultra版本支持,当设置为true时,如果输入内容触发联网检索插件,会先返回检索信源列表,然后再返回星火回复结果,否则仅返回星火回复结果
      en_US: The parameter is only supported in the 4.0 Ultra version. When set to true, if the input triggers the online search plugin, it will first return a list of search sources and then return the Spark response. Otherwise, it will only return the Spark response.
@ -1,4 +1,5 @@
model: spark-4
deprecated: true
label:
  en_US: Spark V4.0
model_type: llm
@ -0,0 +1,33 @@
model: spark-lite
label:
  en_US: Spark Lite
model_type: llm
model_properties:
  mode: chat
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.5
    help:
      zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
      en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
  - name: max_tokens
    use_template: max_tokens
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 模型回答的tokens的最大长度。
      en_US: Maximum length of tokens for the model response.
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    default: 4
    min: 1
    max: 6
    help:
      zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
      en_US: Randomly select one from k candidates (non-equal probability).
    required: false
@ -0,0 +1,33 @@
model: spark-max
label:
  en_US: Spark Max
model_type: llm
model_properties:
  mode: chat
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.5
    help:
      zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
      en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
  - name: max_tokens
    use_template: max_tokens
    default: 4096
    min: 1
    max: 8192
    help:
      zh_Hans: 模型回答的tokens的最大长度。
      en_US: Maximum length of tokens for the model response.
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    default: 4
    min: 1
    max: 6
    help:
      zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
      en_US: Randomly select one from k candidates (non-equal probability).
    required: false
@ -0,0 +1,33 @@
model: spark-pro-128k
label:
  en_US: Spark Pro-128K
model_type: llm
model_properties:
  mode: chat
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.5
    help:
      zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
      en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
  - name: max_tokens
    use_template: max_tokens
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 模型回答的tokens的最大长度。
      en_US: Maximum length of tokens for the model response.
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    default: 4
    min: 1
    max: 6
    help:
      zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
      en_US: Randomly select one from k candidates (non-equal probability).
    required: false
@ -0,0 +1,33 @@
model: spark-pro
label:
  en_US: Spark Pro
model_type: llm
model_properties:
  mode: chat
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.5
    help:
      zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
      en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
  - name: max_tokens
    use_template: max_tokens
    default: 4096
    min: 1
    max: 8192
    help:
      zh_Hans: 模型回答的tokens的最大长度。
      en_US: Maximum length of tokens for the model response.
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    default: 4
    min: 1
    max: 6
    help:
      zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
      en_US: Randomly select one from k candidates (non-equal probability).
    required: false
@ -67,7 +67,7 @@ class FlashRecognitionRequest:

class FlashRecognizer:
    """
    reponse:
    response:
        request_id  string
        status      Integer
        message     String
@ -132,9 +132,9 @@ class FlashRecognizer:
        signstr = self._format_sign_string(query)
        signature = self._sign(signstr, secret_key)
        header["Authorization"] = signature
        requrl = "https://"
        requrl += signstr[4::]
        return requrl
        req_url = "https://"
        req_url += signstr[4::]
        return req_url

    def _create_query_arr(self, req):
        return {
@ -17,7 +17,6 @@ from dashscope.common.error import (
    UnsupportedModel,
)

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
@ -64,88 +63,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # invoke model
        # invoke model without code wrapper
        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def _code_block_mode_wrapper(self, model: str, credentials: dict,
                                 prompt_messages: list[PromptMessage], model_parameters: dict,
                                 tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                 stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
            -> LLMResult | Generator:
        """
        Wrapper for code block mode
        """
        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
You should also complete the text started with ``` but not tell ``` directly.
"""

        code_block = model_parameters.get("response_format", "")
        if not code_block:
            return self._invoke(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                stream=stream,
                user=user
            )

        model_parameters.pop("response_format")
        stop = stop or []
        stop.extend(["\n```", "```\n"])
        block_prompts = block_prompts.replace("{{block}}", code_block)

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=block_prompts
                .replace("{{instructions}}", prompt_messages[0].content)
            )
        else:
            # insert the system message
            prompt_messages.insert(0, SystemPromptMessage(
                content=block_prompts
                .replace("{{instructions}}", f"Please output a valid {code_block} with markdown codeblocks.")
            ))

        if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
            # add ```JSON\n to the last message
            prompt_messages[-1].content += f"\n```{code_block}\n"
        else:
            # append a user message
            prompt_messages.append(UserPromptMessage(
                content=f"```{code_block}\n"
            ))

        response = self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user
        )

        if isinstance(response, Generator):
            return self._code_block_mode_stream_processor_with_backtick(
                model=model,
                prompt_messages=prompt_messages,
                input_generator=response
            )

        return response

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
@ -24,7 +24,7 @@ parameter_rules:
    type: int
    default: 2000
    min: 1
    max: 2000
    max: 6000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
Some files were not shown because too many files have changed in this diff