mirror of https://github.com/langgenius/dify.git
merge main
This commit is contained in:
commit
009cb2a650
|
|
@ -1,6 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
cd web && npm install
|
||||
pipx install poetry
|
||||
|
||||
echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
# Ensure that .sh scripts use LF as line separator, even if they are checked out
|
||||
# to Windows(NTFS) file-system, by a user of Docker for Window.
|
||||
# These .sh scripts will be run from the Container after `docker compose up -d`.
|
||||
# If they appear to be CRLF style, Dash from the Container will fail to execute
|
||||
# them.
|
||||
|
||||
*.sh text eol=lf
|
||||
|
|
@ -1,13 +1,21 @@
|
|||
# Checklist:
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Please review the checklist below before submitting your pull request.
|
||||
|
||||
- [ ] Please open an issue before creating a PR or link to an existing issue
|
||||
- [ ] I have performed a self-review of my own code
|
||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
||||
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
|
||||
|
||||
# Description
|
||||
|
||||
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
|
||||
Describe the big picture of your changes here to communicate to the maintainers why we should accept this pull request. If it fixes a bug or resolves a feature request, be sure to link to that issue. Close issue syntax: `Fixes #<issue number>`, see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) for more details.
|
||||
|
||||
Fixes # (issue)
|
||||
Fixes
|
||||
|
||||
## Type of Change
|
||||
|
||||
Please delete options that are not relevant.
|
||||
|
||||
- [ ] Bug fix (non-breaking change which fixes an issue)
|
||||
- [ ] New feature (non-breaking change which adds functionality)
|
||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||
|
|
@ -15,18 +23,12 @@ Please delete options that are not relevant.
|
|||
- [ ] Improvement, including but not limited to code refactoring, performance optimization, and UI/UX improvement
|
||||
- [ ] Dependency upgrade
|
||||
|
||||
# How Has This Been Tested?
|
||||
# Testing Instructions
|
||||
|
||||
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
|
||||
|
||||
- [ ] TODO
|
||||
- [ ] Test A
|
||||
- [ ] Test B
|
||||
|
||||
|
||||
# Suggested Checklist:
|
||||
|
||||
- [ ] I have performed a self-review of my own code
|
||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
||||
- [ ] My changes generate no new warnings
|
||||
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
|
||||
- [ ] `optional` I have made corresponding changes to the documentation
|
||||
- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
|
||||
- [ ] `optional` New and existing unit tests pass locally with my changes
|
||||
|
|
|
|||
|
|
@ -48,18 +48,18 @@ jobs:
|
|||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
|
|
|
|||
|
|
@ -174,3 +174,5 @@ sdks/python-client/dify_client.egg-info
|
|||
.vscode/*
|
||||
!.vscode/launch.json
|
||||
pyrightconfig.json
|
||||
|
||||
.idea/
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@
|
|||
"jinja": true,
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_DEBUG": "1",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
|
|
|
|||
|
|
@ -192,6 +192,11 @@ If you'd like to configure a highly-available setup, there are community-contrib
|
|||
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
|
||||
#### Using Terraform for Deployment
|
||||
|
||||
##### Azure Global
|
||||
Deploy Dify to Azure with a single click using [terraform](https://www.terraform.io/).
|
||||
- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -175,6 +175,12 @@ docker compose up -d
|
|||
- [رسم بياني Helm من قبل @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [ملف YAML من قبل @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
|
||||
#### استخدام Terraform للتوزيع
|
||||
|
||||
##### Azure Global
|
||||
استخدم [terraform](https://www.terraform.io/) لنشر Dify على Azure بنقرة واحدة.
|
||||
- [Azure Terraform بواسطة @nikawang](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
|
||||
## المساهمة
|
||||
|
||||
|
|
|
|||
|
|
@ -197,6 +197,12 @@ docker compose up -d
|
|||
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [YAML 文件 by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
|
||||
#### 使用 Terraform 部署
|
||||
|
||||
##### Azure Global
|
||||
使用 [terraform](https://www.terraform.io/) 一键部署 Dify 到 Azure。
|
||||
- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#langgenius/dify&Date)
|
||||
|
|
|
|||
|
|
@ -199,6 +199,12 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop
|
|||
- [Gráfico Helm por @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [Ficheros YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
|
||||
#### Uso de Terraform para el despliegue
|
||||
|
||||
##### Azure Global
|
||||
Utiliza [terraform](https://www.terraform.io/) para desplegar Dify en Azure con un solo clic.
|
||||
- [Azure Terraform por @nikawang](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
|
||||
## Contribuir
|
||||
|
||||
|
|
|
|||
|
|
@ -197,6 +197,12 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau
|
|||
- [Helm Chart par @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [Fichier YAML par @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
|
||||
#### Utilisation de Terraform pour le déploiement
|
||||
|
||||
##### Azure Global
|
||||
Utilisez [terraform](https://www.terraform.io/) pour déployer Dify sur Azure en un clic.
|
||||
- [Azure Terraform par @nikawang](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
|
||||
## Contribuer
|
||||
|
||||
|
|
|
|||
|
|
@ -196,6 +196,12 @@ docker compose up -d
|
|||
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
|
||||
#### Terraformを使用したデプロイ
|
||||
|
||||
##### Azure Global
|
||||
[terraform](https://www.terraform.io/) を使用して、AzureにDifyをワンクリックでデプロイします。
|
||||
- [nikawangのAzure Terraform](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
|
||||
## 貢献
|
||||
|
||||
|
|
|
|||
|
|
@ -197,6 +197,13 @@ If you'd like to configure a highly-available setup, there are community-contrib
|
|||
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
|
||||
#### Terraform atorlugu pilersitsineq
|
||||
|
||||
##### Azure Global
|
||||
Atoruk [terraform](https://www.terraform.io/) Dify-mik Azure-mut ataatsikkut ikkussuilluarlugu.
|
||||
- [Azure Terraform atorlugu @nikawang](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
|
|
|
|||
|
|
@ -190,6 +190,12 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
|
|||
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
|
||||
#### Terraform을 사용한 배포
|
||||
|
||||
##### Azure Global
|
||||
[terraform](https://www.terraform.io/)을 사용하여 Azure에 Dify를 원클릭으로 배포하세요.
|
||||
- [nikawang의 Azure Terraform](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
## 기여
|
||||
|
||||
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
|
||||
|
|
|
|||
|
|
@ -72,6 +72,13 @@ TENCENT_COS_SECRET_ID=your-secret-id
|
|||
TENCENT_COS_REGION=your-region
|
||||
TENCENT_COS_SCHEME=your-scheme
|
||||
|
||||
# OCI Storage configuration
|
||||
OCI_ENDPOINT=your-endpoint
|
||||
OCI_BUCKET_NAME=your-bucket-name
|
||||
OCI_ACCESS_KEY=your-access-key
|
||||
OCI_SECRET_KEY=your-secret-key
|
||||
OCI_REGION=your-region
|
||||
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
|
@ -144,6 +151,16 @@ CHROMA_DATABASE=default_database
|
|||
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
|
||||
CHROMA_AUTH_CREDENTIALS=difyai123456
|
||||
|
||||
# AnalyticDB configuration
|
||||
ANALYTICDB_KEY_ID=your-ak
|
||||
ANALYTICDB_KEY_SECRET=your-sk
|
||||
ANALYTICDB_REGION_ID=cn-hangzhou
|
||||
ANALYTICDB_INSTANCE_ID=gp-ab123456
|
||||
ANALYTICDB_ACCOUNT=testaccount
|
||||
ANALYTICDB_PASSWORD=testpassword
|
||||
ANALYTICDB_NAMESPACE=dify
|
||||
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
|
||||
|
||||
# OpenSearch configuration
|
||||
OPENSEARCH_HOST=127.0.0.1
|
||||
OPENSEARCH_PORT=9200
|
||||
|
|
@ -230,4 +247,4 @@ WORKFLOW_CALL_MAX_DEPTH=5
|
|||
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
|
|
|||
|
|
@ -5,8 +5,7 @@ WORKDIR /app/api
|
|||
|
||||
# Install Poetry
|
||||
ENV POETRY_VERSION=1.8.3
|
||||
RUN pip install --no-cache-dir --upgrade pip && \
|
||||
pip install --no-cache-dir --upgrade poetry==${POETRY_VERSION}
|
||||
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
|
||||
|
||||
# Configure Poetry
|
||||
ENV POETRY_CACHE_DIR=/tmp/poetry_cache
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
```bash
|
||||
cd ../docker
|
||||
cp middleware.env.example middleware.env
|
||||
docker compose -f docker-compose.middleware.yaml -p dify up -d
|
||||
cd ../api
|
||||
```
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import os
|
||||
|
||||
from configs.app_config import DifyConfig
|
||||
from configs import dify_config
|
||||
|
||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG", "false").lower() != 'true':
|
||||
if os.environ.get("DEBUG", "false").lower() != 'true':
|
||||
from gevent import monkey
|
||||
|
||||
monkey.patch_all()
|
||||
|
|
@ -43,6 +43,8 @@ from extensions import (
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
from libs.passport import PassportService
|
||||
|
||||
# TODO: Find a way to avoid importing models here
|
||||
from models import account, dataset, model, source, task, tool, tools, web
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
|
@ -81,7 +83,7 @@ def create_flask_app_with_configs() -> Flask:
|
|||
with configs loaded from .env file
|
||||
"""
|
||||
dify_app = DifyApp(__name__)
|
||||
dify_app.config.from_mapping(DifyConfig().model_dump())
|
||||
dify_app.config.from_mapping(dify_config.model_dump())
|
||||
|
||||
# populate configs into system environment variables
|
||||
for key, value in dify_app.config.items():
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import click
|
|||
from flask import current_app
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
|
|
@ -112,7 +113,7 @@ def reset_encrypt_key_pair():
|
|||
After the reset, all LLM credentials will become invalid, requiring re-entry.
|
||||
Only support SELF_HOSTED mode.
|
||||
"""
|
||||
if current_app.config['EDITION'] != 'SELF_HOSTED':
|
||||
if dify_config.EDITION != 'SELF_HOSTED':
|
||||
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
|
||||
return
|
||||
|
||||
|
|
@ -336,6 +337,14 @@ def migrate_knowledge_vector_database():
|
|||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.ANALYTICDB:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.ANALYTICDB,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .app_config import DifyConfig
|
||||
|
||||
dify_config = DifyConfig()
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import Field, computed_field
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from configs.deploy import DeploymentConfig
|
||||
from configs.enterprise import EnterpriseFeatureConfig
|
||||
|
|
@ -9,9 +10,6 @@ from configs.packaging import PackagingInfo
|
|||
|
||||
|
||||
class DifyConfig(
|
||||
# based on pydantic-settings
|
||||
BaseSettings,
|
||||
|
||||
# Packaging info
|
||||
PackagingInfo,
|
||||
|
||||
|
|
@ -31,12 +29,39 @@ class DifyConfig(
|
|||
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
EnterpriseFeatureConfig,
|
||||
):
|
||||
DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
# read from dotenv format config file
|
||||
env_file='.env',
|
||||
env_file_encoding='utf-8',
|
||||
frozen=True,
|
||||
|
||||
# ignore extra attributes
|
||||
extra='ignore',
|
||||
)
|
||||
|
||||
CODE_MAX_NUMBER: int = 9223372036854775807
|
||||
CODE_MIN_NUMBER: int = -9223372036854775808
|
||||
CODE_MAX_STRING_LENGTH: int = 80000
|
||||
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10
|
||||
|
||||
@computed_field
|
||||
def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
|
||||
return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024
|
||||
|
||||
@computed_field
|
||||
def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
|
||||
return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
|
||||
|
||||
SSRF_PROXY_HTTP_URL: str | None = None
|
||||
SSRF_PROXY_HTTPS_URL: str | None = None
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class DeploymentConfig(BaseModel):
|
||||
class DeploymentConfig(BaseSettings):
|
||||
"""
|
||||
Deployment configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class EnterpriseFeatureConfig(BaseModel):
|
||||
class EnterpriseFeatureConfig(BaseSettings):
|
||||
"""
|
||||
Enterprise feature configs.
|
||||
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from configs.extra.notion_config import NotionConfig
|
||||
from configs.extra.sentry_config import SentryConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class NotionConfig(BaseModel):
|
||||
class NotionConfig(BaseSettings):
|
||||
"""
|
||||
Notion integration configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, NonNegativeFloat
|
||||
from pydantic import Field, NonNegativeFloat
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class SentryConfig(BaseModel):
|
||||
class SentryConfig(BaseSettings):
|
||||
"""
|
||||
Sentry configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.feature.hosted_service import HostedServiceConfig
|
||||
|
||||
|
||||
class SecurityConfig(BaseModel):
|
||||
class SecurityConfig(BaseSettings):
|
||||
"""
|
||||
Secret Key configs
|
||||
"""
|
||||
|
|
@ -17,8 +18,12 @@ class SecurityConfig(BaseModel):
|
|||
default=None,
|
||||
)
|
||||
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
|
||||
description='Expiry time in hours for reset token',
|
||||
default=24,
|
||||
)
|
||||
|
||||
class AppExecutionConfig(BaseModel):
|
||||
class AppExecutionConfig(BaseSettings):
|
||||
"""
|
||||
App Execution configs
|
||||
"""
|
||||
|
|
@ -26,9 +31,13 @@ class AppExecutionConfig(BaseModel):
|
|||
description='execution timeout in seconds for app execution',
|
||||
default=1200,
|
||||
)
|
||||
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
|
||||
description='max active request per app, 0 means unlimited',
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutionSandboxConfig(BaseModel):
|
||||
class CodeExecutionSandboxConfig(BaseSettings):
|
||||
"""
|
||||
Code Execution Sandbox configs
|
||||
"""
|
||||
|
|
@ -43,36 +52,36 @@ class CodeExecutionSandboxConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseModel):
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Module URL configs
|
||||
"""
|
||||
CONSOLE_API_URL: str = Field(
|
||||
description='The backend URL prefix of the console API.'
|
||||
'used to concatenate the login authorization callback or notion integration callback.',
|
||||
default='https://cloud.dify.ai',
|
||||
default='',
|
||||
)
|
||||
|
||||
CONSOLE_WEB_URL: str = Field(
|
||||
description='The front-end URL prefix of the console web.'
|
||||
'used to concatenate some front-end addresses and for CORS configuration use.',
|
||||
default='https://cloud.dify.ai',
|
||||
default='',
|
||||
)
|
||||
|
||||
SERVICE_API_URL: str = Field(
|
||||
description='Service API Url prefix.'
|
||||
'used to display Service API Base Url to the front-end.',
|
||||
default='https://api.dify.ai',
|
||||
default='',
|
||||
)
|
||||
|
||||
APP_WEB_URL: str = Field(
|
||||
description='WebApp Url prefix.'
|
||||
'used to display WebAPP API Base Url to the front-end.',
|
||||
default='https://udify.app',
|
||||
default='',
|
||||
)
|
||||
|
||||
|
||||
class FileAccessConfig(BaseModel):
|
||||
class FileAccessConfig(BaseSettings):
|
||||
"""
|
||||
File Access configs
|
||||
"""
|
||||
|
|
@ -82,7 +91,7 @@ class FileAccessConfig(BaseModel):
|
|||
'Url is signed and has expiration time.',
|
||||
validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
|
||||
alias_priority=1,
|
||||
default='https://cloud.dify.ai',
|
||||
default='',
|
||||
)
|
||||
|
||||
FILES_ACCESS_TIMEOUT: int = Field(
|
||||
|
|
@ -91,7 +100,7 @@ class FileAccessConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class FileUploadConfig(BaseModel):
|
||||
class FileUploadConfig(BaseSettings):
|
||||
"""
|
||||
File Uploading configs
|
||||
"""
|
||||
|
|
@ -116,7 +125,7 @@ class FileUploadConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class HttpConfig(BaseModel):
|
||||
class HttpConfig(BaseSettings):
|
||||
"""
|
||||
HTTP configs
|
||||
"""
|
||||
|
|
@ -136,7 +145,7 @@ class HttpConfig(BaseModel):
|
|||
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
|
||||
|
||||
inner_WEB_API_CORS_ALLOW_ORIGINS: Optional[str] = Field(
|
||||
inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description='',
|
||||
validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
|
||||
default='*',
|
||||
|
|
@ -148,7 +157,7 @@ class HttpConfig(BaseModel):
|
|||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
|
||||
|
||||
|
||||
class InnerAPIConfig(BaseModel):
|
||||
class InnerAPIConfig(BaseSettings):
|
||||
"""
|
||||
Inner API configs
|
||||
"""
|
||||
|
|
@ -163,7 +172,7 @@ class InnerAPIConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
class LoggingConfig(BaseSettings):
|
||||
"""
|
||||
Logging configs
|
||||
"""
|
||||
|
|
@ -195,7 +204,7 @@ class LoggingConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class ModelLoadBalanceConfig(BaseModel):
|
||||
class ModelLoadBalanceConfig(BaseSettings):
|
||||
"""
|
||||
Model load balance configs
|
||||
"""
|
||||
|
|
@ -205,7 +214,7 @@ class ModelLoadBalanceConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class BillingConfig(BaseModel):
|
||||
class BillingConfig(BaseSettings):
|
||||
"""
|
||||
Platform Billing Configurations
|
||||
"""
|
||||
|
|
@ -215,7 +224,7 @@ class BillingConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class UpdateConfig(BaseModel):
|
||||
class UpdateConfig(BaseSettings):
|
||||
"""
|
||||
Update configs
|
||||
"""
|
||||
|
|
@ -225,7 +234,7 @@ class UpdateConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class WorkflowConfig(BaseModel):
|
||||
class WorkflowConfig(BaseSettings):
|
||||
"""
|
||||
Workflow feature configs
|
||||
"""
|
||||
|
|
@ -246,7 +255,7 @@ class WorkflowConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class OAuthConfig(BaseModel):
|
||||
class OAuthConfig(BaseSettings):
|
||||
"""
|
||||
oauth configs
|
||||
"""
|
||||
|
|
@ -276,7 +285,7 @@ class OAuthConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseModel):
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
Moderation in app configs.
|
||||
"""
|
||||
|
|
@ -288,7 +297,7 @@ class ModerationConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
class ToolConfig(BaseSettings):
|
||||
"""
|
||||
Tool configs
|
||||
"""
|
||||
|
|
@ -299,7 +308,7 @@ class ToolConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class MailConfig(BaseModel):
|
||||
class MailConfig(BaseSettings):
|
||||
"""
|
||||
Mail Configurations
|
||||
"""
|
||||
|
|
@ -355,7 +364,7 @@ class MailConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class RagEtlConfig(BaseModel):
|
||||
class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
RAG ETL Configurations.
|
||||
"""
|
||||
|
|
@ -381,7 +390,7 @@ class RagEtlConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class DataSetConfig(BaseModel):
|
||||
class DataSetConfig(BaseSettings):
|
||||
"""
|
||||
Dataset configs
|
||||
"""
|
||||
|
|
@ -391,8 +400,13 @@ class DataSetConfig(BaseModel):
|
|||
default=30,
|
||||
)
|
||||
|
||||
DATASET_OPERATOR_ENABLED: bool = Field(
|
||||
description='whether to enable dataset operator',
|
||||
default=False,
|
||||
)
|
||||
|
||||
class WorkspaceConfig(BaseModel):
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
Workspace configs
|
||||
"""
|
||||
|
|
@ -403,7 +417,7 @@ class WorkspaceConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class IndexingConfig(BaseModel):
|
||||
class IndexingConfig(BaseSettings):
|
||||
"""
|
||||
Indexing configs.
|
||||
"""
|
||||
|
|
@ -414,7 +428,7 @@ class IndexingConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class ImageFormatConfig(BaseModel):
|
||||
class ImageFormatConfig(BaseSettings):
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
|
||||
description='multi model send image format, support base64, url, default is base64',
|
||||
default='base64',
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, NonNegativeInt
|
||||
from pydantic import Field, NonNegativeInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class HostedOpenAiConfig(BaseModel):
|
||||
class HostedOpenAiConfig(BaseSettings):
|
||||
"""
|
||||
Hosted OpenAI service config
|
||||
"""
|
||||
|
|
@ -68,7 +69,7 @@ class HostedOpenAiConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class HostedAzureOpenAiConfig(BaseModel):
|
||||
class HostedAzureOpenAiConfig(BaseSettings):
|
||||
"""
|
||||
Hosted OpenAI service config
|
||||
"""
|
||||
|
|
@ -94,7 +95,7 @@ class HostedAzureOpenAiConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class HostedAnthropicConfig(BaseModel):
|
||||
class HostedAnthropicConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Azure OpenAI service config
|
||||
"""
|
||||
|
|
@ -125,7 +126,7 @@ class HostedAnthropicConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class HostedMinmaxConfig(BaseModel):
|
||||
class HostedMinmaxConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Minmax service config
|
||||
"""
|
||||
|
|
@ -136,7 +137,7 @@ class HostedMinmaxConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class HostedSparkConfig(BaseModel):
|
||||
class HostedSparkConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Spark service config
|
||||
"""
|
||||
|
|
@ -147,7 +148,7 @@ class HostedSparkConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class HostedZhipuAIConfig(BaseModel):
|
||||
class HostedZhipuAIConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Minmax service config
|
||||
"""
|
||||
|
|
@ -158,7 +159,7 @@ class HostedZhipuAIConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class HostedModerationConfig(BaseModel):
|
||||
class HostedModerationConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Moderation service config
|
||||
"""
|
||||
|
|
@ -174,7 +175,7 @@ class HostedModerationConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class HostedFetchAppTemplateConfig(BaseModel):
|
||||
class HostedFetchAppTemplateConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Moderation service config
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.middleware.cache.redis_config import RedisConfig
|
||||
from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
|
||||
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
|
||||
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
|
||||
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
|
||||
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
|
||||
from configs.middleware.vdb.chroma_config import ChromaConfig
|
||||
from configs.middleware.vdb.milvus_config import MilvusConfig
|
||||
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
|
||||
|
|
@ -21,7 +24,7 @@ from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
|
|||
from configs.middleware.vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
|
||||
class StorageConfig(BaseModel):
|
||||
class StorageConfig(BaseSettings):
|
||||
STORAGE_TYPE: str = Field(
|
||||
description='storage type,'
|
||||
' default to `local`,'
|
||||
|
|
@ -35,14 +38,14 @@ class StorageConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseModel):
|
||||
class VectorStoreConfig(BaseSettings):
|
||||
VECTOR_STORE: Optional[str] = Field(
|
||||
description='vector store type',
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class KeywordStoreConfig(BaseModel):
|
||||
class KeywordStoreConfig(BaseSettings):
|
||||
KEYWORD_STORE: str = Field(
|
||||
description='keyword store type',
|
||||
default='jieba',
|
||||
|
|
@ -80,6 +83,11 @@ class DatabaseConfig:
|
|||
default='',
|
||||
)
|
||||
|
||||
DB_EXTRAS: str = Field(
|
||||
description='db extras options. Example: keepalives_idle=60&keepalives=1',
|
||||
default='',
|
||||
)
|
||||
|
||||
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
|
||||
description='db uri scheme',
|
||||
default='postgresql',
|
||||
|
|
@ -88,7 +96,12 @@ class DatabaseConfig:
|
|||
@computed_field
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||
db_extras = f"?client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else ""
|
||||
db_extras = (
|
||||
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
|
||||
if self.DB_CHARSET
|
||||
else self.DB_EXTRAS
|
||||
).strip("&")
|
||||
db_extras = f"?{db_extras}" if db_extras else ""
|
||||
return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
|
||||
f"{self.DB_USERNAME}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
|
||||
f"{db_extras}")
|
||||
|
|
@ -113,7 +126,7 @@ class DatabaseConfig:
|
|||
default=False,
|
||||
)
|
||||
|
||||
SQLALCHEMY_ECHO: bool = Field(
|
||||
SQLALCHEMY_ECHO: bool | str = Field(
|
||||
description='whether to enable SqlAlchemy echo',
|
||||
default=False,
|
||||
)
|
||||
|
|
@ -143,7 +156,7 @@ class CeleryConfig(DatabaseConfig):
|
|||
|
||||
@computed_field
|
||||
@property
|
||||
def CELERY_RESULT_BACKEND(self) -> str:
|
||||
def CELERY_RESULT_BACKEND(self) -> str | None:
|
||||
return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
|
||||
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
|
||||
|
||||
|
|
@ -167,9 +180,11 @@ class MiddlewareConfig(
|
|||
GoogleCloudStorageConfig,
|
||||
TencentCloudCOSStorageConfig,
|
||||
S3StorageConfig,
|
||||
OCIStorageConfig,
|
||||
|
||||
# configs of vdb and vdb providers
|
||||
VectorStoreConfig,
|
||||
AnalyticdbConfig,
|
||||
ChromaConfig,
|
||||
MilvusConfig,
|
||||
OpenSearchConfig,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class RedisConfig(BaseModel):
|
||||
class RedisConfig(BaseSettings):
|
||||
"""
|
||||
Redis configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AliyunOSSStorageConfig(BaseModel):
|
||||
class AliyunOSSStorageConfig(BaseSettings):
|
||||
"""
|
||||
Aliyun storage configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class S3StorageConfig(BaseModel):
|
||||
class S3StorageConfig(BaseSettings):
|
||||
"""
|
||||
S3 storage configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AzureBlobStorageConfig(BaseModel):
|
||||
class AzureBlobStorageConfig(BaseSettings):
|
||||
"""
|
||||
Azure Blob storage configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class GoogleCloudStorageConfig(BaseModel):
|
||||
class GoogleCloudStorageConfig(BaseSettings):
|
||||
"""
|
||||
Google Cloud storage configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class OCIStorageConfig(BaseSettings):
|
||||
"""
|
||||
OCI storage configs
|
||||
"""
|
||||
|
||||
OCI_ENDPOINT: Optional[str] = Field(
|
||||
description='OCI storage endpoint',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_REGION: Optional[str] = Field(
|
||||
description='OCI storage region',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_BUCKET_NAME: Optional[str] = Field(
|
||||
description='OCI storage bucket name',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_ACCESS_KEY: Optional[str] = Field(
|
||||
description='OCI storage access key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_SECRET_KEY: Optional[str] = Field(
|
||||
description='OCI storage secret key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class TencentCloudCOSStorageConfig(BaseModel):
|
||||
class TencentCloudCOSStorageConfig(BaseSettings):
|
||||
"""
|
||||
Tencent Cloud COS storage configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AnalyticdbConfig(BaseModel):
|
||||
"""
|
||||
Configuration for connecting to AnalyticDB.
|
||||
Refer to the following documentation for details on obtaining credentials:
|
||||
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
|
||||
"""
|
||||
|
||||
ANALYTICDB_KEY_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Access Key ID provided by Alibaba Cloud for authentication."
|
||||
)
|
||||
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Secret Access Key corresponding to the Access Key ID for secure access."
|
||||
)
|
||||
ANALYTICDB_REGION_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
|
||||
)
|
||||
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
|
||||
)
|
||||
ANALYTICDB_ACCOUNT : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The account name used to log in to the AnalyticDB instance."
|
||||
)
|
||||
ANALYTICDB_PASSWORD : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The password associated with the AnalyticDB account for authentication."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The namespace within AnalyticDB for schema isolation."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The password for accessing the specified namespace within the AnalyticDB instance."
|
||||
)
|
||||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ChromaConfig(BaseModel):
|
||||
class ChromaConfig(BaseSettings):
|
||||
"""
|
||||
Chroma configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
class MilvusConfig(BaseSettings):
|
||||
"""
|
||||
Milvus configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class OpenSearchConfig(BaseModel):
|
||||
class OpenSearchConfig(BaseSettings):
|
||||
"""
|
||||
OpenSearch configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class OracleConfig(BaseModel):
|
||||
class OracleConfig(BaseSettings):
|
||||
"""
|
||||
ORACLE configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class PGVectorConfig(BaseModel):
|
||||
class PGVectorConfig(BaseSettings):
|
||||
"""
|
||||
PGVector configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class PGVectoRSConfig(BaseModel):
|
||||
class PGVectoRSConfig(BaseSettings):
|
||||
"""
|
||||
PGVectoRS configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
class QdrantConfig(BaseSettings):
|
||||
"""
|
||||
Qdrant configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class RelytConfig(BaseModel):
|
||||
class RelytConfig(BaseSettings):
|
||||
"""
|
||||
Relyt configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class TencentVectorDBConfig(BaseModel):
|
||||
class TencentVectorDBConfig(BaseSettings):
|
||||
"""
|
||||
Tencent Vector configs
|
||||
"""
|
||||
|
|
@ -24,7 +25,7 @@ class TencentVectorDBConfig(BaseModel):
|
|||
)
|
||||
|
||||
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
|
||||
description='Tencent Vector password',
|
||||
description='Tencent Vector username',
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
|
@ -38,7 +39,12 @@ class TencentVectorDBConfig(BaseModel):
|
|||
default=1,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_REPLICAS: PositiveInt = Field(
|
||||
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
|
||||
description='Tencent Vector replicas',
|
||||
default=2,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
|
||||
description='Tencent Vector Database',
|
||||
default=None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class TiDBVectorConfig(BaseModel):
|
||||
class TiDBVectorConfig(BaseSettings):
|
||||
"""
|
||||
TiDB Vector configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
class WeaviateConfig(BaseSettings):
|
||||
"""
|
||||
Weaviate configs
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class PackagingInfo(BaseModel):
|
||||
class PackagingInfo(BaseSettings):
|
||||
"""
|
||||
Packaging build information
|
||||
"""
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description='Dify version',
|
||||
default='0.6.12-fix1',
|
||||
default='0.6.13',
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ language_timezone_mapping = {
|
|||
'vi-VN': 'Asia/Ho_Chi_Minh',
|
||||
'ro-RO': 'Europe/Bucharest',
|
||||
'pl-PL': 'Europe/Warsaw',
|
||||
'hi-IN': 'Asia/Kolkata'
|
||||
'hi-IN': 'Asia/Kolkata',
|
||||
}
|
||||
|
||||
languages = list(language_timezone_mapping.keys())
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,4 @@
|
|||
TTS_AUTO_PLAY_TIMEOUT = 5
|
||||
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02
|
||||
|
|
@ -30,7 +30,7 @@ from .app import (
|
|||
)
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing
|
||||
|
|
|
|||
|
|
@ -134,6 +134,7 @@ class AppApi(Resource):
|
|||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
parser.add_argument('max_active_requests', type=int, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
|
|
@ -190,6 +191,10 @@ class AppExportApi(Resource):
|
|||
@get_app_model
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -81,15 +81,36 @@ class ChatMessageTextApi(Resource):
|
|||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=str, location='json')
|
||||
parser.add_argument('text', type=str, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id', None)
|
||||
text = args.get('text', None)
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
|
||||
'voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
text=request.form['text'],
|
||||
voice=request.form['voice'],
|
||||
streaming=False
|
||||
text=text,
|
||||
message_id=message_id,
|
||||
voice=voice
|
||||
)
|
||||
|
||||
return {'data': response.data.decode('latin1')}
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
|
|
|
|||
|
|
@ -19,7 +19,12 @@ from controllers.console.setup import setup_required
|
|||
from controllers.console.wraps import account_initialization_required
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.errors.error import (
|
||||
AppInvokeQuotaExceededError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
|
|
@ -75,7 +80,7 @@ class CompletionMessageApi(Resource):
|
|||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
|
@ -141,7 +146,7 @@ class ChatMessageApi(Resource):
|
|||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from controllers.console.setup import setup_required
|
|||
from controllers.console.wraps import account_initialization_required
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from fields.workflow_fields import workflow_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
|
|
@ -279,7 +280,7 @@ class DraftWorkflowRunApi(Resource):
|
|||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from flask_login import current_user
|
|||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
|
@ -16,11 +17,11 @@ from ..wraps import account_initialization_required
|
|||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'),
|
||||
client_secret=current_app.config.get(
|
||||
'NOTION_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/data-source/callback/notion')
|
||||
if not dify_config.NOTION_CLIENT_ID or not dify_config.NOTION_CLIENT_SECRET:
|
||||
return {}
|
||||
notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
|
||||
client_secret=dify_config.NOTION_CLIENT_SECRET,
|
||||
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'notion': notion_oauth
|
||||
|
|
@ -39,8 +40,10 @@ class OAuthDataSource(Resource):
|
|||
print(vars(oauth_provider))
|
||||
if not oauth_provider:
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
|
||||
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
|
||||
if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
|
||||
internal_secret = dify_config.NOTION_INTERNAL_SECRET
|
||||
if not internal_secret:
|
||||
return {'error': 'Internal secret is not set'},
|
||||
oauth_provider.save_internal_access_token(internal_secret)
|
||||
return { 'data': '' }
|
||||
else:
|
||||
|
|
@ -60,13 +63,13 @@ class OAuthDataSourceCallback(Resource):
|
|||
if 'code' in request.args:
|
||||
code = request.args.get('code')
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&code={code}')
|
||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
|
||||
elif 'error' in request.args:
|
||||
error = request.args.get('error')
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error={error}')
|
||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
|
||||
else:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error=Access denied')
|
||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
|
||||
|
||||
|
||||
class OAuthDataSourceBinding(Resource):
|
||||
|
|
|
|||
|
|
@ -5,3 +5,28 @@ class ApiKeyAuthFailedError(BaseHTTPException):
|
|||
error_code = 'auth_failed'
|
||||
description = "{message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class InvalidEmailError(BaseHTTPException):
|
||||
error_code = 'invalid_email'
|
||||
description = "The email address is not valid."
|
||||
code = 400
|
||||
|
||||
|
||||
class PasswordMismatchError(BaseHTTPException):
|
||||
error_code = 'password_mismatch'
|
||||
description = "The passwords do not match."
|
||||
code = 400
|
||||
|
||||
|
||||
class InvalidTokenError(BaseHTTPException):
|
||||
error_code = 'invalid_or_expired_token'
|
||||
description = "The token is invalid or has expired."
|
||||
code = 400
|
||||
|
||||
|
||||
class PasswordResetRateLimitExceededError(BaseHTTPException):
|
||||
error_code = 'password_reset_rate_limit_exceeded'
|
||||
description = "Password reset rate limit exceeded. Try again later."
|
||||
code = 429
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,107 @@
|
|||
import base64
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
PasswordResetRateLimitExceededError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
from services.errors.account import RateLimitExceededError
|
||||
|
||||
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('email', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
email = args['email']
|
||||
|
||||
if not email_validate(email):
|
||||
raise InvalidEmailError()
|
||||
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
|
||||
if account:
|
||||
try:
|
||||
AccountService.send_reset_password_email(account=account)
|
||||
except RateLimitExceededError:
|
||||
logging.warning(f"Rate limit exceeded for email: {account.email}")
|
||||
raise PasswordResetRateLimitExceededError()
|
||||
else:
|
||||
# Return success to avoid revealing email registration status
|
||||
logging.warning(f"Attempt to reset password for unregistered email: {email}")
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
token = args['token']
|
||||
|
||||
reset_data = AccountService.get_reset_password_data(token)
|
||||
|
||||
if reset_data is None:
|
||||
return {'is_valid': False, 'email': None}
|
||||
return {'is_valid': True, 'email': reset_data.get('email')}
|
||||
|
||||
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
|
||||
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
new_password = args['new_password']
|
||||
password_confirm = args['password_confirm']
|
||||
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
raise PasswordMismatchError()
|
||||
|
||||
token = args['token']
|
||||
reset_data = AccountService.get_reset_password_data(token)
|
||||
|
||||
if reset_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
AccountService.revoke_reset_password_token(token)
|
||||
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
account = Account.query.filter_by(email=reset_data.get('email')).first()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
|
||||
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
|
||||
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import current_app, request
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
import services
|
||||
|
|
@ -56,14 +56,14 @@ class LogoutApi(Resource):
|
|||
class ResetPasswordApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('email', type=email, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
# parser = reqparse.RequestParser()
|
||||
# parser.add_argument('email', type=email, required=True, location='json')
|
||||
# args = parser.parse_args()
|
||||
|
||||
# import mailchimp_transactional as MailchimpTransactional
|
||||
# from mailchimp_transactional.api_client import ApiClientError
|
||||
|
||||
account = {'email': args['email']}
|
||||
# account = {'email': args['email']}
|
||||
# account = AccountService.get_by_email(args['email'])
|
||||
# if account is None:
|
||||
# raise ValueError('Email not found')
|
||||
|
|
@ -71,22 +71,22 @@ class ResetPasswordApi(Resource):
|
|||
# AccountService.update_password(account, new_password)
|
||||
|
||||
# todo: Send email
|
||||
MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY']
|
||||
# MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY']
|
||||
# mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY)
|
||||
|
||||
message = {
|
||||
'from_email': 'noreply@example.com',
|
||||
'to': [{'email': account.email}],
|
||||
'subject': 'Reset your Dify password',
|
||||
'html': """
|
||||
<p>Dear User,</p>
|
||||
<p>The Dify team has generated a new password for you, details as follows:</p>
|
||||
<p><strong>{new_password}</strong></p>
|
||||
<p>Please change your password to log in as soon as possible.</p>
|
||||
<p>Regards,</p>
|
||||
<p>The Dify Team</p>
|
||||
"""
|
||||
}
|
||||
# message = {
|
||||
# 'from_email': 'noreply@example.com',
|
||||
# 'to': [{'email': account['email']}],
|
||||
# 'subject': 'Reset your Dify password',
|
||||
# 'html': """
|
||||
# <p>Dear User,</p>
|
||||
# <p>The Dify team has generated a new password for you, details as follows:</p>
|
||||
# <p><strong>{new_password}</strong></p>
|
||||
# <p>Please change your password to log in as soon as possible.</p>
|
||||
# <p>Regards,</p>
|
||||
# <p>The Dify Team</p>
|
||||
# """
|
||||
# }
|
||||
|
||||
# response = mailchimp.messages.send({
|
||||
# 'message': message,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import requests
|
|||
from flask import current_app, redirect, request
|
||||
from flask_restful import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import get_remote_ip
|
||||
|
|
@ -18,22 +19,24 @@ from .. import api
|
|||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
github_oauth = GitHubOAuth(client_id=current_app.config.get('GITHUB_CLIENT_ID'),
|
||||
client_secret=current_app.config.get(
|
||||
'GITHUB_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/authorize/github')
|
||||
if not dify_config.GITHUB_CLIENT_ID or not dify_config.GITHUB_CLIENT_SECRET:
|
||||
github_oauth = None
|
||||
else:
|
||||
github_oauth = GitHubOAuth(
|
||||
client_id=dify_config.GITHUB_CLIENT_ID,
|
||||
client_secret=dify_config.GITHUB_CLIENT_SECRET,
|
||||
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
|
||||
)
|
||||
if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
|
||||
google_oauth = None
|
||||
else:
|
||||
google_oauth = GoogleOAuth(
|
||||
client_id=dify_config.GOOGLE_CLIENT_ID,
|
||||
client_secret=dify_config.GOOGLE_CLIENT_SECRET,
|
||||
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
|
||||
)
|
||||
|
||||
google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'),
|
||||
client_secret=current_app.config.get(
|
||||
'GOOGLE_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/authorize/google')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'github': github_oauth,
|
||||
'google': google_oauth
|
||||
}
|
||||
OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
|
||||
return OAUTH_PROVIDERS
|
||||
|
||||
|
||||
|
|
@ -63,8 +66,7 @@ class OAuthCallback(Resource):
|
|||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(
|
||||
f"An error occurred during the OAuth process with {provider}: {e.response.text}")
|
||||
logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
|
||||
return {'error': 'OAuth process failed'}, 400
|
||||
|
||||
account = _generate_account(provider, user_info)
|
||||
|
|
@ -81,7 +83,7 @@ class OAuthCallback(Resource):
|
|||
|
||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
|
||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||
|
|
@ -101,11 +103,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||
# Create account
|
||||
account_name = user_info.name if user_info.name else 'Dify'
|
||||
account = RegisterService.register(
|
||||
email=user_info.email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
)
|
||||
|
||||
# Set interface language
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from fields.document_fields import document_status_fields
|
|||
from libs.login import login_required
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.model import ApiToken, UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
|
|
@ -85,6 +85,12 @@ class DatasetListApi(Resource):
|
|||
else:
|
||||
item['embedding_available'] = True
|
||||
|
||||
if item.get('permission') == 'partial_members':
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
|
||||
item.update({'partial_member_list': part_users_list})
|
||||
else:
|
||||
item.update({'partial_member_list': []})
|
||||
|
||||
response = {
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
|
|
@ -108,8 +114,8 @@ class DatasetListApi(Resource):
|
|||
help='Invalid indexing technique.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
|
|
@ -140,6 +146,10 @@ class DatasetApi(Resource):
|
|||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
if data.get('permission') == 'partial_members':
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({'partial_member_list': part_users_list})
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(
|
||||
|
|
@ -163,6 +173,11 @@ class DatasetApi(Resource):
|
|||
data['embedding_available'] = False
|
||||
else:
|
||||
data['embedding_available'] = True
|
||||
|
||||
if data.get('permission') == 'partial_members':
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({'partial_member_list': part_users_list})
|
||||
|
||||
return data, 200
|
||||
|
||||
@setup_required
|
||||
|
|
@ -188,17 +203,21 @@ class DatasetApi(Resource):
|
|||
nullable=True,
|
||||
help='Invalid indexing technique.')
|
||||
parser.add_argument('permission', type=str, location='json', choices=(
|
||||
'only_me', 'all_team_members'), help='Invalid permission.')
|
||||
'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
|
||||
)
|
||||
parser.add_argument('embedding_model', type=str,
|
||||
location='json', help='Invalid embedding model.')
|
||||
parser.add_argument('embedding_model_provider', type=str,
|
||||
location='json', help='Invalid embedding model provider.')
|
||||
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
|
||||
parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
DatasetPermissionService.check_permission(
|
||||
current_user, dataset, data.get('permission'), data.get('partial_member_list')
|
||||
)
|
||||
|
||||
dataset = DatasetService.update_dataset(
|
||||
dataset_id_str, args, current_user)
|
||||
|
|
@ -206,7 +225,20 @@ class DatasetApi(Resource):
|
|||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 200
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get('partial_member_list') and data.get('permission') == 'partial_members':
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
tenant_id, dataset_id_str, data.get('partial_member_list')
|
||||
)
|
||||
else:
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
result_data.update({'partial_member_list': partial_member_list})
|
||||
|
||||
return result_data, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -215,17 +247,27 @@ class DatasetApi(Resource):
|
|||
dataset_id_str = str(dataset_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.is_editor or current_user.is_dataset_operator:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
return {'result': 'success'}, 204
|
||||
else:
|
||||
raise NotFound("Dataset not found.")
|
||||
except services.errors.dataset.DatasetInUseError:
|
||||
raise DatasetInUseError()
|
||||
|
||||
class DatasetUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
|
||||
return {'is_using': dataset_is_using}, 200
|
||||
|
||||
class DatasetQueryApi(Resource):
|
||||
|
||||
|
|
@ -506,7 +548,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
RetrievalMethod.SEMANTIC_SEARCH
|
||||
]
|
||||
}
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH,
|
||||
|
|
@ -530,7 +572,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||
RetrievalMethod.SEMANTIC_SEARCH
|
||||
]
|
||||
}
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH,
|
||||
|
|
@ -560,8 +602,30 @@ class DatasetErrorDocs(Resource):
|
|||
}, 200
|
||||
|
||||
|
||||
class DatasetPermissionUserListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
||||
return {
|
||||
'data': partial_members_list,
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, '/datasets')
|
||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
||||
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
|
||||
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
||||
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
|
||||
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
|
||||
|
|
@ -572,3 +636,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
|||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
||||
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
||||
api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
|
|||
raise NotFound('Dataset not found.')
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
|
|
@ -294,6 +294,11 @@ class DatasetInitApi(Resource):
|
|||
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if args['indexing_technique'] == 'high_quality':
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
|
|
@ -757,14 +762,18 @@ class DocumentStatusApi(DocumentResource):
|
|||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
# check user's permission
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
indexing_cache_key = 'document_{}_indexing'.format(document.id)
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
|
|
@ -955,10 +964,11 @@ class DocumentRenameApi(DocumentResource):
|
|||
@account_initialization_required
|
||||
@marshal_with(document_fields)
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from controllers.console.app.error import (
|
|||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import AppMode
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
|
|
@ -70,16 +71,33 @@ class ChatAudioApi(InstalledAppResource):
|
|||
|
||||
class ChatTextApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
from flask_restful import reqparse
|
||||
|
||||
app_model = installed_app.app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id')
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
text=request.form['text'],
|
||||
voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
message_id=message_id,
|
||||
voice=voice
|
||||
)
|
||||
return {'data': response.data.decode('latin1')}
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
|
|
@ -108,3 +126,5 @@ class ChatTextApi(InstalledAppResource):
|
|||
|
||||
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
|
||||
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
|
||||
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
|
||||
# endpoint='installed_app_text_with_message_id')
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class TagListApi(Resource):
|
|||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -68,7 +68,7 @@ class TagUpdateDeleteApi(Resource):
|
|||
def patch(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -109,8 +109,8 @@ class TagBindingCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -134,8 +134,8 @@ class TagBindingDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
|
|||
|
|
@ -245,6 +245,8 @@ class AccountIntegrateApi(Resource):
|
|||
return {'data': integrate_data}
|
||||
|
||||
|
||||
|
||||
|
||||
# Register API resources
|
||||
api.add_resource(AccountInitApi, '/account/init')
|
||||
api.add_resource(AccountProfileApi, '/account/profile')
|
||||
|
|
|
|||
|
|
@ -131,7 +131,20 @@ class MemberUpdateRoleApi(Resource):
|
|||
return {'result': 'success'}
|
||||
|
||||
|
||||
class DatasetOperatorMemberListApi(Resource):
|
||||
"""List all members of current tenant."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_fields)
|
||||
def get(self):
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
return {'result': 'success', 'accounts': members}, 200
|
||||
|
||||
|
||||
api.add_resource(MemberListApi, '/workspaces/current/members')
|
||||
api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
|
||||
api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
|
||||
api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
|
||||
api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
|
|||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, EndUser
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
|
|
@ -72,19 +72,30 @@ class AudioApi(Resource):
|
|||
class TextApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id')
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
|
||||
'voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
text=args['text'],
|
||||
end_user=end_user,
|
||||
voice=args.get('voice'),
|
||||
streaming=args['streaming']
|
||||
message_id=message_id,
|
||||
end_user=end_user.external_user_id,
|
||||
voice=voice
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -17,7 +17,12 @@ from controllers.service_api.app.error import (
|
|||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.errors.error import (
|
||||
AppInvokeQuotaExceededError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
|
|
@ -69,7 +74,7 @@ class CompletionApi(Resource):
|
|||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
|
@ -132,7 +137,7 @@ class ChatApi(Resource):
|
|||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
|
|
|||
|
|
@ -14,7 +14,12 @@ from controllers.service_api.app.error import (
|
|||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.errors.error import (
|
||||
AppInvokeQuotaExceededError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
|
|
@ -59,7 +64,7 @@ class WorkflowRunApi(Resource):
|
|||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from controllers.web.error import (
|
|||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App
|
||||
from models.model import App, AppMode
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
|
|
@ -69,16 +69,35 @@ class AudioApi(WebApiResource):
|
|||
|
||||
class TextApi(WebApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
from flask_restful import reqparse
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id')
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = args.get('voice') if args.get(
|
||||
'voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
text=request.form['text'],
|
||||
message_id=message_id,
|
||||
end_user=end_user.external_user_id,
|
||||
voice=request.form['voice'] if request.form.get('voice') else None,
|
||||
streaming=False
|
||||
voice=voice
|
||||
)
|
||||
|
||||
return {'data': response.data.decode('latin1')}
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
|
|
|
|||
|
|
@ -114,6 +114,10 @@ class VariableEntity(BaseModel):
|
|||
default: Optional[str] = None
|
||||
hint: Optional[str] = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.variable
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,135 @@
|
|||
import base64
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import queue
|
||||
import re
|
||||
import threading
|
||||
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class AudioTrunk:
|
||||
def __init__(self, status: str, audio):
|
||||
self.audio = audio
|
||||
self.status = status
|
||||
|
||||
|
||||
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
if not text_content or text_content.isspace():
|
||||
return
|
||||
return model_instance.invoke_tts(
|
||||
content_text=text_content.strip(),
|
||||
user="responding_tts",
|
||||
tenant_id=tenant_id,
|
||||
voice=voice
|
||||
)
|
||||
|
||||
|
||||
def _process_future(future_queue, audio_queue):
|
||||
while True:
|
||||
try:
|
||||
future = future_queue.get()
|
||||
if future is None:
|
||||
break
|
||||
for audio in future.result():
|
||||
audio_base64 = base64.b64encode(bytes(audio))
|
||||
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(e)
|
||||
break
|
||||
audio_queue.put(AudioTrunk("finish", b''))
|
||||
|
||||
|
||||
class AppGeneratorTTSPublisher:
|
||||
|
||||
def __init__(self, tenant_id: str, voice: str):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.tenant_id = tenant_id
|
||||
self.msg_text = ''
|
||||
self._audio_queue = queue.Queue()
|
||||
self._msg_queue = queue.Queue()
|
||||
self.match = re.compile(r'[。.!?]')
|
||||
self.model_manager = ModelManager()
|
||||
self.model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.TTS
|
||||
)
|
||||
self.voices = self.model_instance.get_tts_voices()
|
||||
values = [voice.get('value') for voice in self.voices]
|
||||
self.voice = voice
|
||||
if not voice or voice not in values:
|
||||
self.voice = self.voices[0].get('value')
|
||||
self.MAX_SENTENCE = 2
|
||||
self._last_audio_event = None
|
||||
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||
|
||||
def publish(self, message):
|
||||
try:
|
||||
self._msg_queue.put(message)
|
||||
except Exception as e:
|
||||
self.logger.warning(e)
|
||||
|
||||
def _runtime(self):
|
||||
future_queue = queue.Queue()
|
||||
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
|
||||
while True:
|
||||
try:
|
||||
message = self._msg_queue.get()
|
||||
if message is None:
|
||||
if self.msg_text and len(self.msg_text.strip()) > 0:
|
||||
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
|
||||
self.model_instance, self.tenant_id, self.voice)
|
||||
future_queue.put(futures_result)
|
||||
break
|
||||
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
|
||||
self.msg_text += message.event.chunk.delta.message.content
|
||||
elif isinstance(message.event, QueueTextChunkEvent):
|
||||
self.msg_text += message.event.text
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
|
||||
self.MAX_SENTENCE += 1
|
||||
text_content = ''.join(sentence_arr)
|
||||
futures_result = self.executor.submit(_invoiceTTS, text_content,
|
||||
self.model_instance,
|
||||
self.tenant_id,
|
||||
self.voice)
|
||||
future_queue.put(futures_result)
|
||||
if text_tmp:
|
||||
self.msg_text = text_tmp
|
||||
else:
|
||||
self.msg_text = ''
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(e)
|
||||
break
|
||||
future_queue.put(None)
|
||||
|
||||
def checkAndGetAudio(self) -> AudioTrunk | None:
|
||||
try:
|
||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||
if self.executor:
|
||||
self.executor.shutdown(wait=False)
|
||||
return self.last_message
|
||||
audio = self._audio_queue.get_nowait()
|
||||
if audio and audio.status == "finish":
|
||||
self.executor.shutdown(wait=False)
|
||||
self._runtime_thread = None
|
||||
if audio:
|
||||
self._last_audio_event = audio
|
||||
return audio
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
def _extract_sentence(self, org_text):
|
||||
tx = self.match.finditer(org_text)
|
||||
start = 0
|
||||
result = []
|
||||
for i in tx:
|
||||
end = i.regs[0][1]
|
||||
result.append(org_text[start:end])
|
||||
start = end
|
||||
return result, org_text[start:]
|
||||
|
|
@ -4,6 +4,8 @@ import time
|
|||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
|
|
@ -33,6 +35,8 @@ from core.app.entities.task_entities import (
|
|||
ChatbotAppStreamResponse,
|
||||
ChatflowStreamGenerateRoute,
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamResponse,
|
||||
)
|
||||
|
|
@ -71,13 +75,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(
|
||||
self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool
|
||||
self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool
|
||||
) -> None:
|
||||
"""
|
||||
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||
|
|
@ -129,7 +133,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._process_stream_response(
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
if self._stream:
|
||||
|
|
@ -138,7 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
return self._to_blocking_response(generator)
|
||||
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
|
||||
-> ChatbotAppBlockingResponse:
|
||||
-> ChatbotAppBlockingResponse:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
|
|
@ -169,7 +173,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
raise Exception('Queue listening stopped unexpectedly.')
|
||||
|
||||
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
|
||||
-> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
-> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
"""
|
||||
To stream response.
|
||||
:return:
|
||||
|
|
@ -182,14 +186,68 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
stream_response=stream_response
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow.features_dict
|
||||
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
break
|
||||
yield response
|
||||
|
||||
start_listener_time = time.time()
|
||||
# timeout
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
try:
|
||||
if not publisher:
|
||||
break
|
||||
audio_trunk = publisher.checkAndGetAudio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
|
||||
continue
|
||||
if audio_trunk.status == "finish":
|
||||
break
|
||||
else:
|
||||
start_listener_time = time.time()
|
||||
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
self,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if publisher:
|
||||
publisher.publish(message=message)
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
|
|
@ -301,7 +359,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
continue
|
||||
|
||||
if not self._is_stream_out_support(
|
||||
event=event
|
||||
event=event
|
||||
):
|
||||
continue
|
||||
|
||||
|
|
@ -318,7 +376,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
yield self._ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
|
|
@ -402,7 +461,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
return stream_generate_routes
|
||||
|
||||
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
|
||||
-> list[str]:
|
||||
-> list[str]:
|
||||
"""
|
||||
Get answer start at node id.
|
||||
:param graph: graph
|
||||
|
|
@ -457,7 +516,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
start_node_id = target_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
elif node_type == NodeType.START.value or \
|
||||
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
|
||||
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
|
||||
start_node_id = source_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
else:
|
||||
|
|
@ -515,7 +574,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
|
||||
# all route chunks are generated
|
||||
if self._task_state.current_stream_generate_state.current_route_position == len(
|
||||
self._task_state.current_stream_generate_state.generate_route
|
||||
self._task_state.current_stream_generate_state.generate_route
|
||||
):
|
||||
self._task_state.current_stream_generate_state = None
|
||||
|
||||
|
|
@ -525,7 +584,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
:return:
|
||||
"""
|
||||
if not self._task_state.current_stream_generate_state:
|
||||
return None
|
||||
return
|
||||
|
||||
route_chunks = self._task_state.current_stream_generate_state.generate_route[
|
||||
self._task_state.current_stream_generate_state.current_route_position:]
|
||||
|
|
@ -573,7 +632,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
# get route chunk node execution info
|
||||
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
|
||||
if (route_chunk_node_execution_info.node_type == NodeType.LLM
|
||||
and latest_node_execution_info.node_type == NodeType.LLM):
|
||||
and latest_node_execution_info.node_type == NodeType.LLM):
|
||||
# only LLM support chunk stream output
|
||||
self._task_state.current_stream_generate_state.current_route_position += 1
|
||||
continue
|
||||
|
|
@ -643,7 +702,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
|
||||
# all route chunks are generated
|
||||
if self._task_state.current_stream_generate_state.current_route_position == len(
|
||||
self._task_state.current_stream_generate_state.generate_route
|
||||
self._task_state.current_stream_generate_state.generate_route
|
||||
):
|
||||
self._task_state.current_stream_generate_state = None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,52 +1,56 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import AppConfig, VariableEntity
|
||||
|
||||
|
||||
class BaseAppGenerator:
|
||||
def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig):
|
||||
if user_inputs is None:
|
||||
user_inputs = {}
|
||||
|
||||
filtered_inputs = {}
|
||||
|
||||
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
|
||||
user_inputs = user_inputs or {}
|
||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||
variables = app_config.variables
|
||||
for variable_config in variables:
|
||||
variable = variable_config.variable
|
||||
|
||||
if (variable not in user_inputs
|
||||
or user_inputs[variable] is None
|
||||
or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')):
|
||||
if variable_config.required:
|
||||
raise ValueError(f"{variable} is required in input form")
|
||||
else:
|
||||
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
|
||||
continue
|
||||
|
||||
value = user_inputs[variable]
|
||||
|
||||
if value is not None:
|
||||
if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
|
||||
raise ValueError(f"{variable} in input form must be a string")
|
||||
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
|
||||
if '.' in value:
|
||||
value = float(value)
|
||||
else:
|
||||
value = int(value)
|
||||
|
||||
if variable_config.type == VariableEntity.Type.SELECT:
|
||||
options = variable_config.options if variable_config.options is not None else []
|
||||
if value not in options:
|
||||
raise ValueError(f"{variable} in input form must be one of the following: {options}")
|
||||
elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
|
||||
if variable_config.max_length is not None:
|
||||
max_length = variable_config.max_length
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
||||
|
||||
if value and isinstance(value, str):
|
||||
filtered_inputs[variable] = value.replace('\x00', '')
|
||||
else:
|
||||
filtered_inputs[variable] = value if value is not None else None
|
||||
|
||||
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
|
||||
return filtered_inputs
|
||||
|
||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
||||
user_input_value = inputs.get(var.name)
|
||||
if var.required and not user_input_value:
|
||||
raise ValueError(f'{var.name} is required in input form')
|
||||
if not var.required and not user_input_value:
|
||||
# TODO: should we return None here if the default value is None?
|
||||
return var.default or ''
|
||||
if (
|
||||
var.type
|
||||
in (
|
||||
VariableEntity.Type.TEXT_INPUT,
|
||||
VariableEntity.Type.SELECT,
|
||||
VariableEntity.Type.PARAGRAPH,
|
||||
)
|
||||
and user_input_value
|
||||
and not isinstance(user_input_value, str)
|
||||
):
|
||||
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
|
||||
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if '.' in user_input_value:
|
||||
return float(user_input_value)
|
||||
else:
|
||||
return int(user_input_value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{var.name} in input form must be a valid number")
|
||||
if var.type == VariableEntity.Type.SELECT:
|
||||
options = var.options or []
|
||||
if user_input_value not in options:
|
||||
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
|
||||
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
|
||||
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
|
||||
|
||||
return user_input_value
|
||||
|
||||
def _sanitize_value(self, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return value.replace('\x00', '')
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -51,7 +51,6 @@ class AppQueueManager:
|
|||
listen_timeout = current_app.config.get("APP_MAX_EXECUTION_TIME")
|
||||
start_time = time.time()
|
||||
last_ping_time = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = self._q.get(timeout=1)
|
||||
|
|
|
|||
|
|
@ -94,7 +94,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
stream=stream,
|
||||
call_depth=call_depth,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
|
|
@ -104,7 +103,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
call_depth: int = 0
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
|
@ -166,10 +164,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
"""
|
||||
if not node_id:
|
||||
raise ValueError('node_id is required')
|
||||
|
||||
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
|
|
@ -25,6 +28,8 @@ from core.app.entities.queue_entities import (
|
|||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
StreamResponse,
|
||||
TextChunkStreamResponse,
|
||||
TextReplaceStreamResponse,
|
||||
|
|
@ -105,7 +110,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
db.session.refresh(self._user)
|
||||
db.session.close()
|
||||
|
||||
generator = self._process_stream_response(
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
if self._stream:
|
||||
|
|
@ -161,8 +166,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
stream_response=stream_response
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow.features_dict
|
||||
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
break
|
||||
yield response
|
||||
|
||||
start_listener_time = time.time()
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
try:
|
||||
if not publisher:
|
||||
break
|
||||
audio_trunk = publisher.checkAndGetAudio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
|
||||
continue
|
||||
if audio_trunk.status == "finish":
|
||||
break
|
||||
else:
|
||||
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
|
||||
|
||||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
|
|
@ -170,6 +225,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if publisher:
|
||||
publisher.publish(message=message)
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
|
|
@ -251,6 +308,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
else:
|
||||
continue
|
||||
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
|
||||
|
||||
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
|
||||
"""
|
||||
Save workflow app log.
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ class WorkflowTaskState(TaskState):
|
|||
|
||||
iteration_nested_node_ids: list[str] = None
|
||||
|
||||
|
||||
class AdvancedChatTaskState(WorkflowTaskState):
|
||||
"""
|
||||
AdvancedChatTaskState entity
|
||||
|
|
@ -86,6 +87,8 @@ class StreamEvent(Enum):
|
|||
ERROR = "error"
|
||||
MESSAGE = "message"
|
||||
MESSAGE_END = "message_end"
|
||||
TTS_MESSAGE = "tts_message"
|
||||
TTS_MESSAGE_END = "tts_message_end"
|
||||
MESSAGE_FILE = "message_file"
|
||||
MESSAGE_REPLACE = "message_replace"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
|
|
@ -130,6 +133,22 @@ class MessageStreamResponse(StreamResponse):
|
|||
answer: str
|
||||
|
||||
|
||||
class MessageAudioStreamResponse(StreamResponse):
|
||||
"""
|
||||
MessageStreamResponse entity
|
||||
"""
|
||||
event: StreamEvent = StreamEvent.TTS_MESSAGE
|
||||
audio: str
|
||||
|
||||
|
||||
class MessageAudioEndStreamResponse(StreamResponse):
|
||||
"""
|
||||
MessageStreamResponse entity
|
||||
"""
|
||||
event: StreamEvent = StreamEvent.TTS_MESSAGE_END
|
||||
audio: str
|
||||
|
||||
|
||||
class MessageEndStreamResponse(StreamResponse):
|
||||
"""
|
||||
MessageEndStreamResponse entity
|
||||
|
|
@ -186,6 +205,7 @@ class WorkflowStartStreamResponse(StreamResponse):
|
|||
"""
|
||||
WorkflowStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -205,6 +225,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
|||
"""
|
||||
WorkflowFinishStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -232,6 +253,7 @@ class NodeStartStreamResponse(StreamResponse):
|
|||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -273,6 +295,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||
"""
|
||||
NodeFinishStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -323,10 +346,12 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
class IterationNodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -344,10 +369,12 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
|||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeNextStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -365,10 +392,12 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
|||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeCompletedStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -393,10 +422,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
|||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class TextChunkStreamResponse(StreamResponse):
|
||||
"""
|
||||
TextChunkStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -411,6 +442,7 @@ class TextReplaceStreamResponse(StreamResponse):
|
|||
"""
|
||||
TextReplaceStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -473,6 +505,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
|
|||
"""
|
||||
ChatbotAppBlockingResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -492,6 +525,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
|
|||
"""
|
||||
CompletionAppBlockingResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -510,6 +544,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
|||
"""
|
||||
WorkflowAppBlockingResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
@ -528,10 +563,12 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
|||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class WorkflowIterationState(BaseModel):
|
||||
"""
|
||||
WorkflowIterationState entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from .rate_limit import RateLimit
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimit:
|
||||
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
|
||||
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
|
||||
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
|
||||
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
|
||||
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
||||
_instance_dict = {}
|
||||
|
||||
def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
|
||||
if client_id not in cls._instance_dict:
|
||||
instance = super().__new__(cls)
|
||||
cls._instance_dict[client_id] = instance
|
||||
return cls._instance_dict[client_id]
|
||||
|
||||
def __init__(self, client_id: str, max_active_requests: int):
|
||||
self.max_active_requests = max_active_requests
|
||||
if hasattr(self, 'initialized'):
|
||||
return
|
||||
self.initialized = True
|
||||
self.client_id = client_id
|
||||
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
|
||||
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
|
||||
self.last_recalculate_time = float('-inf')
|
||||
self.flush_cache(use_local_value=True)
|
||||
|
||||
def flush_cache(self, use_local_value=False):
|
||||
self.last_recalculate_time = time.time()
|
||||
# flush max active requests
|
||||
if use_local_value or not redis_client.exists(self.max_active_requests_key):
|
||||
with redis_client.pipeline() as pipe:
|
||||
pipe.set(self.max_active_requests_key, self.max_active_requests)
|
||||
pipe.expire(self.max_active_requests_key, timedelta(days=1))
|
||||
pipe.execute()
|
||||
else:
|
||||
with redis_client.pipeline() as pipe:
|
||||
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
|
||||
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
|
||||
|
||||
# flush max active requests (in-transit request list)
|
||||
if not redis_client.exists(self.active_requests_key):
|
||||
return
|
||||
request_details = redis_client.hgetall(self.active_requests_key)
|
||||
redis_client.expire(self.active_requests_key, timedelta(days=1))
|
||||
timeout_requests = [k for k, v in request_details.items() if
|
||||
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
|
||||
if timeout_requests:
|
||||
redis_client.hdel(self.active_requests_key, *timeout_requests)
|
||||
|
||||
def enter(self, request_id: Optional[str] = None) -> str:
|
||||
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
|
||||
self.flush_cache()
|
||||
if self.max_active_requests <= 0:
|
||||
return RateLimit._UNLIMITED_REQUEST_ID
|
||||
if not request_id:
|
||||
request_id = RateLimit.gen_request_key()
|
||||
|
||||
active_requests_count = redis_client.hlen(self.active_requests_key)
|
||||
if active_requests_count >= self.max_active_requests:
|
||||
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
|
||||
"concurrent requests allowed is {}.".format(self.max_active_requests))
|
||||
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
|
||||
return request_id
|
||||
|
||||
def exit(self, request_id: str):
|
||||
if request_id == RateLimit._UNLIMITED_REQUEST_ID:
|
||||
return
|
||||
redis_client.hdel(self.active_requests_key, request_id)
|
||||
|
||||
@staticmethod
|
||||
def gen_request_key() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def generate(self, generator: Union[Generator, callable, dict], request_id: str):
|
||||
if isinstance(generator, dict):
|
||||
return generator
|
||||
else:
|
||||
return RateLimitGenerator(self, generator, request_id)
|
||||
|
||||
|
||||
class RateLimitGenerator:
|
||||
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
|
||||
self.rate_limit = rate_limit
|
||||
if callable(generator):
|
||||
self.generator = generator()
|
||||
else:
|
||||
self.generator = generator
|
||||
self.request_id = request_id
|
||||
self.closed = False
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.closed:
|
||||
raise StopIteration
|
||||
try:
|
||||
return next(self.generator)
|
||||
except StopIteration:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
self.rate_limit.exit(self.request_id)
|
||||
if self.generator is not None and hasattr(self.generator, 'close'):
|
||||
self.generator.close()
|
||||
|
|
@ -4,6 +4,8 @@ import time
|
|||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AgentChatAppGenerateEntity,
|
||||
|
|
@ -32,6 +34,8 @@ from core.app.entities.task_entities import (
|
|||
CompletionAppStreamResponse,
|
||||
EasyUITaskState,
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamResponse,
|
||||
)
|
||||
|
|
@ -87,6 +91,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
"""
|
||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||
self._model_config = application_generate_entity.model_conf
|
||||
self._app_config = application_generate_entity.app_config
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
|
||||
|
|
@ -102,7 +107,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
self._conversation_name_generate_thread = None
|
||||
|
||||
def process(
|
||||
self,
|
||||
self,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
|
|
@ -123,7 +128,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._process_stream_response(
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
if self._stream:
|
||||
|
|
@ -202,14 +207,64 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
stream_response=stream_response
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
if publisher is None:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
task_id = self._application_generate_entity.task_id
|
||||
publisher = None
|
||||
text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
|
||||
if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
break
|
||||
yield response
|
||||
|
||||
start_listener_time = time.time()
|
||||
# timeout
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
if publisher is None:
|
||||
break
|
||||
audio = publisher.checkAndGetAudio()
|
||||
if audio is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
|
||||
continue
|
||||
if audio.status == "finish":
|
||||
break
|
||||
else:
|
||||
start_listener_time = time.time()
|
||||
yield MessageAudioStreamResponse(audio=audio.audio,
|
||||
task_id=task_id)
|
||||
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
self,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if publisher:
|
||||
publisher.publish(message)
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
|
|
@ -272,12 +327,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
yield self._ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> None:
|
||||
"""
|
||||
Save message.
|
||||
|
|
|
|||
|
|
@ -31,6 +31,13 @@ class QuotaExceededError(Exception):
|
|||
description = "Quota Exceeded"
|
||||
|
||||
|
||||
class AppInvokeQuotaExceededError(Exception):
|
||||
"""
|
||||
Custom exception raised when the quota for an app has been exceeded.
|
||||
"""
|
||||
description = "App Invoke Quota Exceeded"
|
||||
|
||||
|
||||
class ModelCurrentlyNotSupportError(Exception):
|
||||
"""
|
||||
Custom exception raised when the model not support
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from configs import dify_config
|
||||
from models.api_based_extension import APIBasedExtensionPoint
|
||||
|
||||
|
||||
|
|
@ -31,10 +30,10 @@ class APIBasedExtensionRequestor:
|
|||
try:
|
||||
# proxy support for security
|
||||
proxies = None
|
||||
if os.environ.get("SSRF_PROXY_HTTP_URL") and os.environ.get("SSRF_PROXY_HTTPS_URL"):
|
||||
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxies = {
|
||||
'http': os.environ.get("SSRF_PROXY_HTTP_URL"),
|
||||
'https': os.environ.get("SSRF_PROXY_HTTPS_URL"),
|
||||
'http': dify_config.SSRF_PROXY_HTTP_URL,
|
||||
'https': dify_config.SSRF_PROXY_HTTPS_URL,
|
||||
}
|
||||
|
||||
response = requests.request(
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ class MessageFileParser:
|
|||
}
|
||||
|
||||
response = requests.head(url, headers=headers, allow_redirects=True)
|
||||
if response.status_code == 200:
|
||||
if response.status_code in {200, 304}:
|
||||
return True, ""
|
||||
else:
|
||||
return False, "URL does not exist."
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from enum import Enum
|
||||
from threading import Lock
|
||||
|
|
@ -9,6 +8,7 @@ from httpx import get, post
|
|||
from pydantic import BaseModel
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
|
||||
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
|
||||
|
|
@ -18,8 +18,8 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Code Executor
|
||||
CODE_EXECUTION_ENDPOINT = os.environ.get('CODE_EXECUTION_ENDPOINT', 'http://sandbox:8194')
|
||||
CODE_EXECUTION_API_KEY = os.environ.get('CODE_EXECUTION_API_KEY', 'dify-sandbox')
|
||||
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
|
||||
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
|
||||
|
||||
CODE_EXECUTION_TIMEOUT= (10, 60)
|
||||
|
||||
|
|
|
|||
|
|
@ -730,7 +730,7 @@ class IndexingRunner:
|
|||
self._check_document_paused_status(dataset_document.id)
|
||||
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
|
||||
if embedding_model_instance:
|
||||
tokens += sum(
|
||||
embedding_model_instance.get_text_embedding_num_tokens(
|
||||
[document.page_content]
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ from core.model_runtime.entities.message_entities import (
|
|||
UserPromptMessage,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppMode, Conversation, Message
|
||||
from models.model import AppMode, Conversation, Message, MessageFile
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
|
||||
class TokenBufferMemory:
|
||||
|
|
@ -30,33 +31,46 @@ class TokenBufferMemory:
|
|||
app_record = self.conversation.app
|
||||
|
||||
# fetch limited messages, and return reversed
|
||||
query = db.session.query(Message).filter(
|
||||
query = db.session.query(
|
||||
Message.id,
|
||||
Message.query,
|
||||
Message.answer,
|
||||
Message.created_at,
|
||||
Message.workflow_run_id
|
||||
).filter(
|
||||
Message.conversation_id == self.conversation.id,
|
||||
Message.answer != ''
|
||||
).order_by(Message.created_at.desc())
|
||||
|
||||
if message_limit and message_limit > 0:
|
||||
messages = query.limit(message_limit).all()
|
||||
message_limit = message_limit if message_limit <= 500 else 500
|
||||
else:
|
||||
messages = query.all()
|
||||
message_limit = 500
|
||||
|
||||
messages = query.limit(message_limit).all()
|
||||
|
||||
messages = list(reversed(messages))
|
||||
message_file_parser = MessageFileParser(
|
||||
tenant_id=app_record.tenant_id,
|
||||
app_id=app_record.id
|
||||
)
|
||||
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
files = message.message_files
|
||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||
if files:
|
||||
file_extra_config = None
|
||||
if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
else:
|
||||
file_extra_config = FileUploadConfigManager.convert(
|
||||
message.workflow_run.workflow.features_dict,
|
||||
is_vision=False
|
||||
)
|
||||
if message.workflow_run_id:
|
||||
workflow_run = (db.session.query(WorkflowRun)
|
||||
.filter(WorkflowRun.id == message.workflow_run_id).first())
|
||||
|
||||
if workflow_run:
|
||||
file_extra_config = FileUploadConfigManager.convert(
|
||||
workflow_run.workflow.features_dict,
|
||||
is_vision=False
|
||||
)
|
||||
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
|
|
@ -136,4 +150,4 @@ class TokenBufferMemory:
|
|||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
return "\n".join(string_messages)
|
||||
|
|
|
|||
|
|
@ -264,7 +264,7 @@ class ModelInstance:
|
|||
user=user
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
|
||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language tts model
|
||||
|
|
@ -287,8 +287,7 @@ class ModelInstance:
|
|||
content_text=content_text,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice,
|
||||
streaming=streaming
|
||||
voice=voice
|
||||
)
|
||||
|
||||
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
|
|
@ -10,7 +12,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
|
|||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class TTSModel(AIModel):
|
||||
"""
|
||||
Model class for ttstext model.
|
||||
|
|
@ -20,7 +22,7 @@ class TTSModel(AIModel):
|
|||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
|
||||
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
|
||||
user: Optional[str] = None):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
|
@ -35,14 +37,15 @@ class TTSModel(AIModel):
|
|||
:return: translated audio file
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Invoke TTS model: {model} , invoke content : {content_text}")
|
||||
self._is_ffmpeg_installed()
|
||||
return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming,
|
||||
return self._invoke(model=model, credentials=credentials, user=user,
|
||||
content_text=content_text, voice=voice, tenant_id=tenant_id)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
|
||||
user: Optional[str] = None):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
|
@ -123,26 +126,26 @@ class TTSModel(AIModel):
|
|||
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
||||
|
||||
@staticmethod
|
||||
def _split_text_into_sentences(text: str, limit: int, delimiters=None):
|
||||
if delimiters is None:
|
||||
delimiters = set('。!?;\n')
|
||||
|
||||
buf = []
|
||||
word_count = 0
|
||||
for char in text:
|
||||
buf.append(char)
|
||||
if char in delimiters:
|
||||
if word_count >= limit:
|
||||
yield ''.join(buf)
|
||||
buf = []
|
||||
word_count = 0
|
||||
else:
|
||||
word_count += 1
|
||||
else:
|
||||
word_count += 1
|
||||
|
||||
if buf:
|
||||
yield ''.join(buf)
|
||||
def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
|
||||
match = re.compile(pattern)
|
||||
tx = match.finditer(org_text)
|
||||
start = 0
|
||||
result = []
|
||||
one_sentence = ''
|
||||
for i in tx:
|
||||
end = i.regs[0][1]
|
||||
tmp = org_text[start:end]
|
||||
if len(one_sentence + tmp) > max_length:
|
||||
result.append(one_sentence)
|
||||
one_sentence = ''
|
||||
one_sentence += tmp
|
||||
start = end
|
||||
last_sens = org_text[start:]
|
||||
if last_sens:
|
||||
one_sentence += last_sens
|
||||
if one_sentence != '':
|
||||
result.append(one_sentence)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _is_ffmpeg_installed():
|
||||
|
|
|
|||
|
|
@ -33,3 +33,4 @@
|
|||
- deepseek
|
||||
- hunyuan
|
||||
- siliconflow
|
||||
- perfxcloud
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from functools import reduce
|
|||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask import Response
|
||||
from openai import AzureOpenAI
|
||||
from pydub import AudioSegment
|
||||
|
||||
|
|
@ -14,7 +14,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
||||
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
|
|
@ -23,7 +22,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||
"""
|
||||
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict,
|
||||
content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
|
||||
content_text: str, voice: str, user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_invoke text2speech model
|
||||
|
||||
|
|
@ -32,30 +31,23 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: model timbre
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
if streaming:
|
||||
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice)),
|
||||
status=200, mimetype=f'audio/{audio_type}')
|
||||
else:
|
||||
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
|
||||
return self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
validate credentials text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
try:
|
||||
|
|
@ -82,7 +74,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
max_workers = self._get_model_workers_limit(model, credentials)
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
|
||||
audio_bytes_list = []
|
||||
|
||||
# Create a thread pool and map the function to the list of sentences
|
||||
|
|
@ -107,34 +99,37 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
# Todo: To improve the streaming function
|
||||
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
|
||||
voice: str) -> any:
|
||||
"""
|
||||
_tts_invoke_streaming text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: model timbre
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
tts_file_id = self._get_file_name(content_text)
|
||||
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
|
||||
try:
|
||||
# doc: https://platform.openai.com/docs/guides/text-to-speech
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = AzureOpenAI(**credentials_kwargs)
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
for sentence in sentences:
|
||||
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
|
||||
# response.stream_to_file(file_path)
|
||||
storage.save(file_path, response.read())
|
||||
# max font is 4096,there is 3500 limit for each request
|
||||
max_length = 3500
|
||||
if len(content_text) > max_length:
|
||||
sentences = self._split_text_into_sentences(content_text, max_length=max_length)
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
|
||||
futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
|
||||
response_format="mp3",
|
||||
input=sentences[i], voice=voice) for i in range(len(sentences))]
|
||||
for index, future in enumerate(futures):
|
||||
yield from future.result().__enter__().iter_bytes(1024)
|
||||
|
||||
else:
|
||||
response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
|
||||
response_format="mp3",
|
||||
input=content_text.strip())
|
||||
|
||||
yield from response.__enter__().iter_bytes(1024)
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
|
|
@ -162,7 +157,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
|
||||
for ai_model_entity in TTS_BASE_MODELS:
|
||||
if ai_model_entity.base_model_name == base_model_name:
|
||||
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
|
||||
|
|
@ -170,5 +165,4 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||
ai_model_entity_copy.entity.label.en_US = model
|
||||
ai_model_entity_copy.entity.label.zh_Hans = model
|
||||
return ai_model_entity_copy
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ model_type: llm
|
|||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ model_type: llm
|
|||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ model_type: llm
|
|||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ model_type: llm
|
|||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
|
|||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
|
|
@ -68,7 +69,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
# TODO: consolidate different invocation methods for models based on base model capabilities
|
||||
# invoke anthropic models via boto3 client
|
||||
if "anthropic" in model:
|
||||
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
||||
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
|
||||
# invoke Cohere models via boto3 client
|
||||
if "cohere.command-r" in model:
|
||||
return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
|
||||
|
|
@ -151,7 +152,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
|
||||
|
||||
def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke Anthropic large language model
|
||||
|
||||
|
|
@ -171,23 +172,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
|
||||
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
|
||||
|
||||
parameters = {
|
||||
'modelId': model,
|
||||
'messages': prompt_message_dicts,
|
||||
'inferenceConfig': inference_config,
|
||||
'additionalModelRequestFields': additional_model_fields,
|
||||
}
|
||||
|
||||
if system and len(system) > 0:
|
||||
parameters['system'] = system
|
||||
|
||||
if tools:
|
||||
parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
|
||||
|
||||
if stream:
|
||||
response = bedrock_client.converse_stream(
|
||||
modelId=model,
|
||||
messages=prompt_message_dicts,
|
||||
system=system,
|
||||
inferenceConfig=inference_config,
|
||||
additionalModelRequestFields=additional_model_fields
|
||||
)
|
||||
response = bedrock_client.converse_stream(**parameters)
|
||||
return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
|
||||
else:
|
||||
response = bedrock_client.converse(
|
||||
modelId=model,
|
||||
messages=prompt_message_dicts,
|
||||
system=system,
|
||||
inferenceConfig=inference_config,
|
||||
additionalModelRequestFields=additional_model_fields
|
||||
)
|
||||
response = bedrock_client.converse(**parameters)
|
||||
return self._handle_converse_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_converse_response(self, model: str, credentials: dict, response: dict,
|
||||
|
|
@ -246,12 +248,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
output_tokens = 0
|
||||
finish_reason = None
|
||||
index = 0
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_use = {}
|
||||
|
||||
for chunk in response['stream']:
|
||||
if 'messageStart' in chunk:
|
||||
return_model = model
|
||||
elif 'messageStop' in chunk:
|
||||
finish_reason = chunk['messageStop']['stopReason']
|
||||
elif 'contentBlockStart' in chunk:
|
||||
tool = chunk['contentBlockStart']['start']['toolUse']
|
||||
tool_use['toolUseId'] = tool['toolUseId']
|
||||
tool_use['name'] = tool['name']
|
||||
elif 'metadata' in chunk:
|
||||
input_tokens = chunk['metadata']['usage']['inputTokens']
|
||||
output_tokens = chunk['metadata']['usage']['outputTokens']
|
||||
|
|
@ -260,29 +268,49 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
model=return_model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index + 1,
|
||||
index=index,
|
||||
message=AssistantPromptMessage(
|
||||
content=''
|
||||
content='',
|
||||
tool_calls=tool_calls
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
elif 'contentBlockDelta' in chunk:
|
||||
chunk_text = chunk['contentBlockDelta']['delta']['text'] if chunk['contentBlockDelta']['delta']['text'] else ''
|
||||
full_assistant_content += chunk_text
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=chunk_text if chunk_text else '',
|
||||
)
|
||||
index = chunk['contentBlockDelta']['contentBlockIndex']
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
delta = chunk['contentBlockDelta']['delta']
|
||||
if 'text' in delta:
|
||||
chunk_text = delta['text'] if delta['text'] else ''
|
||||
full_assistant_content += chunk_text
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=chunk_text if chunk_text else '',
|
||||
)
|
||||
)
|
||||
index = chunk['contentBlockDelta']['contentBlockIndex']
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index+1,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
)
|
||||
elif 'toolUse' in delta:
|
||||
if 'input' not in tool_use:
|
||||
tool_use['input'] = ''
|
||||
tool_use['input'] += delta['toolUse']['input']
|
||||
elif 'contentBlockStop' in chunk:
|
||||
if 'input' in tool_use:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_use['toolUseId'],
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_use['name'],
|
||||
arguments=tool_use['input']
|
||||
)
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
tool_use = {}
|
||||
|
||||
except Exception as ex:
|
||||
raise InvokeError(str(ex))
|
||||
|
||||
|
|
@ -312,16 +340,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
"""
|
||||
|
||||
system = []
|
||||
first_loop = True
|
||||
for message in prompt_messages:
|
||||
if isinstance(message, SystemPromptMessage):
|
||||
message.content=message.content.strip()
|
||||
if first_loop:
|
||||
system=message.content
|
||||
first_loop=False
|
||||
else:
|
||||
system+="\n"
|
||||
system+=message.content
|
||||
system.append({"text": message.content})
|
||||
|
||||
prompt_message_dicts = []
|
||||
for message in prompt_messages:
|
||||
|
|
@ -330,6 +352,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
|
||||
return system, prompt_message_dicts
|
||||
|
||||
def _convert_converse_tool_config(self, tools: Optional[list[PromptMessageTool]] = None) -> dict:
|
||||
tool_config = {}
|
||||
configs = []
|
||||
if tools:
|
||||
for tool in tools:
|
||||
configs.append(
|
||||
{
|
||||
"toolSpec": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": {
|
||||
"json": tool.parameters
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
tool_config["tools"] = configs
|
||||
return tool_config
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict
|
||||
|
|
@ -379,10 +420,32 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": [{'text': message.content}]}
|
||||
if message.tool_calls:
|
||||
message_dict = {
|
||||
"role": "assistant", "content":[{
|
||||
"toolUse": {
|
||||
"toolUseId": message.tool_calls[0].id,
|
||||
"name": message.tool_calls[0].function.name,
|
||||
"input": json.loads(message.tool_calls[0].function.arguments)
|
||||
}
|
||||
}]
|
||||
}
|
||||
else:
|
||||
message_dict = {"role": "assistant", "content": [{'text': message.content}]}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = [{'text': message.content}]
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"toolResult": {
|
||||
"toolUseId": message.tool_call_id,
|
||||
"content": [{"json": {"text": message.content}}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
|
@ -401,11 +464,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
"""
|
||||
prefix = model.split('.')[0]
|
||||
model_name = model.split('.')[1]
|
||||
|
||||
if isinstance(prompt_messages, str):
|
||||
prompt = prompt_messages
|
||||
else:
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
|
||||
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
|
|
@ -489,11 +554,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
content = message.content
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt_prefix} {content} {human_prompt_postfix}"
|
||||
body = content
|
||||
if (isinstance(content, list)):
|
||||
body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT])
|
||||
message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message_text = content
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message_text = f"{human_prompt_prefix} {message.content}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ model_properties:
|
|||
- mode: 'shimmer'
|
||||
name: 'Shimmer'
|
||||
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
|
||||
word_limit: 120
|
||||
word_limit: 3500
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
pricing:
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ model_properties:
|
|||
- mode: 'shimmer'
|
||||
name: 'Shimmer'
|
||||
language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
||||
word_limit: 120
|
||||
word_limit: 3500
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
pricing:
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue