mirror of https://github.com/langgenius/dify.git
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/tests/integration_tests/workflow/nodes/test_code.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
This commit is contained in:
commit 4771e85630
@@ -1,3 +1,3 @@
 from .app_config import DifyConfig

 dify_config = DifyConfig()

@@ -1,4 +1,3 @@
-from pydantic import Field, computed_field
 from pydantic_settings import SettingsConfigDict

 from configs.deploy import DeploymentConfig

@@ -24,44 +23,16 @@ class DifyConfig(
     # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
     EnterpriseFeatureConfig,
 ):
-    DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
-
     model_config = SettingsConfigDict(
         # read from dotenv format config file
-        env_file='.env',
-        env_file_encoding='utf-8',
+        env_file=".env",
+        env_file_encoding="utf-8",
         frozen=True,
         # ignore extra attributes
-        extra='ignore',
+        extra="ignore",
     )

-    CODE_MAX_NUMBER: int = 9223372036854775807
-    CODE_MIN_NUMBER: int = -9223372036854775808
-    CODE_MAX_DEPTH: int = 5
-    CODE_MAX_PRECISION: int = 20
-    CODE_MAX_STRING_LENGTH: int = 80000
-    CODE_MAX_STRING_ARRAY_LENGTH: int = 30
-    CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
-    CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000
-
-    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
-    HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
-    HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
-    HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10
-
-    @computed_field
-    def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
-        return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
-
-    HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024
-
-    @computed_field
-    def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
-        return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
-
-    SSRF_PROXY_HTTP_URL: str | None = None
-    SSRF_PROXY_HTTPS_URL: str | None = None
-
-    MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
-
-    MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')
     # Before adding any config,
     # please consider to arrange it in the proper config group of existed or added
     # for better readability and maintainability.
     # Thanks for your concentration and consideration.

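For context, the `model_config = SettingsConfigDict(...)` block kept above is what makes pydantic-settings read values from a dotenv file. A minimal standalone sketch of the same pattern (the class and field names here are illustrative, not part of this commit):

    from pydantic import Field
    from pydantic_settings import BaseSettings, SettingsConfigDict

    class ExampleConfig(BaseSettings):
        model_config = SettingsConfigDict(
            env_file=".env",           # read from dotenv format config file
            env_file_encoding="utf-8",
            frozen=True,               # instances are immutable after creation
            extra="ignore",            # unknown variables in .env are skipped
        )

        DEBUG: bool = Field(default=False, description="whether to enable debug mode.")

    # With `DEBUG=true` in .env (or exported in the environment),
    # ExampleConfig().DEBUG evaluates to True.
    config = ExampleConfig()
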
@@ -6,22 +6,28 @@ class DeploymentConfig(BaseSettings):
     """
     Deployment configs
     """

     APPLICATION_NAME: str = Field(
-        description='application name',
-        default='langgenius/dify',
+        description="application name",
+        default="langgenius/dify",
     )

+    DEBUG: bool = Field(
+        description="whether to enable debug mode.",
+        default=False,
+    )
+
     TESTING: bool = Field(
-        description='',
+        description="",
         default=False,
     )

     EDITION: str = Field(
-        description='deployment edition',
-        default='SELF_HOSTED',
+        description="deployment edition",
+        default="SELF_HOSTED",
     )

     DEPLOY_ENV: str = Field(
-        description='deployment environment, default to PRODUCTION.',
-        default='PRODUCTION',
+        description="deployment environment, default to PRODUCTION.",
+        default="PRODUCTION",
     )

@@ -7,13 +7,14 @@ class EnterpriseFeatureConfig(BaseSettings):
     Enterprise feature configs.
     **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
     """

     ENTERPRISE_ENABLED: bool = Field(
-        description='whether to enable enterprise features.'
-        'Before using, please contact business@dify.ai by email to inquire about licensing matters.',
+        description="whether to enable enterprise features."
+        "Before using, please contact business@dify.ai by email to inquire about licensing matters.",
         default=False,
     )

     CAN_REPLACE_LOGO: bool = Field(
-        description='whether to allow replacing enterprise logo.',
+        description="whether to allow replacing enterprise logo.",
         default=False,
     )

@@ -8,27 +8,28 @@ class NotionConfig(BaseSettings):
     """
     Notion integration configs
     """

     NOTION_CLIENT_ID: Optional[str] = Field(
-        description='Notion client ID',
+        description="Notion client ID",
         default=None,
     )

     NOTION_CLIENT_SECRET: Optional[str] = Field(
-        description='Notion client secret key',
+        description="Notion client secret key",
         default=None,
     )

     NOTION_INTEGRATION_TYPE: Optional[str] = Field(
-        description='Notion integration type, default to None, available values: internal.',
+        description="Notion integration type, default to None, available values: internal.",
         default=None,
     )

     NOTION_INTERNAL_SECRET: Optional[str] = Field(
-        description='Notion internal secret key',
+        description="Notion internal secret key",
         default=None,
     )

     NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
-        description='Notion integration token',
+        description="Notion integration token",
         default=None,
     )

@@ -8,17 +8,18 @@ class SentryConfig(BaseSettings):
     """
     Sentry configs
     """

     SENTRY_DSN: Optional[str] = Field(
-        description='Sentry DSN',
+        description="Sentry DSN",
         default=None,
     )

     SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
-        description='Sentry trace sample rate',
+        description="Sentry trace sample rate",
         default=1.0,
     )

     SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
-        description='Sentry profiles sample rate',
+        description="Sentry profiles sample rate",
         default=1.0,
     )

@@ -1,6 +1,6 @@
 from typing import Optional

-from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
+from pydantic import AliasChoices, Field, NegativeInt, NonNegativeInt, PositiveInt, computed_field
 from pydantic_settings import BaseSettings

 from configs.feature.hosted_service import HostedServiceConfig

@@ -10,16 +10,17 @@ class SecurityConfig(BaseSettings):
     """
     Secret Key configs
     """

     SECRET_KEY: Optional[str] = Field(
-        description='Your App secret key will be used for securely signing the session cookie'
-        'Make sure you are changing this key for your deployment with a strong key.'
-        'You can generate a strong key using `openssl rand -base64 42`.'
-        'Alternatively you can set it with `SECRET_KEY` environment variable.',
+        description="Your App secret key will be used for securely signing the session cookie"
+        "Make sure you are changing this key for your deployment with a strong key."
+        "You can generate a strong key using `openssl rand -base64 42`."
+        "Alternatively you can set it with `SECRET_KEY` environment variable.",
         default=None,
     )

     RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
-        description='Expiry time in hours for reset token',
+        description="Expiry time in hours for reset token",
         default=24,
     )

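The SECRET_KEY description above suggests minting a key with `openssl rand -base64 42`. As a hedged convenience (not part of this commit), the same shape of key can be generated with Python's standard library alone:

    import base64
    import secrets

    # 42 random bytes, base64-encoded -- equivalent to `openssl rand -base64 42`
    print(base64.b64encode(secrets.token_bytes(42)).decode())
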
@@ -28,12 +29,13 @@ class AppExecutionConfig(BaseSettings):
     """
     App Execution configs
     """

     APP_MAX_EXECUTION_TIME: PositiveInt = Field(
-        description='execution timeout in seconds for app execution',
+        description="execution timeout in seconds for app execution",
         default=1200,
     )
     APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
-        description='max active request per app, 0 means unlimited',
+        description="max active request per app, 0 means unlimited",
         default=0,
     )

@@ -42,14 +44,55 @@ class CodeExecutionSandboxConfig(BaseSettings):
     """
     Code Execution Sandbox configs
     """

     CODE_EXECUTION_ENDPOINT: str = Field(
-        description='endpoint URL of code execution servcie',
-        default='http://sandbox:8194',
+        description="endpoint URL of code execution servcie",
+        default="http://sandbox:8194",
     )

     CODE_EXECUTION_API_KEY: str = Field(
-        description='API key for code execution service',
-        default='dify-sandbox',
+        description="API key for code execution service",
+        default="dify-sandbox",
     )

+    CODE_MAX_NUMBER: PositiveInt = Field(
+        description="max depth for code execution",
+        default=9223372036854775807,
+    )
+
+    CODE_MIN_NUMBER: NegativeInt = Field(
+        description="",
+        default=-9223372036854775807,
+    )
+
+    CODE_MAX_DEPTH: PositiveInt = Field(
+        description="max depth for code execution",
+        default=5,
+    )
+
+    CODE_MAX_PRECISION: PositiveInt = Field(
+        description="max precision digits for float type in code execution",
+        default=20,
+    )
+
+    CODE_MAX_STRING_LENGTH: PositiveInt = Field(
+        description="max string length for code execution",
+        default=80000,
+    )
+
+    CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
+        description="",
+        default=30,
+    )
+
+    CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
+        description="",
+        default=30,
+    )
+
+    CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
+        description="",
+        default=1000,
+    )

@@ -57,28 +100,27 @@ class EndpointConfig(BaseSettings):
     """
     Module URL configs
     """

     CONSOLE_API_URL: str = Field(
-        description='The backend URL prefix of the console API.'
-        'used to concatenate the login authorization callback or notion integration callback.',
-        default='',
+        description="The backend URL prefix of the console API."
+        "used to concatenate the login authorization callback or notion integration callback.",
+        default="",
     )

     CONSOLE_WEB_URL: str = Field(
-        description='The front-end URL prefix of the console web.'
-        'used to concatenate some front-end addresses and for CORS configuration use.',
-        default='',
+        description="The front-end URL prefix of the console web."
+        "used to concatenate some front-end addresses and for CORS configuration use.",
+        default="",
     )

     SERVICE_API_URL: str = Field(
-        description='Service API Url prefix.'
-        'used to display Service API Base Url to the front-end.',
-        default='',
+        description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
+        default="",
     )

     APP_WEB_URL: str = Field(
-        description='WebApp Url prefix.'
-        'used to display WebAPP API Base Url to the front-end.',
-        default='',
+        description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
+        default="",
     )

@@ -86,17 +128,18 @@ class FileAccessConfig(BaseSettings):
     """
     File Access configs
     """

     FILES_URL: str = Field(
-        description='File preview or download Url prefix.'
-        ' used to display File preview or download Url to the front-end or as Multi-model inputs;'
-        'Url is signed and has expiration time.',
-        validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
+        description="File preview or download Url prefix."
+        " used to display File preview or download Url to the front-end or as Multi-model inputs;"
+        "Url is signed and has expiration time.",
+        validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
         alias_priority=1,
-        default='',
+        default="",
     )

     FILES_ACCESS_TIMEOUT: int = Field(
-        description='timeout in seconds for file accessing',
+        description="timeout in seconds for file accessing",
         default=300,
     )

@@ -105,23 +148,24 @@ class FileUploadConfig(BaseSettings):
     """
     File Uploading configs
     """

     UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
-        description='size limit in Megabytes for uploading files',
+        description="size limit in Megabytes for uploading files",
         default=15,
     )

     UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
-        description='batch size limit for uploading files',
+        description="batch size limit for uploading files",
         default=5,
     )

     UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
-        description='image file size limit in Megabytes for uploading files',
+        description="image file size limit in Megabytes for uploading files",
         default=10,
     )

     BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
-        description='',  # todo: to be clarified
+        description="",  # todo: to be clarified
         default=20,
     )

@@ -130,45 +174,82 @@ class HttpConfig(BaseSettings):
     """
     HTTP configs
     """

     API_COMPRESSION_ENABLED: bool = Field(
-        description='whether to enable HTTP response compression of gzip',
+        description="whether to enable HTTP response compression of gzip",
         default=False,
     )

     inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
-        description='',
-        validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'),
-        default='',
+        description="",
+        validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
+        default="",
     )

     @computed_field
     @property
     def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
-        return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
+        return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")

     inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field(
-        description='',
-        validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
-        default='*',
+        description="",
+        validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"),
+        default="*",
     )

     @computed_field
     @property
     def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
-        return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
+        return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")

+    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field(
+        description="",
+        default=300,
+    )
+
+    HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field(
+        description="",
+        default=600,
+    )
+
+    HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field(
+        description="",
+        default=600,
+    )
+
+    HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
+        description="",
+        default=10 * 1024 * 1024,
+    )
+
+    HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
+        description="",
+        default=1 * 1024 * 1024,
+    )
+
+    SSRF_PROXY_HTTP_URL: Optional[str] = Field(
+        description="HTTP URL for SSRF proxy",
+        default=None,
+    )
+
+    SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
+        description="HTTPS URL for SSRF proxy",
+        default=None,
+    )


 class InnerAPIConfig(BaseSettings):
     """
     Inner API configs
     """

     INNER_API: bool = Field(
-        description='whether to enable the inner API',
+        description="whether to enable the inner API",
         default=False,
     )

     INNER_API_KEY: Optional[str] = Field(
-        description='The inner API key is used to authenticate the inner API',
+        description="The inner API key is used to authenticate the inner API",
         default=None,
     )

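The `inner_*` field plus `@computed_field` pairing in the hunk above is how a single comma-separated environment variable is exposed to the rest of the app as a list. A minimal sketch of the same pattern (class name and values are illustrative, not from the diff):

    from pydantic import Field, computed_field
    from pydantic_settings import BaseSettings

    class CorsDemo(BaseSettings):
        # populated from the environment, e.g. CONSOLE_CORS_ALLOW_ORIGINS
        inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(default="*")

        @computed_field  # exposed as a derived, read-only property
        @property
        def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
            return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")

    demo = CorsDemo(inner_CONSOLE_CORS_ALLOW_ORIGINS="https://a.com,https://b.com")
    assert demo.CONSOLE_CORS_ALLOW_ORIGINS == ["https://a.com", "https://b.com"]
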
@@ -179,28 +260,27 @@ class LoggingConfig(BaseSettings):
     """

     LOG_LEVEL: str = Field(
-        description='Log output level, default to INFO.'
-        'It is recommended to set it to ERROR for production.',
-        default='INFO',
+        description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
+        default="INFO",
     )

     LOG_FILE: Optional[str] = Field(
-        description='logging output file path',
+        description="logging output file path",
         default=None,
     )

     LOG_FORMAT: str = Field(
-        description='log format',
-        default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s',
+        description="log format",
+        default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
     )

     LOG_DATEFORMAT: Optional[str] = Field(
-        description='log date format',
+        description="log date format",
         default=None,
     )

     LOG_TZ: Optional[str] = Field(
-        description='specify log timezone, eg: America/New_York',
+        description="specify log timezone, eg: America/New_York",
         default=None,
     )

@@ -209,8 +289,9 @@ class ModelLoadBalanceConfig(BaseSettings):
     """
     Model load balance configs
     """

     MODEL_LB_ENABLED: bool = Field(
-        description='whether to enable model load balancing',
+        description="whether to enable model load balancing",
         default=False,
     )

@@ -219,8 +300,9 @@ class BillingConfig(BaseSettings):
     """
     Platform Billing Configurations
     """

     BILLING_ENABLED: bool = Field(
-        description='whether to enable billing',
+        description="whether to enable billing",
         default=False,
     )

@@ -229,9 +311,10 @@ class UpdateConfig(BaseSettings):
     """
     Update configs
     """

     CHECK_UPDATE_URL: str = Field(
-        description='url for checking updates',
-        default='https://updates.dify.ai',
+        description="url for checking updates",
+        default="https://updates.dify.ai",
     )

@@ -241,47 +324,53 @@ class WorkflowConfig(BaseSettings):
     """

     WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
-        description='max execution steps in single workflow execution',
+        description="max execution steps in single workflow execution",
         default=500,
     )

     WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
-        description='max execution time in seconds in single workflow execution',
+        description="max execution time in seconds in single workflow execution",
         default=1200,
     )

     WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
-        description='max depth of calling in single workflow execution',
+        description="max depth of calling in single workflow execution",
         default=5,
     )

+    MAX_VARIABLE_SIZE: PositiveInt = Field(
+        description="The maximum size in bytes of a variable. default to 5KB.",
+        default=5 * 1024,
+    )
+

 class OAuthConfig(BaseSettings):
     """
     oauth configs
     """

     OAUTH_REDIRECT_PATH: str = Field(
-        description='redirect path for OAuth',
-        default='/console/api/oauth/authorize',
+        description="redirect path for OAuth",
+        default="/console/api/oauth/authorize",
     )

     GITHUB_CLIENT_ID: Optional[str] = Field(
-        description='GitHub client id for OAuth',
+        description="GitHub client id for OAuth",
         default=None,
     )

     GITHUB_CLIENT_SECRET: Optional[str] = Field(
-        description='GitHub client secret key for OAuth',
+        description="GitHub client secret key for OAuth",
         default=None,
     )

     GOOGLE_CLIENT_ID: Optional[str] = Field(
-        description='Google client id for OAuth',
+        description="Google client id for OAuth",
         default=None,
     )

     GOOGLE_CLIENT_SECRET: Optional[str] = Field(
-        description='Google client secret key for OAuth',
+        description="Google client secret key for OAuth",
         default=None,
     )

@@ -291,9 +380,8 @@ class ModerationConfig(BaseSettings):
     Moderation in app configs.
     """

-    # todo: to be clarified in usage and unit
-    OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field(
-        description='buffer size for moderation',
+    MODERATION_BUFFER_SIZE: PositiveInt = Field(
+        description="buffer size for moderation",
         default=300,
     )

@@ -304,7 +392,7 @@ class ToolConfig(BaseSettings):
     """

     TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
-        description='max age in seconds for tool icon caching',
+        description="max age in seconds for tool icon caching",
         default=3600,
     )

@@ -315,52 +403,52 @@ class MailConfig(BaseSettings):
     """

     MAIL_TYPE: Optional[str] = Field(
-        description='Mail provider type name, default to None, availabile values are `smtp` and `resend`.',
+        description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
         default=None,
     )

     MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
-        description='default email address for sending from ',
+        description="default email address for sending from ",
         default=None,
     )

     RESEND_API_KEY: Optional[str] = Field(
-        description='API key for Resend',
+        description="API key for Resend",
         default=None,
     )

     RESEND_API_URL: Optional[str] = Field(
-        description='API URL for Resend',
+        description="API URL for Resend",
         default=None,
     )

     SMTP_SERVER: Optional[str] = Field(
-        description='smtp server host',
+        description="smtp server host",
         default=None,
     )

     SMTP_PORT: Optional[int] = Field(
-        description='smtp server port',
+        description="smtp server port",
         default=465,
     )

     SMTP_USERNAME: Optional[str] = Field(
-        description='smtp server username',
+        description="smtp server username",
         default=None,
     )

     SMTP_PASSWORD: Optional[str] = Field(
-        description='smtp server password',
+        description="smtp server password",
         default=None,
     )

     SMTP_USE_TLS: bool = Field(
-        description='whether to use TLS connection to smtp server',
+        description="whether to use TLS connection to smtp server",
         default=False,
     )

     SMTP_OPPORTUNISTIC_TLS: bool = Field(
-        description='whether to use opportunistic TLS connection to smtp server',
+        description="whether to use opportunistic TLS connection to smtp server",
         default=False,
     )

@@ -371,22 +459,22 @@ class RagEtlConfig(BaseSettings):
     """

     ETL_TYPE: str = Field(
-        description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ',
-        default='dify',
+        description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ",
+        default="dify",
     )

     KEYWORD_DATA_SOURCE_TYPE: str = Field(
-        description='source type for keyword data, default to `database`, available values are `database` .',
-        default='database',
+        description="source type for keyword data, default to `database`, available values are `database` .",
+        default="database",
     )

     UNSTRUCTURED_API_URL: Optional[str] = Field(
-        description='API URL for Unstructured',
+        description="API URL for Unstructured",
         default=None,
     )

     UNSTRUCTURED_API_KEY: Optional[str] = Field(
-        description='API key for Unstructured',
+        description="API key for Unstructured",
         default=None,
     )

@@ -397,12 +485,12 @@ class DataSetConfig(BaseSettings):
     """

     CLEAN_DAY_SETTING: PositiveInt = Field(
-        description='interval in days for cleaning up dataset',
+        description="interval in days for cleaning up dataset",
         default=30,
     )

     DATASET_OPERATOR_ENABLED: bool = Field(
-        description='whether to enable dataset operator',
+        description="whether to enable dataset operator",
         default=False,
     )

@@ -413,7 +501,7 @@ class WorkspaceConfig(BaseSettings):
     """

     INVITE_EXPIRY_HOURS: PositiveInt = Field(
-        description='workspaces invitation expiration in hours',
+        description="workspaces invitation expiration in hours",
         default=72,
     )

@@ -424,80 +512,79 @@ class IndexingConfig(BaseSettings):
     """

     INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
-        description='max segmentation token length for indexing',
+        description="max segmentation token length for indexing",
         default=1000,
     )


 class ImageFormatConfig(BaseSettings):
     MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
-        description='multi model send image format, support base64, url, default is base64',
-        default='base64',
+        description="multi model send image format, support base64, url, default is base64",
+        default="base64",
     )


 class CeleryBeatConfig(BaseSettings):
     CELERY_BEAT_SCHEDULER_TIME: int = Field(
-        description='the time of the celery scheduler, default to 1 day',
+        description="the time of the celery scheduler, default to 1 day",
         default=1,
     )


 class PositionConfig(BaseSettings):
-
     POSITION_PROVIDER_PINS: str = Field(
-        description='The heads of model providers',
-        default='',
+        description="The heads of model providers",
+        default="",
     )

     POSITION_PROVIDER_INCLUDES: str = Field(
-        description='The included model providers',
-        default='',
+        description="The included model providers",
+        default="",
     )

     POSITION_PROVIDER_EXCLUDES: str = Field(
-        description='The excluded model providers',
-        default='',
+        description="The excluded model providers",
+        default="",
     )

     POSITION_TOOL_PINS: str = Field(
-        description='The heads of tools',
-        default='',
+        description="The heads of tools",
+        default="",
     )

     POSITION_TOOL_INCLUDES: str = Field(
-        description='The included tools',
-        default='',
+        description="The included tools",
+        default="",
     )

     POSITION_TOOL_EXCLUDES: str = Field(
-        description='The excluded tools',
-        default='',
+        description="The excluded tools",
+        default="",
     )

     @computed_field
     def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
-        return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
+        return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]

     @computed_field
     def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
-        return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
+        return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}

     @computed_field
     def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
-        return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
+        return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}

     @computed_field
     def POSITION_TOOL_PINS_LIST(self) -> list[str]:
-        return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
+        return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]

     @computed_field
     def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
-        return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
+        return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}

     @computed_field
     def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
-        return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
+        return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}


 class FeatureConfig(

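Every POSITION_* computed field in the hunk above applies the same comma-split-and-strip parsing. The behaviour in isolation (a free function written here for illustration, not taken from the diff):

    def parse_position_list(raw: str) -> list[str]:
        # "a, b ,,c" -> ["a", "b", "c"]: split on commas, trim whitespace, drop empties
        return [item.strip() for item in raw.split(",") if item.strip() != ""]

    assert parse_position_list("openai, anthropic ,,") == ["openai", "anthropic"]
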
@@ -525,7 +612,6 @@ class FeatureConfig(
     WorkflowConfig,
     WorkspaceConfig,
     PositionConfig,

     # hosted services config
     HostedServiceConfig,
-    CeleryBeatConfig,

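FeatureConfig aggregates all of the group classes above by multiple inheritance, so one settings object carries every field. A reduced sketch of that composition pattern (the toy classes here are illustrative, not the real groups):

    from pydantic_settings import BaseSettings

    class GroupA(BaseSettings):
        A_ENABLED: bool = False

    class GroupB(BaseSettings):
        B_TIMEOUT: int = 60

    class Combined(GroupA, GroupB):
        """Inherits every field; env vars A_ENABLED and B_TIMEOUT both apply."""

    settings = Combined()
    print(settings.A_ENABLED, settings.B_TIMEOUT)
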
@@ -10,62 +10,62 @@ class HostedOpenAiConfig(BaseSettings):
     """

     HOSTED_OPENAI_API_KEY: Optional[str] = Field(
-        description='',
+        description="",
         default=None,
     )

     HOSTED_OPENAI_API_BASE: Optional[str] = Field(
-        description='',
+        description="",
         default=None,
     )

     HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
-        description='',
+        description="",
         default=None,
     )

     HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
-        description='',
+        description="",
         default=False,
     )

     HOSTED_OPENAI_TRIAL_MODELS: str = Field(
-        description='',
-        default='gpt-3.5-turbo,'
-        'gpt-3.5-turbo-1106,'
-        'gpt-3.5-turbo-instruct,'
-        'gpt-3.5-turbo-16k,'
-        'gpt-3.5-turbo-16k-0613,'
-        'gpt-3.5-turbo-0613,'
-        'gpt-3.5-turbo-0125,'
-        'text-davinci-003',
+        description="",
+        default="gpt-3.5-turbo,"
+        "gpt-3.5-turbo-1106,"
+        "gpt-3.5-turbo-instruct,"
+        "gpt-3.5-turbo-16k,"
+        "gpt-3.5-turbo-16k-0613,"
+        "gpt-3.5-turbo-0613,"
+        "gpt-3.5-turbo-0125,"
+        "text-davinci-003",
     )

     HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
-        description='',
+        description="",
         default=200,
     )

     HOSTED_OPENAI_PAID_ENABLED: bool = Field(
-        description='',
+        description="",
         default=False,
     )

     HOSTED_OPENAI_PAID_MODELS: str = Field(
-        description='',
-        default='gpt-4,'
-        'gpt-4-turbo-preview,'
-        'gpt-4-turbo-2024-04-09,'
-        'gpt-4-1106-preview,'
-        'gpt-4-0125-preview,'
-        'gpt-3.5-turbo,'
-        'gpt-3.5-turbo-16k,'
-        'gpt-3.5-turbo-16k-0613,'
-        'gpt-3.5-turbo-1106,'
-        'gpt-3.5-turbo-0613,'
-        'gpt-3.5-turbo-0125,'
-        'gpt-3.5-turbo-instruct,'
-        'text-davinci-003',
+        description="",
+        default="gpt-4,"
+        "gpt-4-turbo-preview,"
+        "gpt-4-turbo-2024-04-09,"
+        "gpt-4-1106-preview,"
+        "gpt-4-0125-preview,"
+        "gpt-3.5-turbo,"
+        "gpt-3.5-turbo-16k,"
+        "gpt-3.5-turbo-16k-0613,"
+        "gpt-3.5-turbo-1106,"
+        "gpt-3.5-turbo-0613,"
+        "gpt-3.5-turbo-0125,"
+        "gpt-3.5-turbo-instruct,"
+        "text-davinci-003",
     )

@@ -75,22 +75,22 @@ class HostedAzureOpenAiConfig(BaseSettings):
     """

     HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
-        description='',
+        description="",
         default=False,
     )

     HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
-        description='',
+        description="",
         default=None,
     )

     HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
-        description='',
+        description="",
         default=None,
     )

     HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
-        description='',
+        description="",
         default=200,
     )

@@ -101,27 +101,27 @@ class HostedAnthropicConfig(BaseSettings):
     """

     HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
-        description='',
+        description="",
         default=None,
     )

     HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
-        description='',
+        description="",
         default=None,
     )

     HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
-        description='',
+        description="",
         default=False,
     )

     HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
-        description='',
+        description="",
         default=600000,
     )

     HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
-        description='',
+        description="",
         default=False,
     )

@@ -132,7 +132,7 @@ class HostedMinmaxConfig(BaseSettings):
     """

     HOSTED_MINIMAX_ENABLED: bool = Field(
-        description='',
+        description="",
         default=False,
     )

@@ -143,7 +143,7 @@ class HostedSparkConfig(BaseSettings):
     """

     HOSTED_SPARK_ENABLED: bool = Field(
-        description='',
+        description="",
         default=False,
     )

@@ -154,7 +154,7 @@ class HostedZhipuAIConfig(BaseSettings):
     """

     HOSTED_ZHIPUAI_ENABLED: bool = Field(
-        description='',
+        description="",
         default=False,
     )

@@ -165,13 +165,13 @@ class HostedModerationConfig(BaseSettings):
     """

     HOSTED_MODERATION_ENABLED: bool = Field(
-        description='',
+        description="",
         default=False,
     )

     HOSTED_MODERATION_PROVIDERS: str = Field(
-        description='',
-        default='',
+        description="",
+        default="",
     )

@@ -181,15 +181,15 @@ class HostedFetchAppTemplateConfig(BaseSettings):
     """

     HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
-        description='the mode for fetching app templates,'
-        ' default to remote,'
-        ' available values: remote, db, builtin',
-        default='remote',
+        description="the mode for fetching app templates,"
+        " default to remote,"
+        " available values: remote, db, builtin",
+        default="remote",
     )

     HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
-        description='the domain for fetching remote app templates',
-        default='https://tmpl.dify.ai',
+        description="the domain for fetching remote app templates",
+        default="https://tmpl.dify.ai",
     )

@@ -202,7 +202,6 @@ class HostedServiceConfig(
     HostedOpenAiConfig,
     HostedSparkConfig,
     HostedZhipuAIConfig,
-
     # moderation
     HostedModerationConfig,
 ):

@@ -28,108 +28,108 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig

 class StorageConfig(BaseSettings):
     STORAGE_TYPE: str = Field(
-        description='storage type,'
-        ' default to `local`,'
-        ' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.',
-        default='local',
+        description="storage type,"
+        " default to `local`,"
+        " available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.",
+        default="local",
     )

     STORAGE_LOCAL_PATH: str = Field(
-        description='local storage path',
-        default='storage',
+        description="local storage path",
+        default="storage",
     )


 class VectorStoreConfig(BaseSettings):
     VECTOR_STORE: Optional[str] = Field(
-        description='vector store type',
+        description="vector store type",
         default=None,
     )


 class KeywordStoreConfig(BaseSettings):
     KEYWORD_STORE: str = Field(
-        description='keyword store type',
-        default='jieba',
+        description="keyword store type",
+        default="jieba",
     )


 class DatabaseConfig:
     DB_HOST: str = Field(
-        description='db host',
-        default='localhost',
+        description="db host",
+        default="localhost",
     )

     DB_PORT: PositiveInt = Field(
-        description='db port',
+        description="db port",
         default=5432,
     )

     DB_USERNAME: str = Field(
-        description='db username',
-        default='postgres',
+        description="db username",
+        default="postgres",
     )

     DB_PASSWORD: str = Field(
-        description='db password',
-        default='',
+        description="db password",
+        default="",
     )

     DB_DATABASE: str = Field(
-        description='db database',
-        default='dify',
+        description="db database",
+        default="dify",
     )

     DB_CHARSET: str = Field(
-        description='db charset',
-        default='',
+        description="db charset",
+        default="",
     )

     DB_EXTRAS: str = Field(
-        description='db extras options. Example: keepalives_idle=60&keepalives=1',
-        default='',
+        description="db extras options. Example: keepalives_idle=60&keepalives=1",
+        default="",
     )

     SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
-        description='db uri scheme',
-        default='postgresql',
+        description="db uri scheme",
+        default="postgresql",
     )

     @computed_field
     @property
     def SQLALCHEMY_DATABASE_URI(self) -> str:
         db_extras = (
-            f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
-            if self.DB_CHARSET
-            else self.DB_EXTRAS
+            f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
         ).strip("&")
         db_extras = f"?{db_extras}" if db_extras else ""
-        return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
-                f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
-                f"{db_extras}")
+        return (
+            f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
+            f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
+            f"{db_extras}"
+        )

     SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
-        description='pool size of SqlAlchemy',
+        description="pool size of SqlAlchemy",
         default=30,
     )

     SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
-        description='max overflows for SqlAlchemy',
+        description="max overflows for SqlAlchemy",
         default=10,
     )

     SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
-        description='SqlAlchemy pool recycle',
+        description="SqlAlchemy pool recycle",
         default=3600,
     )

     SQLALCHEMY_POOL_PRE_PING: bool = Field(
-        description='whether to enable pool pre-ping in SqlAlchemy',
+        description="whether to enable pool pre-ping in SqlAlchemy",
         default=False,
     )

     SQLALCHEMY_ECHO: bool | str = Field(
-        description='whether to enable SqlAlchemy echo',
+        description="whether to enable SqlAlchemy echo",
         default=False,
     )

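The computed SQLALCHEMY_DATABASE_URI above URL-escapes the credentials and appends the optional extras as a query string. A trimmed, self-contained version of that logic (the hard-coded values below are illustrative only):

    from urllib.parse import quote_plus

    def build_db_uri(scheme: str, user: str, password: str, host: str,
                     port: int, database: str, extras: str, charset: str) -> str:
        # append client_encoding only when a charset is configured
        db_extras = (f"{extras}&client_encoding={charset}" if charset else extras).strip("&")
        db_extras = f"?{db_extras}" if db_extras else ""
        return f"{scheme}://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}{db_extras}"

    # prints: postgresql://postgres:p%40ss@localhost:5432/dify?client_encoding=utf8
    print(build_db_uri("postgresql", "postgres", "p@ss", "localhost", 5432, "dify", "", "utf8"))
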
@@ -137,35 +137,38 @@ class DatabaseConfig:
     @property
     def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
         return {
-            'pool_size': self.SQLALCHEMY_POOL_SIZE,
-            'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW,
-            'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE,
-            'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING,
-            'connect_args': {'options': '-c timezone=UTC'},
+            "pool_size": self.SQLALCHEMY_POOL_SIZE,
+            "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
+            "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
+            "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
+            "connect_args": {"options": "-c timezone=UTC"},
         }


 class CeleryConfig(DatabaseConfig):
     CELERY_BACKEND: str = Field(
-        description='Celery backend, available values are `database`, `redis`',
-        default='database',
+        description="Celery backend, available values are `database`, `redis`",
+        default="database",
     )

     CELERY_BROKER_URL: Optional[str] = Field(
-        description='CELERY_BROKER_URL',
+        description="CELERY_BROKER_URL",
         default=None,
     )

     @computed_field
     @property
     def CELERY_RESULT_BACKEND(self) -> str | None:
-        return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
-            if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
+        return (
+            "db+{}".format(self.SQLALCHEMY_DATABASE_URI)
+            if self.CELERY_BACKEND == "database"
+            else self.CELERY_BROKER_URL
+        )

     @computed_field
     @property
     def BROKER_USE_SSL(self) -> bool:
-        return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False
+        return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False


 class MiddlewareConfig(

@@ -174,7 +177,6 @@ class MiddlewareConfig(
     DatabaseConfig,
     KeywordStoreConfig,
     RedisConfig,
-
     # configs of storage and storage providers
     StorageConfig,
     AliyunOSSStorageConfig,

@@ -183,7 +185,6 @@ class MiddlewareConfig(
     TencentCloudCOSStorageConfig,
     S3StorageConfig,
     OCIStorageConfig,
-
     # configs of vdb and vdb providers
     VectorStoreConfig,
     AnalyticdbConfig,

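CELERY_RESULT_BACKEND in the CeleryConfig hunk above simply routes on CELERY_BACKEND: the database backend reuses the SQLAlchemy URI with a `db+` prefix, and anything else falls back to the broker URL. The same decision in isolation (function name and values are illustrative):

    def celery_result_backend(backend: str, sqlalchemy_uri: str, broker_url: str | None) -> str | None:
        # "database" -> SQLAlchemy result backend; otherwise reuse the broker (e.g. Redis)
        return "db+{}".format(sqlalchemy_uri) if backend == "database" else broker_url

    assert celery_result_backend("database", "postgresql://u:p@db/dify", None) == "db+postgresql://u:p@db/dify"
    assert celery_result_backend("redis", "postgresql://u:p@db/dify", "redis://redis:6379/1") == "redis://redis:6379/1"
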
@@ -8,32 +8,33 @@ class RedisConfig(BaseSettings):
     """
     Redis configs
     """

     REDIS_HOST: str = Field(
-        description='Redis host',
-        default='localhost',
+        description="Redis host",
+        default="localhost",
     )

     REDIS_PORT: PositiveInt = Field(
-        description='Redis port',
+        description="Redis port",
         default=6379,
     )

     REDIS_USERNAME: Optional[str] = Field(
-        description='Redis username',
+        description="Redis username",
         default=None,
     )

     REDIS_PASSWORD: Optional[str] = Field(
-        description='Redis password',
+        description="Redis password",
         default=None,
     )

     REDIS_DB: NonNegativeInt = Field(
-        description='Redis database id, default to 0',
+        description="Redis database id, default to 0",
         default=0,
     )

     REDIS_USE_SSL: bool = Field(
-        description='whether to use SSL for Redis connection',
+        description="whether to use SSL for Redis connection",
         default=False,
     )

@@ -10,31 +10,31 @@ class AliyunOSSStorageConfig(BaseSettings):
     """

     ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
-        description='Aliyun OSS bucket name',
+        description="Aliyun OSS bucket name",
         default=None,
     )

     ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
-        description='Aliyun OSS access key',
+        description="Aliyun OSS access key",
         default=None,
     )

     ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
-        description='Aliyun OSS secret key',
+        description="Aliyun OSS secret key",
         default=None,
     )

     ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
-        description='Aliyun OSS endpoint URL',
+        description="Aliyun OSS endpoint URL",
         default=None,
     )

     ALIYUN_OSS_REGION: Optional[str] = Field(
-        description='Aliyun OSS region',
+        description="Aliyun OSS region",
         default=None,
     )

     ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
-        description='Aliyun OSS authentication version',
+        description="Aliyun OSS authentication version",
         default=None,
     )

@@ -10,36 +10,36 @@ class S3StorageConfig(BaseSettings):
     """

     S3_ENDPOINT: Optional[str] = Field(
-        description='S3 storage endpoint',
+        description="S3 storage endpoint",
         default=None,
     )

     S3_REGION: Optional[str] = Field(
-        description='S3 storage region',
+        description="S3 storage region",
         default=None,
     )

     S3_BUCKET_NAME: Optional[str] = Field(
-        description='S3 storage bucket name',
+        description="S3 storage bucket name",
         default=None,
     )

     S3_ACCESS_KEY: Optional[str] = Field(
-        description='S3 storage access key',
+        description="S3 storage access key",
         default=None,
     )

     S3_SECRET_KEY: Optional[str] = Field(
-        description='S3 storage secret key',
+        description="S3 storage secret key",
         default=None,
     )

     S3_ADDRESS_STYLE: str = Field(
-        description='S3 storage address style',
-        default='auto',
+        description="S3 storage address style",
+        default="auto",
     )

     S3_USE_AWS_MANAGED_IAM: bool = Field(
-        description='whether to use aws managed IAM for S3',
+        description="whether to use aws managed IAM for S3",
         default=False,
     )

@@ -10,21 +10,21 @@ class AzureBlobStorageConfig(BaseSettings):
     """

     AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
-        description='Azure Blob account name',
+        description="Azure Blob account name",
         default=None,
     )

     AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
-        description='Azure Blob account key',
+        description="Azure Blob account key",
         default=None,
     )

     AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
-        description='Azure Blob container name',
+        description="Azure Blob container name",
         default=None,
     )

     AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
-        description='Azure Blob account URL',
+        description="Azure Blob account URL",
         default=None,
     )

@@ -10,11 +10,11 @@ class GoogleCloudStorageConfig(BaseSettings):
     """

     GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
-        description='Google Cloud storage bucket name',
+        description="Google Cloud storage bucket name",
         default=None,
     )

     GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
-        description='Google Cloud storage service account json base64',
+        description="Google Cloud storage service account json base64",
         default=None,
     )

@@ -10,27 +10,26 @@ class OCIStorageConfig(BaseSettings):
     """

     OCI_ENDPOINT: Optional[str] = Field(
-        description='OCI storage endpoint',
+        description="OCI storage endpoint",
         default=None,
     )

     OCI_REGION: Optional[str] = Field(
-        description='OCI storage region',
+        description="OCI storage region",
         default=None,
     )

     OCI_BUCKET_NAME: Optional[str] = Field(
-        description='OCI storage bucket name',
+        description="OCI storage bucket name",
         default=None,
     )

     OCI_ACCESS_KEY: Optional[str] = Field(
-        description='OCI storage access key',
+        description="OCI storage access key",
         default=None,
     )

     OCI_SECRET_KEY: Optional[str] = Field(
-        description='OCI storage secret key',
+        description="OCI storage secret key",
         default=None,
     )

@@ -10,26 +10,26 @@ class TencentCloudCOSStorageConfig(BaseSettings):
     """

     TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
-        description='Tencent Cloud COS bucket name',
+        description="Tencent Cloud COS bucket name",
         default=None,
     )

     TENCENT_COS_REGION: Optional[str] = Field(
-        description='Tencent Cloud COS region',
+        description="Tencent Cloud COS region",
         default=None,
     )

     TENCENT_COS_SECRET_ID: Optional[str] = Field(
-        description='Tencent Cloud COS secret id',
+        description="Tencent Cloud COS secret id",
         default=None,
     )

     TENCENT_COS_SECRET_KEY: Optional[str] = Field(
-        description='Tencent Cloud COS secret key',
+        description="Tencent Cloud COS secret key",
         default=None,
     )

     TENCENT_COS_SCHEME: Optional[str] = Field(
-        description='Tencent Cloud COS scheme',
+        description="Tencent Cloud COS scheme",
         default=None,
     )

@@ -10,35 +10,28 @@ class AnalyticdbConfig(BaseModel):
     https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
     """

-    ANALYTICDB_KEY_ID : Optional[str] = Field(
-        default=None,
-        description="The Access Key ID provided by Alibaba Cloud for authentication."
+    ANALYTICDB_KEY_ID: Optional[str] = Field(
+        default=None, description="The Access Key ID provided by Alibaba Cloud for authentication."
     )
-    ANALYTICDB_KEY_SECRET : Optional[str] = Field(
-        default=None,
-        description="The Secret Access Key corresponding to the Access Key ID for secure access."
+    ANALYTICDB_KEY_SECRET: Optional[str] = Field(
+        default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access."
     )
-    ANALYTICDB_REGION_ID : Optional[str] = Field(
-        default=None,
-        description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
+    ANALYTICDB_REGION_ID: Optional[str] = Field(
+        default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
     )
-    ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
+    ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
         default=None,
-        description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
+        description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..",
     )
-    ANALYTICDB_ACCOUNT : Optional[str] = Field(
-        default=None,
-        description="The account name used to log in to the AnalyticDB instance."
+    ANALYTICDB_ACCOUNT: Optional[str] = Field(
+        default=None, description="The account name used to log in to the AnalyticDB instance."
     )
-    ANALYTICDB_PASSWORD : Optional[str] = Field(
-        default=None,
-        description="The password associated with the AnalyticDB account for authentication."
+    ANALYTICDB_PASSWORD: Optional[str] = Field(
+        default=None, description="The password associated with the AnalyticDB account for authentication."
     )
-    ANALYTICDB_NAMESPACE : Optional[str] = Field(
-        default=None,
-        description="The namespace within AnalyticDB for schema isolation."
+    ANALYTICDB_NAMESPACE: Optional[str] = Field(
+        default=None, description="The namespace within AnalyticDB for schema isolation."
     )
-    ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
-        default=None,
-        description="The password for accessing the specified namespace within the AnalyticDB instance."
+    ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
+        default=None, description="The password for accessing the specified namespace within the AnalyticDB instance."
     )

@@ -10,31 +10,31 @@ class ChromaConfig(BaseSettings):
     """

     CHROMA_HOST: Optional[str] = Field(
-        description='Chroma host',
+        description="Chroma host",
         default=None,
     )

     CHROMA_PORT: PositiveInt = Field(
-        description='Chroma port',
+        description="Chroma port",
         default=8000,
     )

     CHROMA_TENANT: Optional[str] = Field(
-        description='Chroma database',
+        description="Chroma database",
         default=None,
     )

     CHROMA_DATABASE: Optional[str] = Field(
-        description='Chroma database',
+        description="Chroma database",
         default=None,
     )

     CHROMA_AUTH_PROVIDER: Optional[str] = Field(
-        description='Chroma authentication provider',
+        description="Chroma authentication provider",
         default=None,
     )

     CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
-        description='Chroma authentication credentials',
+        description="Chroma authentication credentials",
         default=None,
     )

@@ -10,31 +10,31 @@ class MilvusConfig(BaseSettings):
     """

     MILVUS_HOST: Optional[str] = Field(
-        description='Milvus host',
+        description="Milvus host",
         default=None,
     )

     MILVUS_PORT: PositiveInt = Field(
-        description='Milvus RestFul API port',
+        description="Milvus RestFul API port",
         default=9091,
     )

     MILVUS_USER: Optional[str] = Field(
-        description='Milvus user',
+        description="Milvus user",
         default=None,
     )

     MILVUS_PASSWORD: Optional[str] = Field(
-        description='Milvus password',
+        description="Milvus password",
         default=None,
     )

     MILVUS_SECURE: bool = Field(
-        description='whether to use SSL connection for Milvus',
+        description="whether to use SSL connection for Milvus",
         default=False,
     )

     MILVUS_DATABASE: str = Field(
-        description='Milvus database, default to `default`',
-        default='default',
+        description="Milvus database, default to `default`",
+        default="default",
     )

@@ -1,4 +1,3 @@
-
 from pydantic import BaseModel, Field, PositiveInt


@@ -8,31 +7,31 @@ class MyScaleConfig(BaseModel):
     """

     MYSCALE_HOST: str = Field(
-        description='MyScale host',
-        default='localhost',
+        description="MyScale host",
+        default="localhost",
     )

     MYSCALE_PORT: PositiveInt = Field(
-        description='MyScale port',
+        description="MyScale port",
         default=8123,
     )

     MYSCALE_USER: str = Field(
-        description='MyScale user',
-        default='default',
+        description="MyScale user",
+        default="default",
     )

     MYSCALE_PASSWORD: str = Field(
-        description='MyScale password',
-        default='',
+        description="MyScale password",
+        default="",
     )

     MYSCALE_DATABASE: str = Field(
-        description='MyScale database name',
-        default='default',
+        description="MyScale database name",
+        default="default",
     )

     MYSCALE_FTS_PARAMS: str = Field(
-        description='MyScale fts index parameters',
-        default='',
+        description="MyScale fts index parameters",
+        default="",
     )

@@ -10,26 +10,26 @@ class OpenSearchConfig(BaseSettings):
     """

     OPENSEARCH_HOST: Optional[str] = Field(
-        description='OpenSearch host',
+        description="OpenSearch host",
         default=None,
     )

     OPENSEARCH_PORT: PositiveInt = Field(
-        description='OpenSearch port',
+        description="OpenSearch port",
         default=9200,
     )

     OPENSEARCH_USER: Optional[str] = Field(
-        description='OpenSearch user',
+        description="OpenSearch user",
         default=None,
     )

     OPENSEARCH_PASSWORD: Optional[str] = Field(
-        description='OpenSearch password',
+        description="OpenSearch password",
         default=None,
     )

     OPENSEARCH_SECURE: bool = Field(
-        description='whether to use SSL connection for OpenSearch',
+        description="whether to use SSL connection for OpenSearch",
         default=False,
     )

@@ -10,26 +10,26 @@ class OracleConfig(BaseSettings):
     """

     ORACLE_HOST: Optional[str] = Field(
-        description='ORACLE host',
+        description="ORACLE host",
         default=None,
     )

     ORACLE_PORT: Optional[PositiveInt] = Field(
-        description='ORACLE port',
+        description="ORACLE port",
         default=1521,
     )

     ORACLE_USER: Optional[str] = Field(
-        description='ORACLE user',
+        description="ORACLE user",
         default=None,
     )

     ORACLE_PASSWORD: Optional[str] = Field(
-        description='ORACLE password',
+        description="ORACLE password",
         default=None,
     )

     ORACLE_DATABASE: Optional[str] = Field(
-        description='ORACLE database',
+        description="ORACLE database",
         default=None,
     )

@@ -10,26 +10,26 @@ class PGVectorConfig(BaseSettings):
    """

    PGVECTOR_HOST: Optional[str] = Field(
        description='PGVector host',
        description="PGVector host",
        default=None,
    )

    PGVECTOR_PORT: Optional[PositiveInt] = Field(
        description='PGVector port',
        description="PGVector port",
        default=5433,
    )

    PGVECTOR_USER: Optional[str] = Field(
        description='PGVector user',
        description="PGVector user",
        default=None,
    )

    PGVECTOR_PASSWORD: Optional[str] = Field(
        description='PGVector password',
        description="PGVector password",
        default=None,
    )

    PGVECTOR_DATABASE: Optional[str] = Field(
        description='PGVector database',
        description="PGVector database",
        default=None,
    )

@@ -10,26 +10,26 @@ class PGVectoRSConfig(BaseSettings):
    """

    PGVECTO_RS_HOST: Optional[str] = Field(
        description='PGVectoRS host',
        description="PGVectoRS host",
        default=None,
    )

    PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
        description='PGVectoRS port',
        description="PGVectoRS port",
        default=5431,
    )

    PGVECTO_RS_USER: Optional[str] = Field(
        description='PGVectoRS user',
        description="PGVectoRS user",
        default=None,
    )

    PGVECTO_RS_PASSWORD: Optional[str] = Field(
        description='PGVectoRS password',
        description="PGVectoRS password",
        default=None,
    )

    PGVECTO_RS_DATABASE: Optional[str] = Field(
        description='PGVectoRS database',
        description="PGVectoRS database",
        default=None,
    )

@@ -10,26 +10,26 @@ class QdrantConfig(BaseSettings):
    """

    QDRANT_URL: Optional[str] = Field(
        description='Qdrant url',
        description="Qdrant url",
        default=None,
    )

    QDRANT_API_KEY: Optional[str] = Field(
        description='Qdrant api key',
        description="Qdrant api key",
        default=None,
    )

    QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
        description='Qdrant client timeout in seconds',
        description="Qdrant client timeout in seconds",
        default=20,
    )

    QDRANT_GRPC_ENABLED: bool = Field(
        description='whether enable grpc support for Qdrant connection',
        description="whether enable grpc support for Qdrant connection",
        default=False,
    )

    QDRANT_GRPC_PORT: PositiveInt = Field(
        description='Qdrant grpc port',
        description="Qdrant grpc port",
        default=6334,
    )

@@ -10,26 +10,26 @@ class RelytConfig(BaseSettings):
    """

    RELYT_HOST: Optional[str] = Field(
        description='Relyt host',
        description="Relyt host",
        default=None,
    )

    RELYT_PORT: PositiveInt = Field(
        description='Relyt port',
        description="Relyt port",
        default=9200,
    )

    RELYT_USER: Optional[str] = Field(
        description='Relyt user',
        description="Relyt user",
        default=None,
    )

    RELYT_PASSWORD: Optional[str] = Field(
        description='Relyt password',
        description="Relyt password",
        default=None,
    )

    RELYT_DATABASE: Optional[str] = Field(
        description='Relyt database',
        default='default',
        description="Relyt database",
        default="default",
    )

@@ -10,41 +10,41 @@ class TencentVectorDBConfig(BaseSettings):
    """

    TENCENT_VECTOR_DB_URL: Optional[str] = Field(
        description='Tencent Vector URL',
        description="Tencent Vector URL",
        default=None,
    )

    TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
        description='Tencent Vector API key',
        description="Tencent Vector API key",
        default=None,
    )

    TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
        description='Tencent Vector timeout in seconds',
        description="Tencent Vector timeout in seconds",
        default=30,
    )

    TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
        description='Tencent Vector username',
        description="Tencent Vector username",
        default=None,
    )

    TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
        description='Tencent Vector password',
        description="Tencent Vector password",
        default=None,
    )

    TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
        description='Tencent Vector sharding number',
        description="Tencent Vector sharding number",
        default=1,
    )

    TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
        description='Tencent Vector replicas',
        description="Tencent Vector replicas",
        default=2,
    )

    TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
        description='Tencent Vector Database',
        description="Tencent Vector Database",
        default=None,
    )

@@ -10,26 +10,26 @@ class TiDBVectorConfig(BaseSettings):
    """

    TIDB_VECTOR_HOST: Optional[str] = Field(
        description='TiDB Vector host',
        description="TiDB Vector host",
        default=None,
    )

    TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
        description='TiDB Vector port',
        description="TiDB Vector port",
        default=4000,
    )

    TIDB_VECTOR_USER: Optional[str] = Field(
        description='TiDB Vector user',
        description="TiDB Vector user",
        default=None,
    )

    TIDB_VECTOR_PASSWORD: Optional[str] = Field(
        description='TiDB Vector password',
        description="TiDB Vector password",
        default=None,
    )

    TIDB_VECTOR_DATABASE: Optional[str] = Field(
        description='TiDB Vector database',
        description="TiDB Vector database",
        default=None,
    )

@@ -10,21 +10,21 @@ class WeaviateConfig(BaseSettings):
    """

    WEAVIATE_ENDPOINT: Optional[str] = Field(
        description='Weaviate endpoint URL',
        description="Weaviate endpoint URL",
        default=None,
    )

    WEAVIATE_API_KEY: Optional[str] = Field(
        description='Weaviate API key',
        description="Weaviate API key",
        default=None,
    )

    WEAVIATE_GRPC_ENABLED: bool = Field(
        description='whether to enable gRPC for Weaviate connection',
        description="whether to enable gRPC for Weaviate connection",
        default=True,
    )

    WEAVIATE_BATCH_SIZE: PositiveInt = Field(
        description='Weaviate batch size',
        description="Weaviate batch size",
        default=100,
    )

@@ -8,11 +8,11 @@ class PackagingInfo(BaseSettings):
    """

    CURRENT_VERSION: str = Field(
        description='Dify version',
        default='0.7.1',
        description="Dify version",
        default="0.7.1",
    )

    COMMIT_SHA: str = Field(
        description="SHA-1 checksum of the git commit used to build the app",
        default='',
        default="",
    )

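Note on the config groups above: each one is a plain pydantic-settings model, so every Field can be overridden from the environment or a .env file without code changes. A minimal, self-contained sketch of that mechanism; the class name is illustrative and not part of the diff:

from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class DemoMilvusConfig(BaseSettings):
    # illustrative stand-in for the MilvusConfig group in the diff above
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    MILVUS_HOST: Optional[str] = Field(description="Milvus host", default=None)
    MILVUS_PORT: int = Field(description="Milvus RestFul API port", default=9091)


# running with MILVUS_HOST=10.0.0.5 in the environment yields ("10.0.0.5", 9091)
config = DemoMilvusConfig()
print(config.MILVUS_HOST, config.MILVUS_PORT)
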
@@ -21,6 +21,18 @@ parameter_rules:
    default: 1024
    min: 1
    max: 128000
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.06'
  output: '0.06'

@@ -21,6 +21,18 @@ parameter_rules:
    default: 1024
    min: 1
    max: 32000
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.024'
  output: '0.024'

@@ -21,6 +21,18 @@ parameter_rules:
    default: 512
    min: 1
    max: 8192
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.012'
  output: '0.012'

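The response_format rule added to each model YAML above is an ordinary enumerated string parameter, so a runtime only needs to check a submitted value against the declared options. A rough sketch of that check; the rule dict is hand-copied from the YAML and the function name is illustrative:

rule = {
    "name": "response_format",
    "type": "string",
    "required": False,
    "options": ["text", "json_object"],
}


def validate_parameter(rule: dict, value) -> bool:
    # a missing optional parameter is acceptable
    if value is None:
        return not rule.get("required", False)
    # an enumerated string must match one of the declared options
    return value in rule.get("options", [])


assert validate_parameter(rule, "json_object")
assert not validate_parameter(rule, "xml")
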
@@ -0,0 +1,9 @@
model: text-embedding-v3
model_type: text-embedding
model_properties:
  context_size: 8192
  max_chunks: 25
pricing:
  input: "0.0007"
  unit: "0.001"
  currency: RMB

@@ -444,6 +444,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                delta=LLMResultChunkDelta(
                    index=delta.index,
                    message=assistant_prompt_message,
                    finish_reason=delta.finish_reason
                )
            )

@@ -21,6 +21,7 @@ class LangfuseConfig(BaseTracingConfig):
    """
    public_key: str
    secret_key: str
    project_key: str
    host: str = 'https://api.langfuse.com'

    @field_validator("host")

@@ -419,3 +419,11 @@ class LangFuseDataTrace(BaseTraceInstance):
        except Exception as e:
            logger.debug(f"LangFuse API check failed: {str(e)}")
            raise ValueError(f"LangFuse API check failed: {str(e)}")

    def get_project_key(self):
        try:
            projects = self.langfuse_client.client.projects.get()
            return projects.data[0].id
        except Exception as e:
            logger.debug(f"LangFuse get project key failed: {str(e)}")
            raise ValueError(f"LangFuse get project key failed: {str(e)}")

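get_project_key above leans on the Langfuse SDK's project listing and takes the first project's id, since a public/secret key pair is scoped to a single project. The same lookup in isolation, as a sketch; the client attribute chain mirrors the code above but should be treated as an assumption about the SDK rather than a documented contract:

from langfuse import Langfuse


def fetch_project_key(public_key: str, secret_key: str, host: str) -> str:
    client = Langfuse(public_key=public_key, secret_key=secret_key, host=host)
    # same call chain as LangFuseDataTrace.get_project_key in the diff above
    projects = client.client.projects.get()
    return projects.data[0].id
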
@@ -38,7 +38,7 @@ provider_config_map = {
    TracingProviderEnum.LANGFUSE.value: {
        'config_class': LangfuseConfig,
        'secret_keys': ['public_key', 'secret_key'],
        'other_keys': ['host'],
        'other_keys': ['host', 'project_key'],
        'trace_instance': LangFuseDataTrace
    },
    TracingProviderEnum.LANGSMITH.value: {

@@ -123,7 +123,6 @@ class OpsTraceManager:

        for key in other_keys:
            new_config[key] = decrypt_tracing_config.get(key, "")

        return config_class(**new_config).model_dump()

    @classmethod

@@ -252,6 +251,19 @@ class OpsTraceManager:
        tracing_config = config_type(**tracing_config)
        return trace_instance(tracing_config).api_check()

    @staticmethod
    def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
        """
        Get the project key from the trace config.
        :param tracing_config: tracing config
        :param tracing_provider: tracing provider
        :return:
        """
        config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
            provider_config_map[tracing_provider]['trace_instance']
        tracing_config = config_type(**tracing_config)
        return trace_instance(tracing_config).get_project_key()


class TraceTask:
    def __init__(

@@ -614,7 +614,7 @@ class DatasetRetrieval:
                          top_k: int, score_threshold: float) -> list[Document]:
        filter_documents = []
        for document in all_documents:
            if score_threshold and document.metadata['score'] >= score_threshold:
            if score_threshold is None or document.metadata['score'] >= score_threshold:
                filter_documents.append(document)
        if not filter_documents:
            return []

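The one-line change above fixes a truthiness bug: a score_threshold of 0.0 is falsy, so the old condition silently dropped every document, while the new condition only skips the comparison when no threshold is set at all. A tiny repro:

docs = [{"score": 0.2}, {"score": 0.8}]
threshold = 0.0

# old behaviour: `threshold and ...` short-circuits on 0.0 and keeps nothing
old = [d for d in docs if threshold and d["score"] >= threshold]
# new behaviour: a zero threshold admits everything, as intended
new = [d for d in docs if threshold is None or d["score"] >= threshold]

assert old == []
assert new == docs
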
Binary file not shown. (new image, 37 KiB)

@@ -0,0 +1,12 @@
from typing import Any

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class OneBotProvider(BuiltinToolProviderController):

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:

        if not credentials.get("ob11_http_url"):
            raise ToolProviderCredentialValidationError('OneBot HTTP URL is required.')

@@ -0,0 +1,35 @@
identity:
  author: RockChinQ
  name: onebot
  label:
    en_US: OneBot v11 Protocol
    zh_Hans: OneBot v11 协议
  description:
    en_US: Unofficial OneBot v11 Protocol Tool
    zh_Hans: 非官方 OneBot v11 协议工具
  icon: icon.ico
credentials_for_provider:
  ob11_http_url:
    type: text-input
    required: true
    label:
      en_US: HTTP URL
      zh_Hans: HTTP URL
    description:
      en_US: Forward HTTP URL of OneBot v11
      zh_Hans: OneBot v11 正向 HTTP URL
    help:
      en_US: Fill this with the HTTP URL of your OneBot server
      zh_Hans: 请在你的 OneBot 协议端开启 正向 HTTP 并填写其 URL
  access_token:
    type: secret-input
    required: false
    label:
      en_US: Access Token
      zh_Hans: 访问令牌
    description:
      en_US: Access Token for OneBot v11 Protocol
      zh_Hans: OneBot 协议访问令牌
    help:
      en_US: Fill this if you set an access token in your OneBot server
      zh_Hans: 如果你在 OneBot 服务器中设置了 access token,请填写此项

@@ -0,0 +1,64 @@
from typing import Any, Union

import requests
from yarl import URL

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class SendGroupMsg(BuiltinTool):
    """OneBot v11 Tool: Send Group Message"""

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

        # Get parameters
        send_group_id = tool_parameters.get('group_id', '')

        message = tool_parameters.get('message', '')
        if not message:
            return self.create_json_message(
                {
                    'error': 'Message is empty.'
                }
            )

        auto_escape = tool_parameters.get('auto_escape', False)

        try:
            url = URL(self.runtime.credentials['ob11_http_url']) / 'send_group_msg'

            resp = requests.post(
                url,
                json={
                    'group_id': send_group_id,
                    'message': message,
                    'auto_escape': auto_escape
                },
                headers={
                    'Authorization': 'Bearer ' + self.runtime.credentials['access_token']
                }
            )

            if resp.status_code != 200:
                return self.create_json_message(
                    {
                        'error': f'Failed to send group message: {resp.text}'
                    }
                )

            return self.create_json_message(
                {
                    'response': resp.json()
                }
            )
        except Exception as e:
            return self.create_json_message(
                {
                    'error': f'Failed to send group message: {e}'
                }
            )

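The tool above is a thin wrapper over OneBot v11's forward-HTTP send_group_msg action. For reference, the same call outside Dify looks roughly like this; the URL, group id, and token are placeholders, and the token is only needed if the OneBot server has one configured:

import requests

OB11_HTTP_URL = "http://127.0.0.1:5700"  # placeholder forward-HTTP endpoint
ACCESS_TOKEN = ""  # optional; set only if the OneBot server requires it

resp = requests.post(
    f"{OB11_HTTP_URL}/send_group_msg",
    json={"group_id": 123456, "message": "hello", "auto_escape": False},
    headers={"Authorization": f"Bearer {ACCESS_TOKEN}"} if ACCESS_TOKEN else {},
)
print(resp.json())
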
@@ -0,0 +1,46 @@
identity:
  name: send_group_msg
  author: RockChinQ
  label:
    en_US: Send Group Message
    zh_Hans: 发送群消息
description:
  human:
    en_US: Send a message to a group
    zh_Hans: 发送消息到群聊
  llm: A tool for sending a message segment to a group
parameters:
  - name: group_id
    type: number
    required: true
    label:
      en_US: Target Group ID
      zh_Hans: 目标群 ID
    human_description:
      en_US: The group ID of the target group
      zh_Hans: 目标群的群 ID
    llm_description: The group ID of the target group
    form: llm
  - name: message
    type: string
    required: true
    label:
      en_US: Message
      zh_Hans: 消息
    human_description:
      en_US: The message to send
      zh_Hans: 要发送的消息。支持 CQ码(需要同时设置 auto_escape 为 true)
    llm_description: The message to send
    form: llm
  - name: auto_escape
    type: boolean
    required: false
    default: false
    label:
      en_US: Auto Escape
      zh_Hans: 自动转义
    human_description:
      en_US: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes.
      zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。
    llm_description: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending.
    form: form

@@ -0,0 +1,64 @@
from typing import Any, Union

import requests
from yarl import URL

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class SendPrivateMsg(BuiltinTool):
    """OneBot v11 Tool: Send Private Message"""

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

        # Get parameters
        send_user_id = tool_parameters.get('user_id', '')

        message = tool_parameters.get('message', '')
        if not message:
            return self.create_json_message(
                {
                    'error': 'Message is empty.'
                }
            )

        auto_escape = tool_parameters.get('auto_escape', False)

        try:
            url = URL(self.runtime.credentials['ob11_http_url']) / 'send_private_msg'

            resp = requests.post(
                url,
                json={
                    'user_id': send_user_id,
                    'message': message,
                    'auto_escape': auto_escape
                },
                headers={
                    'Authorization': 'Bearer ' + self.runtime.credentials['access_token']
                }
            )

            if resp.status_code != 200:
                return self.create_json_message(
                    {
                        'error': f'Failed to send private message: {resp.text}'
                    }
                )

            return self.create_json_message(
                {
                    'response': resp.json()
                }
            )
        except Exception as e:
            return self.create_json_message(
                {
                    'error': f'Failed to send private message: {e}'
                }
            )

@@ -0,0 +1,46 @@
identity:
  name: send_private_msg
  author: RockChinQ
  label:
    en_US: Send Private Message
    zh_Hans: 发送私聊消息
description:
  human:
    en_US: Send a private message to a user
    zh_Hans: 发送私聊消息给用户
  llm: A tool for sending a message segment to a user in private chat
parameters:
  - name: user_id
    type: number
    required: true
    label:
      en_US: Target User ID
      zh_Hans: 目标用户 ID
    human_description:
      en_US: The user ID of the target user
      zh_Hans: 目标用户的用户 ID
    llm_description: The user ID of the target user
    form: llm
  - name: message
    type: string
    required: true
    label:
      en_US: Message
      zh_Hans: 消息
    human_description:
      en_US: The message to send
      zh_Hans: 要发送的消息。支持 CQ码(需要同时设置 auto_escape 为 true)
    llm_description: The message to send
    form: llm
  - name: auto_escape
    type: boolean
    required: false
    default: false
    label:
      en_US: Auto Escape
      zh_Hans: 自动转义
    human_description:
      en_US: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes.
      zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。
    llm_description: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending.
    form: form

@@ -27,7 +27,7 @@ DRAW_TEXT_OPTIONS = {
    "seed_resize_from_w": -1,

    # Samplers
    # "sampler_name": "DPM++ 2M",
    "sampler_name": "DPM++ 2M",
    # "scheduler": "",
    # "sampler_index": "Automatic",

@@ -178,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
                return [d['model_name'] for d in response.json()]
        except Exception as e:
            return []

    def get_sample_methods(self) -> list[str]:
        """
        get sample methods
        """
        try:
            base_url = self.runtime.credentials.get('base_url', None)
            if not base_url:
                return []
            api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers')
            response = get(url=api_url, timeout=(2, 10))
            if response.status_code != 200:
                return []
            else:
                return [d['name'] for d in response.json()]
        except Exception as e:
            return []

    def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \
            -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

@@ -339,7 +356,27 @@ class StableDiffusionTool(BuiltinTool):
                              label=I18nObject(en_US=i, zh_Hans=i)
                          ) for i in models])
            )

        except:
            pass

        sample_methods = self.get_sample_methods()
        if len(sample_methods) != 0:
            parameters.append(
                ToolParameter(name='sampler_name',
                              label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'),
                              human_description=I18nObject(
                                  en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
                                  zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档',
                              ),
                              type=ToolParameter.ToolParameterType.SELECT,
                              form=ToolParameter.ToolParameterForm.FORM,
                              llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
                              required=True,
                              default=sample_methods[0],
                              options=[ToolParameterOption(
                                  value=i,
                                  label=I18nObject(en_US=i, zh_Hans=i)
                              ) for i in sample_methods])
            )
        return parameters

@@ -144,7 +144,7 @@ class ApiTool(Tool):
                path_params[parameter['name']] = value

            elif parameter['in'] == 'query':
                params[parameter['name']] = value
                if value != '': params[parameter['name']] = value

            elif parameter['in'] == 'cookie':
                cookies[parameter['name']] = value

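Guarding on `value != ''` keeps optional OpenAPI query parameters out of the request entirely instead of sending empty `?page=`-style pairs that some upstream APIs reject. The same filter in isolation:

values = {"q": "dify", "page": ""}

# only populated parameters survive the filter
params = {name: value for name, value in values.items() if value != ""}

assert params == {"q": "dify"}
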
@@ -11,15 +11,6 @@ from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
from models.workflow import WorkflowNodeExecutionStatus

MAX_NUMBER = dify_config.CODE_MAX_NUMBER
MIN_NUMBER = dify_config.CODE_MIN_NUMBER
MAX_PRECISION = dify_config.CODE_MAX_PRECISION
MAX_DEPTH = dify_config.CODE_MAX_DEPTH
MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH
MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH
MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH


class CodeNode(BaseNode):
    _node_data_cls = CodeNodeData

@@ -97,8 +88,9 @@ class CodeNode(BaseNode):
        else:
            raise ValueError(f"Output variable `{variable}` must be a string")

        if len(value) > MAX_STRING_LENGTH:
            raise ValueError(f'The length of output variable `{variable}` must be less than {MAX_STRING_LENGTH} characters')
        if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
            raise ValueError(f'The length of output variable `{variable}` must be'
                             f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters')

        return value.replace('\x00', '')

@@ -115,13 +107,15 @@ class CodeNode(BaseNode):
        else:
            raise ValueError(f"Output variable `{variable}` must be a number")

        if value > MAX_NUMBER or value < MIN_NUMBER:
            raise ValueError(f'Output variable `{variable}` is out of range, it must be between {MIN_NUMBER} and {MAX_NUMBER}.')
        if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
            raise ValueError(f'Output variable `{variable}` is out of range,'
                             f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.')

        if isinstance(value, float):
            # raise error if precision is too high
            if len(str(value).split('.')[1]) > MAX_PRECISION:
                raise ValueError(f'Output variable `{variable}` has too high precision, it must be less than {MAX_PRECISION} digits.')
            if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION:
                raise ValueError(f'Output variable `{variable}` has too high precision,'
                                 f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.')

        return value

@@ -134,8 +128,8 @@ class CodeNode(BaseNode):
        :param output_schema: output schema
        :return:
        """
        if depth > MAX_DEPTH:
            raise ValueError("Depth limit reached, object too deep.")
        if depth > dify_config.CODE_MAX_DEPTH:
            raise ValueError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")

        transformed_result = {}
        if output_schema is None:

@@ -235,9 +229,10 @@ class CodeNode(BaseNode):
                        f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
                    )
                else:
                    if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH:
                    if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
                        raise ValueError(
                            f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_NUMBER_ARRAY_LENGTH} elements.'
                            f'The length of output variable `{prefix}{dot}{output_name}` must be'
                            f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.'
                        )

                    transformed_result[output_name] = [

@@ -257,9 +252,10 @@ class CodeNode(BaseNode):
                        f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
                    )
                else:
                    if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
                    if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
                        raise ValueError(
                            f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_STRING_ARRAY_LENGTH} elements.'
                            f'The length of output variable `{prefix}{dot}{output_name}` must be'
                            f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.'
                        )

                    transformed_result[output_name] = [

@@ -279,9 +275,10 @@ class CodeNode(BaseNode):
                        f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
                    )
                else:
                    if len(result[output_name]) > MAX_OBJECT_ARRAY_LENGTH:
                    if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
                        raise ValueError(
                            f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_OBJECT_ARRAY_LENGTH} elements.'
                            f'The length of output variable `{prefix}{dot}{output_name}` must be'
                            f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.'
                        )

                    for i, value in enumerate(result[output_name]):

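All of the checks above now read their limits from dify_config at call time instead of module-level constants frozen at import. The depth guard is the recursive one; a stripped-down sketch of how such a limit bounds nested output (the limit value here is illustrative; the real one is dify_config.CODE_MAX_DEPTH):

MAX_DEPTH = 5  # illustrative stand-in for dify_config.CODE_MAX_DEPTH


def check_depth(value, depth: int = 1) -> None:
    if depth > MAX_DEPTH:
        raise ValueError(f"Depth limit {MAX_DEPTH} reached, object too deep.")
    if isinstance(value, dict):
        for child in value.values():
            check_depth(child, depth + 1)
    elif isinstance(value, list):
        for child in value:
            check_depth(child, depth + 1)


check_depth({"a": {"b": {"c": 1}}})  # three levels deep: passes
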
@@ -18,11 +18,6 @@ from core.workflow.nodes.http_request.entities import (
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser

MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
READABLE_MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE
MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
READABLE_MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE


class HttpExecutorResponse:
    headers: dict[str, str]

@@ -237,16 +232,14 @@ class HttpExecutor:
        else:
            raise ValueError(f'Invalid response type {type(response)}')

        if executor_response.is_file:
            if executor_response.size > MAX_BINARY_SIZE:
                raise ValueError(
                    f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.'
                )
        else:
            if executor_response.size > MAX_TEXT_SIZE:
                raise ValueError(
                    f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.'
                )
        threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \
            else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
        if executor_response.size > threshold_size:
            raise ValueError(
                f'{"File" if executor_response.is_file else "Text"} size is too large,'
                f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
                f' but current size is {executor_response.readable_size}.'
            )

        return executor_response

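The rewrite above folds the file/text branches into a single threshold and derives the human-readable limit from the byte count, which is what made the READABLE_* config fields removable. The same logic in isolation, with defaults mirroring the HTTP_REQUEST_NODE_MAX_* values:

def check_size(size: int, is_file: bool, max_binary: int = 10 * 1024 * 1024, max_text: int = 1024 * 1024) -> None:
    # pick the limit for this payload kind, then report it in MB
    threshold = max_binary if is_file else max_text
    if size > threshold:
        kind = "File" if is_file else "Text"
        raise ValueError(f"{kind} size is too large, max size is {threshold / 1024 / 1024:.2f} MB.")


check_size(512 * 1024, is_file=False)  # 0.5 MB of text under the 1 MB limit: passes
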
@@ -6372,13 +6372,13 @@ semver = ["semver (>=3.0.2)"]

[[package]]
name = "pydantic-settings"
version = "2.3.4"
version = "2.4.0"
description = "Settings management using Pydantic"
optional = false
python-versions = ">=3.8"
files = [
    {file = "pydantic_settings-2.3.4-py3-none-any.whl", hash = "sha256:11ad8bacb68a045f00e4f862c7a718c8a9ec766aa8fd4c32e39a0594b207b53a"},
    {file = "pydantic_settings-2.3.4.tar.gz", hash = "sha256:c5802e3d62b78e82522319bbc9b8f8ffb28ad1c988a99311d04f2a6051fca0a7"},
    {file = "pydantic_settings-2.4.0-py3-none-any.whl", hash = "sha256:bb6849dc067f1687574c12a639e231f3a6feeed0a12d710c1382045c5db1c315"},
    {file = "pydantic_settings-2.4.0.tar.gz", hash = "sha256:ed81c3a0f46392b4d7c0a565c05884e6e54b3456e6f0fe4d8814981172dc9a88"},
]

[package.dependencies]

@@ -6386,6 +6386,7 @@ pydantic = ">=2.7.0"
python-dotenv = ">=0.21.0"

[package.extras]
azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"]
toml = ["tomli (>=2.0.1)"]
yaml = ["pyyaml (>=6.0.1)"]

@@ -9633,4 +9634,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "69c20af8ecacced3cca092662223a1511acaf65cb2616a5a1e38b498223463e0"
content-hash = "d7336115709114c2a4ff09b392f717e9c3547ae82b6a111d0c885c7a44269f02"

@@ -76,8 +76,6 @@ exclude = [
    "migrations/**/*",
    "services/**/*.py",
    "tasks/**/*.py",
    "tests/**/*.py",
    "configs/**/*.py",
]

[tool.pytest_env]

@@ -162,7 +160,7 @@ pandas = { version = "~2.2.2", extras = ["performance", "excel"] }
psycopg2-binary = "~2.9.6"
pycryptodome = "3.19.1"
pydantic = "~2.8.2"
pydantic-settings = "~2.3.4"
pydantic-settings = "~2.4.0"
pydantic_extra_types = "~2.9.0"
pyjwt = "~2.8.0"
pypdfium2 = "~4.17.0"

@@ -22,6 +22,10 @@ class OpsService:
        # decrypt_token and obfuscated_token
        tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
        decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config)
        if tracing_provider == 'langfuse' and ('project_key' not in decrypt_tracing_config or not decrypt_tracing_config.get('project_key')):
            project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider)
            decrypt_tracing_config['project_key'] = project_key

        decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)

        trace_config_data.tracing_config = decrypt_tracing_config

@@ -37,7 +41,7 @@ class OpsService:
        :param tracing_config: tracing config
        :return:
        """
        if tracing_provider not in provider_config_map.keys() and tracing_provider != None:
        if tracing_provider not in provider_config_map.keys() and tracing_provider:
            return {"error": f"Invalid tracing provider: {tracing_provider}"}

        config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \

@@ -51,6 +55,9 @@ class OpsService:
        if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider):
            return {"error": "Invalid Credentials"}

        # get project key
        project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)

        # check if trace config already exists
        trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
            TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider

@@ -62,6 +69,8 @@ class OpsService:
        # get tenant id
        tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
        tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
        if tracing_provider == 'langfuse':
            tracing_config['project_key'] = project_key
        trace_config_data = TraceAppConfig(
            app_id=app_id,
            tracing_provider=tracing_provider,

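The read path above is a read-repair: configs saved before project_key existed get the key fetched once and backfilled the next time they are loaded. The pattern in isolation; the names here are illustrative:

def ensure_project_key(config: dict, fetch_key) -> dict:
    # backfill configs that predate the project_key field
    if not config.get("project_key"):
        config["project_key"] = fetch_key(config)
    return config


cfg = ensure_project_key({"public_key": "pk", "secret_key": "sk"}, lambda c: "proj-1")
assert cfg["project_key"] == "proj-1"
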
@@ -22,23 +22,20 @@ from anthropic.types import (
)
from anthropic.types.message_delta_event import Delta

MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false") == "true"


class MockAnthropicClass:
    @staticmethod
    def mocked_anthropic_chat_create_sync(model: str) -> Message:
        return Message(
            id='msg-123',
            type='message',
            role='assistant',
            content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
            id="msg-123",
            type="message",
            role="assistant",
            content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")],
            model=model,
            stop_reason='stop_sequence',
            usage=Usage(
                input_tokens=1,
                output_tokens=1
            )
            stop_reason="stop_sequence",
            usage=Usage(input_tokens=1, output_tokens=1),
        )

    @staticmethod

@@ -46,52 +43,43 @@ class MockAnthropicClass:
        full_response_text = "hello, I'm a chatbot from anthropic"

        yield MessageStartEvent(
            type='message_start',
            type="message_start",
            message=Message(
                id='msg-123',
                id="msg-123",
                content=[],
                role='assistant',
                role="assistant",
                model=model,
                stop_reason=None,
                type='message',
                usage=Usage(
                    input_tokens=1,
                    output_tokens=1
                )
            )
                type="message",
                usage=Usage(input_tokens=1, output_tokens=1),
            ),
        )

        index = 0
        for i in range(0, len(full_response_text)):
            yield ContentBlockDeltaEvent(
                type='content_block_delta',
                delta=TextDelta(text=full_response_text[i], type='text_delta'),
                index=index
                type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index
            )

            index += 1

        yield MessageDeltaEvent(
            type='message_delta',
            delta=Delta(
                stop_reason='stop_sequence'
            ),
            usage=MessageDeltaUsage(
                output_tokens=1
            )
            type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1)
        )

        yield MessageStopEvent(type='message_stop')
        yield MessageStopEvent(type="message_stop")

    def mocked_anthropic(self: Messages, *,
                         max_tokens: int,
                         messages: Iterable[MessageParam],
                         model: str,
                         stream: Literal[True],
                         **kwargs: Any
                         ) -> Union[Message, Stream[MessageStreamEvent]]:
    def mocked_anthropic(
        self: Messages,
        *,
        max_tokens: int,
        messages: Iterable[MessageParam],
        model: str,
        stream: Literal[True],
        **kwargs: Any,
    ) -> Union[Message, Stream[MessageStreamEvent]]:
        if len(self._client.api_key) < 18:
            raise anthropic.AuthenticationError('Invalid API key')
            raise anthropic.AuthenticationError("Invalid API key")

        if stream:
            return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)

@@ -102,7 +90,7 @@ class MockAnthropicClass:
@pytest.fixture
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)
        monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic)

    yield

@@ -12,63 +12,46 @@ from google.generativeai.client import _ClientManager, configure
from google.generativeai.types import GenerateContentResponse
from google.generativeai.types.generation_types import BaseGenerateContentResponse

current_api_key = ''
current_api_key = ""


class MockGoogleResponseClass:
    _done = False

    def __iter__(self):
        full_response_text = 'it\'s google!'
        full_response_text = "it's google!"

        for i in range(0, len(full_response_text) + 1, 1):
            if i == len(full_response_text):
                self._done = True
                yield GenerateContentResponse(
                    done=True,
                    iterator=None,
                    result=glm.GenerateContentResponse({

                    }),
                    chunks=[]
                    done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
                )
            else:
                yield GenerateContentResponse(
                    done=False,
                    iterator=None,
                    result=glm.GenerateContentResponse({

                    }),
                    chunks=[]
                    done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
                )


class MockGoogleResponseCandidateClass:
    finish_reason = 'stop'
    finish_reason = "stop"

    @property
    def content(self) -> gag_content.Content:
        return gag_content.Content(
            parts=[
                gag_content.Part(text='it\'s google!')
            ]
        )
        return gag_content.Content(parts=[gag_content.Part(text="it's google!")])


class MockGoogleClass:
    @staticmethod
    def generate_content_sync() -> GenerateContentResponse:
        return GenerateContentResponse(
            done=True,
            iterator=None,
            result=glm.GenerateContentResponse({

            }),
            chunks=[]
        )
        return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])

    @staticmethod
    def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
        return MockGoogleResponseClass()

    def generate_content(self: GenerativeModel,
    def generate_content(
        self: GenerativeModel,
        contents: content_types.ContentsType,
        *,
        generation_config: generation_config_types.GenerationConfigType | None = None,

@@ -79,21 +62,21 @@ class MockGoogleClass:
        global current_api_key

        if len(current_api_key) < 16:
            raise Exception('Invalid API key')
            raise Exception("Invalid API key")

        if stream:
            return MockGoogleClass.generate_content_stream()

        return MockGoogleClass.generate_content_sync()

    @property
    def generative_response_text(self) -> str:
        return 'it\'s google!'
        return "it's google!"

    @property
    def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
        return [MockGoogleResponseCandidateClass()]

    def make_client(self: _ClientManager, name: str):
        global current_api_key

@@ -121,7 +104,8 @@ class MockGoogleClass:

        if not self.default_metadata:
            return client


@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
    monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)

@@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):

    yield

    monkeypatch.undo()
    monkeypatch.undo()

@@ -6,14 +6,15 @@ from huggingface_hub import InferenceClient

from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass

MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)

    yield

    if MOCK:
        monkeypatch.undo()
        monkeypatch.undo()

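Every mock module in this test tree follows the same shape: a MOCK_SWITCH env flag decides whether the fixture patches the real client, the test body runs at yield, and monkeypatch.undo() restores the original. The skeleton, reduced to a runnable sketch with illustrative names:

import os

import pytest
from pytest import MonkeyPatch

MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


class Client:  # illustrative stand-in for e.g. InferenceClient
    def call(self) -> str:
        return "real"


def fake_call(self) -> str:
    return "mocked"


@pytest.fixture
def setup_client_mock(monkeypatch: MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(Client, "call", fake_call)
    yield  # the test body runs here
    if MOCK:
        monkeypatch.undo()


def test_call(setup_client_mock):
    assert Client().call() == ("mocked" if MOCK else "real")
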
@@ -22,10 +22,8 @@ class MockHuggingfaceChatClass:
            details=Details(
                finish_reason="length",
                generated_tokens=6,
                tokens=[
                    Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
                ]
            )
                tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)],
            ),
        )

        return response

@@ -36,26 +34,23 @@ class MockHuggingfaceChatClass:

        for i in range(0, len(full_text)):
            response = TextGenerationStreamResponse(
                token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
                token=Token(id=i, text=full_text[i], logprob=0.0, special=False),
            )
            response.generated_text = full_text[i]
            response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
            response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1)

            yield response

    def text_generation(self: InferenceClient, prompt: str, *,
                        stream: Literal[False] = ...,
                        model: Optional[str] = None,
                        **kwargs: Any
    def text_generation(
        self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any
    ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
        # check if key is valid
        if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
            raise BadRequestError('Invalid API key')

        if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]):
            raise BadRequestError("Invalid API key")

        if model is None:
            raise BadRequestError('Invalid model')

            raise BadRequestError("Invalid model")

        if stream:
            return MockHuggingfaceChatClass.generate_create_stream(model)
        return MockHuggingfaceChatClass.generate_create_sync(model)

@@ -5,10 +5,10 @@ class MockTEIClass:
    @staticmethod
    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
        # During mock, we don't have a real server to query, so we just return a dummy value
        if 'rerank' in model_name:
            model_type = 'reranker'
        if "rerank" in model_name:
            model_type = "reranker"
        else:
            model_type = 'embedding'
            model_type = "embedding"

        return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)

@@ -17,16 +17,16 @@ class MockTEIClass:
        # Use space as token separator, and split the text into tokens
        tokenized_texts = []
        for text in texts:
            tokens = text.split(' ')
            tokens = text.split(" ")
            current_index = 0
            tokenized_text = []
            for idx, token in enumerate(tokens):
                s_token = {
                    'id': idx,
                    'text': token,
                    'special': False,
                    'start': current_index,
                    'stop': current_index + len(token),
                    "id": idx,
                    "text": token,
                    "special": False,
                    "start": current_index,
                    "stop": current_index + len(token),
                }
                current_index += len(token) + 1
                tokenized_text.append(s_token)

@@ -55,18 +55,18 @@ class MockTEIClass:
            embedding = [0.1] * 768
            embeddings.append(
                {
                    'object': 'embedding',
                    'embedding': embedding,
                    'index': idx,
                    "object": "embedding",
                    "embedding": embedding,
                    "index": idx,
                }
            )
        return {
            'object': 'list',
            'data': embeddings,
            'model': 'MODEL_NAME',
            'usage': {
                'prompt_tokens': sum(len(text.split(' ')) for text in texts),
                'total_tokens': sum(len(text.split(' ')) for text in texts),
            "object": "list",
            "data": embeddings,
            "model": "MODEL_NAME",
            "usage": {
                "prompt_tokens": sum(len(text.split(" ")) for text in texts),
                "total_tokens": sum(len(text.split(" ")) for text in texts),
            },
        }

@@ -83,9 +83,9 @@ class MockTEIClass:
        for idx, text in enumerate(texts):
            reranked_docs.append(
                {
                    'index': idx,
                    'text': text,
                    'score': 0.9,
                    "index": idx,
                    "text": text,
                    "score": 0.9,
                }
            )
        # For mock, only return the first document

@@ -21,13 +21,17 @@ from tests.integration_tests.model_runtime.__mock.openai_remote import MockModel
from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass


def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]:
def mock_openai(
    monkeypatch: MonkeyPatch,
    methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
) -> Callable[[], None]:
    """
    mock openai module
    mock openai module

    :param monkeypatch: pytest monkeypatch fixture
    :return: unpatch function
    :param monkeypatch: pytest monkeypatch fixture
    :return: unpatch function
    """

    def unpatch() -> None:
        monkeypatch.undo()

@@ -52,15 +56,16 @@ def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "c
    return unpatch


MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_openai_mock(request, monkeypatch):
    methods = request.param if hasattr(request, 'param') else []
    methods = request.param if hasattr(request, "param") else []
    if MOCK:
        unpatch = mock_openai(monkeypatch, methods=methods)

    yield

    if MOCK:
        unpatch()
        unpatch()

@@ -43,62 +43,64 @@ class MockChatClass:
        if not functions or len(functions) == 0:
            return None
        function: completion_create_params.Function = functions[0]
        function_name = function['name']
        function_description = function['description']
        function_parameters = function['parameters']
        function_parameters_type = function_parameters['type']
        if function_parameters_type != 'object':
        function_name = function["name"]
        function_description = function["description"]
        function_parameters = function["parameters"]
        function_parameters_type = function_parameters["type"]
        if function_parameters_type != "object":
            return None
        function_parameters_properties = function_parameters['properties']
        function_parameters_required = function_parameters['required']
        function_parameters_properties = function_parameters["properties"]
        function_parameters_required = function_parameters["required"]
        parameters = {}
        for parameter_name, parameter in function_parameters_properties.items():
            if parameter_name not in function_parameters_required:
                continue
            parameter_type = parameter['type']
            if parameter_type == 'string':
                if 'enum' in parameter:
                    if len(parameter['enum']) == 0:
            parameter_type = parameter["type"]
            if parameter_type == "string":
                if "enum" in parameter:
                    if len(parameter["enum"]) == 0:
                        continue
                    parameters[parameter_name] = parameter['enum'][0]
                    parameters[parameter_name] = parameter["enum"][0]
                else:
                    parameters[parameter_name] = 'kawaii'
            elif parameter_type == 'integer':
                    parameters[parameter_name] = "kawaii"
            elif parameter_type == "integer":
                parameters[parameter_name] = 114514
            elif parameter_type == 'number':
            elif parameter_type == "number":
                parameters[parameter_name] = 1919810.0
            elif parameter_type == 'boolean':
            elif parameter_type == "boolean":
                parameters[parameter_name] = True

        return FunctionCall(name=function_name, arguments=dumps(parameters))

    @staticmethod
    def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
    def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
        list_tool_calls = []
        if not tools or len(tools) == 0:
            return None
        tool = tools[0]

        if 'type' in tools and tools['type'] != 'function':
        if "type" in tools and tools["type"] != "function":
            return None

        function = tool['function']
        function = tool["function"]

        function_call = MockChatClass.generate_function_call(functions=[function])
        if function_call is None:
            return None

        list_tool_calls.append(ChatCompletionMessageToolCall(
            id='sakurajima-mai',
            function=Function(
                name=function_call.name,
                arguments=function_call.arguments,
            ),
            type='function'
        ))

        list_tool_calls.append(
            ChatCompletionMessageToolCall(
                id="sakurajima-mai",
                function=Function(
                    name=function_call.name,
                    arguments=function_call.arguments,
                ),
                type="function",
            )
        )

        return list_tool_calls

    @staticmethod
    def mocked_openai_chat_create_sync(
        model: str,

@@ -111,30 +113,27 @@ class MockChatClass:
        tool_calls = MockChatClass.generate_tool_calls(tools=tools)

        return _ChatCompletion(
            id='cmpl-3QJQa5jXJ5Z5X',
            id="cmpl-3QJQa5jXJ5Z5X",
            choices=[
                _ChatCompletionChoice(
                    finish_reason='content_filter',
                    finish_reason="content_filter",
                    index=0,
                    message=ChatCompletionMessage(
                        content='elaina',
                        role='assistant',
                        function_call=function_call,
                        tool_calls=tool_calls
                    )
                        content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls
                    ),
                )
            ],
            created=int(time()),
            model=model,
            object='chat.completion',
            system_fingerprint='',
            object="chat.completion",
            system_fingerprint="",
            usage=CompletionUsage(
                prompt_tokens=2,
                completion_tokens=1,
                total_tokens=3,
            )
            ),
        )

    @staticmethod
    def mocked_openai_chat_create_stream(
        model: str,

@@ -150,36 +149,40 @@ class MockChatClass:
        for i in range(0, len(full_text) + 1):
            if i == len(full_text):
                yield ChatCompletionChunk(
                    id='cmpl-3QJQa5jXJ5Z5X',
                    id="cmpl-3QJQa5jXJ5Z5X",
                    choices=[
                        Choice(
                            delta=ChoiceDelta(
                                content='',
                                content="",
                                function_call=ChoiceDeltaFunctionCall(
                                    name=function_call.name,
                                    arguments=function_call.arguments,
                                ) if function_call else None,
                                role='assistant',
                                )
                                if function_call
                                else None,
                                role="assistant",
                                tool_calls=[
                                    ChoiceDeltaToolCall(
                                        index=0,
                                        id='misaka-mikoto',
                                        id="misaka-mikoto",
                                        function=ChoiceDeltaToolCallFunction(
                                            name=tool_calls[0].function.name,
                                            arguments=tool_calls[0].function.arguments,
                                        ),
                                        type='function'
                                        type="function",
                                    )
                                ] if tool_calls and len(tool_calls) > 0 else None
                                ]
                                if tool_calls and len(tool_calls) > 0
                                else None,
                            ),
                            finish_reason='function_call',
                            finish_reason="function_call",
                            index=0,
                        )
                    ],
                    created=int(time()),
                    model=model,
                    object='chat.completion.chunk',
                    system_fingerprint='',
                    object="chat.completion.chunk",
                    system_fingerprint="",
                    usage=CompletionUsage(
                        prompt_tokens=2,
                        completion_tokens=17,

@@ -188,30 +191,45 @@ class MockChatClass:
                )
            else:
                yield ChatCompletionChunk(
                    id='cmpl-3QJQa5jXJ5Z5X',
                    id="cmpl-3QJQa5jXJ5Z5X",
                    choices=[
                        Choice(
                            delta=ChoiceDelta(
                                content=full_text[i],
                                role='assistant',
                                role="assistant",
                            ),
                            finish_reason='content_filter',
                            finish_reason="content_filter",
                            index=0,
                        )
                    ],
                    created=int(time()),
                    model=model,
                    object='chat.completion.chunk',
                    system_fingerprint='',
                    object="chat.completion.chunk",
                    system_fingerprint="",
                )

    def chat_create(self: Completions, *,
    def chat_create(
        self: Completions,
        *,
        messages: list[ChatCompletionMessageParam],
        model: Union[str,Literal[
            "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
            "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
            "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
            "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"],
        model: Union[
            str,
            Literal[
                "gpt-4-1106-preview",
                "gpt-4-vision-preview",
                "gpt-4",
                "gpt-4-0314",
                "gpt-4-0613",
                "gpt-4-32k",
                "gpt-4-32k-0314",
                "gpt-4-32k-0613",
                "gpt-3.5-turbo-1106",
                "gpt-3.5-turbo",
                "gpt-3.5-turbo-16k",
                "gpt-3.5-turbo-0301",
                "gpt-3.5-turbo-0613",
                "gpt-3.5-turbo-16k-0613",
            ],
        ],
        functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,

@@ -220,24 +238,32 @@ class MockChatClass:
        **kwargs: Any,
    ):
        openai_models = [
            "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
            "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
            "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
            "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613",
            "gpt-4-1106-preview",
            "gpt-4-vision-preview",
            "gpt-4",
            "gpt-4-0314",
            "gpt-4-0613",
            "gpt-4-32k",
            "gpt-4-32k-0314",
            "gpt-4-32k-0613",
            "gpt-3.5-turbo-1106",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
            "gpt-3.5-turbo-0301",
            "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-16k-0613",
        ]
        azure_openai_models = [
            "gpt35", "gpt-4v", "gpt-35-turbo"
        ]
        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
            raise InvokeAuthorizationError('Invalid base url')
        azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"]
        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
            raise InvokeAuthorizationError("Invalid base url")
        if model in openai_models + azure_openai_models:
            if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
            if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
                # sometime, provider use OpenAI compatible API will not have api key or have different api key format
                # so we only check if model is in openai_models
                raise InvokeAuthorizationError('Invalid api key')
                raise InvokeAuthorizationError("Invalid api key")
            if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
                raise InvokeAuthorizationError('Invalid api key')
                raise InvokeAuthorizationError("Invalid api key")
        if stream:
            return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)

        return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
        return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)

@@ -17,9 +17,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError

 class MockCompletionsClass:
     @staticmethod
-    def mocked_openai_completion_create_sync(
-        model: str
-    ) -> CompletionMessage:
+    def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
         return CompletionMessage(
             id="cmpl-3QJQa5jXJ5Z5X",
             object="text_completion",
@@ -38,13 +36,11 @@ class MockCompletionsClass:
                 prompt_tokens=2,
                 completion_tokens=1,
                 total_tokens=3,
-            )
+            ),
         )

     @staticmethod
-    def mocked_openai_completion_create_stream(
-        model: str
-    ) -> Generator[CompletionMessage, None, None]:
+    def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
         full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
         for i in range(0, len(full_text) + 1):
             if i == len(full_text):
@@ -76,46 +72,59 @@ class MockCompletionsClass:
                 model=model,
                 system_fingerprint="",
                 choices=[
-                    CompletionChoice(
-                        text=full_text[i],
-                        index=0,
-                        logprobs=None,
-                        finish_reason="content_filter"
-                    )
+                    CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter")
                 ],
             )

-    def completion_create(self: Completions, *, model: Union[
-        str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct",
-                     "text-davinci-003", "text-davinci-002", "text-davinci-001",
-                     "code-davinci-002", "text-curie-001", "text-babbage-001",
-                     "text-ada-001"],
+    def completion_create(
+        self: Completions,
+        *,
+        model: Union[
+            str,
+            Literal[
+                "babbage-002",
+                "davinci-002",
+                "gpt-3.5-turbo-instruct",
+                "text-davinci-003",
+                "text-davinci-002",
+                "text-davinci-001",
+                "code-davinci-002",
+                "text-curie-001",
+                "text-babbage-001",
+                "text-ada-001",
+            ],
+        ],
         prompt: Union[str, list[str], list[int], list[list[int]], None],
         stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
-        **kwargs: Any
+        **kwargs: Any,
     ):
         openai_models = [
-            "babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001",
-            "code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001",
-        ]
-        azure_openai_models = [
-            "gpt-35-turbo-instruct"
+            "babbage-002",
+            "davinci-002",
+            "gpt-3.5-turbo-instruct",
+            "text-davinci-003",
+            "text-davinci-002",
+            "text-davinci-001",
+            "code-davinci-002",
+            "text-curie-001",
+            "text-babbage-001",
+            "text-ada-001",
         ]
+        azure_openai_models = ["gpt-35-turbo-instruct"]

-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
-            raise InvokeAuthorizationError('Invalid base url')
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
+            raise InvokeAuthorizationError("Invalid base url")
         if model in openai_models + azure_openai_models:
-            if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
+            if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
                 # sometime, provider use OpenAI compatible API will not have api key or have different api key format
                 # so we only check if model is in openai_models
-                raise InvokeAuthorizationError('Invalid api key')
+                raise InvokeAuthorizationError("Invalid api key")
             if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
-                raise InvokeAuthorizationError('Invalid api key')
+                raise InvokeAuthorizationError("Invalid api key")

         if not prompt:
-            raise BadRequestError('Invalid prompt')
+            raise BadRequestError("Invalid prompt")
         if stream:
             return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)

-        return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
+        return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
File diff suppressed because one or more lines are too long
@@ -10,58 +10,92 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError


 class MockModerationClass:
-    def moderation_create(self: Moderations,*,
+    def moderation_create(
+        self: Moderations,
+        *,
         input: Union[str, list[str]],
         model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> ModerationCreateResponse:
         if isinstance(input, str):
             input = [input]

-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
-            raise InvokeAuthorizationError('Invalid base url')
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
+            raise InvokeAuthorizationError("Invalid base url")

         if len(self._client.api_key) < 18:
-            raise InvokeAuthorizationError('Invalid API key')
+            raise InvokeAuthorizationError("Invalid API key")

         for text in input:
             result = []
-            if 'kill' in text:
+            if "kill" in text:
                 moderation_categories = {
-                    'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
-                    'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
-                    'sexual/minors': False, 'violence': False, 'violence/graphic': False
+                    "harassment": False,
+                    "harassment/threatening": False,
+                    "hate": False,
+                    "hate/threatening": False,
+                    "self-harm": False,
+                    "self-harm/instructions": False,
+                    "self-harm/intent": False,
+                    "sexual": False,
+                    "sexual/minors": False,
+                    "violence": False,
+                    "violence/graphic": False,
                 }
                 moderation_categories_scores = {
-                    'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0,
-                    'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0,
-                    'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0
+                    "harassment": 1.0,
+                    "harassment/threatening": 1.0,
+                    "hate": 1.0,
+                    "hate/threatening": 1.0,
+                    "self-harm": 1.0,
+                    "self-harm/instructions": 1.0,
+                    "self-harm/intent": 1.0,
+                    "sexual": 1.0,
+                    "sexual/minors": 1.0,
+                    "violence": 1.0,
+                    "violence/graphic": 1.0,
                 }

-                result.append(Moderation(
-                    flagged=True,
-                    categories=Categories(**moderation_categories),
-                    category_scores=CategoryScores(**moderation_categories_scores)
-                ))
+                result.append(
+                    Moderation(
+                        flagged=True,
+                        categories=Categories(**moderation_categories),
+                        category_scores=CategoryScores(**moderation_categories_scores),
+                    )
+                )
             else:
                 moderation_categories = {
-                    'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
-                    'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
-                    'sexual/minors': False, 'violence': False, 'violence/graphic': False
+                    "harassment": False,
+                    "harassment/threatening": False,
+                    "hate": False,
+                    "hate/threatening": False,
+                    "self-harm": False,
+                    "self-harm/instructions": False,
+                    "self-harm/intent": False,
+                    "sexual": False,
+                    "sexual/minors": False,
+                    "violence": False,
+                    "violence/graphic": False,
                 }
                 moderation_categories_scores = {
-                    'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0,
-                    'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0,
-                    'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0
+                    "harassment": 0.0,
+                    "harassment/threatening": 0.0,
+                    "hate": 0.0,
+                    "hate/threatening": 0.0,
+                    "self-harm": 0.0,
+                    "self-harm/instructions": 0.0,
+                    "self-harm/intent": 0.0,
+                    "sexual": 0.0,
+                    "sexual/minors": 0.0,
+                    "violence": 0.0,
+                    "violence/graphic": 0.0,
                 }
-                result.append(Moderation(
-                    flagged=False,
-                    categories=Categories(**moderation_categories),
-                    category_scores=CategoryScores(**moderation_categories_scores)
-                ))
+                result.append(
+                    Moderation(
+                        flagged=False,
+                        categories=Categories(**moderation_categories),
+                        category_scores=CategoryScores(**moderation_categories_scores),
+                    )
+                )

-        return ModerationCreateResponse(
-            id='shiroii kuloko',
-            model=model,
-            results=result
-        )
+        return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result)
@@ -6,17 +6,18 @@ from openai.types.model import Model

 class MockModelClass:
     """
-    mock class for openai.models.Models
+    mock class for openai.models.Models
     """
+
     def list(
         self,
         **kwargs,
     ) -> list[Model]:
         return [
             Model(
-                id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ',
+                id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ",
                 created=int(time()),
-                object='model',
-                owned_by='organization:org-123',
+                object="model",
+                owned_by="organization:org-123",
             )
-        ]
+        ]
@@ -9,7 +9,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError


 class MockSpeech2TextClass:
-    def speech2text_create(self: Transcriptions,
+    def speech2text_create(
+        self: Transcriptions,
         *,
         file: FileTypes,
         model: Union[str, Literal["whisper-1"]],
@@ -17,14 +18,12 @@ class MockSpeech2TextClass:
         prompt: str | NotGiven = NOT_GIVEN,
         response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
         temperature: float | NotGiven = NOT_GIVEN,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> Transcription:
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
-            raise InvokeAuthorizationError('Invalid base url')
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
+            raise InvokeAuthorizationError("Invalid base url")

         if len(self._client.api_key) < 18:
-            raise InvokeAuthorizationError('Invalid API key')
+            raise InvokeAuthorizationError("Invalid API key")

-        return Transcription(
-            text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
-        )
+        return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10")
@@ -19,40 +19,43 @@ from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage


 class MockXinferenceClass:
-    def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
-        if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
-            raise RuntimeError('404 Not Found')
-
-        if 'generate' == model_uid:
+    def get_chat_model(
+        self: Client, model_uid: str
+    ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
+        if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
+            raise RuntimeError("404 Not Found")
+
+        if "generate" == model_uid:
             return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'chat' == model_uid:
+        if "chat" == model_uid:
             return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'embedding' == model_uid:
+        if "embedding" == model_uid:
             return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'rerank' == model_uid:
+        if "rerank" == model_uid:
             return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        raise RuntimeError('404 Not Found')
+        raise RuntimeError("404 Not Found")

     def get(self: Session, url: str, **kwargs):
         response = Response()
-        if 'v1/models/' in url:
+        if "v1/models/" in url:
             # get model uid
-            model_uid = url.split('/')[-1] or ''
-            if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
-                model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
+            model_uid = url.split("/")[-1] or ""
+            if not re.match(
+                r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
+            ) and model_uid not in ["generate", "chat", "embedding", "rerank"]:
                 response.status_code = 404
-                response._content = b'{}'
+                response._content = b"{}"
                 return response

             # check if url is valid
-            if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
+            if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
                 response.status_code = 404
-                response._content = b'{}'
+                response._content = b"{}"
                 return response

-            if model_uid in ['generate', 'chat']:
+            if model_uid in ["generate", "chat"]:
                 response.status_code = 200
-                response._content = b'''{
+                response._content = b"""{
     "model_type": "LLM",
     "address": "127.0.0.1:43877",
     "accelerators": [
@@ -75,12 +78,12 @@ class MockXinferenceClass:
     "revision": null,
     "context_length": 2048,
     "replica": 1
-}'''
+}"""
                 return response

-            elif model_uid == 'embedding':
+            elif model_uid == "embedding":
                 response.status_code = 200
-                response._content = b'''{
+                response._content = b"""{
     "model_type": "embedding",
     "address": "127.0.0.1:43877",
     "accelerators": [
@@ -93,51 +96,48 @@ class MockXinferenceClass:
     ],
     "revision": null,
     "max_tokens": 512
-}'''
+}"""
             return response

-        elif 'v1/cluster/auth' in url:
+        elif "v1/cluster/auth" in url:
             response.status_code = 200
-            response._content = b'''{
+            response._content = b"""{
     "auth": true
-}'''
+}"""
             return response

     def _check_cluster_authenticated(self):
         self._cluster_authed = True

-    def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
+    def rerank(
+        self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
+    ) -> dict:
         # check if self._model_uid is a valid uuid
-        if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
-            self._model_uid != 'rerank':
-            raise RuntimeError('404 Not Found')
-
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
-            raise RuntimeError('404 Not Found')
+        if (
+            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
+            and self._model_uid != "rerank"
+        ):
+            raise RuntimeError("404 Not Found")
+
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
+            raise RuntimeError("404 Not Found")

         if top_n is None:
             top_n = 1

         return {
-            'results': [
-                {
-                    'index': i,
-                    'document': doc,
-                    'relevance_score': 0.9
-                }
-                for i, doc in enumerate(documents[:top_n])
+            "results": [
+                {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
             ]
         }

-    def create_embedding(
-        self: RESTfulGenerateModelHandle,
-        input: Union[str, list[str]],
-        **kwargs
-    ) -> dict:
+    def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
         # check if self._model_uid is a valid uuid
-        if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
-            self._model_uid != 'embedding':
-            raise RuntimeError('404 Not Found')
+        if (
+            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
+            and self._model_uid != "embedding"
+        ):
+            raise RuntimeError("404 Not Found")

         if isinstance(input, str):
             input = [input]
@@ -147,32 +147,27 @@ class MockXinferenceClass:
             object="list",
             model=self._model_uid,
             data=[
-                EmbeddingData(
-                    index=i,
-                    object="embedding",
-                    embedding=[1919.810 for _ in range(768)]
-                )
+                EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
                 for i in range(ipt_len)
             ],
-            usage=EmbeddingUsage(
-                prompt_tokens=ipt_len,
-                total_tokens=ipt_len
-            )
+            usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
         )

         return embedding

-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


 @pytest.fixture
 def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
-        monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
-        monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
-        monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
-        monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
+        monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
+        monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
+        monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
+        monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
+        monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
     yield

     if MOCK:
-        monkeypatch.undo()
+        monkeypatch.undo()
@@ -10,79 +10,60 @@ from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeL
 from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock


-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_validate_credentials(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='claude-instant-1.2',
-            credentials={
-                'anthropic_api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"})

     model.validate_credentials(
-        model='claude-instant-1.2',
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        }
+        model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
     )

-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_invoke_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()

     response = model.invoke(
-        model='claude-instant-1.2',
+        model="claude-instant-1.2",
         credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
-            'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
+            "anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"),
+            "anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'top_p': 1.0,
-            'max_tokens': 10
-        },
-        stop=['How'],
+        model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0

-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_invoke_stream_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()

     response = model.invoke(
-        model='claude-instant-1.2',
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        },
+        model="claude-instant-1.2",
+        credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
@@ -98,18 +79,14 @@ def test_get_num_tokens():
     model = AnthropicLargeLanguageModel()

     num_tokens = model.get_num_tokens(
-        model='claude-instant-1.2',
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        },
+        model="claude-instant-1.2",
+        credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )

     assert num_tokens == 18
@@ -7,17 +7,11 @@ from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProv
 from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock


-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_validate_provider_credentials(setup_anthropic_mock):
     provider = AnthropicProvider()

     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})

-    provider.validate_provider_credentials(
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")})
File diff suppressed because one or more lines are too long
@@ -8,45 +8,43 @@ from core.model_runtime.model_providers.azure_openai.text_embedding.text_embeddi
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


-@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
 def test_validate_credentials(setup_openai_mock):
     model = AzureOpenAITextEmbeddingModel()

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='embedding',
+            model="embedding",
             credentials={
-                'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
-                'openai_api_key': 'invalid_key',
-                'base_model_name': 'text-embedding-ada-002'
-            }
+                "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
+                "openai_api_key": "invalid_key",
+                "base_model_name": "text-embedding-ada-002",
+            },
         )

     model.validate_credentials(
-        model='embedding',
+        model="embedding",
         credentials={
-            'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
-            'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
-            'base_model_name': 'text-embedding-ada-002'
-        }
+            "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
+            "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
+            "base_model_name": "text-embedding-ada-002",
+        },
     )

-@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = AzureOpenAITextEmbeddingModel()

     result = model.invoke(
-        model='embedding',
+        model="embedding",
         credentials={
-            'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
-            'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
-            'base_model_name': 'text-embedding-ada-002'
-        },
+            "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
+            "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
+            "base_model_name": "text-embedding-ada-002",
+        },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )

     assert isinstance(result, TextEmbeddingResult)
@@ -58,14 +56,7 @@ def test_get_num_tokens():
     model = AzureOpenAITextEmbeddingModel()

     num_tokens = model.get_num_tokens(
-        model='embedding',
-        credentials={
-            'base_model_name': 'text-embedding-ada-002'
-        },
-        texts=[
-            "hello",
-            "world"
-        ]
+        model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"]
     )

     assert num_tokens == 2
@@ -17,111 +17,99 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)


 def test_validate_credentials_for_chat_model():
     sleep(3)
     model = BaichuanLarguageModel()

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='baichuan2-turbo',
-            credentials={
-                'api_key': 'invalid_key',
-                'secret_key': 'invalid_key'
-            }
+            model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"}
         )

     model.validate_credentials(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
-        }
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
+        },
     )


 def test_invoke_model():
     sleep(3)
     model = BaichuanLarguageModel()

     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0


 def test_invoke_model_with_system_message():
     sleep(3)
     model = BaichuanLarguageModel()

     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
         prompt_messages=[
-            SystemPromptMessage(
-                content='请记住你是Kasumi。'
-            ),
-            UserPromptMessage(
-                content='现在告诉我你是谁?'
-            )
+            SystemPromptMessage(content="请记住你是Kasumi。"),
+            UserPromptMessage(content="现在告诉我你是谁?"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0


 def test_invoke_stream_model():
     sleep(3)
     model = BaichuanLarguageModel()

     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
@@ -131,34 +119,31 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


 def test_invoke_with_search():
     sleep(3)
     model = BaichuanLarguageModel()

     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='北京今天的天气怎么样'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
-            'with_search_enhance': True,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
+            "with_search_enhance": True,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
-    total_message = ''
+    total_message = ""
     for chunk in response:
         assert isinstance(chunk, LLMResultChunk)
         assert isinstance(chunk.delta, LLMResultChunkDelta)
@@ -166,25 +151,22 @@ def test_invoke_with_search():
         assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
         total_message += chunk.delta.message.content

-    assert '不' not in total_message
+    assert "不" not in total_message


 def test_get_num_tokens():
     sleep(3)
     model = BaichuanLarguageModel()

     response = model.get_num_tokens(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        tools=[]
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        tools=[],
     )

     assert isinstance(response, int)
-    assert response == 9
+    assert response == 9
@@ -10,14 +10,6 @@ def test_validate_provider_credentials():
     provider = BaichuanProvider()

     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={
-                'api_key': 'hahahaha'
-            }
-        )
+        provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})

-    provider.validate_provider_credentials(
-        credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")})
@@ -11,18 +11,10 @@ def test_validate_credentials():
     model = BaichuanTextEmbeddingModel()

     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='baichuan-text-embedding',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"})

     model.validate_credentials(
-        model='baichuan-text-embedding',
-        credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY')
-        }
+        model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}
     )
@@ -30,44 +22,40 @@ def test_invoke_model():
     model = BaichuanTextEmbeddingModel()

     result = model.invoke(
-        model='baichuan-text-embedding',
+        model="baichuan-text-embedding",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )

     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens == 6


 def test_get_num_tokens():
     model = BaichuanTextEmbeddingModel()

     num_tokens = model.get_num_tokens(
-        model='baichuan-text-embedding',
+        model="baichuan-text-embedding",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )

     assert num_tokens == 2


 def test_max_chunks():
     model = BaichuanTextEmbeddingModel()

     result = model.invoke(
-        model='baichuan-text-embedding',
+        model="baichuan-text-embedding",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
         },
         texts=[
             "hello",
@@ -92,8 +80,8 @@ def test_max_chunks():
             "world",
             "hello",
             "world",
-        ]
+        ],
     )

     assert isinstance(result, TextEmbeddingResult)
-    assert len(result.embeddings) == 22
+    assert len(result.embeddings) == 22
@@ -13,77 +13,63 @@ def test_validate_credentials():
     model = BedrockLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='meta.llama2-13b-chat-v1',
-            credentials={
-                'anthropic_api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"})

     model.validate_credentials(
-        model='meta.llama2-13b-chat-v1',
+        model="meta.llama2-13b-chat-v1",
         credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
-        }
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
+        },
     )


 def test_invoke_model():
     model = BedrockLargeLanguageModel()

     response = model.invoke(
-        model='meta.llama2-13b-chat-v1',
+        model="meta.llama2-13b-chat-v1",
         credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'top_p': 1.0,
-            'max_tokens_to_sample': 10
-        },
-        stop=['How'],
+        model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0


 def test_invoke_stream_model():
     model = BedrockLargeLanguageModel()

     response = model.invoke(
-        model='meta.llama2-13b-chat-v1',
+        model="meta.llama2-13b-chat-v1",
         credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens_to_sample': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
@@ -100,20 +86,18 @@ def test_get_num_tokens():
     model = BedrockLargeLanguageModel()

     num_tokens = model.get_num_tokens(
-        model='meta.llama2-13b-chat-v1',
-        credentials = {
+        model="meta.llama2-13b-chat-v1",
+        credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
         },
         messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )

     assert num_tokens == 18
@@ -10,14 +10,12 @@ def test_validate_provider_credentials():
     provider = BedrockProvider()

     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})

     provider.validate_provider_credentials(
         credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
-            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
         }
     )
@@ -23,79 +23,64 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)

-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_validate_credentials_for_chat_model(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='chatglm2-6b',
-            credentials={
-                'api_base': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"})

-    model.validate_credentials(
-        model='chatglm2-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        }
-    )
+    model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})

-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()

     response = model.invoke(
-        model='chatglm2-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
+        model="chatglm2-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0

-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_stream_model(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()

     response = model.invoke(
-        model='chatglm2-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
+        model="chatglm2-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
@@ -105,56 +90,45 @@ def test_invoke_stream_model(setup_openai_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True

-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_stream_model_with_functions(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()

     response = model.invoke(
-        model='chatglm3-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
+        model="chatglm3-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
         prompt_messages=[
             SystemPromptMessage(
-                content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。'
+                content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。"
             ),
-            UserPromptMessage(
-                content='波士顿天气如何?'
-            )
+            UserPromptMessage(content="波士顿天气如何?"),
         ],
         model_parameters={
-            'temperature': 0,
-            'top_p': 1.0,
+            "temperature": 0,
+            "top_p": 1.0,
         },
-        stop=['you'],
-        user='abc-123',
+        stop=["you"],
+        user="abc-123",
         stream=True,
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": ["celsius", "fahrenheit"]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
-        ]
+        ],
     )

     assert isinstance(response, Generator)

     call: LLMResultChunk = None
     chunks = []
@@ -170,122 +144,87 @@ def test_invoke_stream_model_with_functions(setup_openai_mock):
             break

     assert call is not None
-    assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
+    assert call.delta.message.tool_calls[0].function.name == "get_current_weather"

-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_invoke_model_with_functions(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()

     response = model.invoke(
-        model='chatglm3-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='What is the weather like in San Francisco?'
-            )
-        ],
+        model="chatglm3-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
+        prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
-        user='abc-123',
+        stop=["you"],
+        user="abc-123",
         stream=False,
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": [
-                                "c",
-                                "f"
-                            ]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
-        ]
+        ],
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
-    assert response.message.tool_calls[0].function.name == 'get_current_weather'
+    assert response.message.tool_calls[0].function.name == "get_current_weather"


 def test_get_num_tokens():
     model = ChatGLMLargeLanguageModel()

     num_tokens = model.get_num_tokens(
-        model='chatglm2-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
+        model="chatglm2-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": [
-                                "c",
-                                "f"
-                            ]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
-        ]
+        ],
     )

     assert isinstance(num_tokens, int)
     assert num_tokens == 77

     num_tokens = model.get_num_tokens(
-        model='chatglm2-6b',
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        },
+        model="chatglm2-6b",
+        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
     )

     assert isinstance(num_tokens, int)
-    assert num_tokens == 21
+    assert num_tokens == 21
@@ -7,19 +7,11 @@ from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


-@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
 def test_validate_provider_credentials(setup_openai_mock):
     provider = ChatGLMProvider()

     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={
-                'api_base': 'hahahaha'
-            }
-        )
+        provider.validate_provider_credentials(credentials={"api_base": "hahahaha"})

-    provider.validate_provider_credentials(
-        credentials={
-            'api_base': os.environ.get('CHATGLM_API_BASE')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})
@@ -13,87 +13,49 @@ def test_validate_credentials_for_chat_model():
     model = CohereLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='command-light-chat',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"})

-    model.validate_credentials(
-        model='command-light-chat',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        }
-    )
+    model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")})


 def test_validate_credentials_for_completion_model():
     model = CohereLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='command-light',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"})

-    model.validate_credentials(
-        model='command-light',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        }
-    )
+    model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")})


 def test_invoke_completion_model():
     model = CohereLargeLanguageModel()

-    credentials = {
-        'api_key': os.environ.get('COHERE_API_KEY')
-    }
+    credentials = {"api_key": os.environ.get("COHERE_API_KEY")}

     result = model.invoke(
-        model='command-light',
+        model="command-light",
         credentials=credentials,
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 1
-        },
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.0, "max_tokens": 1},
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(result, LLMResult)
     assert len(result.message.content) > 0
-    assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
+    assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1


 def test_invoke_stream_completion_model():
     model = CohereLargeLanguageModel()

     result = model.invoke(
-        model='command-light',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model="command-light",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
    )

     assert isinstance(result, Generator)
@@ -109,28 +71,24 @@ def test_invoke_chat_model():
     model = CohereLargeLanguageModel()

     result = model.invoke(
-        model='command-light-chat',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
+        model="command-light-chat",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.0,
-            'p': 0.99,
-            'presence_penalty': 0.0,
-            'frequency_penalty': 0.0,
-            'max_tokens': 10
+            "temperature": 0.0,
+            "p": 0.99,
+            "presence_penalty": 0.0,
+            "frequency_penalty": 0.0,
+            "max_tokens": 10,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(result, LLMResult)
@@ -141,24 +99,17 @@ def test_invoke_stream_chat_model():
     model = CohereLargeLanguageModel()

     result = model.invoke(
-        model='command-light-chat',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
+        model="command-light-chat",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(result, Generator)
@@ -177,32 +128,22 @@ def test_get_num_tokens():
     model = CohereLargeLanguageModel()

     num_tokens = model.get_num_tokens(
-        model='command-light',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+        model="command-light",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
     )

     assert num_tokens == 3

     num_tokens = model.get_num_tokens(
-        model='command-light-chat',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
+        model="command-light-chat",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )

     assert num_tokens == 15
@@ -213,25 +154,17 @@ def test_fine_tuned_model():

     # test invoke
     result = model.invoke(
-        model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY'),
-            'mode': 'completion'
-        },
+        model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(result, LLMResult)
@@ -242,25 +175,17 @@ def test_fine_tuned_chat_model():

     # test invoke
     result = model.invoke(
-        model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY'),
-            'mode': 'chat'
-        },
+        model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(result, LLMResult)
@@ -10,12 +10,6 @@ def test_validate_provider_credentials():
     provider = CohereProvider()

     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})

-    provider.validate_provider_credentials(
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")})
@@ -11,29 +11,17 @@ def test_validate_credentials():
     model = CohereRerankModel()

     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='rerank-english-v2.0',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"})

-    model.validate_credentials(
-        model='rerank-english-v2.0',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        }
-    )
+    model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")})


 def test_invoke_model():
     model = CohereRerankModel()

     result = model.invoke(
-        model='rerank-english-v2.0',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
+        model="rerank-english-v2.0",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
         query="What is the capital of the United States?",
         docs=[
             "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
@@ -41,9 +29,9 @@ def test_invoke_model():
             "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
             "is the capital of the United States. It is a federal district. The President of the USA and many major "
             "national government offices are in the territory. This makes it the political center of the United "
-            "States of America."
+            "States of America.",
         ],
-        score_threshold=0.8
+        score_threshold=0.8,
     )

     assert isinstance(result, RerankResult)
@@ -11,18 +11,10 @@ def test_validate_credentials():
     model = CohereTextEmbeddingModel()

     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='embed-multilingual-v3.0',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"})

     model.validate_credentials(
-        model='embed-multilingual-v3.0',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        }
+        model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}
     )

@@ -30,17 +22,10 @@ def test_invoke_model():
     model = CohereTextEmbeddingModel()

     result = model.invoke(
-        model='embed-multilingual-v3.0',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
-        texts=[
-            "hello",
-            "world",
-            " ".join(["long_text"] * 100),
-            " ".join(["another_long_text"] * 100)
-        ],
-        user="abc-123"
+        model="embed-multilingual-v3.0",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
+        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
+        user="abc-123",
     )

     assert isinstance(result, TextEmbeddingResult)
@@ -52,14 +37,9 @@ def test_get_num_tokens():
     model = CohereTextEmbeddingModel()

     num_tokens = model.get_num_tokens(
-        model='embed-multilingual-v3.0',
-        credentials={
-            'api_key': os.environ.get('COHERE_API_KEY')
-        },
-        texts=[
-            "hello",
-            "world"
-        ]
+        model="embed-multilingual-v3.0",
+        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
+        texts=["hello", "world"],
     )

     assert num_tokens == 3

File diff suppressed because one or more lines are too long
@@ -7,17 +7,11 @@ from core.model_runtime.model_providers.google.google import GoogleProvider
 from tests.integration_tests.model_runtime.__mock.google import setup_google_mock


-@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
 def test_validate_provider_credentials(setup_google_mock):
     provider = GoogleProvider()

     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})

-    provider.validate_provider_credentials(
-        credentials={
-            'google_api_key': os.environ.get('GOOGLE_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")})
@@ -10,87 +10,75 @@ from core.model_runtime.model_providers.huggingface_hub.llm.llm import Huggingfa
 from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock


-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='HuggingFaceH4/zephyr-7b-beta',
-            credentials={
-                'huggingfacehub_api_type': 'hosted_inference_api',
-                'huggingfacehub_api_token': 'invalid_key'
-            }
+            model="HuggingFaceH4/zephyr-7b-beta",
+            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
         )

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='fake-model',
-            credentials={
-                'huggingfacehub_api_type': 'hosted_inference_api',
-                'huggingfacehub_api_token': 'invalid_key'
-            }
+            model="fake-model",
+            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
         )

     model.validate_credentials(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model="HuggingFaceH4/zephyr-7b-beta",
         credentials={
-            'huggingfacehub_api_type': 'hosted_inference_api',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
-        }
+            "huggingfacehub_api_type": "hosted_inference_api",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+        },
     )

-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()

     response = model.invoke(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model="HuggingFaceH4/zephyr-7b-beta",
         credentials={
-            'huggingfacehub_api_type': 'hosted_inference_api',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
+            "huggingfacehub_api_type": "hosted_inference_api",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0

-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()

     response = model.invoke(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model="HuggingFaceH4/zephyr-7b-beta",
         credentials={
-            'huggingfacehub_api_type': 'hosted_inference_api',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
+            "huggingfacehub_api_type": "hosted_inference_api",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
@@ -101,86 +89,81 @@ def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True

-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='openchat/openchat_3.5',
+            model="openchat/openchat_3.5",
             credentials={
-                'huggingfacehub_api_type': 'inference_endpoints',
-                'huggingfacehub_api_token': 'invalid_key',
-                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
-                'task_type': 'text-generation'
-            }
+                "huggingfacehub_api_type": "inference_endpoints",
+                "huggingfacehub_api_token": "invalid_key",
+                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
+                "task_type": "text-generation",
+            },
         )

     model.validate_credentials(
-        model='openchat/openchat_3.5',
+        model="openchat/openchat_3.5",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text-generation'
-        }
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text-generation",
+        },
     )

-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()

     response = model.invoke(
-        model='openchat/openchat_3.5',
+        model="openchat/openchat_3.5",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text-generation'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text-generation",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0

-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()

     response = model.invoke(
-        model='openchat/openchat_3.5',
+        model="openchat/openchat_3.5",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text-generation'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text-generation",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
@@ -191,86 +174,81 @@ def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingfa
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True

-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='google/mt5-base',
+            model="google/mt5-base",
             credentials={
-                'huggingfacehub_api_type': 'inference_endpoints',
-                'huggingfacehub_api_token': 'invalid_key',
-                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
-                'task_type': 'text2text-generation'
-            }
+                "huggingfacehub_api_type": "inference_endpoints",
+                "huggingfacehub_api_token": "invalid_key",
+                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
+                "task_type": "text2text-generation",
+            },
         )

     model.validate_credentials(
-        model='google/mt5-base',
+        model="google/mt5-base",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text2text-generation'
-        }
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text2text-generation",
+        },
     )

-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()

     response = model.invoke(
-        model='google/mt5-base',
+        model="google/mt5-base",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text2text-generation'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text2text-generation",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0

-@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
 def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()

     response = model.invoke(
-        model='google/mt5-base',
+        model="google/mt5-base",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text2text-generation'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text2text-generation",
        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Who are you?'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Who are you?")],
         model_parameters={
-            'temperature': 1.0,
-            'top_k': 2,
-            'top_p': 0.5,
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
         },
-        stop=['How'],
+        stop=["How"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
@@ -286,18 +264,14 @@ def test_get_num_tokens():
     model = HuggingfaceHubLargeLanguageModel()

     num_tokens = model.get_num_tokens(
-        model='google/mt5-base',
+        model="google/mt5-base",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
-            'task_type': 'text2text-generation'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
+            "task_type": "text2text-generation",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
     )

     assert num_tokens == 7
@@ -14,19 +14,19 @@ def test_hosted_inference_api_validate_credentials():

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='facebook/bart-base',
+            model="facebook/bart-base",
             credentials={
-                'huggingfacehub_api_type': 'hosted_inference_api',
-                'huggingfacehub_api_token': 'invalid_key',
-            }
+                "huggingfacehub_api_type": "hosted_inference_api",
+                "huggingfacehub_api_token": "invalid_key",
+            },
         )

     model.validate_credentials(
-        model='facebook/bart-base',
+        model="facebook/bart-base",
         credentials={
-            'huggingfacehub_api_type': 'hosted_inference_api',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-        }
+            "huggingfacehub_api_type": "hosted_inference_api",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+        },
     )

@@ -34,15 +34,12 @@ def test_hosted_inference_api_invoke_model():
     model = HuggingfaceHubTextEmbeddingModel()

     result = model.invoke(
-        model='facebook/bart-base',
+        model="facebook/bart-base",
         credentials={
-            'huggingfacehub_api_type': 'hosted_inference_api',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
+            "huggingfacehub_api_type": "hosted_inference_api",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )

     assert isinstance(result, TextEmbeddingResult)
@@ -55,25 +52,25 @@ def test_inference_endpoints_validate_credentials():

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='all-MiniLM-L6-v2',
+            model="all-MiniLM-L6-v2",
             credentials={
-                'huggingfacehub_api_type': 'inference_endpoints',
-                'huggingfacehub_api_token': 'invalid_key',
-                'huggingface_namespace': 'Dify-AI',
-                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
-                'task_type': 'feature-extraction'
-            }
+                "huggingfacehub_api_type": "inference_endpoints",
+                "huggingfacehub_api_token": "invalid_key",
+                "huggingface_namespace": "Dify-AI",
+                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
+                "task_type": "feature-extraction",
+            },
         )

     model.validate_credentials(
-        model='all-MiniLM-L6-v2',
+        model="all-MiniLM-L6-v2",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingface_namespace': 'Dify-AI',
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
-            'task_type': 'feature-extraction'
-        }
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingface_namespace": "Dify-AI",
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
+            "task_type": "feature-extraction",
+        },
     )

@@ -81,18 +78,15 @@ def test_inference_endpoints_invoke_model():
     model = HuggingfaceHubTextEmbeddingModel()

     result = model.invoke(
-        model='all-MiniLM-L6-v2',
+        model="all-MiniLM-L6-v2",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingface_namespace': 'Dify-AI',
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
-            'task_type': 'feature-extraction'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingface_namespace": "Dify-AI",
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
+            "task_type": "feature-extraction",
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )

     assert isinstance(result, TextEmbeddingResult)
@@ -104,18 +98,15 @@ def test_get_num_tokens():
     model = HuggingfaceHubTextEmbeddingModel()

     num_tokens = model.get_num_tokens(
-        model='all-MiniLM-L6-v2',
+        model="all-MiniLM-L6-v2",
         credentials={
-            'huggingfacehub_api_type': 'inference_endpoints',
-            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'huggingface_namespace': 'Dify-AI',
-            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
-            'task_type': 'feature-extraction'
+            "huggingfacehub_api_type": "inference_endpoints",
+            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
+            "huggingface_namespace": "Dify-AI",
+            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
+            "task_type": "feature-extraction",
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )

     assert num_tokens == 2
@@ -10,61 +10,59 @@ from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embe
 )
 from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


 @pytest.fixture
 def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
-        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
-        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
-        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
+        monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
+        monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
+        monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
+        monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
     yield

     if MOCK:
         monkeypatch.undo()


-@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
 def test_validate_credentials(setup_tei_mock):
     model = HuggingfaceTeiTextEmbeddingModel()
     # model name is only used in mock
-    model_name = 'embedding'
+    model_name = "embedding"

     if MOCK:
         # TEI Provider will check model type by API endpoint, at real server, the model type is correct.
         # So we dont need to check model type here. Only check in mock
         with pytest.raises(CredentialsValidateFailedError):
             model.validate_credentials(
-                model='reranker',
+                model="reranker",
                 credentials={
-                    'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
-                }
+                    "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+                },
             )

     model.validate_credentials(
         model=model_name,
         credentials={
-            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
-        }
+            "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+        },
     )

-@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
 def test_invoke_model(setup_tei_mock):
     model = HuggingfaceTeiTextEmbeddingModel()
-    model_name = 'embedding'
+    model_name = "embedding"

     result = model.invoke(
         model=model_name,
         credentials={
-            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
+            "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )

     assert isinstance(result, TextEmbeddingResult)
@@ -11,63 +11,65 @@ from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
 from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
 from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


 @pytest.fixture
 def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
-        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
-        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
-        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
+        monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
+        monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
+        monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
+        monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
     yield

     if MOCK:
         monkeypatch.undo()

-@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
 def test_validate_credentials(setup_tei_mock):
     model = HuggingfaceTeiRerankModel()
     # model name is only used in mock
-    model_name = 'reranker'
+    model_name = "reranker"

     if MOCK:
         # TEI Provider will check model type by API endpoint, at real server, the model type is correct.
         # So we dont need to check model type here. Only check in mock
         with pytest.raises(CredentialsValidateFailedError):
             model.validate_credentials(
-                model='embedding',
+                model="embedding",
                 credentials={
-                    'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
-                }
+                    "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+                },
             )

     model.validate_credentials(
         model=model_name,
         credentials={
-            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
-        }
+            "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+        },
     )

-@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
 def test_invoke_model(setup_tei_mock):
     model = HuggingfaceTeiRerankModel()
     # model name is only used in mock
-    model_name = 'reranker'
+    model_name = "reranker"

     result = model.invoke(
         model=model_name,
         credentials={
-            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
+            "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
         },
         query="Who is Kasumi?",
         docs=[
-            "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
+            'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
             "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
-            "and she leads a team named PopiParty."
+            "and she leads a team named PopiParty.",
         ],
-        score_threshold=0.8
+        score_threshold=0.8,
     )

     assert isinstance(result, RerankResult)
@@ -14,19 +14,15 @@ def test_validate_credentials():

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='hunyuan-standard',
-            credentials={
-                'secret_id': 'invalid_key',
-                'secret_key': 'invalid_key'
-            }
+            model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
         )

     model.validate_credentials(
-        model='hunyuan-standard',
+        model="hunyuan-standard",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
-        }
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
+        },
     )

@@ -34,23 +30,16 @@ def test_invoke_model():
     model = HunyuanLargeLanguageModel()

     response = model.invoke(
-        model='hunyuan-standard',
+        model="hunyuan-standard",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hi'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 10
-        },
-        stop=['How'],
+        prompt_messages=[UserPromptMessage(content="Hi")],
+        model_parameters={"temperature": 0.5, "max_tokens": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, LLMResult)
@@ -61,23 +50,15 @@ def test_invoke_stream_model():
     model = HunyuanLargeLanguageModel()

     response = model.invoke(
-        model='hunyuan-standard',
+        model="hunyuan-standard",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hi'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.5,
-            'max_tokens': 100,
-            'seed': 1234
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
+        prompt_messages=[UserPromptMessage(content="Hi")],
+        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
@@ -93,19 +74,17 @@ def test_get_num_tokens():
     model = HunyuanLargeLanguageModel()

     num_tokens = model.get_num_tokens(
-        model='hunyuan-standard',
+        model="hunyuan-standard",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )

     assert num_tokens == 14
@@ -10,16 +10,11 @@ def test_validate_provider_credentials():
     provider = HunyuanProvider()

     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={
-                'secret_id': 'invalid_key',
-                'secret_key': 'invalid_key'
-            }
-        )
+        provider.validate_provider_credentials(credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"})

     provider.validate_provider_credentials(
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         }
     )
@@ -12,19 +12,15 @@ def test_validate_credentials():

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='hunyuan-embedding',
-            credentials={
-                'secret_id': 'invalid_key',
-                'secret_key': 'invalid_key'
-            }
+            model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
         )

     model.validate_credentials(
-        model='hunyuan-embedding',
+        model="hunyuan-embedding",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
-        }
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
+        },
     )

@@ -32,47 +28,43 @@ def test_invoke_model():
     model = HunyuanTextEmbeddingModel()

     result = model.invoke(
-        model='hunyuan-embedding',
+        model="hunyuan-embedding",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )

     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens == 6


 def test_get_num_tokens():
     model = HunyuanTextEmbeddingModel()

     num_tokens = model.get_num_tokens(
-        model='hunyuan-embedding',
+        model="hunyuan-embedding",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )

     assert num_tokens == 2


 def test_max_chunks():
     model = HunyuanTextEmbeddingModel()

     result = model.invoke(
-        model='hunyuan-embedding',
+        model="hunyuan-embedding",
         credentials={
-            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
-            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
+            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
         },
         texts=[
             "hello",
@@ -97,8 +89,8 @@ def test_max_chunks():
             "world",
             "hello",
             "world",
-        ]
+        ],
     )

     assert isinstance(result, TextEmbeddingResult)
-    assert len(result.embeddings) == 22
+    assert len(result.embeddings) == 22
@@ -10,14 +10,6 @@ def test_validate_provider_credentials():
     provider = JinaProvider()

     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={
-                'api_key': 'hahahaha'
-            }
-        )
+        provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})

-    provider.validate_provider_credentials(
-        credentials={
-            'api_key': os.environ.get('JINA_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("JINA_API_KEY")})
@@ -11,18 +11,10 @@ def test_validate_credentials():
     model = JinaTextEmbeddingModel()

     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='jina-embeddings-v2-base-en',
-            credentials={
-                'api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"})

     model.validate_credentials(
-        model='jina-embeddings-v2-base-en',
-        credentials={
-            'api_key': os.environ.get('JINA_API_KEY')
-        }
+        model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")}
     )

@@ -30,15 +22,12 @@ def test_invoke_model():
     model = JinaTextEmbeddingModel()

     result = model.invoke(
-        model='jina-embeddings-v2-base-en',
+        model="jina-embeddings-v2-base-en",
         credentials={
-            'api_key': os.environ.get('JINA_API_KEY'),
+            "api_key": os.environ.get("JINA_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )

     assert isinstance(result, TextEmbeddingResult)
@@ -50,14 +39,11 @@ def test_get_num_tokens():
     model = JinaTextEmbeddingModel()

     num_tokens = model.get_num_tokens(
-        model='jina-embeddings-v2-base-en',
+        model="jina-embeddings-v2-base-en",
         credentials={
-            'api_key': os.environ.get('JINA_API_KEY'),
+            "api_key": os.environ.get("JINA_API_KEY"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )

     assert num_tokens == 6
@@ -1,4 +1,4 @@
 """
-LocalAI Embedding Interface is temporarily unavailable due to
-we could not find a way to test it for now.
-"""
+LocalAI Embedding Interface is temporarily unavailable due to
+we could not find a way to test it for now.
+"""
@@ -21,99 +21,78 @@ def test_validate_credentials_for_chat_model():

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='chinese-llama-2-7b',
+            model="chinese-llama-2-7b",
             credentials={
-                'server_url': 'hahahaha',
-                'completion_type': 'completion',
-            }
+                "server_url": "hahahaha",
+                "completion_type": "completion",
+            },
         )

     model.validate_credentials(
-        model='chinese-llama-2-7b',
+        model="chinese-llama-2-7b",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'completion',
-        }
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "completion",
+        },
     )


 def test_invoke_completion_model():
     model = LocalAILanguageModel()

     response = model.invoke(
-        model='chinese-llama-2-7b',
+        model="chinese-llama-2-7b",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'completion',
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='ping'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'max_tokens': 10
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "completion",
         },
+        prompt_messages=[UserPromptMessage(content="ping")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
         stop=[],
         user="abc-123",
-        stream=False
+        stream=False,
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0


 def test_invoke_chat_model():
     model = LocalAILanguageModel()

     response = model.invoke(
-        model='chinese-llama-2-7b',
+        model="chinese-llama-2-7b",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'chat_completion',
-        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='ping'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'max_tokens': 10
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "chat_completion",
         },
+        prompt_messages=[UserPromptMessage(content="ping")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
         stop=[],
         user="abc-123",
-        stream=False
+        stream=False,
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0


 def test_invoke_stream_completion_model():
     model = LocalAILanguageModel()

     response = model.invoke(
-        model='chinese-llama-2-7b',
+        model="chinese-llama-2-7b",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'completion',
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "completion",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'max_tokens': 10
-        },
-        stop=['you'],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
@@ -123,28 +102,21 @@ def test_invoke_stream_completion_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


 def test_invoke_stream_chat_model():
     model = LocalAILanguageModel()

     response = model.invoke(
-        model='chinese-llama-2-7b',
+        model="chinese-llama-2-7b",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'chat_completion',
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "chat_completion",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'max_tokens': 10
-        },
-        stop=['you'],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
@@ -154,64 +126,48 @@ def test_invoke_stream_chat_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


 def test_get_num_tokens():
     model = LocalAILanguageModel()

     num_tokens = model.get_num_tokens(
-        model='????',
+        model="????",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'chat_completion',
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "chat_completion",
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": [
-                                "c",
-                                "f"
-                            ]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
-        ]
+        ],
     )

     assert isinstance(num_tokens, int)
     assert num_tokens == 77

     num_tokens = model.get_num_tokens(
-        model='????',
+        model="????",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'chat_completion',
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "chat_completion",
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
     )

     assert isinstance(num_tokens, int)
@ -12,30 +12,29 @@ def test_validate_credentials_for_chat_model():
|
|||
|
||||
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='bge-reranker-v2-m3',
+            model="bge-reranker-v2-m3",
             credentials={
-                'server_url': 'hahahaha',
-                'completion_type': 'completion',
-            }
+                "server_url": "hahahaha",
+                "completion_type": "completion",
+            },
         )

     model.validate_credentials(
-        model='bge-reranker-base',
+        model="bge-reranker-base",
         credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
-            'completion_type': 'completion',
-        }
+            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
+            "completion_type": "completion",
+        },
     )


 def test_invoke_rerank_model():
     model = LocalaiRerankModel()

     response = model.invoke(
-        model='bge-reranker-base',
-        credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL')
-        },
-        query='Organic skincare products for sensitive skin',
+        model="bge-reranker-base",
+        credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
+        query="Organic skincare products for sensitive skin",
         docs=[
             "Eco-friendly kitchenware for modern homes",
             "Biodegradable cleaning supplies for eco-conscious consumers",
@@ -45,43 +44,38 @@ def test_invoke_rerank_model():
"Sustainable gardening tools and compost solutions",
|
||||
"Sensitive skin-friendly facial cleansers and toners",
|
||||
"Organic food wraps and storage solutions",
|
||||
"Yoga mats made from recycled materials"
|
||||
"Yoga mats made from recycled materials",
|
||||
],
|
||||
top_n=3,
|
||||
score_threshold=0.75,
|
||||
user="abc-123"
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, RerankResult)
|
||||
assert len(response.docs) == 3
|
||||
|
||||
|
||||
def test__invoke():
|
||||
model = LocalaiRerankModel()
|
||||
|
||||
# Test case 1: Empty docs
|
||||
result = model._invoke(
|
||||
model='bge-reranker-base',
|
||||
credentials={
|
||||
'server_url': 'https://example.com',
|
||||
'api_key': '1234567890'
|
||||
},
|
||||
query='Organic skincare products for sensitive skin',
|
||||
model="bge-reranker-base",
|
||||
credentials={"server_url": "https://example.com", "api_key": "1234567890"},
|
||||
query="Organic skincare products for sensitive skin",
|
||||
docs=[],
|
||||
top_n=3,
|
||||
score_threshold=0.75,
|
||||
user="abc-123"
|
||||
user="abc-123",
|
||||
)
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 0
|
||||
|
||||
# Test case 2: Valid invocation
|
||||
result = model._invoke(
|
||||
model='bge-reranker-base',
|
||||
credentials={
|
||||
'server_url': 'https://example.com',
|
||||
'api_key': '1234567890'
|
||||
},
|
||||
query='Organic skincare products for sensitive skin',
|
||||
model="bge-reranker-base",
|
||||
credentials={"server_url": "https://example.com", "api_key": "1234567890"},
|
||||
query="Organic skincare products for sensitive skin",
|
||||
docs=[
|
||||
"Eco-friendly kitchenware for modern homes",
|
||||
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||
@@ -91,12 +85,12 @@ def test__invoke():
"Sustainable gardening tools and compost solutions",
|
||||
"Sensitive skin-friendly facial cleansers and toners",
|
||||
"Organic food wraps and storage solutions",
|
||||
"Yoga mats made from recycled materials"
|
||||
"Yoga mats made from recycled materials",
|
||||
],
|
||||
top_n=3,
|
||||
score_threshold=0.75,
|
||||
user="abc-123"
|
||||
user="abc-123",
|
||||
)
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 3
|
||||
assert all(isinstance(doc, RerankDocument) for doc in result.docs)
|
||||
assert all(isinstance(doc, RerankDocument) for doc in result.docs)
|
||||
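One further check worth sketching: a reranker should return its top-n documents ordered by relevance. Assuming `RerankDocument` exposes a `score` field, as the runtime's rerank entities do:

    # Scores should be non-increasing; if score_threshold is honored,
    # every returned score is at least the requested 0.75.
    scores = [doc.score for doc in result.docs]
    assert scores == sorted(scores, reverse=True)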
@@ -10,19 +10,9 @@ def test_validate_credentials():
     model = LocalAISpeech2text()

     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='whisper-1',
-            credentials={
-                'server_url': 'invalid_url'
-            }
-        )
+        model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"})

-    model.validate_credentials(
-        model='whisper-1',
-        credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL')
-        }
-    )
+    model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")})


 def test_invoke_model():
@@ -32,23 +22,21 @@ def test_invoke_model():
     current_dir = os.path.dirname(os.path.abspath(__file__))

     # Get assets directory
-    assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
+    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")

     # Construct the path to the audio file
-    audio_file_path = os.path.join(assets_dir, 'audio.mp3')
+    audio_file_path = os.path.join(assets_dir, "audio.mp3")

     # Open the file and get the file object
-    with open(audio_file_path, 'rb') as audio_file:
+    with open(audio_file_path, "rb") as audio_file:
         file = audio_file

     result = model.invoke(
-        model='whisper-1',
-        credentials={
-            'server_url': os.environ.get('LOCALAI_SERVER_URL')
-        },
+        model="whisper-1",
+        credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
         file=file,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(result, str)
-    assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
+    assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
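Note that even the reformatted version binds `file` inside the `with` block and only calls `invoke` after the block has closed the handle; a sketch of the safer ordering, under the same assumed `invoke` signature:

    # Keep the handle open for the duration of the call.
    with open(audio_file_path, "rb") as audio_file:
        result = model.invoke(
            model="whisper-1",
            credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
            file=audio_file,
            user="abc-123",
        )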
@@ -12,54 +12,47 @@ def test_validate_credentials():
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='embo-01',
-            credentials={
-                'minimax_api_key': 'invalid_key',
-                'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
-            }
+            model="embo-01",
+            credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")},
         )

     model.validate_credentials(
-        model='embo-01',
+        model="embo-01",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
-        }
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
+        },
     )


 def test_invoke_model():
     model = MinimaxTextEmbeddingModel()

     result = model.invoke(
-        model='embo-01',
+        model="embo-01",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
     )

     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens == 16


 def test_get_num_tokens():
     model = MinimaxTextEmbeddingModel()

     num_tokens = model.get_num_tokens(
-        model='embo-01',
+        model="embo-01",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        texts=[
-            "hello",
-            "world"
-        ]
+        texts=["hello", "world"],
     )

     assert num_tokens == 2
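A hedged sanity check on the embedding payload, assuming `TextEmbeddingResult.embeddings` is one float vector per input text, as elsewhere in the model runtime:

    # Two inputs -> two vectors; every component should be a float.
    assert len(result.embeddings) == 2
    assert all(isinstance(x, float) for vector in result.embeddings for x in vector)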
@@ -17,79 +17,70 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)


 def test_validate_credentials_for_chat_model():
     sleep(3)
     model = MinimaxLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='abab5.5-chat',
-            credentials={
-                'minimax_api_key': 'invalid_key',
-                'minimax_group_id': 'invalid_key'
-            }
+            model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"}
         )

     model.validate_credentials(
-        model='abab5.5-chat',
+        model="abab5.5-chat",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
-        }
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
+        },
     )


 def test_invoke_model():
     sleep(3)
     model = MinimaxLargeLanguageModel()

     response = model.invoke(
-        model='abab5-chat',
+        model="abab5-chat",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )

     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0


 def test_invoke_stream_model():
     sleep(3)
     model = MinimaxLargeLanguageModel()

     response = model.invoke(
-        model='abab5.5-chat',
+        model="abab5.5-chat",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
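The paired tests above rely on `invoke` being polymorphic on `stream`: `stream=False` returns a single `LLMResult`, while `stream=True` returns a generator of chunks. A small normalization sketch using only the types these tests already assert on:

    # Sketch: collapse either return shape to plain text.
    def response_text(response) -> str:
        if isinstance(response, LLMResult):
            return response.message.content
        return "".join(chunk.delta.message.content for chunk in response)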
@@ -99,34 +90,31 @@ def test_invoke_stream_model():
     assert isinstance(chunk.delta.message, AssistantPromptMessage)
     assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


 def test_invoke_with_search():
     sleep(3)
     model = MinimaxLargeLanguageModel()

     response = model.invoke(
-        model='abab5.5-chat',
+        model="abab5.5-chat",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='北京今天的天气怎么样'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
-            'plugin_web_search': True,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
+            "plugin_web_search": True,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
-    total_message = ''
+    total_message = ""
     for chunk in response:
         assert isinstance(chunk, LLMResultChunk)
         assert isinstance(chunk.delta, LLMResultChunkDelta)
@@ -134,25 +122,22 @@ def test_invoke_with_search():
         total_message += chunk.delta.message.content
         assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True

-    assert '参考资料' in total_message
+    assert "参考资料" in total_message


 def test_get_num_tokens():
     sleep(3)
     model = MinimaxLargeLanguageModel()

     response = model.get_num_tokens(
-        model='abab5.5-chat',
+        model="abab5.5-chat",
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
-        tools=[]
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        tools=[],
     )

     assert isinstance(response, int)
-    assert response == 30
+    assert response == 30
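For non-Chinese readers: the search test's prompt 北京今天的天气怎么样 asks "What's the weather like in Beijing today?", and the final assertion checks that the streamed reply cites 参考资料 ("references"). Like the LocalAI suite, these tests need live credentials; a hedged skip guard using the env var names above:

    import os
    import pytest

    # Skip the Minimax suite when live credentials are absent.
    pytestmark = pytest.mark.skipif(
        not (os.environ.get("MINIMAX_API_KEY") and os.environ.get("MINIMAX_GROUP_ID")),
        reason="Minimax credentials are not configured",
    )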
@@ -12,14 +12,14 @@ def test_validate_provider_credentials():
     with pytest.raises(CredentialsValidateFailedError):
         provider.validate_provider_credentials(
             credentials={
-                'minimax_api_key': 'hahahaha',
-                'minimax_group_id': '123',
+                "minimax_api_key": "hahahaha",
+                "minimax_group_id": "123",
             }
         )

     provider.validate_provider_credentials(
         credentials={
-            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
+            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
         }
     )