Merge branch 'refs/heads/main' into feat/workflow-parallel-support

# Conflicts:
#	api/tests/integration_tests/workflow/nodes/test_code.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
takatost 2024-08-24 17:26:44 +08:00
commit 4771e85630
276 changed files with 5714 additions and 6510 deletions

View File

@ -1,3 +1,3 @@
from .app_config import DifyConfig
dify_config = DifyConfig()
dify_config = DifyConfig()
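
A minimal usage sketch of the singleton above (assuming the `configs` package is on the import path): the module builds one `DifyConfig`, and callers import that shared instance rather than constructing their own.

```python
# Hedged usage sketch: import the shared settings instance created above.
from configs import dify_config

print(dify_config.DEBUG)    # False unless DEBUG is set in the environment or .env
print(dify_config.EDITION)  # "SELF_HOSTED" by default
```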

View File

@ -1,4 +1,3 @@
from pydantic import Field, computed_field
from pydantic_settings import SettingsConfigDict
from configs.deploy import DeploymentConfig
@ -24,44 +23,16 @@ class DifyConfig(
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
):
DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
model_config = SettingsConfigDict(
# read from dotenv format config file
env_file='.env',
env_file_encoding='utf-8',
env_file=".env",
env_file_encoding="utf-8",
frozen=True,
# ignore extra attributes
extra='ignore',
extra="ignore",
)
CODE_MAX_NUMBER: int = 9223372036854775807
CODE_MIN_NUMBER: int = -9223372036854775808
CODE_MAX_DEPTH: int = 5
CODE_MAX_PRECISION: int = 20
CODE_MAX_STRING_LENGTH: int = 80000
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
SSRF_PROXY_HTTP_URL: str | None = None
SSRF_PROXY_HTTPS_URL: str | None = None
MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')
# Before adding any config,
# please consider to arrange it in the proper config group of existed or added
# for better readability and maintainability.
# Thanks for your concentration and consideration.
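
The `model_config` above governs how every field in `DifyConfig` is resolved. A self-contained sketch of the same pattern (`DemoConfig` is hypothetical, not part of this diff): environment variables win over `.env` entries, unknown keys are ignored, and `frozen=True` makes instances immutable.

```python
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class DemoConfig(BaseSettings):
    # same SettingsConfigDict shape as DifyConfig above
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        frozen=True,
        extra="ignore",
    )
    DEBUG: bool = Field(default=False, description="whether to enable debug mode.")

cfg = DemoConfig()  # resolves DEBUG from os.environ, then .env, then the default
# cfg.DEBUG = True  # would raise ValidationError: frozen=True forbids mutation
```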

View File

@ -6,22 +6,28 @@ class DeploymentConfig(BaseSettings):
"""
Deployment configs
"""
APPLICATION_NAME: str = Field(
description='application name',
default='langgenius/dify',
description="application name",
default="langgenius/dify",
)
DEBUG: bool = Field(
description="whether to enable debug mode.",
default=False,
)
TESTING: bool = Field(
description='',
description="",
default=False,
)
EDITION: str = Field(
description='deployment edition',
default='SELF_HOSTED',
description="deployment edition",
default="SELF_HOSTED",
)
DEPLOY_ENV: str = Field(
description='deployment environment, default to PRODUCTION.',
default='PRODUCTION',
description="deployment environment, default to PRODUCTION.",
default="PRODUCTION",
)

View File

@ -7,13 +7,14 @@ class EnterpriseFeatureConfig(BaseSettings):
Enterprise feature configs.
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**
"""
ENTERPRISE_ENABLED: bool = Field(
description='whether to enable enterprise features.'
'Before using, please contact business@dify.ai by email to inquire about licensing matters.',
description="whether to enable enterprise features."
"Before using, please contact business@dify.ai by email to inquire about licensing matters.",
default=False,
)
CAN_REPLACE_LOGO: bool = Field(
description='whether to allow replacing enterprise logo.',
description="whether to allow replacing enterprise logo.",
default=False,
)

View File

@ -8,27 +8,28 @@ class NotionConfig(BaseSettings):
"""
Notion integration configs
"""
NOTION_CLIENT_ID: Optional[str] = Field(
description='Notion client ID',
description="Notion client ID",
default=None,
)
NOTION_CLIENT_SECRET: Optional[str] = Field(
description='Notion client secret key',
description="Notion client secret key",
default=None,
)
NOTION_INTEGRATION_TYPE: Optional[str] = Field(
description='Notion integration type, default to None, available values: internal.',
description="Notion integration type, default to None, available values: internal.",
default=None,
)
NOTION_INTERNAL_SECRET: Optional[str] = Field(
description='Notion internal secret key',
description="Notion internal secret key",
default=None,
)
NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
description='Notion integration token',
description="Notion integration token",
default=None,
)

View File

@ -8,17 +8,18 @@ class SentryConfig(BaseSettings):
"""
Sentry configs
"""
SENTRY_DSN: Optional[str] = Field(
description='Sentry DSN',
description="Sentry DSN",
default=None,
)
SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
description='Sentry trace sample rate',
description="Sentry trace sample rate",
default=1.0,
)
SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
description='Sentry profiles sample rate',
description="Sentry profiles sample rate",
default=1.0,
)

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic import AliasChoices, Field, NegativeInt, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
@ -10,16 +10,17 @@ class SecurityConfig(BaseSettings):
"""
Secret Key configs
"""
SECRET_KEY: Optional[str] = Field(
description='Your App secret key will be used for securely signing the session cookie'
'Make sure you are changing this key for your deployment with a strong key.'
'You can generate a strong key using `openssl rand -base64 42`.'
'Alternatively you can set it with `SECRET_KEY` environment variable.',
description="Your App secret key will be used for securely signing the session cookie"
"Make sure you are changing this key for your deployment with a strong key."
"You can generate a strong key using `openssl rand -base64 42`."
"Alternatively you can set it with `SECRET_KEY` environment variable.",
default=None,
)
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description='Expiry time in hours for reset token',
description="Expiry time in hours for reset token",
default=24,
)
@ -28,12 +29,13 @@ class AppExecutionConfig(BaseSettings):
"""
App Execution configs
"""
APP_MAX_EXECUTION_TIME: PositiveInt = Field(
description='execution timeout in seconds for app execution',
description="execution timeout in seconds for app execution",
default=1200,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description='max active request per app, 0 means unlimited',
description="max active request per app, 0 means unlimited",
default=0,
)
@ -42,14 +44,55 @@ class CodeExecutionSandboxConfig(BaseSettings):
"""
Code Execution Sandbox configs
"""
CODE_EXECUTION_ENDPOINT: str = Field(
description='endpoint URL of code execution service',
default='http://sandbox:8194',
description="endpoint URL of code execution servcie",
default="http://sandbox:8194",
)
CODE_EXECUTION_API_KEY: str = Field(
description='API key for code execution service',
default='dify-sandbox',
description="API key for code execution service",
default="dify-sandbox",
)
CODE_MAX_NUMBER: PositiveInt = Field(
description="max depth for code execution",
default=9223372036854775807,
)
CODE_MIN_NUMBER: NegativeInt = Field(
description="",
default=-9223372036854775807,
)
CODE_MAX_DEPTH: PositiveInt = Field(
description="max depth for code execution",
default=5,
)
CODE_MAX_PRECISION: PositiveInt = Field(
description="max precision digits for float type in code execution",
default=20,
)
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="max string length for code execution",
default=80000,
)
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=30,
)
CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=30,
)
CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=1000,
)
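
The `CODE_*` fields bound what the code-execution node accepts. A hedged sketch of how such a limit might be applied (the `check_number` helper is hypothetical, not from this diff):

```python
from configs import dify_config  # the shared settings instance

def check_number(value: int) -> int:
    # reject values outside the configured sandbox range
    if not (dify_config.CODE_MIN_NUMBER <= value <= dify_config.CODE_MAX_NUMBER):
        raise ValueError(f"number {value} is out of the allowed range")
    return value
```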
@ -57,28 +100,27 @@ class EndpointConfig(BaseSettings):
"""
Module URL configs
"""
CONSOLE_API_URL: str = Field(
description='The backend URL prefix of the console API.'
'used to concatenate the login authorization callback or notion integration callback.',
default='',
description="The backend URL prefix of the console API."
"used to concatenate the login authorization callback or notion integration callback.",
default="",
)
CONSOLE_WEB_URL: str = Field(
description='The front-end URL prefix of the console web.'
'used to concatenate some front-end addresses and for CORS configuration use.',
default='',
description="The front-end URL prefix of the console web."
"used to concatenate some front-end addresses and for CORS configuration use.",
default="",
)
SERVICE_API_URL: str = Field(
description='Service API Url prefix.'
'used to display Service API Base Url to the front-end.',
default='',
description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
default="",
)
APP_WEB_URL: str = Field(
description='WebApp Url prefix.'
'used to display WebAPP API Base Url to the front-end.',
default='',
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
default="",
)
@ -86,17 +128,18 @@ class FileAccessConfig(BaseSettings):
"""
File Access configs
"""
FILES_URL: str = Field(
description='File preview or download Url prefix.'
' used to display File preview or download Url to the front-end or as Multi-model inputs;'
'Url is signed and has expiration time.',
validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
description="File preview or download Url prefix."
" used to display File preview or download Url to the front-end or as Multi-model inputs;"
"Url is signed and has expiration time.",
validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
alias_priority=1,
default='',
default="",
)
FILES_ACCESS_TIMEOUT: int = Field(
description='timeout in seconds for file accessing',
description="timeout in seconds for file accessing",
default=300,
)
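
The `AliasChoices` on `FILES_URL` makes the first matching environment variable win, so `CONSOLE_API_URL` acts as a fallback. A sketch under that assumption:

```python
import os

# With FILES_URL unset, the CONSOLE_API_URL alias feeds the field instead.
os.environ.pop("FILES_URL", None)
os.environ["CONSOLE_API_URL"] = "https://console.example.com"

cfg = FileAccessConfig()  # the class from the hunk above
assert cfg.FILES_URL == "https://console.example.com"
```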
@ -105,23 +148,24 @@ class FileUploadConfig(BaseSettings):
"""
File Uploading configs
"""
UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description='size limit in Megabytes for uploading files',
description="size limit in Megabytes for uploading files",
default=15,
)
UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
description='batch size limit for uploading files',
description="batch size limit for uploading files",
default=5,
)
UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description='image file size limit in Megabytes for uploading files',
description="image file size limit in Megabytes for uploading files",
default=10,
)
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
description='', # todo: to be clarified
description="", # todo: to be clarified
default=20,
)
@ -130,45 +174,82 @@ class HttpConfig(BaseSettings):
"""
HTTP configs
"""
API_COMPRESSION_ENABLED: bool = Field(
description='whether to enable HTTP response compression of gzip',
description="whether to enable HTTP response compression of gzip",
default=False,
)
inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
description='',
validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'),
default='',
description="",
validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
default="",
)
@computed_field
@property
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")
inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field(
description='',
validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
default='*',
description="",
validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"),
default="*",
)
@computed_field
@property
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field(
description="",
default=300,
)
HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field(
description="",
default=600,
)
HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field(
description="",
default=600,
)
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
description="",
default=10 * 1024 * 1024,
)
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
description="",
default=1 * 1024 * 1024,
)
SSRF_PROXY_HTTP_URL: Optional[str] = Field(
description="HTTP URL for SSRF proxy",
default=None,
)
SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
description="HTTPS URL for SSRF proxy",
default=None,
)
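
The `inner_*` fields hold the raw comma-separated strings, and the `@computed_field` properties expose them as lists. Illustrative sketch:

```python
import os

os.environ["CONSOLE_CORS_ALLOW_ORIGINS"] = "https://a.example.com,https://b.example.com"
cfg = HttpConfig()  # the class from this hunk
assert cfg.CONSOLE_CORS_ALLOW_ORIGINS == ["https://a.example.com", "https://b.example.com"]
assert cfg.WEB_API_CORS_ALLOW_ORIGINS == ["*"]  # the default "*" splits to one item
```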
class InnerAPIConfig(BaseSettings):
"""
Inner API configs
"""
INNER_API: bool = Field(
description='whether to enable the inner API',
description="whether to enable the inner API",
default=False,
)
INNER_API_KEY: Optional[str] = Field(
description='The inner API key is used to authenticate the inner API',
description="The inner API key is used to authenticate the inner API",
default=None,
)
@ -179,28 +260,27 @@ class LoggingConfig(BaseSettings):
"""
LOG_LEVEL: str = Field(
description='Log output level, default to INFO.'
'It is recommended to set it to ERROR for production.',
default='INFO',
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
default="INFO",
)
LOG_FILE: Optional[str] = Field(
description='logging output file path',
description="logging output file path",
default=None,
)
LOG_FORMAT: str = Field(
description='log format',
default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s',
description="log format",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
)
LOG_DATEFORMAT: Optional[str] = Field(
description='log date format',
description="log date format",
default=None,
)
LOG_TZ: Optional[str] = Field(
description='specify log timezone, eg: America/New_York',
description="specify log timezone, eg: America/New_York",
default=None,
)
@ -209,8 +289,9 @@ class ModelLoadBalanceConfig(BaseSettings):
"""
Model load balance configs
"""
MODEL_LB_ENABLED: bool = Field(
description='whether to enable model load balancing',
description="whether to enable model load balancing",
default=False,
)
@ -219,8 +300,9 @@ class BillingConfig(BaseSettings):
"""
Platform Billing Configurations
"""
BILLING_ENABLED: bool = Field(
description='whether to enable billing',
description="whether to enable billing",
default=False,
)
@ -229,9 +311,10 @@ class UpdateConfig(BaseSettings):
"""
Update configs
"""
CHECK_UPDATE_URL: str = Field(
description='url for checking updates',
default='https://updates.dify.ai',
description="url for checking updates",
default="https://updates.dify.ai",
)
@ -241,47 +324,53 @@ class WorkflowConfig(BaseSettings):
"""
WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
description='max execution steps in single workflow execution',
description="max execution steps in single workflow execution",
default=500,
)
WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
description='max execution time in seconds in single workflow execution',
description="max execution time in seconds in single workflow execution",
default=1200,
)
WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
description='max depth of calling in single workflow execution',
description="max depth of calling in single workflow execution",
default=5,
)
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="The maximum size in bytes of a variable. default to 5KB.",
default=5 * 1024,
)
class OAuthConfig(BaseSettings):
"""
oauth configs
"""
OAUTH_REDIRECT_PATH: str = Field(
description='redirect path for OAuth',
default='/console/api/oauth/authorize',
description="redirect path for OAuth",
default="/console/api/oauth/authorize",
)
GITHUB_CLIENT_ID: Optional[str] = Field(
description='GitHub client id for OAuth',
description="GitHub client id for OAuth",
default=None,
)
GITHUB_CLIENT_SECRET: Optional[str] = Field(
description='GitHub client secret key for OAuth',
description="GitHub client secret key for OAuth",
default=None,
)
GOOGLE_CLIENT_ID: Optional[str] = Field(
description='Google client id for OAuth',
description="Google client id for OAuth",
default=None,
)
GOOGLE_CLIENT_SECRET: Optional[str] = Field(
description='Google client secret key for OAuth',
description="Google client secret key for OAuth",
default=None,
)
@ -291,9 +380,8 @@ class ModerationConfig(BaseSettings):
Moderation in app configs.
"""
# todo: to be clarified in usage and unit
OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field(
description='buffer size for moderation',
MODERATION_BUFFER_SIZE: PositiveInt = Field(
description="buffer size for moderation",
default=300,
)
@ -304,7 +392,7 @@ class ToolConfig(BaseSettings):
"""
TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
description='max age in seconds for tool icon caching',
description="max age in seconds for tool icon caching",
default=3600,
)
@ -315,52 +403,52 @@ class MailConfig(BaseSettings):
"""
MAIL_TYPE: Optional[str] = Field(
description='Mail provider type name, default to None, available values are `smtp` and `resend`.',
description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
default=None,
)
MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
description='default email address for sending from ',
description="default email address for sending from ",
default=None,
)
RESEND_API_KEY: Optional[str] = Field(
description='API key for Resend',
description="API key for Resend",
default=None,
)
RESEND_API_URL: Optional[str] = Field(
description='API URL for Resend',
description="API URL for Resend",
default=None,
)
SMTP_SERVER: Optional[str] = Field(
description='smtp server host',
description="smtp server host",
default=None,
)
SMTP_PORT: Optional[int] = Field(
description='smtp server port',
description="smtp server port",
default=465,
)
SMTP_USERNAME: Optional[str] = Field(
description='smtp server username',
description="smtp server username",
default=None,
)
SMTP_PASSWORD: Optional[str] = Field(
description='smtp server password',
description="smtp server password",
default=None,
)
SMTP_USE_TLS: bool = Field(
description='whether to use TLS connection to smtp server',
description="whether to use TLS connection to smtp server",
default=False,
)
SMTP_OPPORTUNISTIC_TLS: bool = Field(
description='whether to use opportunistic TLS connection to smtp server',
description="whether to use opportunistic TLS connection to smtp server",
default=False,
)
@ -371,22 +459,22 @@ class RagEtlConfig(BaseSettings):
"""
ETL_TYPE: str = Field(
description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ',
default='dify',
description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ",
default="dify",
)
KEYWORD_DATA_SOURCE_TYPE: str = Field(
description='source type for keyword data, default to `database`, available values are `database` .',
default='database',
description="source type for keyword data, default to `database`, available values are `database` .",
default="database",
)
UNSTRUCTURED_API_URL: Optional[str] = Field(
description='API URL for Unstructured',
description="API URL for Unstructured",
default=None,
)
UNSTRUCTURED_API_KEY: Optional[str] = Field(
description='API key for Unstructured',
description="API key for Unstructured",
default=None,
)
@ -397,12 +485,12 @@ class DataSetConfig(BaseSettings):
"""
CLEAN_DAY_SETTING: PositiveInt = Field(
description='interval in days for cleaning up dataset',
description="interval in days for cleaning up dataset",
default=30,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description='whether to enable dataset operator',
description="whether to enable dataset operator",
default=False,
)
@ -413,7 +501,7 @@ class WorkspaceConfig(BaseSettings):
"""
INVITE_EXPIRY_HOURS: PositiveInt = Field(
description='workspaces invitation expiration in hours',
description="workspaces invitation expiration in hours",
default=72,
)
@ -424,80 +512,79 @@ class IndexingConfig(BaseSettings):
"""
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
description='max segmentation token length for indexing',
description="max segmentation token length for indexing",
default=1000,
)
class ImageFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
description='multi model send image format, support base64, url, default is base64',
default='base64',
description="multi model send image format, support base64, url, default is base64",
default="base64",
)
class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field(
description='the time of the celery scheduler, default to 1 day',
description="the time of the celery scheduler, default to 1 day",
default=1,
)
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description='The heads of model providers',
default='',
description="The heads of model providers",
default="",
)
POSITION_PROVIDER_INCLUDES: str = Field(
description='The included model providers',
default='',
description="The included model providers",
default="",
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description='The excluded model providers',
default='',
description="The excluded model providers",
default="",
)
POSITION_TOOL_PINS: str = Field(
description='The heads of tools',
default='',
description="The heads of tools",
default="",
)
POSITION_TOOL_INCLUDES: str = Field(
description='The included tools',
default='',
description="The included tools",
default="",
)
POSITION_TOOL_EXCLUDES: str = Field(
description='The excluded tools',
default='',
description="The excluded tools",
default="",
)
@computed_field
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]
@computed_field
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}
@computed_field
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}
@computed_field
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]
@computed_field
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}
@computed_field
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
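
The pin/include/exclude parsing above strips whitespace and drops empty entries, for example:

```python
import os

os.environ["POSITION_PROVIDER_PINS"] = "openai, anthropic ,,"
cfg = PositionConfig()  # the class from this hunk
assert cfg.POSITION_PROVIDER_PINS_LIST == ["openai", "anthropic"]
```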
class FeatureConfig(
@ -525,7 +612,6 @@ class FeatureConfig(
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,

View File

@ -10,62 +10,62 @@ class HostedOpenAiConfig(BaseSettings):
"""
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_OPENAI_API_BASE: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
description='',
description="",
default=False,
)
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description='',
default='gpt-3.5-turbo,'
'gpt-3.5-turbo-1106,'
'gpt-3.5-turbo-instruct,'
'gpt-3.5-turbo-16k,'
'gpt-3.5-turbo-16k-0613,'
'gpt-3.5-turbo-0613,'
'gpt-3.5-turbo-0125,'
'text-davinci-003',
description="",
default="gpt-3.5-turbo,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-instruct,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"text-davinci-003",
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description='',
description="",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
description='',
description="",
default=False,
)
HOSTED_OPENAI_PAID_MODELS: str = Field(
description='',
default='gpt-4,'
'gpt-4-turbo-preview,'
'gpt-4-turbo-2024-04-09,'
'gpt-4-1106-preview,'
'gpt-4-0125-preview,'
'gpt-3.5-turbo,'
'gpt-3.5-turbo-16k,'
'gpt-3.5-turbo-16k-0613,'
'gpt-3.5-turbo-1106,'
'gpt-3.5-turbo-0613,'
'gpt-3.5-turbo-0125,'
'gpt-3.5-turbo-instruct,'
'text-davinci-003',
description="",
default="gpt-4,"
"gpt-4-turbo-preview,"
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003",
)
@ -75,22 +75,22 @@ class HostedAzureOpenAiConfig(BaseSettings):
"""
HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
description='',
description="",
default=False,
)
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description='',
description="",
default=200,
)
@ -101,27 +101,27 @@ class HostedAnthropicConfig(BaseSettings):
"""
HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
description='',
description="",
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description='',
description="",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description='',
description="",
default=False,
)
@ -132,7 +132,7 @@ class HostedMinmaxConfig(BaseSettings):
"""
HOSTED_MINIMAX_ENABLED: bool = Field(
description='',
description="",
default=False,
)
@ -143,7 +143,7 @@ class HostedSparkConfig(BaseSettings):
"""
HOSTED_SPARK_ENABLED: bool = Field(
description='',
description="",
default=False,
)
@ -154,7 +154,7 @@ class HostedZhipuAIConfig(BaseSettings):
"""
HOSTED_ZHIPUAI_ENABLED: bool = Field(
description='',
description="",
default=False,
)
@ -165,13 +165,13 @@ class HostedModerationConfig(BaseSettings):
"""
HOSTED_MODERATION_ENABLED: bool = Field(
description='',
description="",
default=False,
)
HOSTED_MODERATION_PROVIDERS: str = Field(
description='',
default='',
description="",
default="",
)
@ -181,15 +181,15 @@ class HostedFetchAppTemplateConfig(BaseSettings):
"""
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
description='the mode for fetching app templates,'
' default to remote,'
' available values: remote, db, builtin',
default='remote',
description="the mode for fetching app templates,"
" default to remote,"
" available values: remote, db, builtin",
default="remote",
)
HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
description='the domain for fetching remote app templates',
default='https://tmpl.dify.ai',
description="the domain for fetching remote app templates",
default="https://tmpl.dify.ai",
)
@ -202,7 +202,6 @@ class HostedServiceConfig(
HostedOpenAiConfig,
HostedSparkConfig,
HostedZhipuAIConfig,
# moderation
HostedModerationConfig,
):

View File

@ -28,108 +28,108 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field(
description='storage type,'
' default to `local`,'
' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.',
default='local',
description="storage type,"
" default to `local`,"
" available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.",
default="local",
)
STORAGE_LOCAL_PATH: str = Field(
description='local storage path',
default='storage',
description="local storage path",
default="storage",
)
class VectorStoreConfig(BaseSettings):
VECTOR_STORE: Optional[str] = Field(
description='vector store type',
description="vector store type",
default=None,
)
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
description='keyword store type',
default='jieba',
description="keyword store type",
default="jieba",
)
class DatabaseConfig:
DB_HOST: str = Field(
description='db host',
default='localhost',
description="db host",
default="localhost",
)
DB_PORT: PositiveInt = Field(
description='db port',
description="db port",
default=5432,
)
DB_USERNAME: str = Field(
description='db username',
default='postgres',
description="db username",
default="postgres",
)
DB_PASSWORD: str = Field(
description='db password',
default='',
description="db password",
default="",
)
DB_DATABASE: str = Field(
description='db database',
default='dify',
description="db database",
default="dify",
)
DB_CHARSET: str = Field(
description='db charset',
default='',
description="db charset",
default="",
)
DB_EXTRAS: str = Field(
description='db extras options. Example: keepalives_idle=60&keepalives=1',
default='',
description="db extras options. Example: keepalives_idle=60&keepalives=1",
default="",
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description='db uri scheme',
default='postgresql',
description="db uri scheme",
default="postgresql",
)
@computed_field
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
if self.DB_CHARSET
else self.DB_EXTRAS
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
).strip("&")
db_extras = f"?{db_extras}" if db_extras else ""
return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
f"{db_extras}")
return (
f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
f"{db_extras}"
)
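
A worked example of the URI construction above (values illustrative): with the defaults the computed URI is `postgresql://postgres:@localhost:5432/dify`, and setting `DB_EXTRAS="keepalives_idle=60"` plus `DB_CHARSET="utf8"` yields `postgresql://postgres:@localhost:5432/dify?keepalives_idle=60&client_encoding=utf8`. The `quote_plus` calls escape credential characters:

```python
from urllib.parse import quote_plus

user, password, host, port, db = "postgres", "p@ss", "localhost", 5432, "dify"
print(f"postgresql://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{db}")
# -> postgresql://postgres:p%40ss@localhost:5432/dify
```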
SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
description='pool size of SqlAlchemy',
description="pool size of SqlAlchemy",
default=30,
)
SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
description='max overflows for SqlAlchemy',
description="max overflows for SqlAlchemy",
default=10,
)
SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
description='SqlAlchemy pool recycle',
description="SqlAlchemy pool recycle",
default=3600,
)
SQLALCHEMY_POOL_PRE_PING: bool = Field(
description='whether to enable pool pre-ping in SqlAlchemy',
description="whether to enable pool pre-ping in SqlAlchemy",
default=False,
)
SQLALCHEMY_ECHO: bool | str = Field(
description='whether to enable SqlAlchemy echo',
description="whether to enable SqlAlchemy echo",
default=False,
)
@ -137,35 +137,38 @@ class DatabaseConfig:
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
return {
'pool_size': self.SQLALCHEMY_POOL_SIZE,
'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW,
'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE,
'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING,
'connect_args': {'options': '-c timezone=UTC'},
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": {"options": "-c timezone=UTC"},
}
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description='Celery backend, available values are `database`, `redis`',
default='database',
description="Celery backend, available values are `database`, `redis`",
default="database",
)
CELERY_BROKER_URL: Optional[str] = Field(
description='CELERY_BROKER_URL',
description="CELERY_BROKER_URL",
default=None,
)
@computed_field
@property
def CELERY_RESULT_BACKEND(self) -> str | None:
return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
return (
"db+{}".format(self.SQLALCHEMY_DATABASE_URI)
if self.CELERY_BACKEND == "database"
else self.CELERY_BROKER_URL
)
@computed_field
@property
def BROKER_USE_SSL(self) -> bool:
return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
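
Sketch of the backend selection above: a `database` backend reuses the SQLAlchemy URI with a `db+` prefix, any other value falls back to the broker URL, and `rediss://` brokers switch on SSL.

```python
cfg = CeleryConfig()  # the class from this hunk, default settings assumed
assert cfg.CELERY_RESULT_BACKEND.startswith("db+postgresql://")
assert cfg.BROKER_USE_SSL is False  # no CELERY_BROKER_URL configured, so no SSL
```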
class MiddlewareConfig(
@ -174,7 +177,6 @@ class MiddlewareConfig(
DatabaseConfig,
KeywordStoreConfig,
RedisConfig,
# configs of storage and storage providers
StorageConfig,
AliyunOSSStorageConfig,
@ -183,7 +185,6 @@ class MiddlewareConfig(
TencentCloudCOSStorageConfig,
S3StorageConfig,
OCIStorageConfig,
# configs of vdb and vdb providers
VectorStoreConfig,
AnalyticdbConfig,

View File

@ -8,32 +8,33 @@ class RedisConfig(BaseSettings):
"""
Redis configs
"""
REDIS_HOST: str = Field(
description='Redis host',
default='localhost',
description="Redis host",
default="localhost",
)
REDIS_PORT: PositiveInt = Field(
description='Redis port',
description="Redis port",
default=6379,
)
REDIS_USERNAME: Optional[str] = Field(
description='Redis username',
description="Redis username",
default=None,
)
REDIS_PASSWORD: Optional[str] = Field(
description='Redis password',
description="Redis password",
default=None,
)
REDIS_DB: NonNegativeInt = Field(
description='Redis database id, default to 0',
description="Redis database id, default to 0",
default=0,
)
REDIS_USE_SSL: bool = Field(
description='whether to use SSL for Redis connection',
description="whether to use SSL for Redis connection",
default=False,
)

View File

@ -10,31 +10,31 @@ class AliyunOSSStorageConfig(BaseSettings):
"""
ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
description='Aliyun OSS bucket name',
description="Aliyun OSS bucket name",
default=None,
)
ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
description='Aliyun OSS access key',
description="Aliyun OSS access key",
default=None,
)
ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
description='Aliyun OSS secret key',
description="Aliyun OSS secret key",
default=None,
)
ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
description='Aliyun OSS endpoint URL',
description="Aliyun OSS endpoint URL",
default=None,
)
ALIYUN_OSS_REGION: Optional[str] = Field(
description='Aliyun OSS region',
description="Aliyun OSS region",
default=None,
)
ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
description='Aliyun OSS authentication version',
description="Aliyun OSS authentication version",
default=None,
)

View File

@ -10,36 +10,36 @@ class S3StorageConfig(BaseSettings):
"""
S3_ENDPOINT: Optional[str] = Field(
description='S3 storage endpoint',
description="S3 storage endpoint",
default=None,
)
S3_REGION: Optional[str] = Field(
description='S3 storage region',
description="S3 storage region",
default=None,
)
S3_BUCKET_NAME: Optional[str] = Field(
description='S3 storage bucket name',
description="S3 storage bucket name",
default=None,
)
S3_ACCESS_KEY: Optional[str] = Field(
description='S3 storage access key',
description="S3 storage access key",
default=None,
)
S3_SECRET_KEY: Optional[str] = Field(
description='S3 storage secret key',
description="S3 storage secret key",
default=None,
)
S3_ADDRESS_STYLE: str = Field(
description='S3 storage address style',
default='auto',
description="S3 storage address style",
default="auto",
)
S3_USE_AWS_MANAGED_IAM: bool = Field(
description='whether to use aws managed IAM for S3',
description="whether to use aws managed IAM for S3",
default=False,
)

View File

@ -10,21 +10,21 @@ class AzureBlobStorageConfig(BaseSettings):
"""
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
description='Azure Blob account name',
description="Azure Blob account name",
default=None,
)
AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
description='Azure Blob account key',
description="Azure Blob account key",
default=None,
)
AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
description='Azure Blob container name',
description="Azure Blob container name",
default=None,
)
AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
description='Azure Blob account URL',
description="Azure Blob account URL",
default=None,
)

View File

@ -10,11 +10,11 @@ class GoogleCloudStorageConfig(BaseSettings):
"""
GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
description='Google Cloud storage bucket name',
description="Google Cloud storage bucket name",
default=None,
)
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
description='Google Cloud storage service account json base64',
description="Google Cloud storage service account json base64",
default=None,
)

View File

@ -10,27 +10,26 @@ class OCIStorageConfig(BaseSettings):
"""
OCI_ENDPOINT: Optional[str] = Field(
description='OCI storage endpoint',
description="OCI storage endpoint",
default=None,
)
OCI_REGION: Optional[str] = Field(
description='OCI storage region',
description="OCI storage region",
default=None,
)
OCI_BUCKET_NAME: Optional[str] = Field(
description='OCI storage bucket name',
description="OCI storage bucket name",
default=None,
)
OCI_ACCESS_KEY: Optional[str] = Field(
description='OCI storage access key',
description="OCI storage access key",
default=None,
)
OCI_SECRET_KEY: Optional[str] = Field(
description='OCI storage secret key',
description="OCI storage secret key",
default=None,
)

View File

@ -10,26 +10,26 @@ class TencentCloudCOSStorageConfig(BaseSettings):
"""
TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
description='Tencent Cloud COS bucket name',
description="Tencent Cloud COS bucket name",
default=None,
)
TENCENT_COS_REGION: Optional[str] = Field(
description='Tencent Cloud COS region',
description="Tencent Cloud COS region",
default=None,
)
TENCENT_COS_SECRET_ID: Optional[str] = Field(
description='Tencent Cloud COS secret id',
description="Tencent Cloud COS secret id",
default=None,
)
TENCENT_COS_SECRET_KEY: Optional[str] = Field(
description='Tencent Cloud COS secret key',
description="Tencent Cloud COS secret key",
default=None,
)
TENCENT_COS_SCHEME: Optional[str] = Field(
description='Tencent Cloud COS scheme',
description="Tencent Cloud COS scheme",
default=None,
)

View File

@ -10,35 +10,28 @@ class AnalyticdbConfig(BaseModel):
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
"""
ANALYTICDB_KEY_ID : Optional[str] = Field(
default=None,
description="The Access Key ID provided by Alibaba Cloud for authentication."
ANALYTICDB_KEY_ID: Optional[str] = Field(
default=None, description="The Access Key ID provided by Alibaba Cloud for authentication."
)
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
default=None,
description="The Secret Access Key corresponding to the Access Key ID for secure access."
ANALYTICDB_KEY_SECRET: Optional[str] = Field(
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access."
)
ANALYTICDB_REGION_ID : Optional[str] = Field(
default=None,
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
ANALYTICDB_REGION_ID: Optional[str] = Field(
default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
)
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..",
)
ANALYTICDB_ACCOUNT : Optional[str] = Field(
default=None,
description="The account name used to log in to the AnalyticDB instance."
ANALYTICDB_ACCOUNT: Optional[str] = Field(
default=None, description="The account name used to log in to the AnalyticDB instance."
)
ANALYTICDB_PASSWORD : Optional[str] = Field(
default=None,
description="The password associated with the AnalyticDB account for authentication."
ANALYTICDB_PASSWORD: Optional[str] = Field(
default=None, description="The password associated with the AnalyticDB account for authentication."
)
ANALYTICDB_NAMESPACE : Optional[str] = Field(
default=None,
description="The namespace within AnalyticDB for schema isolation."
ANALYTICDB_NAMESPACE: Optional[str] = Field(
default=None, description="The namespace within AnalyticDB for schema isolation."
)
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
default=None,
description="The password for accessing the specified namespace within the AnalyticDB instance."
ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
default=None, description="The password for accessing the specified namespace within the AnalyticDB instance."
)

View File

@ -10,31 +10,31 @@ class ChromaConfig(BaseSettings):
"""
CHROMA_HOST: Optional[str] = Field(
description='Chroma host',
description="Chroma host",
default=None,
)
CHROMA_PORT: PositiveInt = Field(
description='Chroma port',
description="Chroma port",
default=8000,
)
CHROMA_TENANT: Optional[str] = Field(
description='Chroma tenant',
description="Chroma tenant",
default=None,
)
CHROMA_DATABASE: Optional[str] = Field(
description='Chroma database',
description="Chroma database",
default=None,
)
CHROMA_AUTH_PROVIDER: Optional[str] = Field(
description='Chroma authentication provider',
description="Chroma authentication provider",
default=None,
)
CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
description='Chroma authentication credentials',
description="Chroma authentication credentials",
default=None,
)

View File

@ -10,31 +10,31 @@ class MilvusConfig(BaseSettings):
"""
MILVUS_HOST: Optional[str] = Field(
description='Milvus host',
description="Milvus host",
default=None,
)
MILVUS_PORT: PositiveInt = Field(
description='Milvus RestFul API port',
description="Milvus RestFul API port",
default=9091,
)
MILVUS_USER: Optional[str] = Field(
description='Milvus user',
description="Milvus user",
default=None,
)
MILVUS_PASSWORD: Optional[str] = Field(
description='Milvus password',
description="Milvus password",
default=None,
)
MILVUS_SECURE: bool = Field(
description='whether to use SSL connection for Milvus',
description="whether to use SSL connection for Milvus",
default=False,
)
MILVUS_DATABASE: str = Field(
description='Milvus database, default to `default`',
default='default',
description="Milvus database, default to `default`",
default="default",
)

View File

@ -1,4 +1,3 @@
from pydantic import BaseModel, Field, PositiveInt
@ -8,31 +7,31 @@ class MyScaleConfig(BaseModel):
"""
MYSCALE_HOST: str = Field(
description='MyScale host',
default='localhost',
description="MyScale host",
default="localhost",
)
MYSCALE_PORT: PositiveInt = Field(
description='MyScale port',
description="MyScale port",
default=8123,
)
MYSCALE_USER: str = Field(
description='MyScale user',
default='default',
description="MyScale user",
default="default",
)
MYSCALE_PASSWORD: str = Field(
description='MyScale password',
default='',
description="MyScale password",
default="",
)
MYSCALE_DATABASE: str = Field(
description='MyScale database name',
default='default',
description="MyScale database name",
default="default",
)
MYSCALE_FTS_PARAMS: str = Field(
description='MyScale fts index parameters',
default='',
description="MyScale fts index parameters",
default="",
)

View File

@ -10,26 +10,26 @@ class OpenSearchConfig(BaseSettings):
"""
OPENSEARCH_HOST: Optional[str] = Field(
description='OpenSearch host',
description="OpenSearch host",
default=None,
)
OPENSEARCH_PORT: PositiveInt = Field(
description='OpenSearch port',
description="OpenSearch port",
default=9200,
)
OPENSEARCH_USER: Optional[str] = Field(
description='OpenSearch user',
description="OpenSearch user",
default=None,
)
OPENSEARCH_PASSWORD: Optional[str] = Field(
description='OpenSearch password',
description="OpenSearch password",
default=None,
)
OPENSEARCH_SECURE: bool = Field(
description='whether to use SSL connection for OpenSearch',
description="whether to use SSL connection for OpenSearch",
default=False,
)

View File

@ -10,26 +10,26 @@ class OracleConfig(BaseSettings):
"""
ORACLE_HOST: Optional[str] = Field(
description='ORACLE host',
description="ORACLE host",
default=None,
)
ORACLE_PORT: Optional[PositiveInt] = Field(
description='ORACLE port',
description="ORACLE port",
default=1521,
)
ORACLE_USER: Optional[str] = Field(
description='ORACLE user',
description="ORACLE user",
default=None,
)
ORACLE_PASSWORD: Optional[str] = Field(
description='ORACLE password',
description="ORACLE password",
default=None,
)
ORACLE_DATABASE: Optional[str] = Field(
description='ORACLE database',
description="ORACLE database",
default=None,
)

View File

@ -10,26 +10,26 @@ class PGVectorConfig(BaseSettings):
"""
PGVECTOR_HOST: Optional[str] = Field(
description='PGVector host',
description="PGVector host",
default=None,
)
PGVECTOR_PORT: Optional[PositiveInt] = Field(
description='PGVector port',
description="PGVector port",
default=5433,
)
PGVECTOR_USER: Optional[str] = Field(
description='PGVector user',
description="PGVector user",
default=None,
)
PGVECTOR_PASSWORD: Optional[str] = Field(
description='PGVector password',
description="PGVector password",
default=None,
)
PGVECTOR_DATABASE: Optional[str] = Field(
description='PGVector database',
description="PGVector database",
default=None,
)

View File

@ -10,26 +10,26 @@ class PGVectoRSConfig(BaseSettings):
"""
PGVECTO_RS_HOST: Optional[str] = Field(
description='PGVectoRS host',
description="PGVectoRS host",
default=None,
)
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
description='PGVectoRS port',
description="PGVectoRS port",
default=5431,
)
PGVECTO_RS_USER: Optional[str] = Field(
description='PGVectoRS user',
description="PGVectoRS user",
default=None,
)
PGVECTO_RS_PASSWORD: Optional[str] = Field(
description='PGVectoRS password',
description="PGVectoRS password",
default=None,
)
PGVECTO_RS_DATABASE: Optional[str] = Field(
description='PGVectoRS database',
description="PGVectoRS database",
default=None,
)

View File

@ -10,26 +10,26 @@ class QdrantConfig(BaseSettings):
"""
QDRANT_URL: Optional[str] = Field(
description='Qdrant url',
description="Qdrant url",
default=None,
)
QDRANT_API_KEY: Optional[str] = Field(
description='Qdrant api key',
description="Qdrant api key",
default=None,
)
QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description='Qdrant client timeout in seconds',
description="Qdrant client timeout in seconds",
default=20,
)
QDRANT_GRPC_ENABLED: bool = Field(
description='whether enable grpc support for Qdrant connection',
description="whether enable grpc support for Qdrant connection",
default=False,
)
QDRANT_GRPC_PORT: PositiveInt = Field(
description='Qdrant grpc port',
description="Qdrant grpc port",
default=6334,
)

View File

@ -10,26 +10,26 @@ class RelytConfig(BaseSettings):
"""
RELYT_HOST: Optional[str] = Field(
description='Relyt host',
description="Relyt host",
default=None,
)
RELYT_PORT: PositiveInt = Field(
description='Relyt port',
description="Relyt port",
default=9200,
)
RELYT_USER: Optional[str] = Field(
description='Relyt user',
description="Relyt user",
default=None,
)
RELYT_PASSWORD: Optional[str] = Field(
description='Relyt password',
description="Relyt password",
default=None,
)
RELYT_DATABASE: Optional[str] = Field(
description='Relyt database',
default='default',
description="Relyt database",
default="default",
)

View File

@ -10,41 +10,41 @@ class TencentVectorDBConfig(BaseSettings):
"""
TENCENT_VECTOR_DB_URL: Optional[str] = Field(
description='Tencent Vector URL',
description="Tencent Vector URL",
default=None,
)
TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
description='Tencent Vector API key',
description="Tencent Vector API key",
default=None,
)
TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
description='Tencent Vector timeout in seconds',
description="Tencent Vector timeout in seconds",
default=30,
)
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
description='Tencent Vector username',
description="Tencent Vector username",
default=None,
)
TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
description='Tencent Vector password',
description="Tencent Vector password",
default=None,
)
TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
description='Tencent Vector sharding number',
description="Tencent Vector sharding number",
default=1,
)
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description='Tencent Vector replicas',
description="Tencent Vector replicas",
default=2,
)
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
description='Tencent Vector Database',
description="Tencent Vector Database",
default=None,
)

View File

@ -10,26 +10,26 @@ class TiDBVectorConfig(BaseSettings):
"""
TIDB_VECTOR_HOST: Optional[str] = Field(
description='TiDB Vector host',
description="TiDB Vector host",
default=None,
)
TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
description='TiDB Vector port',
description="TiDB Vector port",
default=4000,
)
TIDB_VECTOR_USER: Optional[str] = Field(
description='TiDB Vector user',
description="TiDB Vector user",
default=None,
)
TIDB_VECTOR_PASSWORD: Optional[str] = Field(
description='TiDB Vector password',
description="TiDB Vector password",
default=None,
)
TIDB_VECTOR_DATABASE: Optional[str] = Field(
description='TiDB Vector database',
description="TiDB Vector database",
default=None,
)

View File

@ -10,21 +10,21 @@ class WeaviateConfig(BaseSettings):
"""
WEAVIATE_ENDPOINT: Optional[str] = Field(
description='Weaviate endpoint URL',
description="Weaviate endpoint URL",
default=None,
)
WEAVIATE_API_KEY: Optional[str] = Field(
description='Weaviate API key',
description="Weaviate API key",
default=None,
)
WEAVIATE_GRPC_ENABLED: bool = Field(
description='whether to enable gRPC for Weaviate connection',
description="whether to enable gRPC for Weaviate connection",
default=True,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description='Weaviate batch size',
description="Weaviate batch size",
default=100,
)

View File

@ -8,11 +8,11 @@ class PackagingInfo(BaseSettings):
"""
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.7.1',
description="Dify version",
default="0.7.1",
)
COMMIT_SHA: str = Field(
description="SHA-1 checksum of the git commit used to build the app",
default='',
default="",
)

View File

@ -21,6 +21,18 @@ parameter_rules:
default: 1024
min: 1
max: 128000
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.06'
output: '0.06'
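
The new `response_format` rule (repeated for each model file below) exposes text vs. JSON output modes. A hedged sketch of how a caller might set it (only the parameter names come from the YAML; the surrounding invoke call shape is assumed):

```python
model_parameters = {
    "max_tokens": 1024,
    "response_format": "json_object",  # or "text", the other declared option
}
```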

View File

@ -21,6 +21,18 @@ parameter_rules:
default: 1024
min: 1
max: 32000
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.024'
output: '0.024'

View File

@ -21,6 +21,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.012'
output: '0.012'

View File

@ -0,0 +1,9 @@
model: text-embedding-v3
model_type: text-embedding
model_properties:
context_size: 8192
max_chunks: 25
pricing:
input: "0.0007"
unit: "0.001"
currency: RMB

View File

@ -444,6 +444,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason
)
)

View File

@ -21,6 +21,7 @@ class LangfuseConfig(BaseTracingConfig):
"""
public_key: str
secret_key: str
project_key: str
host: str = 'https://api.langfuse.com'
@field_validator("host")

View File

@ -419,3 +419,11 @@ class LangFuseDataTrace(BaseTraceInstance):
except Exception as e:
logger.debug(f"LangFuse API check failed: {str(e)}")
raise ValueError(f"LangFuse API check failed: {str(e)}")
def get_project_key(self):
try:
projects = self.langfuse_client.client.projects.get()
return projects.data[0].id
except Exception as e:
logger.debug(f"LangFuse get project key failed: {str(e)}")
raise ValueError(f"LangFuse get project key failed: {str(e)}")

View File

@ -38,7 +38,7 @@ provider_config_map = {
TracingProviderEnum.LANGFUSE.value: {
'config_class': LangfuseConfig,
'secret_keys': ['public_key', 'secret_key'],
'other_keys': ['host'],
'other_keys': ['host', 'project_key'],
'trace_instance': LangFuseDataTrace
},
TracingProviderEnum.LANGSMITH.value: {
@ -123,7 +123,6 @@ class OpsTraceManager:
for key in other_keys:
new_config[key] = decrypt_tracing_config.get(key, "")
return config_class(**new_config).model_dump()
@classmethod
@ -252,6 +251,19 @@ class OpsTraceManager:
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).api_check()
@staticmethod
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
"""
get project key for the trace config
:param tracing_config: tracing config
:param tracing_provider: tracing provider
:return:
"""
config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['trace_instance']
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).get_project_key()
class TraceTask:
def __init__(

View File

@ -614,7 +614,7 @@ class DatasetRetrieval:
top_k: int, score_threshold: float) -> list[Document]:
filter_documents = []
for document in all_documents:
if score_threshold and document.metadata['score'] >= score_threshold:
if score_threshold is None or document.metadata['score'] >= score_threshold:
filter_documents.append(document)
if not filter_documents:
return []
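
The one-line fix above changes the filter semantics: previously a `score_threshold` of `None` (or `0`) made the condition falsy and dropped every document; now an unset threshold lets all documents through. Equivalent comprehension for illustration:

```python
filter_documents = [
    doc for doc in all_documents
    if score_threshold is None or doc.metadata["score"] >= score_threshold
]
```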

Binary file not shown (image added, 37 KiB).

View File

@ -0,0 +1,12 @@
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class OneBotProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
if not credentials.get("ob11_http_url"):
raise ToolProviderCredentialValidationError('OneBot HTTP URL is required.')

View File

@ -0,0 +1,35 @@
identity:
author: RockChinQ
name: onebot
label:
en_US: OneBot v11 Protocol
zh_Hans: OneBot v11 协议
description:
en_US: Unofficial OneBot v11 Protocol Tool
zh_Hans: 非官方 OneBot v11 协议工具
icon: icon.ico
credentials_for_provider:
ob11_http_url:
type: text-input
required: true
label:
en_US: HTTP URL
zh_Hans: HTTP URL
description:
en_US: Forward HTTP URL of OneBot v11
zh_Hans: OneBot v11 正向 HTTP URL
help:
en_US: Fill this with the HTTP URL of your OneBot server
zh_Hans: 请在你的 OneBot 协议端开启 正向 HTTP 并填写其 URL
access_token:
type: secret-input
required: false
label:
en_US: Access Token
zh_Hans: 访问令牌
description:
en_US: Access Token for OneBot v11 Protocol
zh_Hans: OneBot 协议访问令牌
help:
en_US: Fill this if you set a access token in your OneBot server
zh_Hans: 如果你在 OneBot 服务器中设置了 access token,请填写此项

View File

@ -0,0 +1,64 @@
from typing import Any, Union
import requests
from yarl import URL
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class SendGroupMsg(BuiltinTool):
"""OneBot v11 Tool: Send Group Message"""
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
# Get parameters
send_group_id = tool_parameters.get('group_id', '')
message = tool_parameters.get('message', '')
if not message:
return self.create_json_message(
{
'error': 'Message is empty.'
}
)
auto_escape = tool_parameters.get('auto_escape', False)
try:
url = URL(self.runtime.credentials['ob11_http_url']) / 'send_group_msg'
resp = requests.post(
url,
json={
'group_id': send_group_id,
'message': message,
'auto_escape': auto_escape
},
headers={
'Authorization': 'Bearer ' + self.runtime.credentials['access_token']
}
)
if resp.status_code != 200:
return self.create_json_message(
{
'error': f'Failed to send group message: {resp.text}'
}
)
return self.create_json_message(
{
'response': resp.json()
}
)
except Exception as e:
return self.create_json_message(
{
'error': f'Failed to send group message: {e}'
}
)
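
Outside of Dify, the same OneBot v11 call is an ordinary HTTP POST; a minimal sketch against a server with forward HTTP enabled (the URL, group ID and token are placeholders):

import requests
from yarl import URL

OB11_HTTP_URL = "http://127.0.0.1:3000"  # placeholder forward-HTTP endpoint
ACCESS_TOKEN = ""  # leave empty if no access token is configured

resp = requests.post(
    str(URL(OB11_HTTP_URL) / "send_group_msg"),
    json={"group_id": 12345678, "message": "hello", "auto_escape": False},
    headers={"Authorization": f"Bearer {ACCESS_TOKEN}"} if ACCESS_TOKEN else {},
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # OneBot replies with {"status": ..., "retcode": ..., "data": ...}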

View File

@ -0,0 +1,46 @@
identity:
name: send_group_msg
author: RockChinQ
label:
en_US: Send Group Message
zh_Hans: 发送群消息
description:
human:
en_US: Send a message to a group
zh_Hans: 发送消息到群聊
llm: A tool for sending a message segment to a group
parameters:
- name: group_id
type: number
required: true
label:
en_US: Target Group ID
zh_Hans: 目标群 ID
human_description:
en_US: The group ID of the target group
zh_Hans: 目标群的群 ID
llm_description: The group ID of the target group
form: llm
- name: message
type: string
required: true
label:
en_US: Message
zh_Hans: 消息
human_description:
en_US: The message to send
zh_Hans: 要发送的消息。支持 CQ 码,需要同时设置 auto_escape 为 true
llm_description: The message to send
form: llm
- name: auto_escape
type: boolean
required: false
default: false
label:
en_US: Auto Escape
zh_Hans: 自动转义
human_description:
en_US: If true, the message will be treated as a CQ code for parsing; otherwise, it will be treated as plain text and sent directly. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes.
zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。
llm_description: If true, the message will be treated as a CQ code for parsing; otherwise, it will be treated as plain text and sent directly.
form: form

View File

@ -0,0 +1,64 @@
from typing import Any, Union
import requests
from yarl import URL
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class SendPrivateMsg(BuiltinTool):
"""OneBot v11 Tool: Send Private Message"""
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
# Get parameters
send_user_id = tool_parameters.get('user_id', '')
message = tool_parameters.get('message', '')
if not message:
return self.create_json_message(
{
'error': 'Message is empty.'
}
)
auto_escape = tool_parameters.get('auto_escape', False)
try:
url = URL(self.runtime.credentials['ob11_http_url']) / 'send_private_msg'
resp = requests.post(
url,
json={
'user_id': send_user_id,
'message': message,
'auto_escape': auto_escape
},
headers={
'Authorization': 'Bearer ' + self.runtime.credentials.get('access_token', '')  # access_token is optional
}
)
if resp.status_code != 200:
return self.create_json_message(
{
'error': f'Failed to send private message: {resp.text}'
}
)
return self.create_json_message(
{
'response': resp.json()
}
)
except Exception as e:
return self.create_json_message(
{
'error': f'Failed to send private message: {e}'
}
)

View File

@ -0,0 +1,46 @@
identity:
name: send_private_msg
author: RockChinQ
label:
en_US: Send Private Message
zh_Hans: 发送私聊消息
description:
human:
en_US: Send a private message to a user
zh_Hans: 发送私聊消息给用户
llm: A tool for sending a message segment to a user in private chat
parameters:
- name: user_id
type: number
required: true
label:
en_US: Target User ID
zh_Hans: 目标用户 ID
human_description:
en_US: The user ID of the target user
zh_Hans: 目标用户的用户 ID
llm_description: The user ID of the target user
form: llm
- name: message
type: string
required: true
label:
en_US: Message
zh_Hans: 消息
human_description:
en_US: The message to send
zh_Hans: 要发送的消息。支持 CQ 码,需要同时设置 auto_escape 为 true
llm_description: The message to send
form: llm
- name: auto_escape
type: boolean
required: false
default: false
label:
en_US: Auto Escape
zh_Hans: 自动转义
human_description:
en_US: If true, the message will be treated as a CQ code for parsing; otherwise, it will be treated as plain text and sent directly. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes.
zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。
llm_description: If true, the message will be treated as a CQ code for parsing; otherwise, it will be treated as plain text and sent directly.
form: form
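
CQ codes are how rich content fits through the string-only message parameter; a hypothetical parameter set for this tool (the image URL is a placeholder, and the auto_escape value follows the description above):

# one text segment followed by one image segment, encoded as CQ codes
tool_parameters = {
    "user_id": 12345678,
    "message": "look at this: [CQ:image,file=https://example.com/cat.png]",
    "auto_escape": True,  # per the description above: parse the CQ code
}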

View File

@ -27,7 +27,7 @@ DRAW_TEXT_OPTIONS = {
"seed_resize_from_w": -1,
# Samplers
# "sampler_name": "DPM++ 2M",
"sampler_name": "DPM++ 2M",
# "scheduler": "",
# "sampler_index": "Automatic",
@ -178,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
return [d['model_name'] for d in response.json()]
except Exception as e:
return []
def get_sample_methods(self) -> list[str]:
"""
get sample methods
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return []
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers')
response = get(url=api_url, timeout=(2, 10))
if response.status_code != 200:
return []
else:
return [d['name'] for d in response.json()]
except Exception as e:
return []
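
The new samplers endpoint can be probed the same way outside the tool; a minimal sketch against a local AUTOMATIC1111 server (the base URL is a placeholder):

from requests import get
from yarl import URL

base_url = "http://127.0.0.1:7860"  # placeholder Stable Diffusion WebUI address
api_url = str(URL(base_url) / "sdapi" / "v1" / "samplers")
response = get(url=api_url, timeout=(2, 10))  # (connect, read) timeouts, as above
if response.status_code == 200:
    print([d["name"] for d in response.json()])  # e.g. ["DPM++ 2M", "Euler a", ...]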
def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
@ -339,7 +356,27 @@ class StableDiffusionTool(BuiltinTool):
label=I18nObject(en_US=i, zh_Hans=i)
) for i in models])
)
except:
pass
sample_methods = self.get_sample_methods()
if len(sample_methods) != 0:
parameters.append(
ToolParameter(name='sampler_name',
label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'),
human_description=I18nObject(
en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
zh_Hans='Stable Diffusion 的 Sampling method,您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
required=True,
default=sample_methods[0],
options=[ToolParameterOption(
value=i,
label=I18nObject(en_US=i, zh_Hans=i)
) for i in sample_methods])
)
return parameters

View File

@ -144,7 +144,7 @@ class ApiTool(Tool):
path_params[parameter['name']] = value
elif parameter['in'] == 'query':
params[parameter['name']] = value
if value != '': params[parameter['name']] = value
elif parameter['in'] == 'cookie':
cookies[parameter['name']] = value

View File

@ -11,15 +11,6 @@ from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
from models.workflow import WorkflowNodeExecutionStatus
MAX_NUMBER = dify_config.CODE_MAX_NUMBER
MIN_NUMBER = dify_config.CODE_MIN_NUMBER
MAX_PRECISION = dify_config.CODE_MAX_PRECISION
MAX_DEPTH = dify_config.CODE_MAX_DEPTH
MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH
MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH
MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH
class CodeNode(BaseNode):
_node_data_cls = CodeNodeData
@ -97,8 +88,9 @@ class CodeNode(BaseNode):
else:
raise ValueError(f"Output variable `{variable}` must be a string")
if len(value) > MAX_STRING_LENGTH:
raise ValueError(f'The length of output variable `{variable}` must be less than {MAX_STRING_LENGTH} characters')
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise ValueError(f'The length of output variable `{variable}` must be'
f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters')
return value.replace('\x00', '')
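
Reading the limit from dify_config at check time, rather than from a module-level constant captured at import, lets environment overrides take effect; a standalone sketch of the corrected string guard:

from configs import dify_config

def check_string(variable: str, value: str) -> str:
    # the limit is resolved on every call, so CODE_MAX_STRING_LENGTH values
    # overridden through the environment apply without touching this module
    if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
        raise ValueError(f'The length of output variable `{variable}` must be'
                         f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters')
    return value.replace('\x00', '')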
@ -115,13 +107,15 @@ class CodeNode(BaseNode):
else:
raise ValueError(f"Output variable `{variable}` must be a number")
if value > MAX_NUMBER or value < MIN_NUMBER:
raise ValueError(f'Output variable `{variable}` is out of range, it must be between {MIN_NUMBER} and {MAX_NUMBER}.')
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise ValueError(f'Output variable `{variable}` is out of range,'
f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.')
if isinstance(value, float):
# raise error if precision is too high
if len(str(value).split('.')[1]) > MAX_PRECISION:
raise ValueError(f'Output variable `{variable}` has too high precision, it must be less than {MAX_PRECISION} digits.')
if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION:
raise ValueError(f'Output variable `{variable}` has too high precision,'
f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.')
return value
@ -134,8 +128,8 @@ class CodeNode(BaseNode):
:param output_schema: output schema
:return:
"""
if depth > MAX_DEPTH:
raise ValueError("Depth limit reached, object too deep.")
if depth > dify_config.CODE_MAX_DEPTH:
raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result = {}
if output_schema is None:
@ -235,9 +229,10 @@ class CodeNode(BaseNode):
f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
)
else:
if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH:
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
raise ValueError(
f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_NUMBER_ARRAY_LENGTH} elements.'
f'The length of output variable `{prefix}{dot}{output_name}` must be'
f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.'
)
transformed_result[output_name] = [
@ -257,9 +252,10 @@ class CodeNode(BaseNode):
f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
)
else:
if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
raise ValueError(
f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_STRING_ARRAY_LENGTH} elements.'
f'The length of output variable `{prefix}{dot}{output_name}` must be'
f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.'
)
transformed_result[output_name] = [
@ -279,9 +275,10 @@ class CodeNode(BaseNode):
f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
)
else:
if len(result[output_name]) > MAX_OBJECT_ARRAY_LENGTH:
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
raise ValueError(
f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_OBJECT_ARRAY_LENGTH} elements.'
f'The length of output variable `{prefix}{dot}{output_name}` must be'
f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.'
)
for i, value in enumerate(result[output_name]):

View File

@ -18,11 +18,6 @@ from core.workflow.nodes.http_request.entities import (
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
READABLE_MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE
MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
READABLE_MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE
class HttpExecutorResponse:
headers: dict[str, str]
@ -237,16 +232,14 @@ class HttpExecutor:
else:
raise ValueError(f'Invalid response type {type(response)}')
if executor_response.is_file:
if executor_response.size > MAX_BINARY_SIZE:
raise ValueError(
f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.'
)
else:
if executor_response.size > MAX_TEXT_SIZE:
raise ValueError(
f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.'
)
threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
if executor_response.size > threshold_size:
raise ValueError(
f'{"File" if executor_response.is_file else "Text"} size is too large,'
f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
f' but current size is {executor_response.readable_size}.'
)
return executor_response
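
The two size checks collapse into one guard keyed on the response type; a standalone sketch of the same logic:

from configs import dify_config

def check_size(size: int, is_file: bool, readable_size: str) -> None:
    # pick the binary or text ceiling depending on the response type
    threshold = (dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if is_file
                 else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE)
    if size > threshold:
        raise ValueError(
            f'{"File" if is_file else "Text"} size is too large,'
            f' max size is {threshold / 1024 / 1024:.2f} MB,'
            f' but current size is {readable_size}.'
        )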

api/poetry.lock generated
View File

@ -6372,13 +6372,13 @@ semver = ["semver (>=3.0.2)"]
[[package]]
name = "pydantic-settings"
version = "2.3.4"
version = "2.4.0"
description = "Settings management using Pydantic"
optional = false
python-versions = ">=3.8"
files = [
{file = "pydantic_settings-2.3.4-py3-none-any.whl", hash = "sha256:11ad8bacb68a045f00e4f862c7a718c8a9ec766aa8fd4c32e39a0594b207b53a"},
{file = "pydantic_settings-2.3.4.tar.gz", hash = "sha256:c5802e3d62b78e82522319bbc9b8f8ffb28ad1c988a99311d04f2a6051fca0a7"},
{file = "pydantic_settings-2.4.0-py3-none-any.whl", hash = "sha256:bb6849dc067f1687574c12a639e231f3a6feeed0a12d710c1382045c5db1c315"},
{file = "pydantic_settings-2.4.0.tar.gz", hash = "sha256:ed81c3a0f46392b4d7c0a565c05884e6e54b3456e6f0fe4d8814981172dc9a88"},
]
[package.dependencies]
@ -6386,6 +6386,7 @@ pydantic = ">=2.7.0"
python-dotenv = ">=0.21.0"
[package.extras]
azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"]
toml = ["tomli (>=2.0.1)"]
yaml = ["pyyaml (>=6.0.1)"]
@ -9633,4 +9634,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "69c20af8ecacced3cca092662223a1511acaf65cb2616a5a1e38b498223463e0"
content-hash = "d7336115709114c2a4ff09b392f717e9c3547ae82b6a111d0c885c7a44269f02"

View File

@ -76,8 +76,6 @@ exclude = [
"migrations/**/*",
"services/**/*.py",
"tasks/**/*.py",
"tests/**/*.py",
"configs/**/*.py",
]
[tool.pytest_env]
@ -162,7 +160,7 @@ pandas = { version = "~2.2.2", extras = ["performance", "excel"] }
psycopg2-binary = "~2.9.6"
pycryptodome = "3.19.1"
pydantic = "~2.8.2"
pydantic-settings = "~2.3.4"
pydantic-settings = "~2.4.0"
pydantic_extra_types = "~2.9.0"
pyjwt = "~2.8.0"
pypdfium2 = "~4.17.0"

View File

@ -22,6 +22,10 @@ class OpsService:
# decrypt_token and obfuscated_token
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config)
if tracing_provider == 'langfuse' and not decrypt_tracing_config.get('project_key'):
project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider)
decrypt_tracing_config['project_key'] = project_key
decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)
trace_config_data.tracing_config = decrypt_tracing_config
@ -37,7 +41,7 @@ class OpsService:
:param tracing_config: tracing config
:return:
"""
if tracing_provider not in provider_config_map.keys() and tracing_provider != None:
if tracing_provider not in provider_config_map.keys() and tracing_provider:
return {"error": f"Invalid tracing provider: {tracing_provider}"}
config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \
@ -51,6 +55,9 @@ class OpsService:
if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider):
return {"error": "Invalid Credentials"}
# get project key
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
# check if trace config already exists
trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
@ -62,6 +69,8 @@ class OpsService:
# get tenant id
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
if tracing_provider == 'langfuse':
tracing_config['project_key'] = project_key
trace_config_data = TraceAppConfig(
app_id=app_id,
tracing_provider=tracing_provider,

View File

@ -22,23 +22,20 @@ from anthropic.types import (
)
from anthropic.types.message_delta_event import Delta
MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
class MockAnthropicClass:
@staticmethod
def mocked_anthropic_chat_create_sync(model: str) -> Message:
return Message(
id='msg-123',
type='message',
role='assistant',
content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
id="msg-123",
type="message",
role="assistant",
content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")],
model=model,
stop_reason='stop_sequence',
usage=Usage(
input_tokens=1,
output_tokens=1
)
stop_reason="stop_sequence",
usage=Usage(input_tokens=1, output_tokens=1),
)
@staticmethod
@ -46,52 +43,43 @@ class MockAnthropicClass:
full_response_text = "hello, I'm a chatbot from anthropic"
yield MessageStartEvent(
type='message_start',
type="message_start",
message=Message(
id='msg-123',
id="msg-123",
content=[],
role='assistant',
role="assistant",
model=model,
stop_reason=None,
type='message',
usage=Usage(
input_tokens=1,
output_tokens=1
)
)
type="message",
usage=Usage(input_tokens=1, output_tokens=1),
),
)
index = 0
for i in range(0, len(full_response_text)):
yield ContentBlockDeltaEvent(
type='content_block_delta',
delta=TextDelta(text=full_response_text[i], type='text_delta'),
index=index
type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index
)
index += 1
yield MessageDeltaEvent(
type='message_delta',
delta=Delta(
stop_reason='stop_sequence'
),
usage=MessageDeltaUsage(
output_tokens=1
)
type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1)
)
yield MessageStopEvent(type='message_stop')
yield MessageStopEvent(type="message_stop")
def mocked_anthropic(self: Messages, *,
max_tokens: int,
messages: Iterable[MessageParam],
model: str,
stream: Literal[True],
**kwargs: Any
) -> Union[Message, Stream[MessageStreamEvent]]:
def mocked_anthropic(
self: Messages,
*,
max_tokens: int,
messages: Iterable[MessageParam],
model: str,
stream: Literal[True],
**kwargs: Any,
) -> Union[Message, Stream[MessageStreamEvent]]:
if len(self._client.api_key) < 18:
raise anthropic.AuthenticationError('Invalid API key')
raise anthropic.AuthenticationError("Invalid API key")
if stream:
return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
@ -102,7 +90,7 @@ class MockAnthropicClass:
@pytest.fixture
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)
monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic)
yield
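
Every provider mock in these tests follows the same MOCK_SWITCH pattern; a generic, runnable sketch that patches a harmless stand-in target instead of a real SDK entry point:

import os
import time
import pytest

MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"

@pytest.fixture
def setup_example_mock(monkeypatch: pytest.MonkeyPatch):
    if MOCK:
        # stand-in target: freeze time.time() the way the real fixtures
        # replace SDK methods such as Messages.create above
        monkeypatch.setattr(time, "time", lambda: 0.0)
    yield
    if MOCK:
        monkeypatch.undo()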

View File

@ -12,63 +12,46 @@ from google.generativeai.client import _ClientManager, configure
from google.generativeai.types import GenerateContentResponse
from google.generativeai.types.generation_types import BaseGenerateContentResponse
current_api_key = ''
current_api_key = ""
class MockGoogleResponseClass:
_done = False
def __iter__(self):
full_response_text = 'it\'s google!'
full_response_text = "it's google!"
for i in range(0, len(full_response_text) + 1, 1):
if i == len(full_response_text):
self._done = True
yield GenerateContentResponse(
done=True,
iterator=None,
result=glm.GenerateContentResponse({
}),
chunks=[]
done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
)
else:
yield GenerateContentResponse(
done=False,
iterator=None,
result=glm.GenerateContentResponse({
}),
chunks=[]
done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
)
class MockGoogleResponseCandidateClass:
finish_reason = 'stop'
finish_reason = "stop"
@property
def content(self) -> gag_content.Content:
return gag_content.Content(
parts=[
gag_content.Part(text='it\'s google!')
]
)
return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
class MockGoogleClass:
@staticmethod
def generate_content_sync() -> GenerateContentResponse:
return GenerateContentResponse(
done=True,
iterator=None,
result=glm.GenerateContentResponse({
}),
chunks=[]
)
return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])
@staticmethod
def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
return MockGoogleResponseClass()
def generate_content(self: GenerativeModel,
def generate_content(
self: GenerativeModel,
contents: content_types.ContentsType,
*,
generation_config: generation_config_types.GenerationConfigType | None = None,
@ -79,21 +62,21 @@ class MockGoogleClass:
global current_api_key
if len(current_api_key) < 16:
raise Exception('Invalid API key')
raise Exception("Invalid API key")
if stream:
return MockGoogleClass.generate_content_stream()
return MockGoogleClass.generate_content_sync()
@property
def generative_response_text(self) -> str:
return 'it\'s google!'
return "it's google!"
@property
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
return [MockGoogleResponseCandidateClass()]
def make_client(self: _ClientManager, name: str):
global current_api_key
@ -121,7 +104,8 @@ class MockGoogleClass:
if not self.default_metadata:
return client
@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
yield
monkeypatch.undo()
monkeypatch.undo()

View File

@ -6,14 +6,15 @@ from huggingface_hub import InferenceClient
from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)
yield
if MOCK:
monkeypatch.undo()
monkeypatch.undo()

View File

@ -22,10 +22,8 @@ class MockHuggingfaceChatClass:
details=Details(
finish_reason="length",
generated_tokens=6,
tokens=[
Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
]
)
tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)],
),
)
return response
@ -36,26 +34,23 @@ class MockHuggingfaceChatClass:
for i in range(0, len(full_text)):
response = TextGenerationStreamResponse(
token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
token=Token(id=i, text=full_text[i], logprob=0.0, special=False),
)
response.generated_text = full_text[i]
response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1)
yield response
def text_generation(self: InferenceClient, prompt: str, *,
stream: Literal[False] = ...,
model: Optional[str] = None,
**kwargs: Any
def text_generation(
self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any
) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
# check if key is valid
if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
raise BadRequestError('Invalid API key')
if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]):
raise BadRequestError("Invalid API key")
if model is None:
raise BadRequestError('Invalid model')
raise BadRequestError("Invalid model")
if stream:
return MockHuggingfaceChatClass.generate_create_stream(model)
return MockHuggingfaceChatClass.generate_create_sync(model)

View File

@ -5,10 +5,10 @@ class MockTEIClass:
@staticmethod
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
# During mock, we don't have a real server to query, so we just return a dummy value
if 'rerank' in model_name:
model_type = 'reranker'
if "rerank" in model_name:
model_type = "reranker"
else:
model_type = 'embedding'
model_type = "embedding"
return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
@ -17,16 +17,16 @@ class MockTEIClass:
# Use space as token separator, and split the text into tokens
tokenized_texts = []
for text in texts:
tokens = text.split(' ')
tokens = text.split(" ")
current_index = 0
tokenized_text = []
for idx, token in enumerate(tokens):
s_token = {
'id': idx,
'text': token,
'special': False,
'start': current_index,
'stop': current_index + len(token),
"id": idx,
"text": token,
"special": False,
"start": current_index,
"stop": current_index + len(token),
}
current_index += len(token) + 1
tokenized_text.append(s_token)
@ -55,18 +55,18 @@ class MockTEIClass:
embedding = [0.1] * 768
embeddings.append(
{
'object': 'embedding',
'embedding': embedding,
'index': idx,
"object": "embedding",
"embedding": embedding,
"index": idx,
}
)
return {
'object': 'list',
'data': embeddings,
'model': 'MODEL_NAME',
'usage': {
'prompt_tokens': sum(len(text.split(' ')) for text in texts),
'total_tokens': sum(len(text.split(' ')) for text in texts),
"object": "list",
"data": embeddings,
"model": "MODEL_NAME",
"usage": {
"prompt_tokens": sum(len(text.split(" ")) for text in texts),
"total_tokens": sum(len(text.split(" ")) for text in texts),
},
}
@ -83,9 +83,9 @@ class MockTEIClass:
for idx, text in enumerate(texts):
reranked_docs.append(
{
'index': idx,
'text': text,
'score': 0.9,
"index": idx,
"text": text,
"score": 0.9,
}
)
# For mock, only return the first document

View File

@ -21,13 +21,17 @@ from tests.integration_tests.model_runtime.__mock.openai_remote import MockModel
from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]:
def mock_openai(
monkeypatch: MonkeyPatch,
methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
) -> Callable[[], None]:
"""
mock openai module
mock openai module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""
def unpatch() -> None:
monkeypatch.undo()
@ -52,15 +56,16 @@ def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "c
return unpatch
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_openai_mock(request, monkeypatch):
methods = request.param if hasattr(request, 'param') else []
methods = request.param if hasattr(request, "param") else []
if MOCK:
unpatch = mock_openai(monkeypatch, methods=methods)
yield
if MOCK:
unpatch()
unpatch()

View File

@ -43,62 +43,64 @@ class MockChatClass:
if not functions or len(functions) == 0:
return None
function: completion_create_params.Function = functions[0]
function_name = function['name']
function_description = function['description']
function_parameters = function['parameters']
function_parameters_type = function_parameters['type']
if function_parameters_type != 'object':
function_name = function["name"]
function_description = function["description"]
function_parameters = function["parameters"]
function_parameters_type = function_parameters["type"]
if function_parameters_type != "object":
return None
function_parameters_properties = function_parameters['properties']
function_parameters_required = function_parameters['required']
function_parameters_properties = function_parameters["properties"]
function_parameters_required = function_parameters["required"]
parameters = {}
for parameter_name, parameter in function_parameters_properties.items():
if parameter_name not in function_parameters_required:
continue
parameter_type = parameter['type']
if parameter_type == 'string':
if 'enum' in parameter:
if len(parameter['enum']) == 0:
parameter_type = parameter["type"]
if parameter_type == "string":
if "enum" in parameter:
if len(parameter["enum"]) == 0:
continue
parameters[parameter_name] = parameter['enum'][0]
parameters[parameter_name] = parameter["enum"][0]
else:
parameters[parameter_name] = 'kawaii'
elif parameter_type == 'integer':
parameters[parameter_name] = "kawaii"
elif parameter_type == "integer":
parameters[parameter_name] = 114514
elif parameter_type == 'number':
elif parameter_type == "number":
parameters[parameter_name] = 1919810.0
elif parameter_type == 'boolean':
elif parameter_type == "boolean":
parameters[parameter_name] = True
return FunctionCall(name=function_name, arguments=dumps(parameters))
@staticmethod
def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
list_tool_calls = []
if not tools or len(tools) == 0:
return None
tool = tools[0]
if 'type' in tools and tools['type'] != 'function':
if "type" in tools and tools["type"] != "function":
return None
function = tool['function']
function = tool["function"]
function_call = MockChatClass.generate_function_call(functions=[function])
if function_call is None:
return None
list_tool_calls.append(ChatCompletionMessageToolCall(
id='sakurajima-mai',
function=Function(
name=function_call.name,
arguments=function_call.arguments,
),
type='function'
))
list_tool_calls.append(
ChatCompletionMessageToolCall(
id="sakurajima-mai",
function=Function(
name=function_call.name,
arguments=function_call.arguments,
),
type="function",
)
)
return list_tool_calls
@staticmethod
def mocked_openai_chat_create_sync(
model: str,
@ -111,30 +113,27 @@ class MockChatClass:
tool_calls = MockChatClass.generate_tool_calls(tools=tools)
return _ChatCompletion(
id='cmpl-3QJQa5jXJ5Z5X',
id="cmpl-3QJQa5jXJ5Z5X",
choices=[
_ChatCompletionChoice(
finish_reason='content_filter',
finish_reason="content_filter",
index=0,
message=ChatCompletionMessage(
content='elaina',
role='assistant',
function_call=function_call,
tool_calls=tool_calls
)
content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls
),
)
],
created=int(time()),
model=model,
object='chat.completion',
system_fingerprint='',
object="chat.completion",
system_fingerprint="",
usage=CompletionUsage(
prompt_tokens=2,
completion_tokens=1,
total_tokens=3,
)
),
)
@staticmethod
def mocked_openai_chat_create_stream(
model: str,
@ -150,36 +149,40 @@ class MockChatClass:
for i in range(0, len(full_text) + 1):
if i == len(full_text):
yield ChatCompletionChunk(
id='cmpl-3QJQa5jXJ5Z5X',
id="cmpl-3QJQa5jXJ5Z5X",
choices=[
Choice(
delta=ChoiceDelta(
content='',
content="",
function_call=ChoiceDeltaFunctionCall(
name=function_call.name,
arguments=function_call.arguments,
) if function_call else None,
role='assistant',
)
if function_call
else None,
role="assistant",
tool_calls=[
ChoiceDeltaToolCall(
index=0,
id='misaka-mikoto',
id="misaka-mikoto",
function=ChoiceDeltaToolCallFunction(
name=tool_calls[0].function.name,
arguments=tool_calls[0].function.arguments,
),
type='function'
type="function",
)
] if tool_calls and len(tool_calls) > 0 else None
]
if tool_calls and len(tool_calls) > 0
else None,
),
finish_reason='function_call',
finish_reason="function_call",
index=0,
)
],
created=int(time()),
model=model,
object='chat.completion.chunk',
system_fingerprint='',
object="chat.completion.chunk",
system_fingerprint="",
usage=CompletionUsage(
prompt_tokens=2,
completion_tokens=17,
@ -188,30 +191,45 @@ class MockChatClass:
)
else:
yield ChatCompletionChunk(
id='cmpl-3QJQa5jXJ5Z5X',
id="cmpl-3QJQa5jXJ5Z5X",
choices=[
Choice(
delta=ChoiceDelta(
content=full_text[i],
role='assistant',
role="assistant",
),
finish_reason='content_filter',
finish_reason="content_filter",
index=0,
)
],
created=int(time()),
model=model,
object='chat.completion.chunk',
system_fingerprint='',
object="chat.completion.chunk",
system_fingerprint="",
)
def chat_create(self: Completions, *,
def chat_create(
self: Completions,
*,
messages: list[ChatCompletionMessageParam],
model: Union[str,Literal[
"gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"],
model: Union[
str,
Literal[
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
],
],
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
@ -220,24 +238,32 @@ class MockChatClass:
**kwargs: Any,
):
openai_models = [
"gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
]
azure_openai_models = [
"gpt35", "gpt-4v", "gpt-35-turbo"
]
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"]
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if model in openai_models + azure_openai_models:
if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
# sometimes a provider using an OpenAI-compatible API has no api key, or uses a different key format,
# so we only check the key when the model is one of the known openai_models
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if stream:
return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)

View File

@ -17,9 +17,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockCompletionsClass:
@staticmethod
def mocked_openai_completion_create_sync(
model: str
) -> CompletionMessage:
def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
return CompletionMessage(
id="cmpl-3QJQa5jXJ5Z5X",
object="text_completion",
@ -38,13 +36,11 @@ class MockCompletionsClass:
prompt_tokens=2,
completion_tokens=1,
total_tokens=3,
)
),
)
@staticmethod
def mocked_openai_completion_create_stream(
model: str
) -> Generator[CompletionMessage, None, None]:
def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
for i in range(0, len(full_text) + 1):
if i == len(full_text):
@ -76,46 +72,59 @@ class MockCompletionsClass:
model=model,
system_fingerprint="",
choices=[
CompletionChoice(
text=full_text[i],
index=0,
logprobs=None,
finish_reason="content_filter"
)
CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter")
],
)
def completion_create(self: Completions, *, model: Union[
str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct",
"text-davinci-003", "text-davinci-002", "text-davinci-001",
"code-davinci-002", "text-curie-001", "text-babbage-001",
"text-ada-001"],
def completion_create(
self: Completions,
*,
model: Union[
str,
Literal[
"babbage-002",
"davinci-002",
"gpt-3.5-turbo-instruct",
"text-davinci-003",
"text-davinci-002",
"text-davinci-001",
"code-davinci-002",
"text-curie-001",
"text-babbage-001",
"text-ada-001",
],
],
prompt: Union[str, list[str], list[int], list[list[int]], None],
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
**kwargs: Any
**kwargs: Any,
):
openai_models = [
"babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001",
"code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001",
]
azure_openai_models = [
"gpt-35-turbo-instruct"
"babbage-002",
"davinci-002",
"gpt-3.5-turbo-instruct",
"text-davinci-003",
"text-davinci-002",
"text-davinci-001",
"code-davinci-002",
"text-curie-001",
"text-babbage-001",
"text-ada-001",
]
azure_openai_models = ["gpt-35-turbo-instruct"]
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if model in openai_models + azure_openai_models:
if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
# sometimes a provider using an OpenAI-compatible API has no api key, or uses a different key format,
# so we only check the key when the model is one of the known openai_models
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if not prompt:
raise BadRequestError('Invalid prompt')
raise BadRequestError("Invalid prompt")
if stream:
return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)

File diff suppressed because one or more lines are too long

View File

@ -10,58 +10,92 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockModerationClass:
def moderation_create(self: Moderations,*,
def moderation_create(
self: Moderations,
*,
input: Union[str, list[str]],
model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
**kwargs: Any
**kwargs: Any,
) -> ModerationCreateResponse:
if isinstance(input, str):
input = [input]
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if len(self._client.api_key) < 18:
raise InvokeAuthorizationError('Invalid API key')
raise InvokeAuthorizationError("Invalid API key")
for text in input:
result = []
if 'kill' in text:
if "kill" in text:
moderation_categories = {
'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
'sexual/minors': False, 'violence': False, 'violence/graphic': False
"harassment": False,
"harassment/threatening": False,
"hate": False,
"hate/threatening": False,
"self-harm": False,
"self-harm/instructions": False,
"self-harm/intent": False,
"sexual": False,
"sexual/minors": False,
"violence": False,
"violence/graphic": False,
}
moderation_categories_scores = {
'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0,
'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0,
'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0
"harassment": 1.0,
"harassment/threatening": 1.0,
"hate": 1.0,
"hate/threatening": 1.0,
"self-harm": 1.0,
"self-harm/instructions": 1.0,
"self-harm/intent": 1.0,
"sexual": 1.0,
"sexual/minors": 1.0,
"violence": 1.0,
"violence/graphic": 1.0,
}
result.append(Moderation(
flagged=True,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores)
))
result.append(
Moderation(
flagged=True,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores),
)
)
else:
moderation_categories = {
'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
'sexual/minors': False, 'violence': False, 'violence/graphic': False
"harassment": False,
"harassment/threatening": False,
"hate": False,
"hate/threatening": False,
"self-harm": False,
"self-harm/instructions": False,
"self-harm/intent": False,
"sexual": False,
"sexual/minors": False,
"violence": False,
"violence/graphic": False,
}
moderation_categories_scores = {
'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0,
'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0,
'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0
"harassment": 0.0,
"harassment/threatening": 0.0,
"hate": 0.0,
"hate/threatening": 0.0,
"self-harm": 0.0,
"self-harm/instructions": 0.0,
"self-harm/intent": 0.0,
"sexual": 0.0,
"sexual/minors": 0.0,
"violence": 0.0,
"violence/graphic": 0.0,
}
result.append(Moderation(
flagged=False,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores)
))
result.append(
Moderation(
flagged=False,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores),
)
)
return ModerationCreateResponse(
id='shiroii kuloko',
model=model,
results=result
)
return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result)

View File

@ -6,17 +6,18 @@ from openai.types.model import Model
class MockModelClass:
"""
mock class for openai.models.Models
mock class for openai.models.Models
"""
def list(
self,
**kwargs,
) -> list[Model]:
return [
Model(
id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ',
id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ",
created=int(time()),
object='model',
owned_by='organization:org-123',
object="model",
owned_by="organization:org-123",
)
]
]

View File

@ -9,7 +9,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockSpeech2TextClass:
def speech2text_create(self: Transcriptions,
def speech2text_create(
self: Transcriptions,
*,
file: FileTypes,
model: Union[str, Literal["whisper-1"]],
@ -17,14 +18,12 @@ class MockSpeech2TextClass:
prompt: str | NotGiven = NOT_GIVEN,
response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
**kwargs: Any
**kwargs: Any,
) -> Transcription:
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if len(self._client.api_key) < 18:
raise InvokeAuthorizationError('Invalid API key')
return Transcription(
text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
)
raise InvokeAuthorizationError("Invalid API key")
return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10")

View File

@ -19,40 +19,43 @@ from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
class MockXinferenceClass:
def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
raise RuntimeError('404 Not Found')
if 'generate' == model_uid:
def get_chat_model(
self: Client, model_uid: str
) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
raise RuntimeError("404 Not Found")
if "generate" == model_uid:
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'chat' == model_uid:
if "chat" == model_uid:
return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'embedding' == model_uid:
if "embedding" == model_uid:
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'rerank' == model_uid:
if "rerank" == model_uid:
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
raise RuntimeError('404 Not Found')
raise RuntimeError("404 Not Found")
def get(self: Session, url: str, **kwargs):
response = Response()
if 'v1/models/' in url:
if "v1/models/" in url:
# get model uid
model_uid = url.split('/')[-1] or ''
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
model_uid = url.split("/")[-1] or ""
if not re.match(
r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
) and model_uid not in ["generate", "chat", "embedding", "rerank"]:
response.status_code = 404
response._content = b'{}'
response._content = b"{}"
return response
# check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
response.status_code = 404
response._content = b'{}'
response._content = b"{}"
return response
if model_uid in ['generate', 'chat']:
if model_uid in ["generate", "chat"]:
response.status_code = 200
response._content = b'''{
response._content = b"""{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
@ -75,12 +78,12 @@ class MockXinferenceClass:
"revision": null,
"context_length": 2048,
"replica": 1
}'''
}"""
return response
elif model_uid == 'embedding':
elif model_uid == "embedding":
response.status_code = 200
response._content = b'''{
response._content = b"""{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
@ -93,51 +96,48 @@ class MockXinferenceClass:
],
"revision": null,
"max_tokens": 512
}'''
}"""
return response
elif 'v1/cluster/auth' in url:
elif "v1/cluster/auth" in url:
response.status_code = 200
response._content = b'''{
response._content = b"""{
"auth": true
}'''
}"""
return response
def _check_cluster_authenticated(self):
self._cluster_authed = True
def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
def rerank(
self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
) -> dict:
# check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
self._model_uid != 'rerank':
raise RuntimeError('404 Not Found')
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
raise RuntimeError('404 Not Found')
if (
not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
and self._model_uid != "rerank"
):
raise RuntimeError("404 Not Found")
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
raise RuntimeError("404 Not Found")
if top_n is None:
top_n = 1
return {
'results': [
{
'index': i,
'document': doc,
'relevance_score': 0.9
}
for i, doc in enumerate(documents[:top_n])
"results": [
{"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
]
}
def create_embedding(
self: RESTfulGenerateModelHandle,
input: Union[str, list[str]],
**kwargs
) -> dict:
def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
# check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
self._model_uid != 'embedding':
raise RuntimeError('404 Not Found')
if (
not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
and self._model_uid != "embedding"
):
raise RuntimeError("404 Not Found")
if isinstance(input, str):
input = [input]
@ -147,32 +147,27 @@ class MockXinferenceClass:
object="list",
model=self._model_uid,
data=[
EmbeddingData(
index=i,
object="embedding",
embedding=[1919.810 for _ in range(768)]
)
EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
for i in range(ipt_len)
],
usage=EmbeddingUsage(
prompt_tokens=ipt_len,
total_tokens=ipt_len
)
usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
)
return embedding
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
yield
if MOCK:
monkeypatch.undo()
monkeypatch.undo()

View File

@ -10,79 +10,60 @@ from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeL
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_anthropic_mock):
model = AnthropicLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='claude-instant-1.2',
credentials={
'anthropic_api_key': 'invalid_key'
}
)
model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"})
model.validate_credentials(
model='claude-instant-1.2',
credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
}
model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
)
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_invoke_model(setup_anthropic_mock):
model = AnthropicLargeLanguageModel()
response = model.invoke(
model='claude-instant-1.2',
model="claude-instant-1.2",
credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"),
"anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'top_p': 1.0,
'max_tokens': 10
},
stop=['How'],
model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_invoke_stream_model(setup_anthropic_mock):
model = AnthropicLargeLanguageModel()
response = model.invoke(
model='claude-instant-1.2',
credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
},
model="claude-instant-1.2",
credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -98,18 +79,14 @@ def test_get_num_tokens():
model = AnthropicLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='claude-instant-1.2',
credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
},
model="claude-instant-1.2",
credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 18

View File

@ -7,17 +7,11 @@ from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProv
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_validate_provider_credentials(setup_anthropic_mock):
provider = AnthropicProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")})

File diff suppressed because one or more lines are too long

View File

@ -8,45 +8,43 @@ from core.model_runtime.model_providers.azure_openai.text_embedding.text_embeddi
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
model = AzureOpenAITextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embedding',
model="embedding",
credentials={
'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
'openai_api_key': 'invalid_key',
'base_model_name': 'text-embedding-ada-002'
}
"openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
"openai_api_key": "invalid_key",
"base_model_name": "text-embedding-ada-002",
},
)
model.validate_credentials(
model='embedding',
model="embedding",
credentials={
'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
'base_model_name': 'text-embedding-ada-002'
}
"openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
"openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
"base_model_name": "text-embedding-ada-002",
},
)
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = AzureOpenAITextEmbeddingModel()
result = model.invoke(
model='embedding',
model="embedding",
credentials={
'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
'base_model_name': 'text-embedding-ada-002'
"openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
"openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
"base_model_name": "text-embedding-ada-002",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -58,14 +56,7 @@ def test_get_num_tokens():
model = AzureOpenAITextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='embedding',
credentials={
'base_model_name': 'text-embedding-ada-002'
},
texts=[
"hello",
"world"
]
model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"]
)
assert num_tokens == 2

View File

@ -17,111 +17,99 @@ def test_predefined_models():
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)
def test_validate_credentials_for_chat_model():
sleep(3)
model = BaichuanLarguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='baichuan2-turbo',
credentials={
'api_key': 'invalid_key',
'secret_key': 'invalid_key'
}
model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"}
)
model.validate_credentials(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
}
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
)
def test_invoke_model():
sleep(3)
model = BaichuanLarguageModel()
response = model.invoke(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_model_with_system_message():
sleep(3)
model = BaichuanLarguageModel()
response = model.invoke(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='请记住你是Kasumi。'
),
UserPromptMessage(
content='现在告诉我你是谁?'
)
SystemPromptMessage(content="请记住你是Kasumi。"),
UserPromptMessage(content="现在告诉我你是谁?"),
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_stream_model():
sleep(3)
model = BaichuanLarguageModel()
response = model.invoke(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@@ -131,34 +119,31 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_with_search():
sleep(3)
model = BaichuanLarguageModel()
response = model.invoke(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'
)
],
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
'with_search_enhance': True,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
"with_search_enhance": True,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
total_message = ''
total_message = ""
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
@@ -166,25 +151,22 @@ def test_invoke_with_search():
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
total_message += chunk.delta.message.content
assert '' not in total_message
assert "" not in total_message
def test_get_num_tokens():
sleep(3)
model = BaichuanLarguageModel()
response = model.get_num_tokens(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
tools=[]
prompt_messages=[UserPromptMessage(content="Hello World!")],
tools=[],
)
assert isinstance(response, int)
assert response == 9


@@ -10,14 +10,6 @@ def test_validate_provider_credentials():
provider = BaichuanProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'api_key': 'hahahaha'
}
)
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")})


@@ -11,18 +11,10 @@ def test_validate_credentials():
model = BaichuanTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='baichuan-text-embedding',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='baichuan-text-embedding',
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY')
}
model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}
)
@@ -30,44 +22,40 @@ def test_invoke_model():
model = BaichuanTextEmbeddingModel()
result = model.invoke(
model='baichuan-text-embedding',
model="baichuan-text-embedding",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
"api_key": os.environ.get("BAICHUAN_API_KEY"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 6
def test_get_num_tokens():
model = BaichuanTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='baichuan-text-embedding',
model="baichuan-text-embedding",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
"api_key": os.environ.get("BAICHUAN_API_KEY"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2
def test_max_chunks():
model = BaichuanTextEmbeddingModel()
result = model.invoke(
model='baichuan-text-embedding',
model="baichuan-text-embedding",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
"api_key": os.environ.get("BAICHUAN_API_KEY"),
},
texts=[
"hello",
@@ -92,8 +80,8 @@ def test_max_chunks():
"world",
"hello",
"world",
]
],
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 22
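test_max_chunks feeds 22 texts and expects 22 embeddings back, which implies the model wrapper splits oversized requests into per-request batches and concatenates the results. A sketch of that batching logic, assuming a hypothetical per-request limit of 16 chunks (the real limit lives in the model's config, not shown here):

```python
def batch_texts(texts: list[str], max_chunks: int = 16) -> list[list[str]]:
    # Split inputs into request-sized batches; embedding each batch and
    # concatenating the results yields exactly one vector per input text.
    return [texts[i : i + max_chunks] for i in range(0, len(texts), max_chunks)]


batches = batch_texts(["hello", "world"] * 11)  # 22 texts total
assert [len(b) for b in batches] == [16, 6]
assert sum(len(b) for b in batches) == 22
```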


@@ -13,77 +13,63 @@ def test_validate_credentials():
model = BedrockLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='meta.llama2-13b-chat-v1',
credentials={
'anthropic_api_key': 'invalid_key'
}
)
model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"})
model.validate_credentials(
model='meta.llama2-13b-chat-v1',
model="meta.llama2-13b-chat-v1",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
}
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
)
def test_invoke_model():
model = BedrockLargeLanguageModel()
response = model.invoke(
model='meta.llama2-13b-chat-v1',
model="meta.llama2-13b-chat-v1",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'top_p': 1.0,
'max_tokens_to_sample': 10
},
stop=['How'],
model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
def test_invoke_stream_model():
model = BedrockLargeLanguageModel()
response = model.invoke(
model='meta.llama2-13b-chat-v1',
model="meta.llama2-13b-chat-v1",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens_to_sample': 100
},
model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@@ -100,20 +86,18 @@ def test_get_num_tokens():
model = BedrockLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='meta.llama2-13b-chat-v1',
credentials = {
model="meta.llama2-13b-chat-v1",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 18


@@ -10,14 +10,12 @@ def test_validate_provider_credentials():
provider = BedrockProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
}
)


@@ -23,79 +23,64 @@ def test_predefined_models():
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='chatglm2-6b',
credentials={
'api_base': 'invalid_key'
}
)
model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"})
model.validate_credentials(
model='chatglm2-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
}
)
model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
response = model.invoke(
model='chatglm2-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
model="chatglm2-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_model(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
response = model.invoke(
model='chatglm2-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
model="chatglm2-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@@ -105,56 +90,45 @@ def test_invoke_stream_model(setup_openai_mock):
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_model_with_functions(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
response = model.invoke(
model='chatglm3-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
model="chatglm3-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。'
content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。"
),
UserPromptMessage(
content='波士顿天气如何?'
)
UserPromptMessage(content="波士顿天气如何?"),
],
model_parameters={
'temperature': 0,
'top_p': 1.0,
"temperature": 0,
"top_p": 1.0,
},
stop=['you'],
user='abc-123',
stop=["you"],
user="abc-123",
stream=True,
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)
assert isinstance(response, Generator)
call: LLMResultChunk = None
chunks = []
@@ -170,122 +144,87 @@ def test_invoke_stream_model_with_functions(setup_openai_mock):
break
assert call is not None
assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
assert call.delta.message.tool_calls[0].function.name == "get_current_weather"
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_model_with_functions(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
response = model.invoke(
model='chatglm3-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
prompt_messages=[
UserPromptMessage(
content='What is the weather like in San Francisco?'
)
],
model="chatglm3-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
user='abc-123',
stop=["you"],
user="abc-123",
stream=False,
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
assert response.message.tool_calls[0].function.name == 'get_current_weather'
assert response.message.tool_calls[0].function.name == "get_current_weather"
def test_get_num_tokens():
model = ChatGLMLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='chatglm2-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
model="chatglm2-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)
assert isinstance(num_tokens, int)
assert num_tokens == 77
num_tokens = model.get_num_tokens(
model='chatglm2-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
model="chatglm2-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
)
assert isinstance(num_tokens, int)
assert num_tokens == 21


@@ -7,19 +7,11 @@ from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
provider = ChatGLMProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'api_base': 'hahahaha'
}
)
provider.validate_provider_credentials(credentials={"api_base": "hahahaha"})
provider.validate_provider_credentials(
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
}
)
provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})


@@ -13,87 +13,49 @@ def test_validate_credentials_for_chat_model():
model = CohereLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='command-light-chat',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
def test_validate_credentials_for_completion_model():
model = CohereLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='command-light',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='command-light',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
def test_invoke_completion_model():
model = CohereLargeLanguageModel()
credentials = {
'api_key': os.environ.get('COHERE_API_KEY')
}
credentials = {"api_key": os.environ.get("COHERE_API_KEY")}
result = model.invoke(
model='command-light',
model="command-light",
credentials=credentials,
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.0,
'max_tokens': 1
},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.0, "max_tokens": 1},
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1
def test_invoke_stream_completion_model():
model = CohereLargeLanguageModel()
result = model.invoke(
model='command-light',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model="command-light",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(result, Generator)
@@ -109,28 +71,24 @@ def test_invoke_chat_model():
model = CohereLargeLanguageModel()
result = model.invoke(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
model="command-light-chat",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'p': 0.99,
'presence_penalty': 0.0,
'frequency_penalty': 0.0,
'max_tokens': 10
"temperature": 0.0,
"p": 0.99,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"max_tokens": 10,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)
@@ -141,24 +99,17 @@ def test_invoke_stream_chat_model():
model = CohereLargeLanguageModel()
result = model.invoke(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
model="command-light-chat",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(result, Generator)
@@ -177,32 +128,22 @@ def test_get_num_tokens():
model = CohereLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='command-light',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
]
model="command-light",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
)
assert num_tokens == 3
num_tokens = model.get_num_tokens(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
model="command-light-chat",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 15
@@ -213,25 +154,17 @@ def test_fine_tuned_model():
# test invoke
result = model.invoke(
model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
credentials={
'api_key': os.environ.get('COHERE_API_KEY'),
'mode': 'completion'
},
model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft",
credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)
@@ -242,25 +175,17 @@ def test_fine_tuned_chat_model():
# test invoke
result = model.invoke(
model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
credentials={
'api_key': os.environ.get('COHERE_API_KEY'),
'mode': 'chat'
},
model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft",
credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)


@@ -10,12 +10,6 @@ def test_validate_provider_credentials():
provider = CohereProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")})


@@ -11,29 +11,17 @@ def test_validate_credentials():
model = CohereRerankModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='rerank-english-v2.0',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='rerank-english-v2.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
def test_invoke_model():
model = CohereRerankModel()
result = model.invoke(
model='rerank-english-v2.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
model="rerank-english-v2.0",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
@@ -41,9 +29,9 @@ def test_invoke_model():
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
"is the capital of the United States. It is a federal district. The President of the USA and many major "
"national government offices are in the territory. This makes it the political center of the United "
"States of America."
"States of America.",
],
score_threshold=0.8
score_threshold=0.8,
)
assert isinstance(result, RerankResult)
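The rerank test passes a query, candidate documents, and score_threshold=0.8, then asserts only the shape of the RerankResult. A toy sketch of what a score threshold implies downstream — filter and sort by relevance score — using illustrative dataclasses rather than Cohere's actual scoring:

```python
from dataclasses import dataclass


@dataclass
class ScoredDoc:
    index: int
    text: str
    score: float


def rerank_filter(docs: list[ScoredDoc], score_threshold: float) -> list[ScoredDoc]:
    # Keep only documents whose relevance score clears the threshold,
    # ordered from most to least relevant.
    return sorted((d for d in docs if d.score >= score_threshold), key=lambda d: d.score, reverse=True)


docs = [
    ScoredDoc(0, "Carson City is the capital city of the American state of Nevada.", 0.31),
    ScoredDoc(1, "Washington, D.C. is the capital of the United States.", 0.97),
]
assert [d.index for d in rerank_filter(docs, 0.8)] == [1]
```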


@@ -11,18 +11,10 @@ def test_validate_credentials():
model = CohereTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embed-multilingual-v3.0',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='embed-multilingual-v3.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}
)
@@ -30,17 +22,10 @@ def test_invoke_model():
model = CohereTextEmbeddingModel()
result = model.invoke(
model='embed-multilingual-v3.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
texts=[
"hello",
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
model="embed-multilingual-v3.0",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@@ -52,14 +37,9 @@ def test_get_num_tokens():
model = CohereTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='embed-multilingual-v3.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
texts=[
"hello",
"world"
]
model="embed-multilingual-v3.0",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
texts=["hello", "world"],
)
assert num_tokens == 3


@@ -7,17 +7,11 @@ from core.model_runtime.model_providers.google.google import GoogleProvider
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
def test_validate_provider_credentials(setup_google_mock):
provider = GoogleProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")})


@@ -10,87 +10,75 @@ from core.model_runtime.model_providers.huggingface_hub.llm.llm import Huggingfa
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='HuggingFaceH4/zephyr-7b-beta',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'invalid_key'
}
model="HuggingFaceH4/zephyr-7b-beta",
credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
)
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='fake-model',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'invalid_key'
}
model="fake-model",
credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
)
model.validate_credentials(
model='HuggingFaceH4/zephyr-7b-beta',
model="HuggingFaceH4/zephyr-7b-beta",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
}
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
},
)
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='HuggingFaceH4/zephyr-7b-beta',
model="HuggingFaceH4/zephyr-7b-beta",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='HuggingFaceH4/zephyr-7b-beta',
model="HuggingFaceH4/zephyr-7b-beta",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@@ -101,86 +89,81 @@ def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='openchat/openchat_3.5',
model="openchat/openchat_3.5",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'invalid_key',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": "invalid_key",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
"task_type": "text-generation",
},
)
model.validate_credentials(
model='openchat/openchat_3.5',
model="openchat/openchat_3.5",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
"task_type": "text-generation",
},
)
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='openchat/openchat_3.5',
model="openchat/openchat_3.5",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
"task_type": "text-generation",
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='openchat/openchat_3.5',
model="openchat/openchat_3.5",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
"task_type": "text-generation",
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@@ -191,86 +174,81 @@ def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingfa
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='google/mt5-base',
model="google/mt5-base",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'invalid_key',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": "invalid_key",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
"task_type": "text2text-generation",
},
)
model.validate_credentials(
model='google/mt5-base',
model="google/mt5-base",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
"task_type": "text2text-generation",
},
)
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='google/mt5-base',
model="google/mt5-base",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
"task_type": "text2text-generation",
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='google/mt5-base',
model="google/mt5-base",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
"task_type": "text2text-generation",
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@@ -286,18 +264,14 @@ def test_get_num_tokens():
model = HuggingfaceHubLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='google/mt5-base',
model="google/mt5-base",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
"task_type": "text2text-generation",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
]
prompt_messages=[UserPromptMessage(content="Hello World!")],
)
assert num_tokens == 7
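Throughout this file the credentials carry a task_type ("text-generation" vs. "text2text-generation") that changes how the endpoint is driven. A hedged sketch of dispatching on such a field (illustrative only; not the real HuggingfaceHubLargeLanguageModel internals, and the payload keys are assumptions based on common HF Inference API usage):

```python
def build_payload(task_type: str, prompt: str) -> dict:
    # text-generation continues the prompt; text2text-generation maps an
    # input sequence to an output sequence (e.g. mt5-style models).
    if task_type == "text-generation":
        return {"inputs": prompt, "parameters": {"return_full_text": False}}
    if task_type == "text2text-generation":
        return {"inputs": prompt}
    raise ValueError(f"unsupported task_type: {task_type}")


assert "parameters" in build_payload("text-generation", "Who are you?")
assert build_payload("text2text-generation", "Who are you?") == {"inputs": "Who are you?"}
```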


@@ -14,19 +14,19 @@ def test_hosted_inference_api_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='facebook/bart-base',
model="facebook/bart-base",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'invalid_key',
}
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": "invalid_key",
},
)
model.validate_credentials(
model='facebook/bart-base',
model="facebook/bart-base",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
}
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
},
)
@@ -34,15 +34,12 @@ def test_hosted_inference_api_invoke_model():
model = HuggingfaceHubTextEmbeddingModel()
result = model.invoke(
model='facebook/bart-base',
model="facebook/bart-base",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert isinstance(result, TextEmbeddingResult)
@@ -55,25 +52,25 @@ def test_inference_endpoints_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='all-MiniLM-L6-v2',
model="all-MiniLM-L6-v2",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'invalid_key',
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": "invalid_key",
"huggingface_namespace": "Dify-AI",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
"task_type": "feature-extraction",
},
)
model.validate_credentials(
model='all-MiniLM-L6-v2',
model="all-MiniLM-L6-v2",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingface_namespace": "Dify-AI",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
"task_type": "feature-extraction",
},
)
@@ -81,18 +78,15 @@ def test_inference_endpoints_invoke_model():
model = HuggingfaceHubTextEmbeddingModel()
result = model.invoke(
model='all-MiniLM-L6-v2',
model="all-MiniLM-L6-v2",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingface_namespace": "Dify-AI",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
"task_type": "feature-extraction",
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert isinstance(result, TextEmbeddingResult)
@@ -104,18 +98,15 @@ def test_get_num_tokens():
model = HuggingfaceHubTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='all-MiniLM-L6-v2',
model="all-MiniLM-L6-v2",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingface_namespace": "Dify-AI",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
"task_type": "feature-extraction",
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2


@@ -10,61 +10,59 @@ from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embe
)
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK:
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
yield
if MOCK:
monkeypatch.undo()
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_tei_mock):
model = HuggingfaceTeiTextEmbeddingModel()
# the model name is only used in the mock
model_name = 'embedding'
model_name = "embedding"
if MOCK:
# The TEI provider checks the model type via the API endpoint; on a real server the type is correct,
# so we don't need to check it here. We only check it in the mock.
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='reranker',
model="reranker",
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
}
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
},
)
model.validate_credentials(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
}
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
},
)
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_invoke_model(setup_tei_mock):
model = HuggingfaceTeiTextEmbeddingModel()
model_name = 'embedding'
model_name = "embedding"
result = model.invoke(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
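The MOCK_SWITCH pattern above lets one test body run either against a live TEI server or with TeiHelper monkeypatched. A self-contained sketch of the same switch, where Helper and MockHelper are stand-ins for TeiHelper and MockTEIClass:

```python
import os

import pytest

MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


class Helper:
    @staticmethod
    def invoke_embeddings(texts: list[str]) -> list[list[float]]:
        raise RuntimeError("would call a live TEI server")


class MockHelper:
    @staticmethod
    def invoke_embeddings(texts: list[str]) -> list[list[float]]:
        return [[0.0] * 3 for _ in texts]


@pytest.fixture
def setup_tei_mock(monkeypatch: pytest.MonkeyPatch):
    # Patch only when mocking is enabled; otherwise the real helper runs.
    if MOCK:
        monkeypatch.setattr(Helper, "invoke_embeddings", MockHelper.invoke_embeddings)
    yield
    if MOCK:
        monkeypatch.undo()


def test_embeddings(setup_tei_mock):
    if not MOCK:
        pytest.skip("set MOCK_SWITCH=true or point at a live server")
    assert len(Helper.invoke_embeddings(["hello", "world"])) == 2
```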


@@ -11,63 +11,65 @@ from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK:
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
yield
if MOCK:
monkeypatch.undo()
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_tei_mock):
model = HuggingfaceTeiRerankModel()
# the model name is only used in the mock
model_name = 'reranker'
model_name = "reranker"
if MOCK:
# The TEI provider checks the model type via the API endpoint; on a real server the type is correct,
# so we don't need to check it here. We only check it in the mock.
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embedding',
model="embedding",
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
}
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
},
)
model.validate_credentials(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
}
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
},
)
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_invoke_model(setup_tei_mock):
model = HuggingfaceTeiRerankModel()
# the model name is only used in the mock
model_name = 'reranker'
model_name = "reranker"
result = model.invoke(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
},
query="Who is Kasumi?",
docs=[
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty."
"and she leads a team named PopiParty.",
],
score_threshold=0.8
score_threshold=0.8,
)
assert isinstance(result, RerankResult)


@@ -14,19 +14,15 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='hunyuan-standard',
credentials={
'secret_id': 'invalid_key',
'secret_key': 'invalid_key'
}
model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
)
model.validate_credentials(
model='hunyuan-standard',
model="hunyuan-standard",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
}
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
)
@@ -34,23 +30,16 @@ def test_invoke_model():
model = HunyuanLargeLanguageModel()
response = model.invoke(
model='hunyuan-standard',
model="hunyuan-standard",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Hi'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
prompt_messages=[UserPromptMessage(content="Hi")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@@ -61,23 +50,15 @@ def test_invoke_stream_model():
model = HunyuanLargeLanguageModel()
response = model.invoke(
model='hunyuan-standard',
model="hunyuan-standard",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hi'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
prompt_messages=[UserPromptMessage(content="Hi")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@@ -93,19 +74,17 @@ def test_get_num_tokens():
model = HunyuanLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='hunyuan-standard',
model="hunyuan-standard",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 14


@@ -10,16 +10,11 @@ def test_validate_provider_credentials():
provider = HunyuanProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'secret_id': 'invalid_key',
'secret_key': 'invalid_key'
}
)
provider.validate_provider_credentials(credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"})
provider.validate_provider_credentials(
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
}
)


@@ -12,19 +12,15 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='hunyuan-embedding',
credentials={
'secret_id': 'invalid_key',
'secret_key': 'invalid_key'
}
model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
)
model.validate_credentials(
model='hunyuan-embedding',
model="hunyuan-embedding",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
}
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
)
@@ -32,47 +28,43 @@ def test_invoke_model():
model = HunyuanTextEmbeddingModel()
result = model.invoke(
model='hunyuan-embedding',
model="hunyuan-embedding",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 6
def test_get_num_tokens():
model = HunyuanTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='hunyuan-embedding',
model="hunyuan-embedding",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2
def test_max_chunks():
model = HunyuanTextEmbeddingModel()
result = model.invoke(
model='hunyuan-embedding',
model="hunyuan-embedding",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
texts=[
"hello",
@@ -97,8 +89,8 @@ def test_max_chunks():
"world",
"hello",
"world",
]
],
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 22
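A small follow-on sketch (illustrative only, assuming result comes from the invoke() call shown above) that compares the two returned vectors with cosine similarity:

import math


def cosine(a: list[float], b: list[float]) -> float:
    # Dot product over the product of the two vector norms.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


similarity = cosine(result.embeddings[0], result.embeddings[1])
assert -1.0 <= similarity <= 1.0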

View File

@@ -10,14 +10,6 @@ def test_validate_provider_credentials():
provider = JinaProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'api_key': 'hahahaha'
}
)
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('JINA_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("JINA_API_KEY")})

View File

@@ -11,18 +11,10 @@ def test_validate_credentials():
model = JinaTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='jina-embeddings-v2-base-en',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='jina-embeddings-v2-base-en',
credentials={
'api_key': os.environ.get('JINA_API_KEY')
}
model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")}
)
@@ -30,15 +22,12 @@ def test_invoke_model():
model = JinaTextEmbeddingModel()
result = model.invoke(
model='jina-embeddings-v2-base-en',
model="jina-embeddings-v2-base-en",
credentials={
'api_key': os.environ.get('JINA_API_KEY'),
"api_key": os.environ.get("JINA_API_KEY"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@@ -50,14 +39,11 @@ def test_get_num_tokens():
model = JinaTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='jina-embeddings-v2-base-en',
model="jina-embeddings-v2-base-en",
credentials={
'api_key': os.environ.get('JINA_API_KEY'),
"api_key": os.environ.get("JINA_API_KEY"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 6
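The two Jina credential cases above could also be expressed as one parametrized test; a sketch under the assumption that the same imports are in scope (illustrative, not part of the diff):

@pytest.mark.parametrize(
    "api_key, should_pass",
    [("invalid_key", False), (os.environ.get("JINA_API_KEY"), True)],
)
def test_validate_credentials_cases(api_key, should_pass):
    model = JinaTextEmbeddingModel()
    if should_pass:
        model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": api_key})
    else:
        # Invalid keys must raise, mirroring the pytest.raises block above.
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": api_key})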

View File

@@ -1,4 +1,4 @@
"""
LocalAI Embedding Interface is temporarily unavailable due to
we could not find a way to test it for now.
"""
LocalAI Embedding Interface is temporarily unavailable due to
we could not find a way to test it for now.
"""

View File

@@ -21,99 +21,78 @@ def test_validate_credentials_for_chat_model():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': 'hahahaha',
'completion_type': 'completion',
}
"server_url": "hahahaha",
"completion_type": "completion",
},
)
model.validate_credentials(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'completion',
}
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "completion",
},
)
def test_invoke_completion_model():
model = LocalAILanguageModel()
response = model.invoke(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'completion',
},
prompt_messages=[
UserPromptMessage(
content='ping'
)
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'max_tokens': 10
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "completion",
},
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=[],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_chat_model():
model = LocalAILanguageModel()
response = model.invoke(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'chat_completion',
},
prompt_messages=[
UserPromptMessage(
content='ping'
)
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'max_tokens': 10
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "chat_completion",
},
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=[],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_stream_completion_model():
model = LocalAILanguageModel()
response = model.invoke(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'completion',
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "completion",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'max_tokens': 10
},
stop=['you'],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@@ -123,28 +102,21 @@ def test_invoke_stream_completion_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_stream_chat_model():
model = LocalAILanguageModel()
response = model.invoke(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'chat_completion',
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "chat_completion",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'max_tokens': 10
},
stop=['you'],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@@ -154,64 +126,48 @@ def test_invoke_stream_chat_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
model = LocalAILanguageModel()
num_tokens = model.get_num_tokens(
model='????',
model="????",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'chat_completion',
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "chat_completion",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)
assert isinstance(num_tokens, int)
assert num_tokens == 77
num_tokens = model.get_num_tokens(
model='????',
model="????",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'chat_completion',
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "chat_completion",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
)
assert isinstance(num_tokens, int)

View File

@@ -12,30 +12,29 @@ def test_validate_credentials_for_chat_model():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-reranker-v2-m3',
model="bge-reranker-v2-m3",
credentials={
'server_url': 'hahahaha',
'completion_type': 'completion',
}
"server_url": "hahahaha",
"completion_type": "completion",
},
)
model.validate_credentials(
model='bge-reranker-base',
model="bge-reranker-base",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'completion',
}
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "completion",
},
)
def test_invoke_rerank_model():
model = LocalaiRerankModel()
response = model.invoke(
model='bge-reranker-base',
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL')
},
query='Organic skincare products for sensitive skin',
model="bge-reranker-base",
credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
query="Organic skincare products for sensitive skin",
docs=[
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
@@ -45,43 +44,38 @@ def test_invoke_rerank_model():
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"Yoga mats made from recycled materials"
"Yoga mats made from recycled materials",
],
top_n=3,
score_threshold=0.75,
user="abc-123"
user="abc-123",
)
assert isinstance(response, RerankResult)
assert len(response.docs) == 3
def test__invoke():
model = LocalaiRerankModel()
# Test case 1: Empty docs
result = model._invoke(
model='bge-reranker-base',
credentials={
'server_url': 'https://example.com',
'api_key': '1234567890'
},
query='Organic skincare products for sensitive skin',
model="bge-reranker-base",
credentials={"server_url": "https://example.com", "api_key": "1234567890"},
query="Organic skincare products for sensitive skin",
docs=[],
top_n=3,
score_threshold=0.75,
user="abc-123"
user="abc-123",
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 0
# Test case 2: Valid invocation
result = model._invoke(
model='bge-reranker-base',
credentials={
'server_url': 'https://example.com',
'api_key': '1234567890'
},
query='Organic skincare products for sensitive skin',
model="bge-reranker-base",
credentials={"server_url": "https://example.com", "api_key": "1234567890"},
query="Organic skincare products for sensitive skin",
docs=[
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
@@ -91,12 +85,12 @@ def test__invoke():
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"Yoga mats made from recycled materials"
"Yoga mats made from recycled materials",
],
top_n=3,
score_threshold=0.75,
user="abc-123"
user="abc-123",
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 3
assert all(isinstance(doc, RerankDocument) for doc in result.docs)
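A usage sketch (assumptions: result comes from the _invoke() call above, and RerankDocument exposes index, score, and text fields as the assertions suggest):

# Pick the highest-scoring document from the rerank result.
best = max(result.docs, key=lambda doc: doc.score)
print(best.index, best.score, best.text)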

View File

@@ -10,19 +10,9 @@ def test_validate_credentials():
model = LocalAISpeech2text()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='whisper-1',
credentials={
'server_url': 'invalid_url'
}
)
model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"})
model.validate_credentials(
model='whisper-1',
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL')
}
)
model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")})
def test_invoke_model():
@@ -32,23 +22,21 @@ def test_invoke_model():
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, 'audio.mp3')
audio_file_path = os.path.join(assets_dir, "audio.mp3")
# Open the file and get the file object
with open(audio_file_path, 'rb') as audio_file:
with open(audio_file_path, "rb") as audio_file:
file = audio_file
result = model.invoke(
model='whisper-1',
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL')
},
model="whisper-1",
credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
file=file,
user="abc-123"
user="abc-123",
)
assert isinstance(result, str)
assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

View File

@@ -12,54 +12,47 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embo-01',
credentials={
'minimax_api_key': 'invalid_key',
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
}
model="embo-01",
credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")},
)
model.validate_credentials(
model='embo-01',
model="embo-01",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
}
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
)
def test_invoke_model():
model = MinimaxTextEmbeddingModel()
result = model.invoke(
model='embo-01',
model="embo-01",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 16
def test_get_num_tokens():
model = MinimaxTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='embo-01',
model="embo-01",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2

View File

@@ -17,79 +17,70 @@ def test_predefined_models():
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)
def test_validate_credentials_for_chat_model():
sleep(3)
model = MinimaxLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='abab5.5-chat',
credentials={
'minimax_api_key': 'invalid_key',
'minimax_group_id': 'invalid_key'
}
model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"}
)
model.validate_credentials(
model='abab5.5-chat',
model="abab5.5-chat",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
}
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
)
def test_invoke_model():
sleep(3)
model = MinimaxLargeLanguageModel()
response = model.invoke(
model='abab5-chat',
model="abab5-chat",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_stream_model():
sleep(3)
model = MinimaxLargeLanguageModel()
response = model.invoke(
model='abab5.5-chat',
model="abab5.5-chat",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@@ -99,34 +90,31 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_with_search():
sleep(3)
model = MinimaxLargeLanguageModel()
response = model.invoke(
model='abab5.5-chat',
model="abab5.5-chat",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'
)
],
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
'plugin_web_search': True,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
"plugin_web_search": True,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
total_message = ''
total_message = ""
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
@@ -134,25 +122,22 @@ def test_invoke_with_search():
total_message += chunk.delta.message.content
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
assert '参考资料' in total_message
assert "参考资料" in total_message
def test_get_num_tokens():
sleep(3)
model = MinimaxLargeLanguageModel()
response = model.get_num_tokens(
model='abab5.5-chat',
model="abab5.5-chat",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
tools=[]
prompt_messages=[UserPromptMessage(content="Hello World!")],
tools=[],
)
assert isinstance(response, int)
assert response == 30
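The sleep(3) calls above presumably space requests out to stay under MiniMax rate limits; a reusable decorator version might look like this (hypothetical helper, not part of the diff):

import functools
import time


def paced(seconds: float):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Wait before each call so consecutive tests stay under the limit.
            time.sleep(seconds)
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@paced(3)
def test_invoke_model_paced():
    ...  # body as in test_invoke_model above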

View File

@@ -12,14 +12,14 @@ def test_validate_provider_credentials():
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'minimax_api_key': 'hahahaha',
'minimax_group_id': '123',
"minimax_api_key": "hahahaha",
"minimax_group_id": "123",
}
)
provider.validate_provider_credentials(
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
}
)

Some files were not shown because too many files have changed in this diff.