merge main

This commit is contained in:
StyleZhang 2024-07-11 10:20:05 +08:00
commit 009cb2a650
708 changed files with 16834 additions and 5596 deletions

View File

@ -1,6 +1,7 @@
#!/bin/bash
cd web && npm install
pipx install poetry
echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc

7
.gitattributes vendored Normal file
View File

@ -0,0 +1,7 @@
# Ensure that .sh scripts use LF as line separator, even if they are checked out
# to Windows(NTFS) file-system, by a user of Docker for Window.
# These .sh scripts will be run from the Container after `docker compose up -d`.
# If they appear to be CRLF style, Dash from the Container will fail to execute
# them.
*.sh text eol=lf

View File

@ -1,13 +1,21 @@
# Checklist:
> [!IMPORTANT]
> Please review the checklist below before submitting your pull request.
- [ ] Please open an issue before creating a PR or link to an existing issue
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
# Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Describe the big picture of your changes here to communicate to the maintainers why we should accept this pull request. If it fixes a bug or resolves a feature request, be sure to link to that issue. Close issue syntax: `Fixes #<issue number>`, see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) for more details.
Fixes # (issue)
Fixes
## Type of Change
Please delete options that are not relevant.
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
@ -15,18 +23,12 @@ Please delete options that are not relevant.
- [ ] Improvement, including but not limited to code refactoring, performance optimization, and UI/UX improvement
- [ ] Dependency upgrade
# How Has This Been Tested?
# Testing Instructions
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
- [ ] TODO
- [ ] Test A
- [ ] Test B
# Suggested Checklist:
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] My changes generate no new warnings
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
- [ ] `optional` I have made corresponding changes to the documentation
- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
- [ ] `optional` New and existing unit tests pass locally with my changes

View File

@ -48,18 +48,18 @@ jobs:
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Extract metadata for Docker
id: meta
uses: docker/metadata-action@v5

2
.gitignore vendored
View File

@ -174,3 +174,5 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json
pyrightconfig.json
.idea/

1
.vscode/launch.json vendored
View File

@ -13,7 +13,6 @@
"jinja": true,
"env": {
"FLASK_APP": "app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"args": [

View File

@ -192,6 +192,11 @@ If you'd like to configure a highly-available setup, there are community-contrib
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
#### Using Terraform for Deployment
##### Azure Global
Deploy Dify to Azure with a single click using [terraform](https://www.terraform.io/).
- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform)
## Contributing

View File

@ -175,6 +175,12 @@ docker compose up -d
- [رسم بياني Helm من قبل @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [ملف YAML من قبل @Winson-030](https://github.com/Winson-030/dify-kubernetes)
#### استخدام Terraform للتوزيع
##### Azure Global
استخدم [terraform](https://www.terraform.io/) لنشر Dify على Azure بنقرة واحدة.
- [Azure Terraform بواسطة @nikawang](https://github.com/nikawang/dify-azure-terraform)
## المساهمة

View File

@ -197,6 +197,12 @@ docker compose up -d
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [YAML 文件 by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
#### 使用 Terraform 部署
##### Azure Global
使用 [terraform](https://www.terraform.io/) 一键部署 Dify 到 Azure。
- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform)
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)

View File

@ -199,6 +199,12 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop
- [Gráfico Helm por @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [Ficheros YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes)
#### Uso de Terraform para el despliegue
##### Azure Global
Utiliza [terraform](https://www.terraform.io/) para desplegar Dify en Azure con un solo clic.
- [Azure Terraform por @nikawang](https://github.com/nikawang/dify-azure-terraform)
## Contribuir

View File

@ -197,6 +197,12 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau
- [Helm Chart par @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [Fichier YAML par @Winson-030](https://github.com/Winson-030/dify-kubernetes)
#### Utilisation de Terraform pour le déploiement
##### Azure Global
Utilisez [terraform](https://www.terraform.io/) pour déployer Dify sur Azure en un clic.
- [Azure Terraform par @nikawang](https://github.com/nikawang/dify-azure-terraform)
## Contribuer

View File

@ -196,6 +196,12 @@ docker compose up -d
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
#### Terraformを使用したデプロイ
##### Azure Global
[terraform](https://www.terraform.io/) を使用して、AzureにDifyをワンクリックでデプロイします。
- [nikawangのAzure Terraform](https://github.com/nikawang/dify-azure-terraform)
## 貢献

View File

@ -197,6 +197,13 @@ If you'd like to configure a highly-available setup, there are community-contrib
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
#### Terraform atorlugu pilersitsineq
##### Azure Global
Atoruk [terraform](https://www.terraform.io/) Dify-mik Azure-mut ataatsikkut ikkussuilluarlugu.
- [Azure Terraform atorlugu @nikawang](https://github.com/nikawang/dify-azure-terraform)
## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

View File

@ -190,6 +190,12 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
#### Terraform을 사용한 배포
##### Azure Global
[terraform](https://www.terraform.io/)을 사용하여 Azure에 Dify를 원클릭으로 배포하세요.
- [nikawang의 Azure Terraform](https://github.com/nikawang/dify-azure-terraform)
## 기여
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.

View File

@ -72,6 +72,13 @@ TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
# OCI Storage configuration
OCI_ENDPOINT=your-endpoint
OCI_BUCKET_NAME=your-bucket-name
OCI_ACCESS_KEY=your-access-key
OCI_SECRET_KEY=your-secret-key
OCI_REGION=your-region
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@ -144,6 +151,16 @@ CHROMA_DATABASE=default_database
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
CHROMA_AUTH_CREDENTIALS=difyai123456
# AnalyticDB configuration
ANALYTICDB_KEY_ID=your-ak
ANALYTICDB_KEY_SECRET=your-sk
ANALYTICDB_REGION_ID=cn-hangzhou
ANALYTICDB_INSTANCE_ID=gp-ab123456
ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1
OPENSEARCH_PORT=9200
@ -230,4 +247,4 @@ WORKFLOW_CALL_MAX_DEPTH=5
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0

View File

@ -5,8 +5,7 @@ WORKDIR /app/api
# Install Poetry
ENV POETRY_VERSION=1.8.3
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir --upgrade poetry==${POETRY_VERSION}
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
# Configure Poetry
ENV POETRY_CACHE_DIR=/tmp/poetry_cache

View File

@ -11,6 +11,7 @@
```bash
cd ../docker
cp middleware.env.example middleware.env
docker compose -f docker-compose.middleware.yaml -p dify up -d
cd ../api
```

View File

@ -1,8 +1,8 @@
import os
from configs.app_config import DifyConfig
from configs import dify_config
if not os.environ.get("DEBUG") or os.environ.get("DEBUG", "false").lower() != 'true':
if os.environ.get("DEBUG", "false").lower() != 'true':
from gevent import monkey
monkey.patch_all()
@ -43,6 +43,8 @@ from extensions import (
from extensions.ext_database import db
from extensions.ext_login import login_manager
from libs.passport import PassportService
# TODO: Find a way to avoid importing models here
from models import account, dataset, model, source, task, tool, tools, web
from services.account_service import AccountService
@ -81,7 +83,7 @@ def create_flask_app_with_configs() -> Flask:
with configs loaded from .env file
"""
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(DifyConfig().model_dump())
dify_app.config.from_mapping(dify_config.model_dump())
# populate configs into system environment variables
for key, value in dify_app.config.items():

View File

@ -8,6 +8,7 @@ import click
from flask import current_app
from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import languages
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
@ -112,7 +113,7 @@ def reset_encrypt_key_pair():
After the reset, all LLM credentials will become invalid, requiring re-entry.
Only support SELF_HOSTED mode.
"""
if current_app.config['EDITION'] != 'SELF_HOSTED':
if dify_config.EDITION != 'SELF_HOSTED':
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
return
@ -336,6 +337,14 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ANALYTICDB:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.ANALYTICDB,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -0,0 +1,3 @@
from .app_config import DifyConfig
dify_config = DifyConfig()

View File

@ -1,4 +1,5 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field, computed_field
from pydantic_settings import SettingsConfigDict
from configs.deploy import DeploymentConfig
from configs.enterprise import EnterpriseFeatureConfig
@ -9,9 +10,6 @@ from configs.packaging import PackagingInfo
class DifyConfig(
# based on pydantic-settings
BaseSettings,
# Packaging info
PackagingInfo,
@ -31,12 +29,39 @@ class DifyConfig(
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
):
DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
model_config = SettingsConfigDict(
# read from dotenv format config file
env_file='.env',
env_file_encoding='utf-8',
frozen=True,
# ignore extra attributes
extra='ignore',
)
CODE_MAX_NUMBER: int = 9223372036854775807
CODE_MIN_NUMBER: int = -9223372036854775808
CODE_MAX_STRING_LENGTH: int = 80000
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
SSRF_PROXY_HTTP_URL: str | None = None
SSRF_PROXY_HTTPS_URL: str | None = None

View File

@ -1,7 +1,8 @@
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class DeploymentConfig(BaseModel):
class DeploymentConfig(BaseSettings):
"""
Deployment configs
"""

View File

@ -1,7 +1,8 @@
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class EnterpriseFeatureConfig(BaseModel):
class EnterpriseFeatureConfig(BaseSettings):
"""
Enterprise feature configs.
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**

View File

@ -1,5 +1,3 @@
from pydantic import BaseModel
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class NotionConfig(BaseModel):
class NotionConfig(BaseSettings):
"""
Notion integration configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, NonNegativeFloat
from pydantic import Field, NonNegativeFloat
from pydantic_settings import BaseSettings
class SentryConfig(BaseModel):
class SentryConfig(BaseSettings):
"""
Sentry configs
"""

View File

@ -1,11 +1,12 @@
from typing import Optional
from pydantic import AliasChoices, BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
class SecurityConfig(BaseModel):
class SecurityConfig(BaseSettings):
"""
Secret Key configs
"""
@ -17,8 +18,12 @@ class SecurityConfig(BaseModel):
default=None,
)
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description='Expiry time in hours for reset token',
default=24,
)
class AppExecutionConfig(BaseModel):
class AppExecutionConfig(BaseSettings):
"""
App Execution configs
"""
@ -26,9 +31,13 @@ class AppExecutionConfig(BaseModel):
description='execution timeout in seconds for app execution',
default=1200,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description='max active request per app, 0 means unlimited',
default=0,
)
class CodeExecutionSandboxConfig(BaseModel):
class CodeExecutionSandboxConfig(BaseSettings):
"""
Code Execution Sandbox configs
"""
@ -43,36 +52,36 @@ class CodeExecutionSandboxConfig(BaseModel):
)
class EndpointConfig(BaseModel):
class EndpointConfig(BaseSettings):
"""
Module URL configs
"""
CONSOLE_API_URL: str = Field(
description='The backend URL prefix of the console API.'
'used to concatenate the login authorization callback or notion integration callback.',
default='https://cloud.dify.ai',
default='',
)
CONSOLE_WEB_URL: str = Field(
description='The front-end URL prefix of the console web.'
'used to concatenate some front-end addresses and for CORS configuration use.',
default='https://cloud.dify.ai',
default='',
)
SERVICE_API_URL: str = Field(
description='Service API Url prefix.'
'used to display Service API Base Url to the front-end.',
default='https://api.dify.ai',
default='',
)
APP_WEB_URL: str = Field(
description='WebApp Url prefix.'
'used to display WebAPP API Base Url to the front-end.',
default='https://udify.app',
default='',
)
class FileAccessConfig(BaseModel):
class FileAccessConfig(BaseSettings):
"""
File Access configs
"""
@ -82,7 +91,7 @@ class FileAccessConfig(BaseModel):
'Url is signed and has expiration time.',
validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
alias_priority=1,
default='https://cloud.dify.ai',
default='',
)
FILES_ACCESS_TIMEOUT: int = Field(
@ -91,7 +100,7 @@ class FileAccessConfig(BaseModel):
)
class FileUploadConfig(BaseModel):
class FileUploadConfig(BaseSettings):
"""
File Uploading configs
"""
@ -116,7 +125,7 @@ class FileUploadConfig(BaseModel):
)
class HttpConfig(BaseModel):
class HttpConfig(BaseSettings):
"""
HTTP configs
"""
@ -136,7 +145,7 @@ class HttpConfig(BaseModel):
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
inner_WEB_API_CORS_ALLOW_ORIGINS: Optional[str] = Field(
inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field(
description='',
validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
default='*',
@ -148,7 +157,7 @@ class HttpConfig(BaseModel):
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
class InnerAPIConfig(BaseModel):
class InnerAPIConfig(BaseSettings):
"""
Inner API configs
"""
@ -163,7 +172,7 @@ class InnerAPIConfig(BaseModel):
)
class LoggingConfig(BaseModel):
class LoggingConfig(BaseSettings):
"""
Logging configs
"""
@ -195,7 +204,7 @@ class LoggingConfig(BaseModel):
)
class ModelLoadBalanceConfig(BaseModel):
class ModelLoadBalanceConfig(BaseSettings):
"""
Model load balance configs
"""
@ -205,7 +214,7 @@ class ModelLoadBalanceConfig(BaseModel):
)
class BillingConfig(BaseModel):
class BillingConfig(BaseSettings):
"""
Platform Billing Configurations
"""
@ -215,7 +224,7 @@ class BillingConfig(BaseModel):
)
class UpdateConfig(BaseModel):
class UpdateConfig(BaseSettings):
"""
Update configs
"""
@ -225,7 +234,7 @@ class UpdateConfig(BaseModel):
)
class WorkflowConfig(BaseModel):
class WorkflowConfig(BaseSettings):
"""
Workflow feature configs
"""
@ -246,7 +255,7 @@ class WorkflowConfig(BaseModel):
)
class OAuthConfig(BaseModel):
class OAuthConfig(BaseSettings):
"""
oauth configs
"""
@ -276,7 +285,7 @@ class OAuthConfig(BaseModel):
)
class ModerationConfig(BaseModel):
class ModerationConfig(BaseSettings):
"""
Moderation in app configs.
"""
@ -288,7 +297,7 @@ class ModerationConfig(BaseModel):
)
class ToolConfig(BaseModel):
class ToolConfig(BaseSettings):
"""
Tool configs
"""
@ -299,7 +308,7 @@ class ToolConfig(BaseModel):
)
class MailConfig(BaseModel):
class MailConfig(BaseSettings):
"""
Mail Configurations
"""
@ -355,7 +364,7 @@ class MailConfig(BaseModel):
)
class RagEtlConfig(BaseModel):
class RagEtlConfig(BaseSettings):
"""
RAG ETL Configurations.
"""
@ -381,7 +390,7 @@ class RagEtlConfig(BaseModel):
)
class DataSetConfig(BaseModel):
class DataSetConfig(BaseSettings):
"""
Dataset configs
"""
@ -391,8 +400,13 @@ class DataSetConfig(BaseModel):
default=30,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description='whether to enable dataset operator',
default=False,
)
class WorkspaceConfig(BaseModel):
class WorkspaceConfig(BaseSettings):
"""
Workspace configs
"""
@ -403,7 +417,7 @@ class WorkspaceConfig(BaseModel):
)
class IndexingConfig(BaseModel):
class IndexingConfig(BaseSettings):
"""
Indexing configs.
"""
@ -414,7 +428,7 @@ class IndexingConfig(BaseModel):
)
class ImageFormatConfig(BaseModel):
class ImageFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
description='multi model send image format, support base64, url, default is base64',
default='base64',

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, NonNegativeInt
from pydantic import Field, NonNegativeInt
from pydantic_settings import BaseSettings
class HostedOpenAiConfig(BaseModel):
class HostedOpenAiConfig(BaseSettings):
"""
Hosted OpenAI service config
"""
@ -68,7 +69,7 @@ class HostedOpenAiConfig(BaseModel):
)
class HostedAzureOpenAiConfig(BaseModel):
class HostedAzureOpenAiConfig(BaseSettings):
"""
Hosted OpenAI service config
"""
@ -94,7 +95,7 @@ class HostedAzureOpenAiConfig(BaseModel):
)
class HostedAnthropicConfig(BaseModel):
class HostedAnthropicConfig(BaseSettings):
"""
Hosted Azure OpenAI service config
"""
@ -125,7 +126,7 @@ class HostedAnthropicConfig(BaseModel):
)
class HostedMinmaxConfig(BaseModel):
class HostedMinmaxConfig(BaseSettings):
"""
Hosted Minmax service config
"""
@ -136,7 +137,7 @@ class HostedMinmaxConfig(BaseModel):
)
class HostedSparkConfig(BaseModel):
class HostedSparkConfig(BaseSettings):
"""
Hosted Spark service config
"""
@ -147,7 +148,7 @@ class HostedSparkConfig(BaseModel):
)
class HostedZhipuAIConfig(BaseModel):
class HostedZhipuAIConfig(BaseSettings):
"""
Hosted Minmax service config
"""
@ -158,7 +159,7 @@ class HostedZhipuAIConfig(BaseModel):
)
class HostedModerationConfig(BaseModel):
class HostedModerationConfig(BaseSettings):
"""
Hosted Moderation service config
"""
@ -174,7 +175,7 @@ class HostedModerationConfig(BaseModel):
)
class HostedFetchAppTemplateConfig(BaseModel):
class HostedFetchAppTemplateConfig(BaseSettings):
"""
Hosted Moderation service config
"""

View File

@ -1,13 +1,16 @@
from typing import Any, Optional
from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig
from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
@ -21,7 +24,7 @@ from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseModel):
class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field(
description='storage type,'
' default to `local`,'
@ -35,14 +38,14 @@ class StorageConfig(BaseModel):
)
class VectorStoreConfig(BaseModel):
class VectorStoreConfig(BaseSettings):
VECTOR_STORE: Optional[str] = Field(
description='vector store type',
default=None,
)
class KeywordStoreConfig(BaseModel):
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
description='keyword store type',
default='jieba',
@ -80,6 +83,11 @@ class DatabaseConfig:
default='',
)
DB_EXTRAS: str = Field(
description='db extras options. Example: keepalives_idle=60&keepalives=1',
default='',
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description='db uri scheme',
default='postgresql',
@ -88,7 +96,12 @@ class DatabaseConfig:
@computed_field
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = f"?client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else ""
db_extras = (
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
if self.DB_CHARSET
else self.DB_EXTRAS
).strip("&")
db_extras = f"?{db_extras}" if db_extras else ""
return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
f"{self.DB_USERNAME}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
f"{db_extras}")
@ -113,7 +126,7 @@ class DatabaseConfig:
default=False,
)
SQLALCHEMY_ECHO: bool = Field(
SQLALCHEMY_ECHO: bool | str = Field(
description='whether to enable SqlAlchemy echo',
default=False,
)
@ -143,7 +156,7 @@ class CeleryConfig(DatabaseConfig):
@computed_field
@property
def CELERY_RESULT_BACKEND(self) -> str:
def CELERY_RESULT_BACKEND(self) -> str | None:
return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
@ -167,9 +180,11 @@ class MiddlewareConfig(
GoogleCloudStorageConfig,
TencentCloudCOSStorageConfig,
S3StorageConfig,
OCIStorageConfig,
# configs of vdb and vdb providers
VectorStoreConfig,
AnalyticdbConfig,
ChromaConfig,
MilvusConfig,
OpenSearchConfig,

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class RedisConfig(BaseModel):
class RedisConfig(BaseSettings):
"""
Redis configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class AliyunOSSStorageConfig(BaseModel):
class AliyunOSSStorageConfig(BaseSettings):
"""
Aliyun storage configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class S3StorageConfig(BaseModel):
class S3StorageConfig(BaseSettings):
"""
S3 storage configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class AzureBlobStorageConfig(BaseModel):
class AzureBlobStorageConfig(BaseSettings):
"""
Azure Blob storage configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class GoogleCloudStorageConfig(BaseModel):
class GoogleCloudStorageConfig(BaseSettings):
"""
Google Cloud storage configs
"""

View File

@ -0,0 +1,36 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class OCIStorageConfig(BaseSettings):
"""
OCI storage configs
"""
OCI_ENDPOINT: Optional[str] = Field(
description='OCI storage endpoint',
default=None,
)
OCI_REGION: Optional[str] = Field(
description='OCI storage region',
default=None,
)
OCI_BUCKET_NAME: Optional[str] = Field(
description='OCI storage bucket name',
default=None,
)
OCI_ACCESS_KEY: Optional[str] = Field(
description='OCI storage access key',
default=None,
)
OCI_SECRET_KEY: Optional[str] = Field(
description='OCI storage secret key',
default=None,
)

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class TencentCloudCOSStorageConfig(BaseModel):
class TencentCloudCOSStorageConfig(BaseSettings):
"""
Tencent Cloud COS storage configs
"""

View File

@ -0,0 +1,44 @@
from typing import Optional
from pydantic import BaseModel, Field
class AnalyticdbConfig(BaseModel):
"""
Configuration for connecting to AnalyticDB.
Refer to the following documentation for details on obtaining credentials:
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
"""
ANALYTICDB_KEY_ID : Optional[str] = Field(
default=None,
description="The Access Key ID provided by Alibaba Cloud for authentication."
)
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
default=None,
description="The Secret Access Key corresponding to the Access Key ID for secure access."
)
ANALYTICDB_REGION_ID : Optional[str] = Field(
default=None,
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
)
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
)
ANALYTICDB_ACCOUNT : Optional[str] = Field(
default=None,
description="The account name used to log in to the AnalyticDB instance."
)
ANALYTICDB_PASSWORD : Optional[str] = Field(
default=None,
description="The password associated with the AnalyticDB account for authentication."
)
ANALYTICDB_NAMESPACE : Optional[str] = Field(
default=None,
description="The namespace within AnalyticDB for schema isolation."
)
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
default=None,
description="The password for accessing the specified namespace within the AnalyticDB instance."
)

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class ChromaConfig(BaseModel):
class ChromaConfig(BaseSettings):
"""
Chroma configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class MilvusConfig(BaseModel):
class MilvusConfig(BaseSettings):
"""
Milvus configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class OpenSearchConfig(BaseModel):
class OpenSearchConfig(BaseSettings):
"""
OpenSearch configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class OracleConfig(BaseModel):
class OracleConfig(BaseSettings):
"""
ORACLE configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class PGVectorConfig(BaseModel):
class PGVectorConfig(BaseSettings):
"""
PGVector configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class PGVectoRSConfig(BaseModel):
class PGVectoRSConfig(BaseSettings):
"""
PGVectoRS configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class QdrantConfig(BaseModel):
class QdrantConfig(BaseSettings):
"""
Qdrant configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class RelytConfig(BaseModel):
class RelytConfig(BaseSettings):
"""
Relyt configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class TencentVectorDBConfig(BaseModel):
class TencentVectorDBConfig(BaseSettings):
"""
Tencent Vector configs
"""
@ -24,7 +25,7 @@ class TencentVectorDBConfig(BaseModel):
)
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
description='Tencent Vector password',
description='Tencent Vector username',
default=None,
)
@ -38,7 +39,12 @@ class TencentVectorDBConfig(BaseModel):
default=1,
)
TENCENT_VECTOR_DB_REPLICAS: PositiveInt = Field(
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description='Tencent Vector replicas',
default=2,
)
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
description='Tencent Vector Database',
default=None,
)

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class TiDBVectorConfig(BaseModel):
class TiDBVectorConfig(BaseSettings):
"""
TiDB Vector configs
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class WeaviateConfig(BaseModel):
class WeaviateConfig(BaseSettings):
"""
Weaviate configs
"""

View File

@ -1,14 +1,15 @@
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class PackagingInfo(BaseModel):
class PackagingInfo(BaseSettings):
"""
Packaging build information
"""
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.6.12-fix1',
default='0.6.13',
)
COMMIT_SHA: str = Field(

View File

@ -14,7 +14,7 @@ language_timezone_mapping = {
'vi-VN': 'Asia/Ho_Chi_Minh',
'ro-RO': 'Europe/Bucharest',
'pl-PL': 'Europe/Warsaw',
'hi-IN': 'Asia/Kolkata'
'hi-IN': 'Asia/Kolkata',
}
languages = list(language_timezone_mapping.keys())

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,4 @@
TTS_AUTO_PLAY_TIMEOUT = 5
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02

View File

@ -30,7 +30,7 @@ from .app import (
)
# Import auth controllers
from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
# Import billing controllers
from .billing import billing

View File

@ -134,6 +134,7 @@ class AppApi(Resource):
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('max_active_requests', type=int, location='json')
args = parser.parse_args()
app_service = AppService()
@ -190,6 +191,10 @@ class AppExportApi(Resource):
@get_app_model
def get(self, app_model):
"""Export app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
app_service = AppService()
return {

View File

@ -81,15 +81,36 @@ class ChatMessageTextApi(Resource):
@account_initialization_required
@get_app_model
def post(self, app_model):
from werkzeug.exceptions import InternalServerError
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, location='json')
parser.add_argument('text', type=str, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id', None)
text = args.get('text', None)
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
'voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
text=request.form['text'],
voice=request.form['voice'],
streaming=False
text=text,
message_id=message_id,
voice=voice
)
return {'data': response.data.decode('latin1')}
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()

View File

@ -19,7 +19,12 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@ -75,7 +80,7 @@ class CompletionMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@ -141,7 +146,7 @@ class ChatMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@ -13,6 +13,7 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import AppInvokeQuotaExceededError
from fields.workflow_fields import workflow_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@ -279,7 +280,7 @@ class DraftWorkflowRunApi(Resource):
)
return helper.compact_generate_response(response)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@ -6,6 +6,7 @@ from flask_login import current_user
from flask_restful import Resource
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
@ -16,11 +17,11 @@ from ..wraps import account_initialization_required
def get_oauth_providers():
with current_app.app_context():
notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'),
client_secret=current_app.config.get(
'NOTION_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_API_URL') + '/console/api/oauth/data-source/callback/notion')
if not dify_config.NOTION_CLIENT_ID or not dify_config.NOTION_CLIENT_SECRET:
return {}
notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
client_secret=dify_config.NOTION_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')
OAUTH_PROVIDERS = {
'notion': notion_oauth
@ -39,8 +40,10 @@ class OAuthDataSource(Resource):
print(vars(oauth_provider))
if not oauth_provider:
return {'error': 'Invalid provider'}, 400
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
internal_secret = dify_config.NOTION_INTERNAL_SECRET
if not internal_secret:
return {'error': 'Internal secret is not set'},
oauth_provider.save_internal_access_token(internal_secret)
return { 'data': '' }
else:
@ -60,13 +63,13 @@ class OAuthDataSourceCallback(Resource):
if 'code' in request.args:
code = request.args.get('code')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&code={code}')
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
elif 'error' in request.args:
error = request.args.get('error')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error={error}')
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
else:
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error=Access denied')
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
class OAuthDataSourceBinding(Resource):

View File

@ -5,3 +5,28 @@ class ApiKeyAuthFailedError(BaseHTTPException):
error_code = 'auth_failed'
description = "{message}"
code = 500
class InvalidEmailError(BaseHTTPException):
error_code = 'invalid_email'
description = "The email address is not valid."
code = 400
class PasswordMismatchError(BaseHTTPException):
error_code = 'password_mismatch'
description = "The passwords do not match."
code = 400
class InvalidTokenError(BaseHTTPException):
error_code = 'invalid_or_expired_token'
description = "The token is invalid or has expired."
code = 400
class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = 'password_reset_rate_limit_exceeded'
description = "Password reset rate limit exceeded. Try again later."
code = 429

View File

@ -0,0 +1,107 @@
import base64
import logging
import secrets
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.auth.error import (
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
PasswordResetRateLimitExceededError,
)
from controllers.console.setup import setup_required
from extensions.ext_database import db
from libs.helper import email as email_validate
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService
from services.errors.account import RateLimitExceededError
class ForgotPasswordSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('email', type=str, required=True, location='json')
args = parser.parse_args()
email = args['email']
if not email_validate(email):
raise InvalidEmailError()
account = Account.query.filter_by(email=email).first()
if account:
try:
AccountService.send_reset_password_email(account=account)
except RateLimitExceededError:
logging.warning(f"Rate limit exceeded for email: {account.email}")
raise PasswordResetRateLimitExceededError()
else:
# Return success to avoid revealing email registration status
logging.warning(f"Attempt to reset password for unregistered email: {email}")
return {"result": "success"}
class ForgotPasswordCheckApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
token = args['token']
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
return {'is_valid': False, 'email': None}
return {'is_valid': True, 'email': reset_data.get('email')}
class ForgotPasswordResetApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
args = parser.parse_args()
new_password = args['new_password']
password_confirm = args['password_confirm']
if str(new_password).strip() != str(password_confirm).strip():
raise PasswordMismatchError()
token = args['token']
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
raise InvalidTokenError()
AccountService.revoke_reset_password_token(token)
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get('email')).first()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
return {'result': 'success'}
api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')

View File

@ -1,7 +1,7 @@
from typing import cast
import flask_login
from flask import current_app, request
from flask import request
from flask_restful import Resource, reqparse
import services
@ -56,14 +56,14 @@ class LogoutApi(Resource):
class ResetPasswordApi(Resource):
@setup_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('email', type=email, required=True, location='json')
args = parser.parse_args()
# parser = reqparse.RequestParser()
# parser.add_argument('email', type=email, required=True, location='json')
# args = parser.parse_args()
# import mailchimp_transactional as MailchimpTransactional
# from mailchimp_transactional.api_client import ApiClientError
account = {'email': args['email']}
# account = {'email': args['email']}
# account = AccountService.get_by_email(args['email'])
# if account is None:
# raise ValueError('Email not found')
@ -71,22 +71,22 @@ class ResetPasswordApi(Resource):
# AccountService.update_password(account, new_password)
# todo: Send email
MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY']
# MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY']
# mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY)
message = {
'from_email': 'noreply@example.com',
'to': [{'email': account.email}],
'subject': 'Reset your Dify password',
'html': """
<p>Dear User,</p>
<p>The Dify team has generated a new password for you, details as follows:</p>
<p><strong>{new_password}</strong></p>
<p>Please change your password to log in as soon as possible.</p>
<p>Regards,</p>
<p>The Dify Team</p>
"""
}
# message = {
# 'from_email': 'noreply@example.com',
# 'to': [{'email': account['email']}],
# 'subject': 'Reset your Dify password',
# 'html': """
# <p>Dear User,</p>
# <p>The Dify team has generated a new password for you, details as follows:</p>
# <p><strong>{new_password}</strong></p>
# <p>Please change your password to log in as soon as possible.</p>
# <p>Regards,</p>
# <p>The Dify Team</p>
# """
# }
# response = mailchimp.messages.send({
# 'message': message,

View File

@ -6,6 +6,7 @@ import requests
from flask import current_app, redirect, request
from flask_restful import Resource
from configs import dify_config
from constants.languages import languages
from extensions.ext_database import db
from libs.helper import get_remote_ip
@ -18,22 +19,24 @@ from .. import api
def get_oauth_providers():
with current_app.app_context():
github_oauth = GitHubOAuth(client_id=current_app.config.get('GITHUB_CLIENT_ID'),
client_secret=current_app.config.get(
'GITHUB_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_API_URL') + '/console/api/oauth/authorize/github')
if not dify_config.GITHUB_CLIENT_ID or not dify_config.GITHUB_CLIENT_SECRET:
github_oauth = None
else:
github_oauth = GitHubOAuth(
client_id=dify_config.GITHUB_CLIENT_ID,
client_secret=dify_config.GITHUB_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
)
if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
google_oauth = None
else:
google_oauth = GoogleOAuth(
client_id=dify_config.GOOGLE_CLIENT_ID,
client_secret=dify_config.GOOGLE_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
)
google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'),
client_secret=current_app.config.get(
'GOOGLE_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_API_URL') + '/console/api/oauth/authorize/google')
OAUTH_PROVIDERS = {
'github': github_oauth,
'google': google_oauth
}
OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
return OAUTH_PROVIDERS
@ -63,8 +66,7 @@ class OAuthCallback(Resource):
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.HTTPError as e:
logging.exception(
f"An error occurred during the OAuth process with {provider}: {e.response.text}")
logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
return {'error': 'OAuth process failed'}, 400
account = _generate_account(provider, user_info)
@ -81,7 +83,7 @@ class OAuthCallback(Resource):
token = AccountService.login(account, ip_address=get_remote_ip(request))
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@ -101,11 +103,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Create account
account_name = user_info.name if user_info.name else 'Dify'
account = RegisterService.register(
email=user_info.email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)
# Set interface language

View File

@ -25,7 +25,7 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models.dataset import Dataset, Document, DocumentSegment
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetService, DocumentService
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _validate_name(name):
@ -85,6 +85,12 @@ class DatasetListApi(Resource):
else:
item['embedding_available'] = True
if item.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
item.update({'partial_member_list': part_users_list})
else:
item.update({'partial_member_list': []})
response = {
'data': data,
'has_more': len(datasets) == limit,
@ -108,8 +114,8 @@ class DatasetListApi(Resource):
help='Invalid indexing technique.')
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
try:
@ -140,6 +146,10 @@ class DatasetApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({'partial_member_list': part_users_list})
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
@ -163,6 +173,11 @@ class DatasetApi(Resource):
data['embedding_available'] = False
else:
data['embedding_available'] = True
if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({'partial_member_list': part_users_list})
return data, 200
@setup_required
@ -188,17 +203,21 @@ class DatasetApi(Resource):
nullable=True,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
)
parser.add_argument('embedding_model', type=str,
location='json', help='Invalid embedding model.')
parser.add_argument('embedding_model_provider', type=str,
location='json', help='Invalid embedding model provider.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
args = parser.parse_args()
data = request.get_json()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get('permission'), data.get('partial_member_list')
)
dataset = DatasetService.update_dataset(
dataset_id_str, args, current_user)
@ -206,7 +225,20 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
return marshal(dataset, dataset_detail_fields), 200
result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id
if data.get('partial_member_list') and data.get('permission') == 'partial_members':
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get('partial_member_list')
)
else:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({'partial_member_list': partial_member_list})
return result_data, 200
@setup_required
@login_required
@ -215,17 +247,27 @@ class DatasetApi(Resource):
dataset_id_str = str(dataset_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.is_editor or current_user.is_dataset_operator:
raise Forbidden()
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {'result': 'success'}, 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
raise DatasetInUseError()
class DatasetUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
return {'is_using': dataset_is_using}, 200
class DatasetQueryApi(Resource):
@ -506,7 +548,7 @@ class DatasetRetrievalSettingApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH,
@ -530,7 +572,7 @@ class DatasetRetrievalSettingMockApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH,
@ -560,8 +602,30 @@ class DatasetErrorDocs(Resource):
}, 200
class DatasetPermissionUserListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return {
'data': partial_members_list,
}, 200
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
@ -572,3 +636,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')

View File

@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
raise NotFound('Dataset not found.')
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
try:
@ -294,6 +294,11 @@ class DatasetInitApi(Resource):
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
if args['indexing_technique'] == 'high_quality':
try:
model_manager = ModelManager()
@ -757,14 +762,18 @@ class DocumentStatusApi(DocumentResource):
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
# check user's permission
DatasetService.check_dataset_permission(dataset, current_user)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
document = self.get_document(dataset_id, document_id)
indexing_cache_key = 'document_{}_indexing'.format(document.id)
cache_result = redis_client.get(indexing_cache_key)
@ -955,10 +964,11 @@ class DocumentRenameApi(DocumentResource):
@account_initialization_required
@marshal_with(document_fields)
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()

View File

@ -19,6 +19,7 @@ from controllers.console.app.error import (
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -70,16 +71,33 @@ class ChatAudioApi(InstalledAppResource):
class ChatTextApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
from flask_restful import reqparse
app_model = installed_app.app
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
text=request.form['text'],
voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
message_id=message_id,
voice=voice
)
return {'data': response.data.decode('latin1')}
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
@ -108,3 +126,5 @@ class ChatTextApi(InstalledAppResource):
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
# endpoint='installed_app_text_with_message_id')

View File

@ -36,7 +36,7 @@ class TagListApi(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -68,7 +68,7 @@ class TagUpdateDeleteApi(Resource):
def patch(self, tag_id):
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -109,8 +109,8 @@ class TagBindingCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -134,8 +134,8 @@ class TagBindingDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()

View File

@ -245,6 +245,8 @@ class AccountIntegrateApi(Resource):
return {'data': integrate_data}
# Register API resources
api.add_resource(AccountInitApi, '/account/init')
api.add_resource(AccountProfileApi, '/account/profile')

View File

@ -131,7 +131,20 @@ class MemberUpdateRoleApi(Resource):
return {'result': 'success'}
class DatasetOperatorMemberListApi(Resource):
"""List all members of current tenant."""
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {'result': 'success', 'accounts': members}, 200
api.add_resource(MemberListApi, '/workspaces/current/members')
api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')

View File

@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, EndUser
from models.model import App, AppMode, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -72,19 +72,30 @@ class AudioApi(Resource):
class TextApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
args = parser.parse_args()
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
'voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
text=args['text'],
end_user=end_user,
voice=args.get('voice'),
streaming=args['streaming']
message_id=message_id,
end_user=end_user.external_user_id,
voice=voice
)
return response

View File

@ -17,7 +17,12 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@ -69,7 +74,7 @@ class CompletionApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@ -132,7 +137,7 @@ class ChatApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@ -14,7 +14,12 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser
@ -59,7 +64,7 @@ class WorkflowRunApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@ -19,7 +19,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App
from models.model import App, AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -69,16 +69,35 @@ class AudioApi(WebApiResource):
class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
from flask_restful import reqparse
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get(
'voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
text=request.form['text'],
message_id=message_id,
end_user=end_user.external_user_id,
voice=request.form['voice'] if request.form.get('voice') else None,
streaming=False
voice=voice
)
return {'data': response.data.decode('latin1')}
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()

View File

@ -114,6 +114,10 @@ class VariableEntity(BaseModel):
default: Optional[str] = None
hint: Optional[str] = None
@property
def name(self) -> str:
return self.variable
class ExternalDataVariableEntity(BaseModel):
"""

View File

@ -0,0 +1,135 @@
import base64
import concurrent.futures
import logging
import queue
import re
import threading
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
class AudioTrunk:
def __init__(self, status: str, audio):
self.audio = audio
self.status = status
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(),
user="responding_tts",
tenant_id=tenant_id,
voice=voice
)
def _process_future(future_queue, audio_queue):
while True:
try:
future = future_queue.get()
if future is None:
break
for audio in future.result():
audio_base64 = base64.b64encode(bytes(audio))
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
except Exception as e:
logging.getLogger(__name__).warning(e)
break
audio_queue.put(AudioTrunk("finish", b''))
class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ''
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self.match = re.compile(r'[。.!?]')
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.TTS
)
self.voices = self.model_instance.get_tts_voices()
values = [voice.get('value') for voice in self.voices]
self.voice = voice
if not voice or voice not in values:
self.voice = self.voices[0].get('value')
self.MAX_SENTENCE = 2
self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
def publish(self, message):
try:
self._msg_queue.put(message)
except Exception as e:
self.logger.warning(e)
def _runtime(self):
future_queue = queue.Queue()
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
while True:
try:
message = self._msg_queue.get()
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
self.model_instance, self.tenant_id, self.voice)
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
self.msg_text += message.event.chunk.delta.message.content
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
self.MAX_SENTENCE += 1
text_content = ''.join(sentence_arr)
futures_result = self.executor.submit(_invoiceTTS, text_content,
self.model_instance,
self.tenant_id,
self.voice)
future_queue.put(futures_result)
if text_tmp:
self.msg_text = text_tmp
else:
self.msg_text = ''
except Exception as e:
self.logger.warning(e)
break
future_queue.put(None)
def checkAndGetAudio(self) -> AudioTrunk | None:
try:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:
self.executor.shutdown(wait=False)
return self.last_message
audio = self._audio_queue.get_nowait()
if audio and audio.status == "finish":
self.executor.shutdown(wait=False)
self._runtime_thread = None
if audio:
self._last_audio_event = audio
return audio
except queue.Empty:
return None
def _extract_sentence(self, org_text):
tx = self.match.finditer(org_text)
start = 0
result = []
for i in tx:
end = i.regs[0][1]
result.append(org_text[start:end])
start = end
return result, org_text[start:]

View File

@ -4,6 +4,8 @@ import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
@ -33,6 +35,8 @@ from core.app.entities.task_entities import (
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
)
@ -71,13 +75,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_iteration_nested_relations: dict[str, list[str]]
def __init__(
self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@ -129,7 +133,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._application_generate_entity.query
)
generator = self._process_stream_response(
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
@ -138,7 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-> ChatbotAppBlockingResponse:
-> ChatbotAppBlockingResponse:
"""
Process blocking response.
:return:
@ -169,7 +173,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
raise Exception('Queue listening stopped unexpectedly.')
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-> Generator[ChatbotAppStreamResponse, None, None]:
-> Generator[ChatbotAppStreamResponse, None, None]:
"""
To stream response.
:return:
@ -182,14 +186,68 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream_response=stream_response
)
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
break
audio_trunk = publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
if isinstance(event, QueueErrorEvent):
@ -301,7 +359,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
continue
if not self._is_stream_out_support(
event=event
event=event
):
continue
@ -318,7 +376,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
@ -402,7 +461,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return stream_generate_routes
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
-> list[str]:
"""
Get answer start at node id.
:param graph: graph
@ -457,7 +516,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
@ -515,7 +574,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
@ -525,7 +584,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
if not self._task_state.current_stream_generate_state:
return None
return
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
@ -573,7 +632,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM support chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1
continue
@ -643,7 +702,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None

View File

@ -1,52 +1,56 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import AppConfig, VariableEntity
class BaseAppGenerator:
def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig):
if user_inputs is None:
user_inputs = {}
filtered_inputs = {}
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
for variable_config in variables:
variable = variable_config.variable
if (variable not in user_inputs
or user_inputs[variable] is None
or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')):
if variable_config.required:
raise ValueError(f"{variable} is required in input form")
else:
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
continue
value = user_inputs[variable]
if value is not None:
if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string")
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
if '.' in value:
value = float(value)
else:
value = int(value)
if variable_config.type == VariableEntity.Type.SELECT:
options = variable_config.options if variable_config.options is not None else []
if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}")
elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
if variable_config.max_length is not None:
max_length = variable_config.max_length
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
if value and isinstance(value, str):
filtered_inputs[variable] = value.replace('\x00', '')
else:
filtered_inputs[variable] = value if value is not None else None
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.name)
if var.required and not user_input_value:
raise ValueError(f'{var.name} is required in input form')
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ''
if (
var.type
in (
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.SELECT,
VariableEntity.Type.PARAGRAPH,
)
and user_input_value
and not isinstance(user_input_value, str)
):
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if '.' in user_input_value:
return float(user_input_value)
else:
return int(user_input_value)
except ValueError:
raise ValueError(f"{var.name} in input form must be a valid number")
if var.type == VariableEntity.Type.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
return user_input_value
def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str):
return value.replace('\x00', '')
return value

View File

@ -51,7 +51,6 @@ class AppQueueManager:
listen_timeout = current_app.config.get("APP_MAX_EXECUTION_TIME")
start_time = time.time()
last_ping_time = 0
while True:
try:
message = self._q.get(timeout=1)

View File

@ -94,7 +94,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
call_depth=call_depth,
)
def _generate(
@ -104,7 +103,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@ -166,10 +164,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
"""
if not node_id:
raise ValueError('node_id is required')
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}

View File

@ -1,7 +1,10 @@
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
InvokeFrom,
@ -25,6 +28,8 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
StreamResponse,
TextChunkStreamResponse,
TextReplaceStreamResponse,
@ -105,7 +110,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
db.session.refresh(self._user)
db.session.close()
generator = self._process_stream_response(
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
@ -161,8 +166,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
stream_response=stream_response
)
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
break
audio_trunk = publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
@ -170,6 +225,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
if isinstance(event, QueueErrorEvent):
@ -251,6 +308,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else:
continue
if publisher:
publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.

View File

@ -69,6 +69,7 @@ class WorkflowTaskState(TaskState):
iteration_nested_node_ids: list[str] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
AdvancedChatTaskState entity
@ -86,6 +87,8 @@ class StreamEvent(Enum):
ERROR = "error"
MESSAGE = "message"
MESSAGE_END = "message_end"
TTS_MESSAGE = "tts_message"
TTS_MESSAGE_END = "tts_message_end"
MESSAGE_FILE = "message_file"
MESSAGE_REPLACE = "message_replace"
AGENT_THOUGHT = "agent_thought"
@ -130,6 +133,22 @@ class MessageStreamResponse(StreamResponse):
answer: str
class MessageAudioStreamResponse(StreamResponse):
"""
MessageStreamResponse entity
"""
event: StreamEvent = StreamEvent.TTS_MESSAGE
audio: str
class MessageAudioEndStreamResponse(StreamResponse):
"""
MessageStreamResponse entity
"""
event: StreamEvent = StreamEvent.TTS_MESSAGE_END
audio: str
class MessageEndStreamResponse(StreamResponse):
"""
MessageEndStreamResponse entity
@ -186,6 +205,7 @@ class WorkflowStartStreamResponse(StreamResponse):
"""
WorkflowStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -205,6 +225,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
"""
WorkflowFinishStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -232,6 +253,7 @@ class NodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -273,6 +295,7 @@ class NodeFinishStreamResponse(StreamResponse):
"""
NodeFinishStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -323,10 +346,12 @@ class NodeFinishStreamResponse(StreamResponse):
}
}
class IterationNodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -344,10 +369,12 @@ class IterationNodeStartStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
class IterationNodeNextStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -365,10 +392,12 @@ class IterationNodeNextStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
class IterationNodeCompletedStreamResponse(StreamResponse):
"""
NodeCompletedStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -393,10 +422,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
class TextChunkStreamResponse(StreamResponse):
"""
TextChunkStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -411,6 +442,7 @@ class TextReplaceStreamResponse(StreamResponse):
"""
TextReplaceStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -473,6 +505,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
"""
ChatbotAppBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -492,6 +525,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
"""
CompletionAppBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -510,6 +544,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
"""
WorkflowAppBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
@ -528,10 +563,12 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
class WorkflowIterationState(BaseModel):
"""
WorkflowIterationState entity
"""
class Data(BaseModel):
"""
Data entity

View File

@ -0,0 +1 @@
from .rate_limit import RateLimit

View File

@ -0,0 +1,120 @@
import logging
import time
import uuid
from collections.abc import Generator
from datetime import timedelta
from typing import Optional, Union
from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class RateLimit:
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict = {}
def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
return cls._instance_dict[client_id]
def __init__(self, client_id: str, max_active_requests: int):
self.max_active_requests = max_active_requests
if hasattr(self, 'initialized'):
return
self.initialized = True
self.client_id = client_id
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
self.last_recalculate_time = float('-inf')
self.flush_cache(use_local_value=True)
def flush_cache(self, use_local_value=False):
self.last_recalculate_time = time.time()
# flush max active requests
if use_local_value or not redis_client.exists(self.max_active_requests_key):
with redis_client.pipeline() as pipe:
pipe.set(self.max_active_requests_key, self.max_active_requests)
pipe.expire(self.max_active_requests_key, timedelta(days=1))
pipe.execute()
else:
with redis_client.pipeline() as pipe:
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
# flush max active requests (in-transit request list)
if not redis_client.exists(self.active_requests_key):
return
request_details = redis_client.hgetall(self.active_requests_key)
redis_client.expire(self.active_requests_key, timedelta(days=1))
timeout_requests = [k for k, v in request_details.items() if
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)
def enter(self, request_id: Optional[str] = None) -> str:
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
self.flush_cache()
if self.max_active_requests <= 0:
return RateLimit._UNLIMITED_REQUEST_ID
if not request_id:
request_id = RateLimit.gen_request_key()
active_requests_count = redis_client.hlen(self.active_requests_key)
if active_requests_count >= self.max_active_requests:
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
"concurrent requests allowed is {}.".format(self.max_active_requests))
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
return request_id
def exit(self, request_id: str):
if request_id == RateLimit._UNLIMITED_REQUEST_ID:
return
redis_client.hdel(self.active_requests_key, request_id)
@staticmethod
def gen_request_key() -> str:
return str(uuid.uuid4())
def generate(self, generator: Union[Generator, callable, dict], request_id: str):
if isinstance(generator, dict):
return generator
else:
return RateLimitGenerator(self, generator, request_id)
class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
self.rate_limit = rate_limit
if callable(generator):
self.generator = generator()
else:
self.generator = generator
self.request_id = request_id
self.closed = False
def __iter__(self):
return self
def __next__(self):
if self.closed:
raise StopIteration
try:
return next(self.generator)
except StopIteration:
self.close()
raise
def close(self):
if not self.closed:
self.closed = True
self.rate_limit.exit(self.request_id)
if self.generator is not None and hasattr(self.generator, 'close'):
self.generator.close()

View File

@ -4,6 +4,8 @@ import time
from collections.abc import Generator
from typing import Optional, Union, cast
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
@ -32,6 +34,8 @@ from core.app.entities.task_entities import (
CompletionAppStreamResponse,
EasyUITaskState,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
)
@ -87,6 +91,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
self._model_config = application_generate_entity.model_conf
self._app_config = application_generate_entity.app_config
self._conversation = conversation
self._message = message
@ -102,7 +107,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._conversation_name_generate_thread = None
def process(
self,
self,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
@ -123,7 +128,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._application_generate_entity.query
)
generator = self._process_stream_response(
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
@ -202,14 +207,64 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response
)
def _listenAudioMsg(self, publisher, task_id: str):
if publisher is None:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
publisher = None
text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
if publisher is None:
break
audio = publisher.checkAndGetAudio()
if audio is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio.status == "finish":
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio,
task_id=task_id)
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message)
event = message.event
if isinstance(event, QueueErrorEvent):
@ -272,12 +327,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield self._ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: Optional[TraceQueueManager] = None
) -> None:
"""
Save message.

View File

@ -31,6 +31,13 @@ class QuotaExceededError(Exception):
description = "Quota Exceeded"
class AppInvokeQuotaExceededError(Exception):
"""
Custom exception raised when the quota for an app has been exceeded.
"""
description = "App Invoke Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception):
"""
Custom exception raised when the model not support

View File

@ -1,7 +1,6 @@
import os
import requests
from configs import dify_config
from models.api_based_extension import APIBasedExtensionPoint
@ -31,10 +30,10 @@ class APIBasedExtensionRequestor:
try:
# proxy support for security
proxies = None
if os.environ.get("SSRF_PROXY_HTTP_URL") and os.environ.get("SSRF_PROXY_HTTPS_URL"):
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxies = {
'http': os.environ.get("SSRF_PROXY_HTTP_URL"),
'https': os.environ.get("SSRF_PROXY_HTTPS_URL"),
'http': dify_config.SSRF_PROXY_HTTP_URL,
'https': dify_config.SSRF_PROXY_HTTPS_URL,
}
response = requests.request(

View File

@ -186,7 +186,7 @@ class MessageFileParser:
}
response = requests.head(url, headers=headers, allow_redirects=True)
if response.status_code == 200:
if response.status_code in {200, 304}:
return True, ""
else:
return False, "URL does not exist."

View File

@ -1,5 +1,4 @@
import logging
import os
import time
from enum import Enum
from threading import Lock
@ -9,6 +8,7 @@ from httpx import get, post
from pydantic import BaseModel
from yarl import URL
from configs import dify_config
from core.helper.code_executor.entities import CodeDependency
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
@ -18,8 +18,8 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__)
# Code Executor
CODE_EXECUTION_ENDPOINT = os.environ.get('CODE_EXECUTION_ENDPOINT', 'http://sandbox:8194')
CODE_EXECUTION_API_KEY = os.environ.get('CODE_EXECUTION_API_KEY', 'dify-sandbox')
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
CODE_EXECUTION_TIMEOUT= (10, 60)

View File

@ -730,7 +730,7 @@ class IndexingRunner:
self._check_document_paused_status(dataset_document.id)
tokens = 0
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
if embedding_model_instance:
tokens += sum(
embedding_model_instance.get_text_embedding_num_tokens(
[document.page_content]

View File

@ -12,7 +12,8 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from extensions.ext_database import db
from models.model import AppMode, Conversation, Message
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun
class TokenBufferMemory:
@ -30,33 +31,46 @@ class TokenBufferMemory:
app_record = self.conversation.app
# fetch limited messages, and return reversed
query = db.session.query(Message).filter(
query = db.session.query(
Message.id,
Message.query,
Message.answer,
Message.created_at,
Message.workflow_run_id
).filter(
Message.conversation_id == self.conversation.id,
Message.answer != ''
).order_by(Message.created_at.desc())
if message_limit and message_limit > 0:
messages = query.limit(message_limit).all()
message_limit = message_limit if message_limit <= 500 else 500
else:
messages = query.all()
message_limit = 500
messages = query.limit(message_limit).all()
messages = list(reversed(messages))
message_file_parser = MessageFileParser(
tenant_id=app_record.tenant_id,
app_id=app_record.id
)
prompt_messages = []
for message in messages:
files = message.message_files
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
else:
file_extra_config = FileUploadConfigManager.convert(
message.workflow_run.workflow.features_dict,
is_vision=False
)
if message.workflow_run_id:
workflow_run = (db.session.query(WorkflowRun)
.filter(WorkflowRun.id == message.workflow_run_id).first())
if workflow_run:
file_extra_config = FileUploadConfigManager.convert(
workflow_run.workflow.features_dict,
is_vision=False
)
if file_extra_config:
file_objs = message_file_parser.transform_message_files(
@ -136,4 +150,4 @@ class TokenBufferMemory:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)
return "\n".join(string_messages)

View File

@ -264,7 +264,7 @@ class ModelInstance:
user=user
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
-> str:
"""
Invoke large language tts model
@ -287,8 +287,7 @@ class ModelInstance:
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
streaming=streaming
voice=voice
)
def _round_robin_invoke(self, function: Callable, *args, **kwargs):

View File

@ -1,4 +1,6 @@
import hashlib
import logging
import re
import subprocess
import uuid
from abc import abstractmethod
@ -10,7 +12,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class TTSModel(AIModel):
"""
Model class for ttstext model.
@ -20,7 +22,7 @@ class TTSModel(AIModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None):
"""
Invoke large language model
@ -35,14 +37,15 @@ class TTSModel(AIModel):
:return: translated audio file
"""
try:
logger.info(f"Invoke TTS model: {model} , invoke content : {content_text}")
self._is_ffmpeg_installed()
return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming,
return self._invoke(model=model, credentials=credentials, user=user,
content_text=content_text, voice=voice, tenant_id=tenant_id)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None):
"""
Invoke large language model
@ -123,26 +126,26 @@ class TTSModel(AIModel):
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
@staticmethod
def _split_text_into_sentences(text: str, limit: int, delimiters=None):
if delimiters is None:
delimiters = set('。!?;\n')
buf = []
word_count = 0
for char in text:
buf.append(char)
if char in delimiters:
if word_count >= limit:
yield ''.join(buf)
buf = []
word_count = 0
else:
word_count += 1
else:
word_count += 1
if buf:
yield ''.join(buf)
def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
match = re.compile(pattern)
tx = match.finditer(org_text)
start = 0
result = []
one_sentence = ''
for i in tx:
end = i.regs[0][1]
tmp = org_text[start:end]
if len(one_sentence + tmp) > max_length:
result.append(one_sentence)
one_sentence = ''
one_sentence += tmp
start = end
last_sens = org_text[start:]
if last_sens:
one_sentence += last_sens
if one_sentence != '':
result.append(one_sentence)
return result
@staticmethod
def _is_ffmpeg_installed():

View File

@ -33,3 +33,4 @@
- deepseek
- hunyuan
- siliconflow
- perfxcloud

View File

@ -4,7 +4,7 @@ from functools import reduce
from io import BytesIO
from typing import Optional
from flask import Response, stream_with_context
from flask import Response
from openai import AzureOpenAI
from pydub import AudioSegment
@ -14,7 +14,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
from extensions.ext_storage import storage
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
@ -23,7 +22,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
"""
def _invoke(self, model: str, tenant_id: str, credentials: dict,
content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
content_text: str, voice: str, user: Optional[str] = None) -> any:
"""
_invoke text2speech model
@ -32,30 +31,23 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:param streaming: output is streaming
:param user: unique user id
:return: text translated to audio file
"""
audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice)),
status=200, mimetype=f'audio/{audio_type}')
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
return self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
voice=voice)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
validate credentials text2speech model
:param model: model name
:param credentials: model credentials
:param user: unique user id
:return: text translated to audio file
"""
try:
@ -82,7 +74,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
word_limit = self._get_model_word_limit(model, credentials)
max_workers = self._get_model_workers_limit(model, credentials)
try:
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
audio_bytes_list = []
# Create a thread pool and map the function to the list of sentences
@ -107,34 +99,37 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
except Exception as ex:
raise InvokeBadRequestError(str(ex))
# Todo: To improve the streaming function
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
voice: str) -> any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:return: text translated to audio file
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
voice = self._get_model_default_voice(model, credentials)
word_limit = self._get_model_word_limit(model, credentials)
audio_type = self._get_model_audio_type(model, credentials)
tts_file_id = self._get_file_name(content_text)
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
try:
# doc: https://platform.openai.com/docs/guides/text-to-speech
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
for sentence in sentences:
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
# response.stream_to_file(file_path)
storage.save(file_path, response.read())
# max font is 4096,there is 3500 limit for each request
max_length = 3500
if len(content_text) > max_length:
sentences = self._split_text_into_sentences(content_text, max_length=max_length)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
response_format="mp3",
input=sentences[i], voice=voice) for i in range(len(sentences))]
for index, future in enumerate(futures):
yield from future.result().__enter__().iter_bytes(1024)
else:
response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
response_format="mp3",
input=content_text.strip())
yield from response.__enter__().iter_bytes(1024)
except Exception as ex:
raise InvokeBadRequestError(str(ex))
@ -162,7 +157,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
for ai_model_entity in TTS_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@ -170,5 +165,4 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
return None

View File

@ -5,6 +5,8 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000

View File

@ -5,6 +5,8 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000

View File

@ -5,6 +5,8 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000

View File

@ -5,6 +5,8 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000

View File

@ -29,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
@ -68,7 +69,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
# TODO: consolidate different invocation methods for models based on base model capabilities
# invoke anthropic models via boto3 client
if "anthropic" in model:
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
# invoke Cohere models via boto3 client
if "cohere.command-r" in model:
return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
@ -151,7 +152,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
"""
Invoke Anthropic large language model
@ -171,23 +172,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
parameters = {
'modelId': model,
'messages': prompt_message_dicts,
'inferenceConfig': inference_config,
'additionalModelRequestFields': additional_model_fields,
}
if system and len(system) > 0:
parameters['system'] = system
if tools:
parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
if stream:
response = bedrock_client.converse_stream(
modelId=model,
messages=prompt_message_dicts,
system=system,
inferenceConfig=inference_config,
additionalModelRequestFields=additional_model_fields
)
response = bedrock_client.converse_stream(**parameters)
return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
else:
response = bedrock_client.converse(
modelId=model,
messages=prompt_message_dicts,
system=system,
inferenceConfig=inference_config,
additionalModelRequestFields=additional_model_fields
)
response = bedrock_client.converse(**parameters)
return self._handle_converse_response(model, credentials, response, prompt_messages)
def _handle_converse_response(self, model: str, credentials: dict, response: dict,
@ -246,12 +248,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
output_tokens = 0
finish_reason = None
index = 0
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_use = {}
for chunk in response['stream']:
if 'messageStart' in chunk:
return_model = model
elif 'messageStop' in chunk:
finish_reason = chunk['messageStop']['stopReason']
elif 'contentBlockStart' in chunk:
tool = chunk['contentBlockStart']['start']['toolUse']
tool_use['toolUseId'] = tool['toolUseId']
tool_use['name'] = tool['name']
elif 'metadata' in chunk:
input_tokens = chunk['metadata']['usage']['inputTokens']
output_tokens = chunk['metadata']['usage']['outputTokens']
@ -260,29 +268,49 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index + 1,
index=index,
message=AssistantPromptMessage(
content=''
content='',
tool_calls=tool_calls
),
finish_reason=finish_reason,
usage=usage
)
)
elif 'contentBlockDelta' in chunk:
chunk_text = chunk['contentBlockDelta']['delta']['text'] if chunk['contentBlockDelta']['delta']['text'] else ''
full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text if chunk_text else '',
)
index = chunk['contentBlockDelta']['contentBlockIndex']
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
delta = chunk['contentBlockDelta']['delta']
if 'text' in delta:
chunk_text = delta['text'] if delta['text'] else ''
full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text if chunk_text else '',
)
)
index = chunk['contentBlockDelta']['contentBlockIndex']
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index+1,
message=assistant_prompt_message,
)
)
elif 'toolUse' in delta:
if 'input' not in tool_use:
tool_use['input'] = ''
tool_use['input'] += delta['toolUse']['input']
elif 'contentBlockStop' in chunk:
if 'input' in tool_use:
tool_call = AssistantPromptMessage.ToolCall(
id=tool_use['toolUseId'],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_use['name'],
arguments=tool_use['input']
)
)
tool_calls.append(tool_call)
tool_use = {}
except Exception as ex:
raise InvokeError(str(ex))
@ -312,16 +340,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
"""
system = []
first_loop = True
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
message.content=message.content.strip()
if first_loop:
system=message.content
first_loop=False
else:
system+="\n"
system+=message.content
system.append({"text": message.content})
prompt_message_dicts = []
for message in prompt_messages:
@ -330,6 +352,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return system, prompt_message_dicts
def _convert_converse_tool_config(self, tools: Optional[list[PromptMessageTool]] = None) -> dict:
tool_config = {}
configs = []
if tools:
for tool in tools:
configs.append(
{
"toolSpec": {
"name": tool.name,
"description": tool.description,
"inputSchema": {
"json": tool.parameters
}
}
}
)
tool_config["tools"] = configs
return tool_config
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
@ -379,10 +420,32 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": [{'text': message.content}]}
if message.tool_calls:
message_dict = {
"role": "assistant", "content":[{
"toolUse": {
"toolUseId": message.tool_calls[0].id,
"name": message.tool_calls[0].function.name,
"input": json.loads(message.tool_calls[0].function.arguments)
}
}]
}
else:
message_dict = {"role": "assistant", "content": [{'text': message.content}]}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = [{'text': message.content}]
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "user",
"content": [{
"toolResult": {
"toolUseId": message.tool_call_id,
"content": [{"json": {"text": message.content}}]
}
}]
}
else:
raise ValueError(f"Got unknown type {message}")
@ -401,11 +464,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
"""
prefix = model.split('.')[0]
model_name = model.split('.')[1]
if isinstance(prompt_messages, str):
prompt = prompt_messages
else:
prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
return self._get_num_tokens_by_gpt2(prompt)
def validate_credentials(self, model: str, credentials: dict) -> None:
@ -489,11 +554,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
content = message.content
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt_prefix} {content} {human_prompt_postfix}"
body = content
if (isinstance(content, list)):
body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT])
message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
elif isinstance(message, ToolPromptMessage):
message_text = f"{human_prompt_prefix} {message.content}"
else:
raise ValueError(f"Got unknown type {message}")

View File

@ -21,7 +21,7 @@ model_properties:
- mode: 'shimmer'
name: 'Shimmer'
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
word_limit: 120
word_limit: 3500
audio_type: 'mp3'
max_workers: 5
pricing:

View File

@ -21,7 +21,7 @@ model_properties:
- mode: 'shimmer'
name: 'Shimmer'
language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
word_limit: 120
word_limit: 3500
audio_type: 'mp3'
max_workers: 5
pricing:

Some files were not shown because too many files have changed in this diff Show More