mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/model-auth
This commit is contained in:
commit 6aa5273c5e

.env.example (1197 lines changed): file diff suppressed because it is too large
@@ -82,7 +82,7 @@ jobs:
      - name: Install pnpm
        uses: pnpm/action-setup@v4
        with:
          version: 10
          package_json_file: web/package.json
          run_install: false

      - name: Setup NodeJS
@@ -95,10 +95,12 @@ jobs:

      - name: Web dependencies
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm install --frozen-lockfile

      - name: Web style check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm run lint

  docker-compose-template:
@@ -46,7 +46,7 @@ jobs:
      - name: Install pnpm
        uses: pnpm/action-setup@v4
        with:
          version: 10
          package_json_file: web/package.json
          run_install: false

      - name: Set up Node.js
@@ -59,10 +59,12 @@ jobs:

      - name: Install dependencies
        if: env.FILES_CHANGED == 'true'
        working-directory: ./web
        run: pnpm install --frozen-lockfile

      - name: Generate i18n translations
        if: env.FILES_CHANGED == 'true'
        working-directory: ./web
        run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}

      - name: Create Pull Request
@@ -35,7 +35,7 @@ jobs:
        if: steps.changed-files.outputs.any_changed == 'true'
        uses: pnpm/action-setup@v4
        with:
          version: 10
          package_json_file: web/package.json
          run_install: false

      - name: Setup Node.js
@@ -48,8 +48,10 @@ jobs:

      - name: Install dependencies
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm install --frozen-lockfile

      - name: Run tests
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm test
@@ -197,6 +197,8 @@ sdks/python-client/dify_client.egg-info
!.vscode/README.md
pyrightconfig.json
api/.vscode
# vscode Code History Extension
.history

.idea/
@@ -0,0 +1,83 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Dify is an open-source platform for developing LLM applications, with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.

The codebase consists of:

- **Backend API** (`/api`): Python Flask application with a Domain-Driven Design architecture
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
- **Docker deployment** (`/docker`): Containerized deployment configurations

## Development Commands

### Backend (API)

All Python commands must be prefixed with `uv run --project api`:

```bash
# Start development servers
./dev/start-api                                        # Start API server
./dev/start-worker                                     # Start Celery worker

# Run tests
uv run --project api pytest                            # Run all tests
uv run --project api pytest tests/unit_tests/          # Unit tests only
uv run --project api pytest tests/integration_tests/   # Integration tests

# Code quality
./dev/reformat                                         # Run all formatters and linters
uv run --project api ruff check --fix ./               # Fix linting issues
uv run --project api ruff format ./                    # Format code
uv run --project api mypy .                            # Type checking
```

### Frontend (Web)

```bash
cd web
pnpm lint          # Run ESLint
pnpm eslint-fix    # Fix ESLint issues
pnpm test          # Run Jest tests
```

## Testing Guidelines

### Backend Testing

- Use `pytest` for all backend tests
- Write tests first (TDD approach)
- Test structure: Arrange-Act-Assert

## Code Style Requirements

### Python

- Use type hints for all functions and class attributes
- No `Any` types unless absolutely necessary
- Implement special methods (`__repr__`, `__str__`) where appropriate

### TypeScript/JavaScript

- Strict TypeScript configuration
- ESLint with Prettier integration
- Avoid the `any` type

## Important Notes

- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
- **Comments**: Only write meaningful comments that explain "why", not "what"
- **File Creation**: Always prefer editing existing files over creating new ones
- **Documentation**: Don't create documentation files unless explicitly requested
- **Code Quality**: Always run `./dev/reformat` before committing backend changes

## Common Development Tasks

### Adding a New API Endpoint

1. Create a controller in `/api/controllers/`
2. Add service logic in `/api/services/`
3. Register the route in the controller package's `__init__.py`
4. Write tests in `/api/tests/` (a minimal sketch of steps 1-3 follows below)
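
For orientation, a minimal sketch of steps 1-3 (hypothetical names; `DemoApi` and `DemoService` are not part of this commit, and real controllers add decorators such as `@setup_required` and `@login_required`, as seen elsewhere in this diff):

```python
# Hypothetical illustration only: DemoApi and DemoService are invented for this sketch.
from flask_restful import Resource, reqparse

from controllers.console import api
from services.demo_service import DemoService  # assumed service module (step 2)


class DemoApi(Resource):
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, location="json")
        args = parser.parse_args()
        # Keep the controller thin: delegate business logic to the service layer.
        return {"result": DemoService.create(name=args["name"])}


# Step 3: register the route (in the real codebase this lives in the
# controller package's __init__.py).
api.add_resource(DemoApi, "/demo")
```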
## Project-Specific Conventions

- All async tasks use Celery with Redis as the broker (see the sketch below)
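
A minimal sketch of that convention (hypothetical task; the real tasks live under `api/tasks/`, e.g. the `remove_app_and_related_data_task` module imported in `api/commands.py` later in this diff):

```python
# Hypothetical sketch of the Celery-with-Redis convention; the queue name is illustrative.
from celery import shared_task


@shared_task(queue="app_deletion")
def demo_cleanup_task(app_id: str) -> None:
    # Heavy work runs in the Celery worker (started via ./dev/start-worker),
    # never inside the request/response cycle.
    ...
```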
@@ -225,7 +225,8 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)

##### AWS

- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Using Alibaba Cloud Computing Nest
@@ -208,7 +208,8 @@ docker compose up -d

##### AWS

- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Deploying with Alibaba Cloud
[Quickly deploy Dify to Alibaba Cloud with Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
@@ -225,7 +225,8 @@ Star Dify on GitHub

##### AWS

- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Deploy using Alibaba Cloud
@@ -223,7 +223,8 @@ docker compose up -d
Deploy Dify to AWS using [CDK](https://aws.amazon.com/cdk/)

##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Deploy using Alibaba Cloud Computing Nest
@@ -220,7 +220,8 @@ Deploy Dify with just one click using [terraform](https://www.terraf
Deploy Dify on AWS with [CDK](https://aws.amazon.com/cdk/)

##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Alibaba Cloud
@@ -220,7 +220,8 @@ Deploy Dify to a cloud platform with a single click using [terrafo
Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)

##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Alibaba Cloud
@@ -218,7 +218,8 @@ Deploy Dify to a cloud platform in one click using [terraform](http
Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)

##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Alibaba Cloud
@@ -219,7 +219,8 @@ docker compose up -d
Deploy Dify to AWS using [CDK](https://aws.amazon.com/cdk/)

##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
@@ -218,7 +218,8 @@ Deploy with a single click using [terraform](https://www.terraform.io/
Deploy with a single click using [CDK](https://aws.amazon.com/cdk/).

##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Alibaba Cloud
@@ -212,7 +212,8 @@ Deploy Dify to Kubernetes with premium scaling configured
Deploy Dify to AWS using [CDK](https://aws.amazon.com/cdk/)

##### AWS
- [AWS CDK by KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Alibaba Cloud
@@ -217,7 +217,8 @@ Deploy Dify to the Cloud Platform with a single click using [terraform](http
Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)

##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Alibaba Cloud
@@ -218,7 +218,8 @@ deploy Dify to the Cloud Platform with one click using [terraform](https://www.
Deploy Dify to AWS using [CDK](https://aws.amazon.com/cdk/)

##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Alibaba Cloud
@@ -211,7 +211,8 @@ Deploy Dify to the cloud platform with one click using [terraform](https://www.ter
Deploy Dify to AWS using [CDK](https://aws.amazon.com/cdk/)

##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Alibaba Cloud
@@ -223,7 +223,8 @@ Every feature of Dify comes with a corresponding API, so you can easily integrate Dify

### AWS

- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Deploy using Alibaba Cloud Computing Nest
@@ -213,7 +213,8 @@ Deploy Dify to a cloud platform with a single click using
Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)

##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)

#### Alibaba Cloud
@@ -42,6 +42,15 @@ REDIS_PORT=6379
REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
# SSL configuration for Redis (when REDIS_USE_SSL=true)
REDIS_SSL_CERT_REQS=CERT_NONE
# Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
REDIS_SSL_CA_CERTS=
# Path to CA certificate file for SSL verification
REDIS_SSL_CERTFILE=
# Path to client certificate file for SSL authentication
REDIS_SSL_KEYFILE=
# Path to client private key file for SSL authentication
REDIS_DB=0

# redis Sentinel configuration.
@@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp):
        ext_login,
        ext_mail,
        ext_migrate,
        ext_orjson,
        ext_otel,
        ext_proxy_fix,
        ext_redis,
@@ -67,6 +68,7 @@ def initialize_extensions(app: DifyApp):
        ext_logging,
        ext_warnings,
        ext_import_modules,
        ext_orjson,
        ext_set_secretkey,
        ext_compress,
        ext_code_based_extension,
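Both initialization orders in `initialize_extensions` gain a new `ext_orjson` entry; judging by the name, it presumably registers orjson as the Flask app's JSON provider.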
api/commands.py (136 changes)
@@ -36,6 +36,7 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch


@click.command("reset-password", help="Reset the account password.")
@@ -1202,3 +1203,138 @@ def setup_system_tool_oauth_client(provider, client_params):
    db.session.add(oauth_client)
    db.session.commit()
    click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))


def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
    """
    Find draft variables that reference non-existent apps.

    Args:
        batch_size: Maximum number of orphaned app IDs to return

    Returns:
        List of app IDs that have draft variables but don't exist in the apps table
    """
    query = """
        SELECT DISTINCT wdv.app_id
        FROM workflow_draft_variables AS wdv
        WHERE NOT EXISTS(
            SELECT 1 FROM apps WHERE apps.id = wdv.app_id
        )
        LIMIT :batch_size
    """

    with db.engine.connect() as conn:
        result = conn.execute(sa.text(query), {"batch_size": batch_size})
        return [row[0] for row in result]


def _count_orphaned_draft_variables() -> dict[str, Any]:
    """
    Count orphaned draft variables by app.

    Returns:
        Dictionary with statistics about orphaned variables
    """
    query = """
        SELECT
            wdv.app_id,
            COUNT(*) as variable_count
        FROM workflow_draft_variables AS wdv
        WHERE NOT EXISTS(
            SELECT 1 FROM apps WHERE apps.id = wdv.app_id
        )
        GROUP BY wdv.app_id
        ORDER BY variable_count DESC
    """

    with db.engine.connect() as conn:
        result = conn.execute(sa.text(query))
        orphaned_by_app = {row[0]: row[1] for row in result}

    total_orphaned = sum(orphaned_by_app.values())
    app_count = len(orphaned_by_app)

    return {
        "total_orphaned_variables": total_orphaned,
        "orphaned_app_count": app_count,
        "orphaned_by_app": orphaned_by_app,
    }


@click.command()
@click.option("--dry-run", is_flag=True, help="Show what would be deleted without actually deleting")
@click.option("--batch-size", default=1000, help="Number of records to process per batch (default 1000)")
@click.option("--max-apps", default=None, type=int, help="Maximum number of apps to process (default: no limit)")
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
def cleanup_orphaned_draft_variables(
    dry_run: bool,
    batch_size: int,
    max_apps: int | None,
    force: bool = False,
):
    """
    Clean up orphaned draft variables from the database.

    This script finds and removes draft variables that belong to apps
    that no longer exist in the database.
    """
    logger = logging.getLogger(__name__)

    # Get statistics
    stats = _count_orphaned_draft_variables()

    logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"])
    logger.info("Across %s non-existent apps", stats["orphaned_app_count"])

    if stats["total_orphaned_variables"] == 0:
        logger.info("No orphaned draft variables found. Exiting.")
        return

    if dry_run:
        logger.info("DRY RUN: Would delete the following:")
        for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
            :10
        ]:  # Show top 10
            logger.info("  App %s: %s variables", app_id, count)
        if len(stats["orphaned_by_app"]) > 10:
            logger.info("  ... and %s more apps", len(stats["orphaned_by_app"]) - 10)
        return

    # Confirm deletion
    if not force:
        click.confirm(
            f"Are you sure you want to delete {stats['total_orphaned_variables']} "
            f"orphaned draft variables from {stats['orphaned_app_count']} apps?",
            abort=True,
        )

    total_deleted = 0
    processed_apps = 0

    while True:
        if max_apps and processed_apps >= max_apps:
            logger.info("Reached maximum app limit (%s). Stopping.", max_apps)
            break

        orphaned_app_ids = _find_orphaned_draft_variables(batch_size=10)
        if not orphaned_app_ids:
            logger.info("No more orphaned draft variables found.")
            break

        for app_id in orphaned_app_ids:
            if max_apps and processed_apps >= max_apps:
                break

            try:
                deleted_count = delete_draft_variables_batch(app_id, batch_size)
                total_deleted += deleted_count
                processed_apps += 1

                logger.info("Deleted %s variables for app %s", deleted_count, app_id)

            except Exception:
                logger.exception("Error processing app %s", app_id)
                continue

    logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
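
For context, a sketch of how this click command can be exercised (assumption: like the other commands in `api/commands.py`, it is registered on the Flask CLI, so in practice it would run as something like `flask cleanup-orphaned-draft-variables --dry-run`):

```python
# Sketch using Click's test runner; the import assumes the api/ project root.
from click.testing import CliRunner

from commands import cleanup_orphaned_draft_variables

runner = CliRunner()
# --dry-run only logs the top offending apps; nothing is deleted.
result = runner.invoke(cleanup_orphaned_draft_variables, ["--dry-run"])
print(result.exit_code)
```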
@@ -39,6 +39,26 @@ class RedisConfig(BaseSettings):
        default=False,
    )

    REDIS_SSL_CERT_REQS: str = Field(
        description="SSL certificate requirements (CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED)",
        default="CERT_NONE",
    )

    REDIS_SSL_CA_CERTS: Optional[str] = Field(
        description="Path to the CA certificate file for SSL verification",
        default=None,
    )

    REDIS_SSL_CERTFILE: Optional[str] = Field(
        description="Path to the client certificate file for SSL authentication",
        default=None,
    )

    REDIS_SSL_KEYFILE: Optional[str] = Field(
        description="Path to the client private key file for SSL authentication",
        default=None,
    )

    REDIS_USE_SENTINEL: Optional[bool] = Field(
        description="Enable Redis Sentinel mode for high availability",
        default=False,
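
For orientation, a sketch of how `REDIS_SSL_*` settings like these are commonly mapped onto a redis-py client (an assumption about what the Redis extension does with them; the string-to-`ssl` constant mapping is illustrative):

```python
# Illustrative only: shows how the new REDIS_SSL_* settings could feed redis-py.
import ssl

import redis

CERT_REQS = {
    "CERT_NONE": ssl.CERT_NONE,
    "CERT_OPTIONAL": ssl.CERT_OPTIONAL,
    "CERT_REQUIRED": ssl.CERT_REQUIRED,
}

client = redis.Redis(
    host="localhost",
    port=6379,
    ssl=True,                              # REDIS_USE_SSL
    ssl_cert_reqs=CERT_REQS["CERT_NONE"],  # REDIS_SSL_CERT_REQS
    ssl_ca_certs=None,                     # REDIS_SSL_CA_CERTS
    ssl_certfile=None,                     # REDIS_SSL_CERTFILE
    ssl_keyfile=None,                      # REDIS_SSL_KEYFILE
)
```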
@@ -1,5 +1,7 @@
from werkzeug.exceptions import HTTPException

from libs.exception import BaseHTTPException


class FilenameNotExistsError(HTTPException):
    code = 400
@@ -9,3 +11,27 @@ class FilenameNotExistsError(HTTPException):
class RemoteFileUploadError(HTTPException):
    code = 400
    description = "Error uploading remote file."


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400
@@ -3,9 +3,8 @@ from flask_login import current_user
from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden

from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_resource_check,
@@ -1,6 +1,7 @@
import logging

import flask_login
from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound

@@ -24,6 +25,7 @@ from core.errors.error import (
    ProviderTokenNotInitError,
    QuotaExceededError,
)
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@@ -115,6 +117,10 @@ class ChatMessageApi(Resource):
        streaming = args["response_mode"] != "blocking"
        args["auto_generate_name"] = False

        external_trace_id = get_external_trace_id(request)
        if external_trace_id:
            args["external_trace_id"] = external_trace_id

        account = flask_login.current_user

        try:
@@ -79,18 +79,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class DraftWorkflowNotExist(BaseHTTPException):
    error_code = "draft_workflow_not_exist"
    description = "Draft workflow need to be initialized."
@@ -1,3 +1,5 @@
from collections.abc import Sequence

from flask_login import current_user
from flask_restful import Resource, reqparse

@@ -10,6 +12,9 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
@@ -107,6 +112,119 @@ class RuleStructuredOutputGenerateApi(Resource):
        return structured_output


class InstructionGenerateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("flow_id", type=str, required=True, default="", location="json")
        parser.add_argument("node_id", type=str, required=False, default="", location="json")
        parser.add_argument("current", type=str, required=False, default="", location="json")
        parser.add_argument("language", type=str, required=False, default="javascript", location="json")
        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
        args = parser.parse_args()
        providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
        code_provider: type[CodeNodeProvider] | None = next(
            (p for p in providers if p.is_accept_language(args["language"])), None
        )
        code_template = code_provider.get_default_code() if code_provider else ""
        try:
            # Generate from nothing for a workflow node
            if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
                from models import App, db
                from services.workflow_service import WorkflowService

                app = db.session.query(App).filter(App.id == args["flow_id"]).first()
                if not app:
                    return {"error": f"app {args['flow_id']} not found"}, 400
                workflow = WorkflowService().get_draft_workflow(app_model=app)
                if not workflow:
                    return {"error": f"workflow {args['flow_id']} not found"}, 400
                nodes: Sequence = workflow.graph_dict["nodes"]
                node = [node for node in nodes if node["id"] == args["node_id"]]
                if len(node) == 0:
                    return {"error": f"node {args['node_id']} not found"}, 400
                node_type = node[0]["data"]["type"]
                match node_type:
                    case "llm":
                        return LLMGenerator.generate_rule_config(
                            current_user.current_tenant_id,
                            instruction=args["instruction"],
                            model_config=args["model_config"],
                            no_variable=True,
                        )
                    case "agent":
                        return LLMGenerator.generate_rule_config(
                            current_user.current_tenant_id,
                            instruction=args["instruction"],
                            model_config=args["model_config"],
                            no_variable=True,
                        )
                    case "code":
                        return LLMGenerator.generate_code(
                            tenant_id=current_user.current_tenant_id,
                            instruction=args["instruction"],
                            model_config=args["model_config"],
                            code_language=args["language"],
                        )
                    case _:
                        return {"error": f"invalid node type: {node_type}"}
            if args["node_id"] == "" and args["current"] != "":  # For legacy app without a workflow
                return LLMGenerator.instruction_modify_legacy(
                    tenant_id=current_user.current_tenant_id,
                    flow_id=args["flow_id"],
                    current=args["current"],
                    instruction=args["instruction"],
                    model_config=args["model_config"],
                    ideal_output=args["ideal_output"],
                )
            if args["node_id"] != "" and args["current"] != "":  # For workflow node
                return LLMGenerator.instruction_modify_workflow(
                    tenant_id=current_user.current_tenant_id,
                    flow_id=args["flow_id"],
                    node_id=args["node_id"],
                    current=args["current"],
                    instruction=args["instruction"],
                    model_config=args["model_config"],
                    ideal_output=args["ideal_output"],
                )
            return {"error": "incompatible parameters"}, 400
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)


class InstructionGenerationTemplateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self) -> dict:
        parser = reqparse.RequestParser()
        parser.add_argument("type", type=str, required=True, default=False, location="json")
        args = parser.parse_args()
        match args["type"]:
            case "prompt":
                from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT

                return {"data": INSTRUCTION_GENERATE_TEMPLATE_PROMPT}
            case "code":
                from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_CODE

                return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
            case _:
                raise ValueError(f"Invalid type: {args['type']}")


api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
api.add_resource(InstructionGenerateApi, "/instruction-generate")
api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template")
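
A sketch of how the new endpoint might be called (hypothetical values; assumes the console blueprint's `/console/api` prefix and an authenticated console session):

```python
# Hypothetical request; field names mirror the reqparse arguments above.
import requests

payload = {
    "flow_id": "app-uuid",
    "node_id": "node-1",
    "current": "",
    "language": "python3",
    "instruction": "Summarize the input into three bullet points.",
    "model_config": {"provider": "openai", "name": "gpt-4o-mini"},
    "ideal_output": "",
}
resp = requests.post(
    "http://localhost:5001/console/api/instruction-generate",
    json=payload,
    headers={"Authorization": "Bearer <console-token>"},  # illustrative auth
    timeout=30,
)
print(resp.json())
```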
@@ -27,7 +27,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@@ -124,17 +124,34 @@ class MessageFeedbackApi(Resource):
        parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
        args = parser.parse_args()

        try:
            MessageService.create_feedback(
                app_model=app_model,
                message_id=str(args["message_id"]),
                user=current_user,
                rating=args.get("rating"),
                content=None,
            )
        except MessageNotExistsError:
        message_id = str(args["message_id"])

        message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()

        if not message:
            raise NotFound("Message Not Exists.")

        feedback = message.admin_feedback

        if not args["rating"] and feedback:
            db.session.delete(feedback)
        elif args["rating"] and feedback:
            feedback.rating = args["rating"]
        elif not args["rating"] and not feedback:
            raise ValueError("rating cannot be None when feedback not exists")
        else:
            feedback = MessageFeedback(
                app_id=app_model.id,
                conversation_id=message.conversation_id,
                message_id=message.id,
                rating=args["rating"],
                from_source="admin",
                from_account_id=current_user.id,
            )
            db.session.add(feedback)

        db.session.commit()

        return {"result": "success"}
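In effect, the feedback endpoint now upserts admin feedback directly instead of delegating to `MessageService.create_feedback`: a null `rating` deletes any existing feedback, a non-null `rating` updates it in place or creates a new `MessageFeedback` row with `from_source="admin"`, and a null `rating` with no existing feedback raises a `ValueError`.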
@@ -23,6 +23,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@@ -185,6 +186,10 @@ class AdvancedChatDraftWorkflowRunApi(Resource):

        args = parser.parse_args()

        external_trace_id = get_external_trace_id(request)
        if external_trace_id:
            args["external_trace_id"] = external_trace_id

        try:
            response = AppGenerateService.generate(
                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
@@ -373,6 +378,10 @@ class DraftWorkflowRunApi(Resource):
        parser.add_argument("files", type=list, required=False, location="json")
        args = parser.parse_args()

        external_trace_id = get_external_trace_id(request)
        if external_trace_id:
            args["external_trace_id"] = external_trace_id

        try:
            response = AppGenerateService.generate(
                app_model=app_model,
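
With this change, debugger runs can be correlated with an external tracing system. Per `get_external_trace_id` (extended later in this diff), the ID may arrive via the `X-Trace-Id` header, a `trace_id` query parameter or JSON field, or a W3C `traceparent` header. A hypothetical client call (route shown for illustration):

```python
# Propagating an external trace id into a draft workflow run.
import requests

requests.post(
    "http://localhost:5001/console/api/apps/<app_id>/workflows/draft/run",  # illustrative route
    json={"inputs": {}},
    headers={"X-Trace-Id": "4bf92f3577b34da6a3ce929d0e0e4736"},
    timeout=30,
)
```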
@@ -163,11 +163,11 @@ class WorkflowVariableCollectionApi(Resource):
        draft_var_srv = WorkflowDraftVariableService(
            session=session,
        )
        workflow_vars = draft_var_srv.list_variables_without_values(
            app_id=app_model.id,
            page=args.page,
            limit=args.limit,
        )

        return workflow_vars
@@ -1,30 +1,6 @@
from libs.exception import BaseHTTPException


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class DatasetNotInitializedError(BaseHTTPException):
    error_code = "dataset_not_initialized"
    description = "The dataset is still being initialized or indexing. Please wait a moment."
@@ -76,30 +76,6 @@ class EmailSendIpLimitError(BaseHTTPException):
    code = 429


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class UnauthorizedAndForceLogout(BaseHTTPException):
    error_code = "unauthorized_and_force_logout"
    description = "Unauthorized and force logout."
@@ -8,7 +8,13 @@ from werkzeug.exceptions import Forbidden
import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import FilenameNotExistsError
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_resource_check,
@@ -18,13 +24,6 @@ from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from services.file_service import FileService

from .error import (
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)

PREVIEW_WORDS_LIMIT = 3000
@@ -7,18 +7,17 @@ from flask_restful import Resource, marshal_with, reqparse

import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from controllers.common.errors import (
    FileTooLargeError,
    RemoteFileUploadError,
    UnsupportedFileTypeError,
)
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from models.account import Account
from services.file_service import FileService

from .error import (
    FileTooLargeError,
    UnsupportedFileTypeError,
)


class RemoteFileInfoApi(Resource):
    @marshal_with(remote_file_info_fields)
@@ -32,7 +32,7 @@ class VersionApi(Resource):
            return result

        try:
            response = requests.get(check_update_url, {"current_version": args.get("current_version")})
            response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10))
        except Exception as error:
            logging.warning("Check update version error: %s.", str(error))
            result["version"] = args.get("current_version")
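Note: in `requests`, a `(3, 10)` tuple sets the connect and read timeouts separately (3 s to establish the connection, 10 s per read), so a stalled update check no longer blocks indefinitely.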
@@ -7,15 +7,15 @@ from sqlalchemy import select
from werkzeug.exceptions import Unauthorized

import services
from controllers.common.errors import FilenameNotExistsError
from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.datasets.error import (
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
    account_initialization_required,
@@ -1,7 +0,0 @@
from libs.exception import BaseHTTPException


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415
@@ -5,8 +5,8 @@ from flask_restful import Resource, reqparse
from werkzeug.exceptions import NotFound

import services
from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from services.account_service import TenantService
from services.file_service import FileService
@@ -4,8 +4,8 @@ from flask import Response
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound

from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db
@@ -5,11 +5,13 @@ from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden

import services
from controllers.common.errors import (
    FileTooLargeError,
    UnsupportedFileTypeError,
)
from controllers.console.wraps import setup_required
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from controllers.inner_api.plugin.wraps import get_user
from controllers.service_api.app.error import FileTooLargeError
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import file_fields
@ -1,5 +1,3 @@
|
|||
import json
|
||||
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -136,12 +134,15 @@ class ConversationVariableDetailApi(Resource):
|
|||
variable_id = str(variable_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("value", required=True, location="json")
|
||||
# using lambda is for passing the already-typed value without modification
|
||||
# if no lambda, it will be converted to string
|
||||
# the string cannot be converted using json.loads
|
||||
parser.add_argument("value", required=True, location="json", type=lambda x: x)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id, end_user, json.loads(args["value"])
|
||||
app_model, conversation_id, variable_id, end_user, args["value"]
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
|
|||
|
|
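
A quick illustration of the reqparse behavior that comment describes (standalone sketch, not part of the commit):

```python
# Without an explicit `type`, flask-restful coerces JSON values to text, so a
# dict like {"a": 1} arrives as the string "{'a': 1}", which json.loads rejects.
# `type=lambda x: x` passes the parsed JSON object through untouched.
from flask import Flask
from flask_restful import reqparse

app = Flask(__name__)

with app.test_request_context(json={"value": {"a": 1}}):
    parser = reqparse.RequestParser()
    parser.add_argument("value", required=True, location="json", type=lambda x: x)
    print(parser.parse_args()["value"])  # {'a': 1} -- still a dict
```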
@@ -85,30 +85,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class FileNotFoundError(BaseHTTPException):
    error_code = "file_not_found"
    description = "The requested file was not found."
@@ -2,14 +2,14 @@ from flask import request
from flask_restful import Resource, marshal_with

import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields
from models.model import App, EndUser
@@ -6,15 +6,15 @@ from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    ProviderNotInitializeError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import (
    ArchivedDocumentImmutableError,
    DocumentIndexingError,
@@ -1,30 +1,6 @@
from libs.exception import BaseHTTPException


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class DatasetNotInitializedError(BaseHTTPException):
    error_code = "dataset_not_initialized"
    description = "The dataset is still being initialized or indexing. Please wait a moment."
@@ -97,30 +97,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class WebAppAuthRequiredError(BaseHTTPException):
    error_code = "web_sso_auth_required"
    description = "Web app authentication required."
@@ -2,8 +2,13 @@ from flask import request
from flask_restful import marshal_with

import services
from controllers.common.errors import FilenameNotExistsError
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.web.wraps import WebApiResource
from fields.file_fields import file_fields
from services.file_service import FileService
@@ -5,15 +5,17 @@ from flask_restful import marshal_with, reqparse

import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from controllers.common.errors import (
    FileTooLargeError,
    RemoteFileUploadError,
    UnsupportedFileTypeError,
)
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from services.file_service import FileService

from .error import FileTooLargeError, UnsupportedFileTypeError


class RemoteFileInfoApi(WebApiResource):
    @marshal_with(remote_file_info_fields)
@@ -74,6 +74,7 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models.enums import CreatorUserRole
@@ -896,6 +897,7 @@ class AdvancedChatAppGenerateTaskPipeline:
    def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
        message = self._get_message(session=session)
        message.answer = self._task_state.answer
        message.updated_at = naive_utc_now()
        message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
        message.message_metadata = self._task_state.metadata.model_dump_json()
        message_files = [
@@ -140,7 +140,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
        )

        # get tracing instance
        trace_manager = TraceQueueManager(app_id=app_model.id)
        trace_manager = TraceQueueManager(
            app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
        )

        # init application generate entity
        application_generate_entity = ChatAppGenerateEntity(
@@ -124,7 +124,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
        )

        # get tracing instance
        trace_manager = TraceQueueManager(app_model.id)
        trace_manager = TraceQueueManager(
            app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
        )

        # init application generate entity
        application_generate_entity = CompletionAppGenerateEntity(
@@ -6,7 +6,6 @@ from core.app.entities.queue_entities import (
    MessageQueueMessage,
    QueueAdvancedChatMessageEndEvent,
    QueueErrorEvent,
    QueueMessage,
    QueueMessageEndEvent,
    QueueStopEvent,
)
@@ -22,15 +21,6 @@ class MessageBasedAppQueueManager(AppQueueManager):
        self._app_mode = app_mode
        self._message_id = str(message_id)

    def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
        return MessageQueueMessage(
            task_id=self._task_id,
            message_id=self._message_id,
            conversation_id=self._conversation_id,
            app_mode=self._app_mode,
            event=event,
        )

    def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
        """
        Publish event to queue
@@ -57,6 +57,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought

logger = logging.getLogger(__name__)
@@ -389,6 +390,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
            if llm_result.message.content
            else ""
        )
        message.updated_at = naive_utc_now()
        message.answer_tokens = usage.completion_tokens
        message.answer_unit_price = usage.completion_unit_price
        message.answer_price_unit = usage.completion_price_unit
@@ -5,7 +5,7 @@ from base64 import b64encode
from collections.abc import Mapping
from typing import Any

from core.variables.utils import SegmentJSONEncoder
from core.variables.utils import dumps_with_segments


class TemplateTransformer(ABC):
@@ -93,7 +93,7 @@ class TemplateTransformer(ABC):

    @classmethod
    def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
        inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode()
        inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
        input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
        return input_base64_encoded
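
The diff does not show `dumps_with_segments` itself; conceptually it replaces the `cls=SegmentJSONEncoder` pattern with a helper that already knows how to serialize segment objects. A sketch of that idea (assumed shape, not the real implementation in `core/variables/utils.py`):

```python
# Conceptual sketch only; assumes segments expose their wrapped value via `.value`.
import json
from typing import Any


def dumps_with_segments(obj: Any, **kwargs: Any) -> str:
    def default(o: Any) -> Any:
        if hasattr(o, "value"):  # Segment-like object
            return o.value
        raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")

    return json.dumps(obj, default=default, **kwargs)
```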
@@ -16,15 +16,33 @@ def get_external_trace_id(request: Any) -> Optional[str]:
    """
    Retrieve the trace_id from the request.

    Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid.
    Priority:
    1. header ('X-Trace-Id')
    2. parameters
    3. JSON body
    4. Current OpenTelemetry context (if enabled)
    5. OpenTelemetry traceparent header (if present and valid)

    Returns None if no valid trace_id is provided.
    """
    trace_id = request.headers.get("X-Trace-Id")

    if not trace_id:
        trace_id = request.args.get("trace_id")

    if not trace_id and getattr(request, "is_json", False):
        json_data = getattr(request, "json", None)
        if json_data:
            trace_id = json_data.get("trace_id")

    if not trace_id:
        trace_id = get_trace_id_from_otel_context()

    if not trace_id:
        traceparent = request.headers.get("traceparent")
        if traceparent:
            trace_id = parse_traceparent_header(traceparent)

    if isinstance(trace_id, str) and is_valid_trace_id(trace_id):
        return trace_id
    return None
@@ -40,3 +58,49 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
    if trace_id:
        return {"external_trace_id": trace_id}
    return {}


def get_trace_id_from_otel_context() -> Optional[str]:
    """
    Retrieve the current trace ID from the active OpenTelemetry trace context.

    Returns None if:
    1. OpenTelemetry SDK is not installed or enabled.
    2. There is no active span or trace context.
    """
    try:
        from opentelemetry.trace import SpanContext, get_current_span
        from opentelemetry.trace.span import INVALID_TRACE_ID

        span = get_current_span()
        if not span:
            return None

        span_context: SpanContext = span.get_span_context()

        if not span_context or span_context.trace_id == INVALID_TRACE_ID:
            return None

        trace_id_hex = f"{span_context.trace_id:032x}"
        return trace_id_hex

    except Exception:
        return None


def parse_traceparent_header(traceparent: str) -> Optional[str]:
    """
    Parse the `traceparent` header to extract the trace_id.

    Expected format:
        'version-trace_id-span_id-flags'

    Reference:
        W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
    """
    try:
        parts = traceparent.split("-")
        if len(parts) == 4 and len(parts[1]) == 32:
            return parts[1]
    except Exception:
        pass
    return None
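
A quick sanity check of the parser against the example header from the W3C Trace Context specification:

```python
# 'version-trace_id-span_id-flags'; the 32-hex-digit trace_id is the second field.
header = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
assert parse_traceparent_header(header) == "4bf92f3577b34da6a3ce929d0e0e4736"
```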
@@ -1,6 +1,7 @@
import json
import logging
import re
from collections.abc import Sequence
from typing import Optional, cast

import json_repair

@@ -11,6 +12,8 @@ from core.llm_generator.prompts import (
    CONVERSATION_TITLE_PROMPT,
    GENERATOR_QA_PROMPT,
    JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
    LLM_MODIFY_CODE_SYSTEM,
    LLM_MODIFY_PROMPT_SYSTEM,
    PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
    SYSTEM_STRUCTURED_OUTPUT_GENERATE,
    WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
@@ -24,6 +27,9 @@ from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from models import App, Message, WorkflowNodeExecutionModel, db


class LLMGenerator:
@ -388,3 +394,181 @@ class LLMGenerator:
|
|||
except Exception as e:
|
||||
logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
|
||||
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_legacy(
|
||||
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
|
||||
) -> dict:
|
||||
app: App | None = db.session.query(App).filter(App.id == flow_id).first()
|
||||
last_run: Message | None = (
|
||||
db.session.query(Message).filter(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
|
||||
)
|
||||
if not last_run:
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=None,
|
||||
current=current,
|
||||
error_message="",
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
)
|
||||
last_run_dict = {
|
||||
"query": last_run.query,
|
||||
"answer": last_run.answer,
|
||||
"error": last_run.error,
|
||||
}
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=last_run_dict,
|
||||
current=current,
|
||||
error_message=str(last_run.error),
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_workflow(
|
||||
tenant_id: str,
|
||||
flow_id: str,
|
||||
node_id: str,
|
||||
current: str,
|
||||
instruction: str,
|
||||
model_config: dict,
|
||||
ideal_output: str | None,
|
||||
) -> dict:
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
app: App | None = db.session.query(App).filter(App.id == flow_id).first()
|
||||
if not app:
|
||||
raise ValueError("App not found.")
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found for the given app model.")
|
||||
last_run = WorkflowService().get_node_last_run(app_model=app, workflow=workflow, node_id=node_id)
|
||||
try:
|
||||
node_type = cast(WorkflowNodeExecutionModel, last_run).node_type
|
||||
except Exception:
|
||||
try:
|
||||
node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][
|
||||
"type"
|
||||
]
|
||||
except Exception:
|
||||
node_type = "llm"
|
||||
|
||||
if not last_run: # Node is not executed yet
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=None,
|
||||
current=current,
|
||||
error_message="",
|
||||
instruction=instruction,
|
||||
node_type=node_type,
|
||||
ideal_output=ideal_output,
|
||||
)
|
||||
|
||||
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
|
||||
raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)
|
||||
if not raw_agent_log:
|
||||
return []
|
||||
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)
|
||||
|
||||
def dict_of_event(event: AgentLogEvent) -> dict:
|
||||
return {
|
||||
"status": event.status,
|
||||
"error": event.error,
|
||||
"data": event.data,
|
||||
}
|
||||
|
||||
return [dict_of_event(event) for event in parsed]
|
||||
|
||||
last_run_dict = {
|
||||
"inputs": last_run.inputs_dict,
|
||||
"status": last_run.status,
|
||||
"error": last_run.error,
|
||||
"agent_log": agent_log_of(last_run),
|
||||
}
|
||||
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=last_run_dict,
|
||||
current=current,
|
||||
error_message=last_run.error,
|
||||
instruction=instruction,
|
||||
node_type=last_run.node_type,
|
||||
ideal_output=ideal_output,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def __instruction_modify_common(
|
||||
tenant_id: str,
|
||||
model_config: dict,
|
||||
last_run: dict | None,
|
||||
current: str | None,
|
||||
error_message: str | None,
|
||||
instruction: str,
|
||||
node_type: str,
|
||||
ideal_output: str | None,
|
||||
) -> dict:
|
||||
LAST_RUN = "{{#last_run#}}"
|
||||
CURRENT = "{{#current#}}"
|
||||
ERROR_MESSAGE = "{{#error_message#}}"
|
||||
injected_instruction = instruction
|
||||
if LAST_RUN in injected_instruction:
|
||||
injected_instruction = injected_instruction.replace(LAST_RUN, json.dumps(last_run))
|
||||
if CURRENT in injected_instruction:
|
||||
injected_instruction = injected_instruction.replace(CURRENT, current or "null")
|
||||
if ERROR_MESSAGE in injected_instruction:
|
||||
injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.get("provider", ""),
|
||||
model=model_config.get("name", ""),
|
||||
)
|
||||
match node_type:
|
||||
case "llm", "agent":
|
||||
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
|
||||
case "code":
|
||||
system_prompt = LLM_MODIFY_CODE_SYSTEM
|
||||
case _:
|
||||
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content=system_prompt),
|
||||
UserPromptMessage(
|
||||
content=json.dumps(
|
||||
{
|
||||
"current": current,
|
||||
"last_run": last_run,
|
||||
"instruction": injected_instruction,
|
||||
"ideal_output": ideal_output,
|
||||
}
|
||||
)
|
||||
),
|
||||
]
|
||||
model_parameters = {"temperature": 0.4}
|
||||
|
||||
try:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
|
||||
generated_raw = cast(str, response.message.content)
|
||||
first_brace = generated_raw.find("{")
|
||||
last_brace = generated_raw.rfind("}")
|
||||
return {**json.loads(generated_raw[first_brace : last_brace + 1])}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e)
|
||||
return {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
|
|
|||
|
|
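The `{{#last_run#}}` / `{{#current#}}` placeholders above are resolved by plain string replacement before the prompt is sent. A small standalone sketch of that injection step (values are illustrative):

```python
import json

LAST_RUN = "{{#last_run#}}"
CURRENT = "{{#current#}}"

instruction = f"Improve the prompt given {LAST_RUN}; keep {CURRENT} intact."
last_run = {"query": "hi", "answer": "hello", "error": None}

injected = instruction
if LAST_RUN in injected:
    injected = injected.replace(LAST_RUN, json.dumps(last_run))
if CURRENT in injected:
    injected = injected.replace(CURRENT, "You are a helpful assistant.")

print(injected)
```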
@@ -309,3 +309,116 @@ eg:
Here is the JSON schema:
{{schema}}
""" # noqa: E501

LLM_MODIFY_PROMPT_SYSTEM = """
Both your input and output should be in JSON format.

! Below is the schema for input content !
{
    "type": "object",
    "description": "The user is trying to process some content with a prompt, but the output is not as expected. They hope to achieve their goal by modifying the prompt.",
    "properties": {
        "current": {
            "type": "string",
            "description": "The prompt before modification, where placeholders {{}} will be replaced with actual values for the large language model. The content in the placeholders should not be changed."
        },
        "last_run": {
            "type": "object",
            "description": "The output result from the large language model after receiving the prompt."
        },
        "instruction": {
            "type": "string",
            "description": "User's instruction to edit the current prompt"
        },
        "ideal_output": {
            "type": "string",
            "description": "The ideal output that the user expects from the large language model after modifying the prompt. You should compare the last output with the ideal output and make changes to the prompt to achieve the goal."
        }
    }
}
! Above is the schema for input content !

! Below is the schema for output content !
{
    "type": "object",
    "description": "Your feedback to the user after they provide modification suggestions.",
    "properties": {
        "modified": {
            "type": "string",
            "description": "Your modified prompt. You should change the original prompt as little as possible to achieve the goal. Keep the language of the prompt if not asked to change it."
        },
        "message": {
            "type": "string",
            "description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user."
        }
    },
    "required": [
        "modified",
        "message"
    ]
}
! Above is the schema for output content !

Your output must strictly follow the schema format; do not output any content outside of the JSON body.
""" # noqa: E501

LLM_MODIFY_CODE_SYSTEM = """
Both your input and output should be in JSON format.

! Below is the schema for input content !
{
    "type": "object",
    "description": "The user is trying to process some data with a code snippet, but the result is not as expected. They hope to achieve their goal by modifying the code.",
    "properties": {
        "current": {
            "type": "string",
            "description": "The code before modification."
        },
        "last_run": {
            "type": "object",
            "description": "The result of the code."
        },
        "message": {
            "type": "string",
            "description": "User's instruction to edit the current code"
        }
    }
}
! Above is the schema for input content !

! Below is the schema for output content !
{
    "type": "object",
    "description": "Your feedback to the user after they provide modification suggestions.",
    "properties": {
        "modified": {
            "type": "string",
            "description": "Your modified code. You should change the original code as little as possible to achieve the goal. Keep the programming language of the code if not asked to change it."
        },
        "message": {
            "type": "string",
            "description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user."
        }
    },
    "required": [
        "modified",
        "message"
    ]
}
! Above is the schema for output content !

When you are modifying the code, you should remember:
- Do not use print; it does not work in the Dify sandbox.
- Do not attempt dangerous calls like deleting files. This is PROHIBITED.
- Do not use any library that is not built into Python.
- Get inputs from the parameters of the function and use explicit type annotations.
- Write proper imports at the top of the code.
- Use a return statement to return the result.
- You should return a `dict`.
Your output must strictly follow the schema format; do not output any content outside of the JSON body.
""" # noqa: E501

INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as expected: {{#last_run#}}.
You should edit the prompt according to the IDEAL OUTPUT."""

INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""

@@ -10,8 +10,6 @@ from core.mcp.types import (
from models.tools import MCPToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService

-LATEST_PROTOCOL_VERSION = "1.0"
-

class OAuthClientProvider:
    mcp_provider: MCPToolProvider

@@ -7,6 +7,7 @@ from typing import Any, TypeAlias, final
from urllib.parse import urljoin, urlparse

import httpx
+from httpx_sse import EventSource, ServerSentEvent
from sseclient import SSEClient

from core.mcp import types

@@ -37,11 +38,6 @@ WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]


-def remove_request_params(url: str) -> str:
-    """Remove request parameters from URL, keeping only the path."""
-    return urljoin(url, urlparse(url).path)
-
-
class SSETransport:
    """SSE client transport implementation."""


@@ -114,7 +110,7 @@ class SSETransport:
            logger.exception("Error parsing server message")
            read_queue.put(exc)

-    def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
+    def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
        """Handle a single SSE event.

        Args:

@@ -130,7 +126,7 @@ class SSETransport:
            case _:
                logger.warning("Unknown SSE event: %s", sse.event)

-    def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
+    def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
        """Read and process SSE events.

        Args:

@@ -225,7 +221,7 @@ class SSETransport:
        self,
        executor: ThreadPoolExecutor,
        client: httpx.Client,
-        event_source,
+        event_source: EventSource,
    ) -> tuple[ReadQueue, WriteQueue]:
        """Establish connection and start worker threads.

@@ -16,13 +16,14 @@ from extensions.ext_database import db
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService

-"""
-Apply to MCP HTTP streamable server with stateless http
-"""
logger = logging.getLogger(__name__)


class MCPServerStreamableHTTPRequestHandler:
+    """
+    Apply to MCP HTTP streamable server with stateless http
+    """
+
    def __init__(
        self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
    ):

@@ -1,6 +1,10 @@
import json
+from collections.abc import Generator
+from contextlib import AbstractContextManager

import httpx
+import httpx_sse
from httpx_sse import connect_sse

from configs import dify_config
from core.mcp.types import ErrorData, JSONRPCError

@@ -55,20 +59,42 @@ def create_ssrf_proxy_mcp_http_client(
    )


-def ssrf_proxy_sse_connect(url, **kwargs):
+def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]:
    """Connect to SSE endpoint with SSRF proxy protection.

    This function creates an SSE connection using the configured proxy settings
-    to prevent SSRF attacks when connecting to external endpoints.
+    to prevent SSRF attacks when connecting to external endpoints. It returns
+    a context manager that yields an EventSource object for SSE streaming.
+
+    The function handles HTTP client creation and cleanup automatically, but
+    also accepts a pre-configured client via kwargs.

    Args:
-        url: The SSE endpoint URL
-        **kwargs: Additional arguments passed to the SSE connection
+        url (str): The SSE endpoint URL to connect to
+        **kwargs: Additional arguments passed to the SSE connection, including:
+            - client (httpx.Client, optional): Pre-configured HTTP client.
+              If not provided, one will be created with SSRF protection.
+            - method (str, optional): HTTP method to use, defaults to "GET"
+            - headers (dict, optional): HTTP headers to include in the request
+            - timeout (httpx.Timeout, optional): Timeout configuration for the connection

    Returns:
-        EventSource object for SSE streaming
+        AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource
+        object for SSE streaming. The EventSource provides access to server-sent events.
+
+    Example:
+        ```python
+        with ssrf_proxy_sse_connect(url, headers=headers) as event_source:
+            for sse in event_source.iter_sse():
+                print(sse.event, sse.data)
+        ```
+
+    Note:
+        If a client is not provided in kwargs, one will be automatically created
+        with SSRF protection based on the application's configuration. If an
+        exception occurs during connection, any automatically created client
+        will be cleaned up automatically.
    """
    from httpx_sse import connect_sse

    # Extract client if provided, otherwise create one
    client = kwargs.pop("client", None)

@@ -101,7 +127,9 @@ def ssrf_proxy_sse_connect(url, **kwargs):
    raise


-def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
+def create_mcp_error_response(
+    request_id: int | str | None, code: int, message: str, data=None
+) -> Generator[bytes, None, None]:
    """Create MCP error response"""
    error_data = ErrorData(code=code, message=message, data=data)
    json_response = JSONRPCError(

@@ -151,12 +151,9 @@ def jsonable_encoder(
        return format(obj, "f")
    if isinstance(obj, dict):
        encoded_dict = {}
-        allowed_keys = set(obj.keys())
        for key, value in obj.items():
-            if (
-                (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa")))
-                and (value is not None or not exclude_none)
-                and key in allowed_keys
+            if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (
+                value is not None or not exclude_none
            ):
                encoded_key = jsonable_encoder(
                    key,

@@ -4,15 +4,15 @@ from collections.abc import Sequence
from typing import Optional
from urllib.parse import urljoin

-from opentelemetry.trace import Status, StatusCode
+from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy.orm import Session, sessionmaker

from core.ops.aliyun_trace.data_exporter.traceclient import (
    TraceClient,
    convert_datetime_to_nanoseconds,
    convert_string_to_id,
    convert_to_span_id,
    convert_to_trace_id,
    create_link,
    generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData

@@ -103,10 +103,11 @@ class AliyunDataTrace(BaseTraceInstance):

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        trace_id = convert_to_trace_id(trace_info.workflow_run_id)
+        links = []
        if trace_info.trace_id:
            trace_id = convert_string_to_id(trace_info.trace_id)
+            links.append(create_link(trace_id_str=trace_info.trace_id))
        workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
-        self.add_workflow_span(trace_id, workflow_span_id, trace_info)
+        self.add_workflow_span(trace_id, workflow_span_id, trace_info, links)

        workflow_node_executions = self.get_workflow_node_executions(trace_info)
        for node_execution in workflow_node_executions:

@@ -132,8 +133,9 @@ class AliyunDataTrace(BaseTraceInstance):
            status = Status(StatusCode.ERROR, trace_info.error)

        trace_id = convert_to_trace_id(message_id)
+        links = []
        if trace_info.trace_id:
            trace_id = convert_string_to_id(trace_info.trace_id)
+            links.append(create_link(trace_id_str=trace_info.trace_id))

        message_span_id = convert_to_span_id(message_id, "message")
        message_span = SpanData(

@@ -152,6 +154,7 @@ class AliyunDataTrace(BaseTraceInstance):
                OUTPUT_VALUE: str(trace_info.outputs),
            },
            status=status,
+            links=links,
        )
        self.trace_client.add_span(message_span)

@@ -192,8 +195,9 @@ class AliyunDataTrace(BaseTraceInstance):
        message_id = trace_info.message_id

        trace_id = convert_to_trace_id(message_id)
+        links = []
        if trace_info.trace_id:
            trace_id = convert_string_to_id(trace_info.trace_id)
+            links.append(create_link(trace_id_str=trace_info.trace_id))

        documents_data = extract_retrieval_documents(trace_info.documents)
        dataset_retrieval_span = SpanData(

@@ -211,6 +215,7 @@ class AliyunDataTrace(BaseTraceInstance):
                INPUT_VALUE: str(trace_info.inputs),
                OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
            },
+            links=links,
        )
        self.trace_client.add_span(dataset_retrieval_span)

@@ -224,8 +229,9 @@ class AliyunDataTrace(BaseTraceInstance):
            status = Status(StatusCode.ERROR, trace_info.error)

        trace_id = convert_to_trace_id(message_id)
+        links = []
        if trace_info.trace_id:
            trace_id = convert_string_to_id(trace_info.trace_id)
+            links.append(create_link(trace_id_str=trace_info.trace_id))

        tool_span = SpanData(
            trace_id=trace_id,

@@ -244,6 +250,7 @@ class AliyunDataTrace(BaseTraceInstance):
                OUTPUT_VALUE: str(trace_info.tool_outputs),
            },
            status=status,
+            links=links,
        )
        self.trace_client.add_span(tool_span)

@@ -413,7 +420,9 @@ class AliyunDataTrace(BaseTraceInstance):
            status=self.get_workflow_node_status(node_execution),
        )

-    def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
+    def add_workflow_span(
+        self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, links: Sequence[Link]
+    ):
        message_span_id = None
        if trace_info.message_id:
            message_span_id = convert_to_span_id(trace_info.message_id, "message")

@@ -438,6 +447,7 @@ class AliyunDataTrace(BaseTraceInstance):
                OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
            },
            status=status,
+            links=links,
        )
        self.trace_client.add_span(message_span)

@@ -456,6 +466,7 @@ class AliyunDataTrace(BaseTraceInstance):
                OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
            },
            status=status,
+            links=links,
        )
        self.trace_client.add_span(workflow_span)

@@ -466,8 +477,9 @@ class AliyunDataTrace(BaseTraceInstance):
            status = Status(StatusCode.ERROR, trace_info.error)

        trace_id = convert_to_trace_id(message_id)
+        links = []
        if trace_info.trace_id:
            trace_id = convert_string_to_id(trace_info.trace_id)
+            links.append(create_link(trace_id_str=trace_info.trace_id))

        suggested_question_span = SpanData(
            trace_id=trace_id,

@@ -487,6 +499,7 @@ class AliyunDataTrace(BaseTraceInstance):
                OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
            },
            status=status,
+            links=links,
        )
        self.trace_client.add_span(suggested_question_span)

@@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv.resource import ResourceAttributes
+from opentelemetry.trace import Link, SpanContext, TraceFlags

from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData

@@ -166,6 +167,16 @@ class SpanBuilder:
        return span


+def create_link(trace_id_str: str) -> Link:
+    placeholder_span_id = 0x0000000000000000
+    trace_id = int(trace_id_str, 16)
+    span_context = SpanContext(
+        trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED)
+    )
+    return Link(span_context)
+
+
def generate_span_id() -> int:
    span_id = random.getrandbits(64)
    while span_id == INVALID_SPAN_ID:

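A sketch of what `create_link` builds, using the OpenTelemetry API directly; the trace id value is illustrative:

```python
from opentelemetry.trace import Link, SpanContext, TraceFlags

trace_id_str = "4bf92f3577b34da6a3ce929d0e0e4736"  # 32 hex digits

# Only the trace id matters for the link; the span id is a placeholder.
span_context = SpanContext(
    trace_id=int(trace_id_str, 16),
    span_id=0x0000000000000000,
    is_remote=False,
    trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
link = Link(span_context)
print(f"{link.context.trace_id:032x}")  # 4bf92f3577b34da6a3ce929d0e0e4736
```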
@@ -523,7 +523,7 @@ class ProviderManager:
        # Init trial provider records if not exists
        if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
            try:
-                # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
+                # FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
                new_provider_record = Provider(
                    tenant_id=tenant_id,
                    # TODO: Use provider name with prefix after the data migration.

@@ -1,7 +1,7 @@
import json
from collections import defaultdict
from typing import Any, Optional

+import orjson
from pydantic import BaseModel

from configs import dify_config

@@ -134,13 +134,13 @@ class Jieba(BaseKeyword):
        dataset_keyword_table = self.dataset.dataset_keyword_table
        keyword_data_source_type = dataset_keyword_table.data_source_type
        if keyword_data_source_type == "database":
-            dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
+            dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
            db.session.commit()
        else:
            file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
            if storage.exists(file_key):
                storage.delete(file_key)
-            storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))
+            storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))

    def _get_dataset_keyword_table(self) -> Optional[dict]:
        dataset_keyword_table = self.dataset.dataset_keyword_table

@@ -156,12 +156,11 @@ class Jieba(BaseKeyword):
                data_source_type=keyword_data_source_type,
            )
            if keyword_data_source_type == "database":
-                dataset_keyword_table.keyword_table = json.dumps(
+                dataset_keyword_table.keyword_table = dumps_with_sets(
                    {
                        "__type__": "keyword_table",
                        "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
-                    },
-                    cls=SetEncoder,
+                    }
                )
                db.session.add(dataset_keyword_table)
                db.session.commit()

@@ -252,8 +251,13 @@ class Jieba(BaseKeyword):
        self._save_dataset_keyword_table(keyword_table)


-class SetEncoder(json.JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, set):
-            return list(obj)
-        return super().default(obj)
+def set_orjson_default(obj: Any) -> Any:
+    """Default function for orjson serialization of set types"""
+    if isinstance(obj, set):
+        return list(obj)
+    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
+
+
+def dumps_with_sets(obj: Any) -> str:
+    """JSON dumps with set support using orjson"""
+    return orjson.dumps(obj, default=set_orjson_default).decode("utf-8")

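The orjson-backed helper above mirrors what the removed `SetEncoder` did for `set` values. A runnable sketch (requires the `orjson` package):

```python
from typing import Any

import orjson


def set_orjson_default(obj: Any) -> Any:
    # orjson calls this hook for types it cannot serialize natively.
    if isinstance(obj, set):
        return list(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


keyword_table = {"dify": {"doc-1", "doc-2"}}
print(orjson.dumps(keyword_table, default=set_orjson_default).decode("utf-8"))
# e.g. {"dify":["doc-1","doc-2"]} (set ordering is not guaranteed)
```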
@@ -4,8 +4,8 @@ import math
from typing import Any

from pydantic import BaseModel, model_validator
-from pyobvector import VECTOR, ObVecClient  # type: ignore
-from sqlalchemy import JSON, Column, String, func
+from pyobvector import VECTOR, FtsIndexParam, FtsParser, ObVecClient, l2_distance  # type: ignore
+from sqlalchemy import JSON, Column, String
from sqlalchemy.dialects.mysql import LONGTEXT

from configs import dify_config

@@ -119,14 +119,21 @@ class OceanBaseVector(BaseVector):
        )
        try:
            if self._hybrid_search_enabled:
-                self._client.perform_raw_text_sql(f"""ALTER TABLE {self._collection_name}
-                    ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER ik""")
+                self._client.create_fts_idx_with_fts_index_param(
+                    table_name=self._collection_name,
+                    fts_idx_param=FtsIndexParam(
+                        index_name="fulltext_index_for_col_text",
+                        field_names=["text"],
+                        parser_type=FtsParser.IK,
+                    ),
+                )
        except Exception as e:
            raise Exception(
                "Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above "
                + "to support fulltext index and vector index in the same table",
                e,
            )
+        self._client.refresh_metadata([self._collection_name])
        redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def _check_hybrid_search_support(self) -> bool:

@@ -252,7 +259,7 @@ class OceanBaseVector(BaseVector):
            vec_column_name="vector",
            vec_data=query_vector,
            topk=topk,
-            distance_func=func.l2_distance,
+            distance_func=l2_distance,
            output_column_names=["text", "metadata"],
            with_dist=True,
            where_clause=_where_clause,

@@ -331,6 +331,12 @@ class QdrantVector(BaseVector):
    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from qdrant_client.http import models

+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        if score_threshold >= 1:
+            # return empty list because some versions of qdrant may respond with 400 bad request,
+            # and at the same time, the score_threshold with value 1 may be valid for other vector stores
+            return []
+
        filter = models.Filter(
            must=[
                models.FieldCondition(

@@ -355,7 +361,7 @@ class QdrantVector(BaseVector):
            limit=kwargs.get("top_k", 4),
            with_payload=True,
            with_vectors=True,
-            score_threshold=float(kwargs.get("score_threshold") or 0.0),
+            score_threshold=score_threshold,
        )
        docs = []
        for result in results:

@@ -363,7 +369,6 @@ class QdrantVector(BaseVector):
            continue
        metadata = result.payload.get(Field.METADATA_KEY.value) or {}
        # duplicate check score threshold
-        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        if result.score > score_threshold:
            metadata["score"] = result.score
            doc = Document(

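The hoisted threshold and early return in miniature, with toy scores in place of Qdrant results:

```python
def search(score_threshold_raw, scores):
    # Compute the threshold once; treat >= 1 as "nothing can match"
    # instead of sending a request some Qdrant versions reject with 400.
    score_threshold = float(score_threshold_raw or 0.0)
    if score_threshold >= 1:
        return []
    return [s for s in scores if s > score_threshold]


print(search(None, [0.2, 0.8]))  # [0.2, 0.8]
print(search(1.0, [0.2, 0.8]))   # []
```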
@@ -1 +0,0 @@

@@ -108,10 +108,18 @@ class ApiProviderAuthType(Enum):
        :param value: mode value
        :return: mode
        """
+        # 'api_key' deprecated in PR #21656
+        # normalize & tiny alias for backward compatibility
+        v = (value or "").strip().lower()
+        if v == "api_key":
+            v = cls.API_KEY_HEADER.value
+
        for mode in cls:
-            if mode.value == value:
+            if mode.value == v:
                return mode
-        raise ValueError(f"invalid mode value {value}")
+
+        valid = ", ".join(m.value for m in cls)
+        raise ValueError(f"invalid mode value '{value}', expected one of: {valid}")


class ToolInvokeMessage(BaseModel):

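The same normalize-and-alias pattern on a self-contained toy enum (these enum values are illustrative, not the full Dify set):

```python
from enum import Enum


class AuthType(Enum):
    API_KEY_HEADER = "api_key_header"
    API_KEY_QUERY = "api_key_query"


def value_of(value: str) -> AuthType:
    # Normalize the input, then map the deprecated alias onto the new value.
    v = (value or "").strip().lower()
    if v == "api_key":
        v = AuthType.API_KEY_HEADER.value
    for mode in AuthType:
        if mode.value == v:
            return mode
    valid = ", ".join(m.value for m in AuthType)
    raise ValueError(f"invalid mode value '{value}', expected one of: {valid}")


assert value_of(" API_KEY ") is AuthType.API_KEY_HEADER
```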
@@ -1,5 +1,7 @@
import json
from collections.abc import Iterable, Sequence
+from typing import Any

+import orjson

from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, FileSegment, Segment

@@ -12,15 +14,20 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
    return selectors


-class SegmentJSONEncoder(json.JSONEncoder):
-    def default(self, o):
-        if isinstance(o, ArrayFileSegment):
-            return [v.model_dump() for v in o.value]
-        elif isinstance(o, FileSegment):
-            return o.value.model_dump()
-        elif isinstance(o, SegmentGroup):
-            return [self.default(seg) for seg in o.value]
-        elif isinstance(o, Segment):
-            return o.value
-        else:
-            super().default(o)
+def segment_orjson_default(o: Any) -> Any:
+    """Default function for orjson serialization of Segment types"""
+    if isinstance(o, ArrayFileSegment):
+        return [v.model_dump() for v in o.value]
+    elif isinstance(o, FileSegment):
+        return o.value.model_dump()
+    elif isinstance(o, SegmentGroup):
+        return [segment_orjson_default(seg) for seg in o.value]
+    elif isinstance(o, Segment):
+        return o.value
+    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
+
+
+def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str:
+    """JSON dumps with segment support using orjson"""
+    option = orjson.OPT_NON_STR_KEYS
+    return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8")

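For plain data the helper reduces to an orjson round trip; the Segment branches only kick in through the `default=` hook. A minimal sketch:

```python
import orjson

# OPT_NON_STR_KEYS lets non-string keys (e.g. ints) serialize as well,
# matching the option used by dumps_with_segments above.
payload = {1: "one", "nested": {"ok": True}}
text = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
print(text)  # {"1":"one","nested":{"ok":true}}
```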
@@ -276,17 +276,26 @@ class Executor:
                encoded_credentials = credentials
            headers[authorization.config.header] = f"Basic {encoded_credentials}"
        elif self.auth.config.type == "custom":
-            headers[authorization.config.header] = authorization.config.api_key or ""
+            if authorization.config.header and authorization.config.api_key:
+                headers[authorization.config.header] = authorization.config.api_key

        # Handle Content-Type for multipart/form-data requests
-        # Fix for issue #22880: Missing boundary when using multipart/form-data
+        # Fix for issue #23829: Missing boundary when using multipart/form-data
        body = self.node_data.body
        if body and body.type == "form-data":
-            # For multipart/form-data with files, let httpx handle the boundary automatically
-            # by not setting Content-Type header when files are present
-            if not self.files or all(f[0] == "__multipart_placeholder__" for f in self.files):
-                # Only set Content-Type when there are no actual files
-                # This ensures httpx generates the correct boundary
-                # For multipart/form-data with files (including placeholder files),
-                # remove any manually set Content-Type header to let httpx handle
+            # For multipart/form-data, if any files are present (including placeholder files),
+            # we must remove any manually set Content-Type header. This is because httpx needs to
+            # automatically set the Content-Type and boundary for multipart encoding whenever files
+            # are included, even if they are placeholders, to avoid boundary issues and ensure correct
+            # file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the
+            # boundary, resulting in invalid requests.
+            if self.files:
+                # Remove Content-Type if it was manually set to avoid boundary issues
+                headers = {k: v for k, v in headers.items() if k.lower() != "content-type"}
+            else:
+                # No files at all, set Content-Type manually
+                if "content-type" not in (k.lower() for k in headers):
+                    headers["Content-Type"] = "multipart/form-data"
        elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE:

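The boundary behaviour described in the comments is easy to observe with httpx directly (a sketch; the URL is purely illustrative):

```python
import httpx

files = {"file": ("hello.txt", b"hi there", "text/plain")}

# No Content-Type set by hand: httpx generates
# "multipart/form-data; boundary=..." on its own.
request = httpx.Request("POST", "https://example.com/upload", files=files)
print(request.headers["content-type"])
# multipart/form-data; boundary=9f0b...

# Forcing headers={"Content-Type": "multipart/form-data"} instead would
# ship the header without a boundary and break the upload.
```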
@@ -318,33 +318,6 @@ class ToolNode(BaseNode):
                json.append(message.message.json_object)
            elif message.type == ToolInvokeMessage.MessageType.LINK:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)

-                if message.meta:
-                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
-                else:
-                    transfer_method = FileTransferMethod.TOOL_FILE
-
-                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
-
-                with Session(db.engine) as session:
-                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
-                    tool_file = session.scalar(stmt)
-                    if tool_file is None:
-                        raise ToolFileError(f"Tool file {tool_file_id} does not exist")
-
-                mapping = {
-                    "tool_file_id": tool_file_id,
-                    "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
-                    "transfer_method": transfer_method,
-                    "url": message.message.text,
-                }
-
-                file = file_factory.build_from_mapping(
-                    mapping=mapping,
-                    tenant_id=self.tenant_id,
-                )
-                files.append(file)
-
                stream_text = f"Link: {message.message.text}\n"
                text += stream_text
                yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])

@@ -5,7 +5,7 @@ import click
from werkzeug.exceptions import NotFound

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
-from events.event_handlers.document_index_event import document_index_created
+from events.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document

@@ -1,4 +1,6 @@
+import ssl
from datetime import timedelta
+from typing import Any, Optional

import pytz
from celery import Celery, Task  # type: ignore

@@ -8,6 +10,40 @@ from configs import dify_config
from dify_app import DifyApp


+def _get_celery_ssl_options() -> Optional[dict[str, Any]]:
+    """Get SSL configuration for Celery broker/backend connections."""
+    # Use REDIS_USE_SSL for consistency with the main Redis client
+    # Only apply SSL if we're using Redis as broker/backend
+    if not dify_config.REDIS_USE_SSL:
+        return None
+
+    # Check if Celery is actually using Redis
+    broker_is_redis = dify_config.CELERY_BROKER_URL and (
+        dify_config.CELERY_BROKER_URL.startswith("redis://") or dify_config.CELERY_BROKER_URL.startswith("rediss://")
+    )
+
+    if not broker_is_redis:
+        return None
+
+    # Map certificate requirement strings to SSL constants
+    cert_reqs_map = {
+        "CERT_NONE": ssl.CERT_NONE,
+        "CERT_OPTIONAL": ssl.CERT_OPTIONAL,
+        "CERT_REQUIRED": ssl.CERT_REQUIRED,
+    }
+
+    ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
+
+    ssl_options = {
+        "ssl_cert_reqs": ssl_cert_reqs,
+        "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
+        "ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
+        "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
+    }
+
+    return ssl_options
+
+
def init_app(app: DifyApp) -> Celery:
    class FlaskTask(Task):
        def __call__(self, *args: object, **kwargs: object) -> object:

@@ -33,14 +69,6 @@ def init_app(app: DifyApp) -> Celery:
        task_ignore_result=True,
    )

-    # Add SSL options to the Celery configuration
-    ssl_options = {
-        "ssl_cert_reqs": None,
-        "ssl_ca_certs": None,
-        "ssl_certfile": None,
-        "ssl_keyfile": None,
-    }
-
    celery_app.conf.update(
        result_backend=dify_config.CELERY_RESULT_BACKEND,
        broker_transport_options=broker_transport_options,

@@ -51,9 +79,13 @@ def init_app(app: DifyApp) -> Celery:
        timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
    )

-    if dify_config.BROKER_USE_SSL:
+    # Apply SSL configuration if enabled
+    ssl_options = _get_celery_ssl_options()
+    if ssl_options:
        celery_app.conf.update(
-            broker_use_ssl=ssl_options,  # Add the SSL options to the broker configuration
+            broker_use_ssl=ssl_options,
+            # Also apply SSL to the backend if it's Redis
+            redis_backend_use_ssl=ssl_options if dify_config.CELERY_BACKEND == "redis" else None,
        )

    if dify_config.LOG_FILE:

@@ -113,7 +145,7 @@ def init_app(app: DifyApp) -> Celery:
            minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
        ),
    }
-    if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:
+    if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
        imports.append("schedule.check_upgradable_plugin_task")
        beat_schedule["check_upgradable_plugin_task"] = {
            "task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",

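How the resulting options plug into a Celery app, as a hedged sketch; the broker URL and certificate path are placeholders:

```python
import ssl

from celery import Celery

app = Celery("worker", broker="rediss://redis.example.com:6379/0")

# Same shape as the dict built by _get_celery_ssl_options above.
app.conf.update(
    broker_use_ssl={
        "ssl_cert_reqs": ssl.CERT_REQUIRED,
        "ssl_ca_certs": "/etc/ssl/certs/ca.pem",
        "ssl_certfile": None,
        "ssl_keyfile": None,
    }
)
```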
@@ -4,6 +4,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
    from commands import (
        add_qdrant_index,
+        cleanup_orphaned_draft_variables,
        clear_free_plan_tenant_expired_logs,
        clear_orphaned_file_records,
        convert_to_agent_apps,

@@ -42,6 +43,7 @@ def init_app(app: DifyApp):
        clear_orphaned_file_records,
        remove_orphaned_files_on_storage,
        setup_system_tool_oauth_client,
+        cleanup_orphaned_draft_variables,
    ]
    for cmd in cmds_to_register:
        app.cli.add_command(cmd)

@@ -0,0 +1,8 @@
+from flask_orjson import OrjsonProvider
+
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp) -> None:
+    """Initialize Flask-Orjson extension for faster JSON serialization"""
+    app.json = OrjsonProvider(app)

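What swapping the JSON provider changes in practice, as a minimal standalone Flask sketch (requires the `flask-orjson` package):

```python
from flask import Flask, jsonify
from flask_orjson import OrjsonProvider

app = Flask(__name__)
# From here on, jsonify() and JSON responses serialize via orjson.
app.json = OrjsonProvider(app)


@app.route("/ping")
def ping():
    return jsonify({"ok": True})
```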
@@ -1,5 +1,6 @@
import functools
import logging
+import ssl
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union

@@ -116,76 +117,132 @@ class RedisClientWrapper:
redis_client: RedisClientWrapper = RedisClientWrapper()


-def init_app(app: DifyApp):
-    global redis_client
-    connection_class: type[Union[Connection, SSLConnection]] = Connection
-    if dify_config.REDIS_USE_SSL:
-        connection_class = SSLConnection
-    resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL
-    if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE:
-        if resp_protocol >= 3:
-            clientside_cache_config = CacheConfig()
-        else:
-            raise ValueError("Client side cache is only supported in RESP3")
-    else:
-        clientside_cache_config = None
+def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
+    """Get SSL configuration for Redis connection."""
+    if not dify_config.REDIS_USE_SSL:
+        return Connection, {}

-    redis_params: dict[str, Any] = {
+    cert_reqs_map = {
+        "CERT_NONE": ssl.CERT_NONE,
+        "CERT_OPTIONAL": ssl.CERT_OPTIONAL,
+        "CERT_REQUIRED": ssl.CERT_REQUIRED,
+    }
+    ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
+
+    ssl_kwargs = {
+        "ssl_cert_reqs": ssl_cert_reqs,
+        "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
+        "ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
+        "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
+    }
+
+    return SSLConnection, ssl_kwargs
+
+
+def _get_cache_configuration() -> CacheConfig | None:
+    """Get client-side cache configuration if enabled."""
+    if not dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE:
+        return None
+
+    resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL
+    if resp_protocol < 3:
+        raise ValueError("Client side cache is only supported in RESP3")
+
+    return CacheConfig()
+
+
+def _get_base_redis_params() -> dict[str, Any]:
+    """Get base Redis connection parameters."""
+    return {
        "username": dify_config.REDIS_USERNAME,
-        "password": dify_config.REDIS_PASSWORD or None,  # Temporary fix for empty password
+        "password": dify_config.REDIS_PASSWORD or None,
        "db": dify_config.REDIS_DB,
        "encoding": "utf-8",
        "encoding_errors": "strict",
        "decode_responses": False,
-        "protocol": resp_protocol,
-        "cache_config": clientside_cache_config,
+        "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
+        "cache_config": _get_cache_configuration(),
    }

-    if dify_config.REDIS_USE_SENTINEL:
-        assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True"
-        assert dify_config.REDIS_SENTINEL_SERVICE_NAME is not None, (
-            "REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True"
-        )
-        sentinel_hosts = [
-            (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
-        ]
-        sentinel = Sentinel(
-            sentinel_hosts,
-            sentinel_kwargs={
-                "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
-                "username": dify_config.REDIS_SENTINEL_USERNAME,
-                "password": dify_config.REDIS_SENTINEL_PASSWORD,
-            },
-        )
-        master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
-        redis_client.initialize(master)
-    elif dify_config.REDIS_USE_CLUSTERS:
-        assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True"
-        nodes = [
-            ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
-            for node in dify_config.REDIS_CLUSTERS.split(",")
-        ]
-        redis_client.initialize(
-            RedisCluster(
-                startup_nodes=nodes,
-                password=dify_config.REDIS_CLUSTERS_PASSWORD,
-                protocol=resp_protocol,
-                cache_config=clientside_cache_config,
-            )
-        )
-    else:
-        redis_params.update(
-            {
-                "host": dify_config.REDIS_HOST,
-                "port": dify_config.REDIS_PORT,
-                "connection_class": connection_class,
-                "protocol": resp_protocol,
-                "cache_config": clientside_cache_config,
-            }
-        )
-        pool = redis.ConnectionPool(**redis_params)
-        redis_client.initialize(redis.Redis(connection_pool=pool))

+def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
+    """Create Redis client using Sentinel configuration."""
+    if not dify_config.REDIS_SENTINELS:
+        raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")
+
+    if not dify_config.REDIS_SENTINEL_SERVICE_NAME:
+        raise ValueError("REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True")
+
+    sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")]
+
+    sentinel = Sentinel(
+        sentinel_hosts,
+        sentinel_kwargs={
+            "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
+            "username": dify_config.REDIS_SENTINEL_USERNAME,
+            "password": dify_config.REDIS_SENTINEL_PASSWORD,
+        },
+    )
+
+    master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
+    return master
+
+
+def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
+    """Create Redis cluster client."""
+    if not dify_config.REDIS_CLUSTERS:
+        raise ValueError("REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True")
+
+    nodes = [
+        ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
+        for node in dify_config.REDIS_CLUSTERS.split(",")
+    ]
+
+    cluster: RedisCluster = RedisCluster(
+        startup_nodes=nodes,
+        password=dify_config.REDIS_CLUSTERS_PASSWORD,
+        protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
+        cache_config=_get_cache_configuration(),
+    )
+    return cluster
+
+
+def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
+    """Create standalone Redis client."""
+    connection_class, ssl_kwargs = _get_ssl_configuration()
+
+    redis_params.update(
+        {
+            "host": dify_config.REDIS_HOST,
+            "port": dify_config.REDIS_PORT,
+            "connection_class": connection_class,
+        }
+    )
+
+    if ssl_kwargs:
+        redis_params.update(ssl_kwargs)
+
+    pool = redis.ConnectionPool(**redis_params)
+    client: redis.Redis = redis.Redis(connection_pool=pool)
+    return client
+
+
+def init_app(app: DifyApp):
+    """Initialize Redis client and attach it to the app."""
+    global redis_client
+
+    # Determine Redis mode and create appropriate client
+    if dify_config.REDIS_USE_SENTINEL:
+        redis_params = _get_base_redis_params()
+        client = _create_sentinel_client(redis_params)
+    elif dify_config.REDIS_USE_CLUSTERS:
+        client = _create_cluster_client()
+    else:
+        redis_params = _get_base_redis_params()
+        client = _create_standalone_client(redis_params)
+
+    # Initialize the wrapper and attach to app
+    redis_client.initialize(client)
+    app.extensions["redis"] = redis_client

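The `CERT_*` string-to-constant mapping shared by the Redis and Celery helpers, as a standalone sketch:

```python
import ssl

cert_reqs_map = {
    "CERT_NONE": ssl.CERT_NONE,
    "CERT_OPTIONAL": ssl.CERT_OPTIONAL,
    "CERT_REQUIRED": ssl.CERT_REQUIRED,
}

# Unknown or missing config values fall back to CERT_NONE.
print(cert_reqs_map.get("CERT_REQUIRED", ssl.CERT_NONE))  # VerifyMode.CERT_REQUIRED
print(cert_reqs_map.get("bogus", ssl.CERT_NONE))          # VerifyMode.CERT_NONE
```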
@@ -248,6 +248,8 @@ def _get_remote_file_info(url: str):

    # Initialize mime_type from filename as fallback
    mime_type, _ = mimetypes.guess_type(filename)
+    if mime_type is None:
+        mime_type = ""

    resp = ssrf_proxy.head(url, follow_redirects=True)
    resp = cast(httpx.Response, resp)

@@ -256,7 +258,12 @@ def _get_remote_file_info(url: str):
        filename = str(content_disposition.split("filename=")[-1].strip('"'))
        # Re-guess mime_type from updated filename
        mime_type, _ = mimetypes.guess_type(filename)
+        if mime_type is None:
+            mime_type = ""
    file_size = int(resp.headers.get("Content-Length", file_size))
+    # Fallback to Content-Type header if mime_type is still empty
+    if not mime_type:
+        mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip()

    return mime_type, filename, file_size

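The guess-then-fallback chain in miniature; the header value is illustrative:

```python
import mimetypes

filename = "report.xyz123"  # extension mimetypes does not know
mime_type, _ = mimetypes.guess_type(filename)
if mime_type is None:
    mime_type = ""

# Fall back to a Content-Type header, stripping any "; charset=..." suffix.
content_type_header = "application/octet-stream; charset=binary"
if not mime_type:
    mime_type = content_type_header.split(";")[0].strip()

print(mime_type)  # application/octet-stream
```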
@@ -1153,7 +1153,7 @@ class WorkflowDraftVariable(Base):
            value: The Segment object to store as the variable's value.
        """
        self.__value = value
-        self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder)
+        self.value = variable_utils.dumps_with_segments(value)
        self.value_type = value.value_type

    def get_node_id(self) -> str | None:

@@ -5,8 +5,7 @@ check_untyped_defs = True
cache_fine_grained = True
sqlite_cache = True
exclude = (?x)(
-    core/model_runtime/model_providers/
-    | tests/
+    tests/
    | migrations/
)

@@ -18,6 +18,7 @@ dependencies = [
    "flask-cors~=6.0.0",
    "flask-login~=0.6.3",
    "flask-migrate~=4.0.7",
+    "flask-orjson~=2.0.0",
    "flask-restful~=0.3.10",
    "flask-sqlalchemy~=3.1.1",
    "gevent~=24.11.1",

@@ -204,7 +205,7 @@ vdb = [
    "pgvector==0.2.5",
    "pymilvus~=2.5.0",
    "pymochow==1.3.1",
-    "pyobvector~=0.1.6",
+    "pyobvector~=0.2.15",
    "qdrant-client==1.9.0",
    "tablestore==6.2.0",
    "tcvectordb~=1.6.4",

@@ -1,5 +1,6 @@
import datetime
import time
from typing import Optional, TypedDict

import click
from sqlalchemy import func, select

@@ -14,168 +15,140 @@ from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Documen
from services.feature_service import FeatureService


class CleanupConfig(TypedDict):
    clean_day: datetime.datetime
    plan_filter: Optional[str]
    add_logs: bool


@app.celery.task(queue="dataset")
def clean_unused_datasets_task():
    click.echo(click.style("Start clean unused datasets indexes.", fg="green"))
    plan_sandbox_clean_day_setting = dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING
    plan_pro_clean_day_setting = dify_config.PLAN_PRO_CLEAN_DAY_SETTING
    start_at = time.perf_counter()
    plan_sandbox_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_sandbox_clean_day_setting)
    plan_pro_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_pro_clean_day_setting)
    while True:
        try:
            # Subquery for counting new documents
            document_subquery_new = (
                db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
                .where(
                    Document.indexing_status == "completed",
                    Document.enabled == True,
                    Document.archived == False,
                    Document.updated_at > plan_sandbox_clean_day,
                )
                .group_by(Document.dataset_id)
                .subquery()
            )

            # Subquery for counting old documents
            document_subquery_old = (
                db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
                .where(
                    Document.indexing_status == "completed",
                    Document.enabled == True,
                    Document.archived == False,
                    Document.updated_at < plan_sandbox_clean_day,
                )
                .group_by(Document.dataset_id)
                .subquery()
            )
    # Define cleanup configurations
    cleanup_configs: list[CleanupConfig] = [
        {
            "clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING),
            "plan_filter": None,
            "add_logs": True,
        },
        {
            "clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_PRO_CLEAN_DAY_SETTING),
            "plan_filter": "sandbox",
            "add_logs": False,
        },
    ]

            # Main query with join and filter
            stmt = (
                select(Dataset)
                .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
                .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
                .where(
                    Dataset.created_at < plan_sandbox_clean_day,
                    func.coalesce(document_subquery_new.c.document_count, 0) == 0,
                    func.coalesce(document_subquery_old.c.document_count, 0) > 0,
                )
                .order_by(Dataset.created_at.desc())
            )
    for config in cleanup_configs:
        clean_day = config["clean_day"]
        plan_filter = config["plan_filter"]
        add_logs = config["add_logs"]

            datasets = db.paginate(stmt, page=1, per_page=50)

        except SQLAlchemyError:
            raise
        if datasets.items is None or len(datasets.items) == 0:
            break
        for dataset in datasets:
            dataset_query = (
                db.session.query(DatasetQuery)
                .where(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id)
                .all()
            )
            if not dataset_query or len(dataset_query) == 0:
                try:
                    # add auto disable log
                    documents = (
                        db.session.query(Document)
                        .where(
                            Document.dataset_id == dataset.id,
                            Document.enabled == True,
                            Document.archived == False,
                        )
                        .all()
        while True:
            try:
                # Subquery for counting new documents
                document_subquery_new = (
                    db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
                    .where(
                        Document.indexing_status == "completed",
                        Document.enabled == True,
                        Document.archived == False,
                        Document.updated_at > clean_day,
                    )
                    for document in documents:
                        dataset_auto_disable_log = DatasetAutoDisableLog(
                            tenant_id=dataset.tenant_id,
                            dataset_id=dataset.id,
                            document_id=document.id,
                        )
                        db.session.add(dataset_auto_disable_log)
                    # remove index
                    index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
                    index_processor.clean(dataset, None)

                    # update document
                    db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False})
                    db.session.commit()
                    click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
                except Exception as e:
                    click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
    while True:
        try:
            # Subquery for counting new documents
            document_subquery_new = (
                db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
                .where(
                    Document.indexing_status == "completed",
                    Document.enabled == True,
                    Document.archived == False,
                    Document.updated_at > plan_pro_clean_day,
                    .group_by(Document.dataset_id)
                    .subquery()
                )
                .group_by(Document.dataset_id)
                .subquery()
            )

                # Subquery for counting old documents
                document_subquery_old = (
                    db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
                    .where(
                        Document.indexing_status == "completed",
                        Document.enabled == True,
                        Document.archived == False,
                        Document.updated_at < plan_pro_clean_day,
            # Subquery for counting old documents
            document_subquery_old = (
                db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
                .where(
                    Document.indexing_status == "completed",
                    Document.enabled == True,
                    Document.archived == False,
                    Document.updated_at < clean_day,
                    )
                    .group_by(Document.dataset_id)
                    .subquery()
                )
                .group_by(Document.dataset_id)
                .subquery()
            )

                # Main query with join and filter
                stmt = (
                    select(Dataset)
                    .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
                    .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
                    .where(
                        Dataset.created_at < plan_pro_clean_day,
                        func.coalesce(document_subquery_new.c.document_count, 0) == 0,
                        func.coalesce(document_subquery_old.c.document_count, 0) > 0,
                # Main query with join and filter
                stmt = (
                    select(Dataset)
                    .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
                    .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
                    .where(
                        Dataset.created_at < clean_day,
                        func.coalesce(document_subquery_new.c.document_count, 0) == 0,
                        func.coalesce(document_subquery_old.c.document_count, 0) > 0,
                    )
                    .order_by(Dataset.created_at.desc())
                )
                    .order_by(Dataset.created_at.desc())
                )
                datasets = db.paginate(stmt, page=1, per_page=50)

            except SQLAlchemyError:
                raise
            if datasets.items is None or len(datasets.items) == 0:
                break
            for dataset in datasets:
                dataset_query = (
                    db.session.query(DatasetQuery)
                    .where(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id)
                    .all()
                )
                if not dataset_query or len(dataset_query) == 0:
                    try:
                        features_cache_key = f"features:{dataset.tenant_id}"
                        plan_cache = redis_client.get(features_cache_key)
                        if plan_cache is None:
                            features = FeatureService.get_features(dataset.tenant_id)
                            redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
                            plan = features.billing.subscription.plan
                        else:
                            plan = plan_cache.decode()
                        if plan == "sandbox":
                            # remove index
                            index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
                            index_processor.clean(dataset, None)
            datasets = db.paginate(stmt, page=1, per_page=50)

        except SQLAlchemyError:
            raise
||||
|
||||
if datasets.items is None or len(datasets.items) == 0:
|
||||
break
|
||||
|
||||
for dataset in datasets:
|
||||
dataset_query = (
|
||||
db.session.query(DatasetQuery)
|
||||
.where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not dataset_query or len(dataset_query) == 0:
|
||||
try:
|
||||
should_clean = True
|
||||
|
||||
# Check plan filter if specified
|
||||
if plan_filter:
|
||||
features_cache_key = f"features:{dataset.tenant_id}"
|
||||
plan_cache = redis_client.get(features_cache_key)
|
||||
if plan_cache is None:
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
|
||||
plan = features.billing.subscription.plan
|
||||
else:
|
||||
plan = plan_cache.decode()
|
||||
should_clean = plan == plan_filter
|
||||
|
||||
if should_clean:
|
||||
# Add auto disable log if required
|
||||
if add_logs:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
Document.dataset_id == dataset.id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for document in documents:
|
||||
dataset_auto_disable_log = DatasetAutoDisableLog(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
)
|
||||
db.session.add(dataset_auto_disable_log)
|
||||
|
||||
# Remove index
|
||||
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
|
||||
index_processor.clean(dataset, None)
|
||||
|
||||
# Update document
|
||||
db.session.query(Document).filter_by(dataset_id=dataset.id).update(
|
||||
{Document.enabled: False}
|
||||
)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
|
||||
|
||||
# update document
|
||||
db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False})
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
|
||||
end_at = time.perf_counter()
|
||||
click.echo(click.style(f"Cleaned unused dataset from db success latency: {end_at - start_at}", fg="green"))
|
||||
|
|
|
|||
|
|
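
The unified loop above is driven entirely by the CleanupConfig entries. The type itself is defined earlier in the file and is not visible in this hunk, so the following is only a plausible sketch of its shape, assuming a TypedDict; the field comments are inferred from how the loop uses each key.

import datetime
from typing import Optional, TypedDict


class CleanupConfig(TypedDict):
    """Hypothetical reconstruction of the config shape consumed by the loop above."""

    clean_day: datetime.datetime  # retention cutoff for this cleanup pass
    plan_filter: Optional[str]  # only clean tenants on this plan; None means every plan
    add_logs: bool  # whether to write DatasetAutoDisableLog rows before disabling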
@@ -103,10 +103,10 @@ class ConversationService:
    @classmethod
    def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation):
        field_value = getattr(reference_conversation, sort_field)
        if sort_direction == desc:
        if sort_direction is desc:
            return getattr(Conversation, sort_field) < field_value
        else:
            return getattr(Conversation, sort_field) > field_value

        return getattr(Conversation, sort_field) > field_value

    @classmethod
    def rename(
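
The one-line fix above compares the sort callable by identity rather than equality. For module-level functions such as SQLAlchemy's desc, == falls back to identity anyway, so behavior is unchanged; `is` simply states the intended check explicitly. A minimal illustration, assuming plain sqlalchemy imports:

from sqlalchemy import asc, desc

sort_direction = desc
# Function objects have no custom __eq__, so == already means identity;
# `is` makes that intent explicit and satisfies linters.
assert sort_direction is desc
assert (sort_direction == desc) == (sort_direction is desc)
assert sort_direction is not asc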
@@ -147,7 +147,7 @@ class ConversationService:
                    app_model.tenant_id, message.query, conversation.id, app_model.id
                )
                conversation.name = name
            except:
            except Exception:
                pass

        db.session.commit()
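
Narrowing the bare except matters because a bare clause also swallows BaseException subclasses such as SystemExit and KeyboardInterrupt, while except Exception lets them propagate. A small self-contained demonstration:

def swallow_ordinary_errors():
    try:
        raise SystemExit(1)
    except Exception:
        return "caught"  # not reached: SystemExit is not an Exception subclass


try:
    swallow_ordinary_errors()
except SystemExit:
    print("SystemExit propagated past `except Exception:` as intended")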
@@ -277,6 +277,11 @@ class ConversationService:

        # Validate that the new value type matches the expected variable type
        expected_type = SegmentType(current_variable.value_type)

        # The web UI shows this as "number", but the DB stores it as int
        if expected_type == SegmentType.INTEGER:
            expected_type = SegmentType.NUMBER

        if not expected_type.is_valid(new_value):
            inferred_type = SegmentType.infer_segment_type(new_value)
            raise ConversationVariableTypeMismatchError(
@@ -47,7 +47,9 @@ class OAuthProxyService(BasePluginClient):
        if not context_id:
            raise ValueError("context_id is required")
        # get data from redis
        data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}")
        key = f"{OAuthProxyService.__KEY_PREFIX__}{context_id}"
        data = redis_client.get(key)
        if not data:
            raise ValueError("context_id is invalid")
        redis_client.delete(key)
        return json.loads(data)
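
The rewrite trades Redis GETDEL for a read followed by an explicit delete, keeping the payload consumable at most once while also working on clients or servers without GETDEL (it requires Redis 6.2+). A minimal sketch of the same use-once pattern with a plain redis-py client; the function and key names here are illustrative, not the service's API:

import json

import redis

r = redis.Redis()


def consume_once(key: str) -> dict:
    """Fetch a one-time payload and remove it so it cannot be replayed."""
    data = r.get(key)
    if not data:
        raise ValueError("context is missing or already consumed")
    # Delete only after a successful read; GETDEL does both atomically.
    r.delete(key)
    return json.loads(data)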
@@ -33,7 +33,11 @@ from models import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog
from models.workflow import (
    ConversationVariable,
    Workflow,
    WorkflowAppLog,
)
from repositories.factory import DifyAPIRepositoryFactory


@@ -62,6 +66,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
        _delete_end_users(tenant_id, app_id)
        _delete_trace_app_configs(tenant_id, app_id)
        _delete_conversation_variables(app_id=app_id)
        _delete_draft_variables(app_id)

        end_at = time.perf_counter()
        logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
@@ -91,7 +96,12 @@ def _delete_app_site(tenant_id: str, app_id: str):
    def del_site(site_id: str):
        db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)

    _delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site")
    _delete_records(
        """select id from sites where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_site,
        "site",
    )


def _delete_app_mcp_servers(tenant_id: str, app_id: str):
@@ -111,7 +121,10 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
        db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)

    _delete_records(
        """select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token"
        """select id from api_tokens where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_api_token,
        "api token",
    )


@@ -273,7 +286,10 @@ def _delete_app_messages(tenant_id: str, app_id: str):
        db.session.query(Message).where(Message.id == message_id).delete()

    _delete_records(
        """select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message"
        """select id from messages where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_message,
        "message",
    )


@@ -329,6 +345,56 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
    )


def _delete_draft_variables(app_id: str):
    """Delete all workflow draft variables for an app in batches."""
    return delete_draft_variables_batch(app_id, batch_size=1000)


def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
    """
    Delete draft variables for an app in batches.

    Args:
        app_id: The ID of the app whose draft variables should be deleted
        batch_size: Number of records to delete per batch

    Returns:
        Total number of records deleted
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    total_deleted = 0

    while True:
        with db.engine.begin() as conn:
            # Get a batch of draft variable IDs
            query_sql = """
                SELECT id FROM workflow_draft_variables
                WHERE app_id = :app_id
                LIMIT :batch_size
            """
            result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})

            draft_var_ids = [row[0] for row in result]
            if not draft_var_ids:
                break

            # Delete the batch
            delete_sql = """
                DELETE FROM workflow_draft_variables
                WHERE id IN :ids
            """
            deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
            batch_deleted = deleted_result.rowcount
            total_deleted += batch_deleted

            logging.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))

    logging.info(click.style(f"Deleted {total_deleted} total draft variables for app {app_id}", fg="green"))
    return total_deleted


def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
    while True:
        with db.engine.begin() as conn:
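
Because the batch helper is a plain function rather than task-only plumbing, it can also be driven directly; a minimal sketch of a one-off invocation from a maintenance shell (the UUID below is a placeholder, not a real record):

# Hypothetical one-off call from `flask shell`; the app id is a placeholder.
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch

deleted = delete_draft_variables_batch("00000000-0000-0000-0000-000000000000", batch_size=500)
print(f"removed {deleted} draft variables")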
@@ -0,0 +1,214 @@
import uuid

import pytest
from sqlalchemy import delete

from core.variables.segments import StringSegment
from models import Tenant, db
from models.model import App
from models.workflow import WorkflowDraftVariable
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch


@pytest.fixture
def app_and_tenant(flask_req_ctx):
    tenant_id = uuid.uuid4()
    tenant = Tenant(
        id=tenant_id,
        name="test_tenant",
    )
    db.session.add(tenant)

    app = App(
        tenant_id=tenant_id,  # Now tenant.id will have a value
        name=f"Test App for tenant {tenant.id}",
        mode="workflow",
        enable_site=True,
        enable_api=True,
    )
    db.session.add(app)
    db.session.flush()
    yield (tenant, app)

    # Cleanup with proper error handling
    db.session.delete(app)
    db.session.delete(tenant)


class TestDeleteDraftVariablesIntegration:
    @pytest.fixture
    def setup_test_data(self, app_and_tenant):
        """Create test data with apps and draft variables."""
        tenant, app = app_and_tenant

        # Create a second app for testing
        app2 = App(
            tenant_id=tenant.id,
            name="Test App 2",
            mode="workflow",
            enable_site=True,
            enable_api=True,
        )
        db.session.add(app2)
        db.session.commit()

        # Create draft variables for both apps
        variables_app1 = []
        variables_app2 = []

        for i in range(5):
            var1 = WorkflowDraftVariable.new_node_variable(
                app_id=app.id,
                node_id=f"node_{i}",
                name=f"var_{i}",
                value=StringSegment(value="test_value"),
                node_execution_id=str(uuid.uuid4()),
            )
            db.session.add(var1)
            variables_app1.append(var1)

            var2 = WorkflowDraftVariable.new_node_variable(
                app_id=app2.id,
                node_id=f"node_{i}",
                name=f"var_{i}",
                value=StringSegment(value="test_value"),
                node_execution_id=str(uuid.uuid4()),
            )
            db.session.add(var2)
            variables_app2.append(var2)

        # Commit all the variables to the database
        db.session.commit()

        yield {
            "app1": app,
            "app2": app2,
            "tenant": tenant,
            "variables_app1": variables_app1,
            "variables_app2": variables_app2,
        }

        # Cleanup - refresh session and check if objects still exist
        db.session.rollback()  # Clear any pending changes

        # Clean up remaining variables
        cleanup_query = (
            delete(WorkflowDraftVariable)
            .where(
                WorkflowDraftVariable.app_id.in_([app.id, app2.id]),
            )
            .execution_options(synchronize_session=False)
        )
        db.session.execute(cleanup_query)

        # Clean up app2
        app2_obj = db.session.get(App, app2.id)
        if app2_obj:
            db.session.delete(app2_obj)

        db.session.commit()

    def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data):
        """Test that batch deletion only removes variables for the specified app."""
        data = setup_test_data
        app1_id = data["app1"].id
        app2_id = data["app2"].id

        # Verify initial state
        app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
        app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
        assert app1_vars_before == 5
        assert app2_vars_before == 5

        # Delete app1 variables
        deleted_count = delete_draft_variables_batch(app1_id, batch_size=10)

        # Verify results
        assert deleted_count == 5

        app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
        app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()

        assert app1_vars_after == 0  # All app1 variables deleted
        assert app2_vars_after == 5  # App2 variables unchanged

    def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data):
        """Test batch deletion with small batch size processes all records."""
        data = setup_test_data
        app1_id = data["app1"].id

        # Use small batch size to force multiple batches
        deleted_count = delete_draft_variables_batch(app1_id, batch_size=2)

        assert deleted_count == 5

        # Verify all variables are deleted
        remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
        assert remaining_vars == 0

    def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
        """Test that deleting variables for nonexistent app returns 0."""
        nonexistent_app_id = str(uuid.uuid4())  # Use a valid UUID format

        deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100)

        assert deleted_count == 0

    def test_delete_draft_variables_wrapper_function(self, setup_test_data):
        """Test that _delete_draft_variables wrapper function works correctly."""
        data = setup_test_data
        app1_id = data["app1"].id

        # Verify initial state
        vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
        assert vars_before == 5

        # Call wrapper function
        deleted_count = _delete_draft_variables(app1_id)

        # Verify results
        assert deleted_count == 5

        vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
        assert vars_after == 0

    def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
        """Test batch deletion with larger dataset to verify batching logic."""
        tenant, app = app_and_tenant

        # Create many draft variables
        variables = []
        for i in range(25):
            var = WorkflowDraftVariable.new_node_variable(
                app_id=app.id,
                node_id=f"node_{i}",
                name=f"var_{i}",
                value=StringSegment(value="test_value"),
                node_execution_id=str(uuid.uuid4()),
            )
            db.session.add(var)
            variables.append(var)
        variable_ids = [i.id for i in variables]

        # Commit the variables to the database
        db.session.commit()

        try:
            # Use small batch size to force multiple batches
            deleted_count = delete_draft_variables_batch(app.id, batch_size=8)

            assert deleted_count == 25

            # Verify all variables are deleted
            remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
            assert remaining_vars == 0

        finally:
            query = (
                delete(WorkflowDraftVariable)
                .where(
                    WorkflowDraftVariable.id.in_(variable_ids),
                )
                .execution_options(synchronize_session=False)
            )
            db.session.execute(query)
@@ -1,4 +1,5 @@
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
@@ -18,6 +19,14 @@ class QdrantVectorTest(AbstractVectorTest):
            ),
        )

    def search_by_vector(self):
        super().search_by_vector()
        # only test for qdrant, may not work on other vector stores
        hits_by_vector: list[Document] = self.vector.search_by_vector(
            query_vector=self.example_embedding, score_threshold=1
        )
        assert len(hits_by_vector) == 0


def test_qdrant_vector(setup_mock_redis):
    QdrantVectorTest().run_all_tests()
@@ -160,6 +160,177 @@ def test_custom_authorization_header(setup_http_mock):

    assert "?A=b" in data
    assert "X-Header: 123" in data
    # Custom authorization header should be set (may be masked)
    assert "X-Auth:" in data


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
    """Test: In custom authentication mode, when the api_key is empty, no header should be set."""
    from core.workflow.entities.variable_pool import VariablePool
    from core.workflow.nodes.http_request.entities import (
        HttpRequestNodeAuthorization,
        HttpRequestNodeData,
        HttpRequestNodeTimeout,
    )
    from core.workflow.nodes.http_request.executor import Executor
    from core.workflow.system_variable import SystemVariable

    # Create variable pool
    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="test", files=[]),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )

    # Create node data with custom auth and empty api_key
    node_data = HttpRequestNodeData(
        title="http",
        desc="",
        url="http://example.com",
        method="get",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={
                "type": "custom",
                "api_key": "",  # Empty api_key
                "header": "X-Custom-Auth",
            },
        ),
        headers="",
        params="",
        body=None,
        ssl_verify=True,
    )

    # Create executor
    executor = Executor(
        node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
    )

    # Get assembled headers
    headers = executor._assembling_headers()

    # When api_key is empty, the custom header should NOT be set
    assert "X-Custom-Auth" not in headers


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_bearer_authorization_with_custom_header_ignored(setup_http_mock):
    """
    Test that when switching from custom to bearer authorization,
    the custom header settings don't interfere with bearer token.
    This test verifies the fix for issue #23554.
    """
    node = init_http_node(
        config={
            "id": "1",
            "data": {
                "title": "http",
                "desc": "",
                "method": "get",
                "url": "http://example.com",
                "authorization": {
                    "type": "api-key",
                    "config": {
                        "type": "bearer",
                        "api_key": "test-token",
                        "header": "",  # Empty header - should default to Authorization
                    },
                },
                "headers": "",
                "params": "",
                "body": None,
            },
        }
    )

    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    # In bearer mode, should use Authorization header (value is masked with *)
    assert "Authorization: " in data
    # Should contain masked Bearer token
    assert "*" in data


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
    """
    Test that when switching from custom to basic authorization,
    the custom header settings don't interfere with basic auth.
    This test verifies the fix for issue #23554.
    """
    node = init_http_node(
        config={
            "id": "1",
            "data": {
                "title": "http",
                "desc": "",
                "method": "get",
                "url": "http://example.com",
                "authorization": {
                    "type": "api-key",
                    "config": {
                        "type": "basic",
                        "api_key": "user:pass",
                        "header": "",  # Empty header - should default to Authorization
                    },
                },
                "headers": "",
                "params": "",
                "body": None,
            },
        }
    )

    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    # In basic mode, should use Authorization header (value is masked with *)
    assert "Authorization: " in data
    # Should contain masked Basic credentials
    assert "*" in data


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_authorization_with_empty_api_key(setup_http_mock):
    """
    Test that custom authorization doesn't set header when api_key is empty.
    This test verifies the fix for issue #23554.
    """
    node = init_http_node(
        config={
            "id": "1",
            "data": {
                "title": "http",
                "desc": "",
                "method": "get",
                "url": "http://example.com",
                "authorization": {
                    "type": "api-key",
                    "config": {
                        "type": "custom",
                        "api_key": "",  # Empty api_key
                        "header": "X-Custom-Auth",
                    },
                },
                "headers": "",
                "params": "",
                "body": None,
            },
        }
    )

    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    # Custom header should NOT be set when api_key is empty
    assert "X-Custom-Auth:" not in data


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
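
Taken together, the new tests pin down one rule: no authorization header is attached when the api_key is empty, and the bearer and basic modes always target the standard Authorization header regardless of any leftover custom header name. The sketch below restates that rule as a standalone function; it is an illustration of the asserted behavior, not the executor's actual code, and value formatting (masking, base64 for basic auth) is deliberately omitted:

def resolve_auth_header(auth_type: str, api_key: str, custom_header: str) -> dict[str, str]:
    """Illustrative only: mirrors the behavior the tests above assert."""
    if not api_key:
        # No credential, no header: the issue #23554 fix.
        return {}
    if auth_type in ("bearer", "basic"):
        # Non-custom modes always use the standard Authorization header.
        return {"Authorization": f"<{auth_type} credential>"}
    # Custom mode uses the user-supplied header name, defaulting to Authorization.
    return {custom_header or "Authorization": api_key}


assert resolve_auth_header("custom", "", "X-Custom-Auth") == {}
assert "Authorization" in resolve_auth_header("bearer", "test-token", "")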
@@ -239,6 +410,7 @@ def test_json(setup_http_mock):
    assert "X-Header: 123" in data


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_x_www_form_urlencoded(setup_http_mock):
    node = init_http_node(
        config={

@@ -285,6 +457,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
    assert "X-Header: 123" in data


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_form_data(setup_http_mock):
    node = init_http_node(
        config={

@@ -334,6 +507,7 @@ def test_form_data(setup_http_mock):
    assert "X-Header: 123" in data


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_none_data(setup_http_mock):
    node = init_http_node(
        config={

@@ -366,6 +540,7 @@ def test_none_data(setup_http_mock):
    assert "123123123" not in data


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_mock_404(setup_http_mock):
    node = init_http_node(
        config={

@@ -394,6 +569,7 @@ def test_mock_404(setup_http_mock):
    assert "Not Found" in resp.get("body", "")


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_multi_colons_parse(setup_http_mock):
    node = init_http_node(
        config={
@ -0,0 +1,885 @@
|
|||
import copy
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_CONTEXT,
|
||||
CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
CONTEXT,
|
||||
)
|
||||
from models.model import AppMode
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
||||
class TestAdvancedPromptTemplateService:
|
||||
"""Integration tests for AdvancedPromptTemplateService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
# This service doesn't have external dependencies, but we keep the pattern
|
||||
# for consistency with other test files
|
||||
return {}
|
||||
|
||||
def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful prompt generation for Baichuan model.
|
||||
|
||||
This test verifies:
|
||||
- Proper prompt generation for Baichuan models
|
||||
- Correct model detection logic
|
||||
- Appropriate prompt template selection
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test data for Baichuan model
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify context is included for Baichuan model
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful prompt generation for common models.
|
||||
|
||||
This test verifies:
|
||||
- Proper prompt generation for non-Baichuan models
|
||||
- Correct model detection logic
|
||||
- Appropriate prompt template selection
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test data for common model
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify context is included for common model
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_prompt_case_insensitive_baichuan_detection(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan model detection is case insensitive.
|
||||
|
||||
This test verifies:
|
||||
- Model name detection works regardless of case
|
||||
- Proper prompt template selection for different case variations
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test different case variations
|
||||
test_cases = ["Baichuan-13B-Chat", "BAICHUAN-13B-CHAT", "baichuan-13b-chat", "BaiChuan-13B-Chat"]
|
||||
|
||||
for model_name in test_cases:
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"model_mode": "completion",
|
||||
"model_name": model_name,
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify Baichuan template is used
|
||||
assert result is not None
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
|
||||
def test_get_common_prompt_chat_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation for chat app with completion mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct prompt template selection for chat app + completion mode
|
||||
- Proper context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
assert "conversation_histories_role" in result["completion_prompt_config"]
|
||||
assert "stop" in result
|
||||
|
||||
# Verify context is included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test common prompt generation for chat app with chat mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct prompt template selection for chat app + chat mode
|
||||
- Proper context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "role" in result["chat_prompt_config"]["prompt"][0]
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify context is included
|
||||
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_completion_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation for completion app with completion mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct prompt template selection for completion app + completion mode
|
||||
- Proper context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
assert "stop" in result
|
||||
|
||||
# Verify context is included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_completion_app_chat_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation for completion app with chat mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct prompt template selection for completion app + chat mode
|
||||
- Proper context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "role" in result["chat_prompt_config"]["prompt"][0]
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify context is included
|
||||
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test common prompt generation without context.
|
||||
|
||||
This test verifies:
|
||||
- Correct handling when has_context is "false"
|
||||
- Context is not included in prompt
|
||||
- Template structure remains intact
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify context is NOT included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert CONTEXT not in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_unsupported_app_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation with unsupported app mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of unsupported app modes
|
||||
- Default empty dict return
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt("unsupported_mode", "completion", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
|
||||
def test_get_common_prompt_unsupported_model_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation with unsupported model mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of unsupported model modes
|
||||
- Default empty dict return
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
|
||||
def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test completion prompt generation with context.
|
||||
|
||||
This test verifies:
|
||||
- Proper context integration in completion prompts
|
||||
- Template structure preservation
|
||||
- Context placement at the beginning
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test prompt template
|
||||
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "true", CONTEXT)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify context is prepended to original text
|
||||
result_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert result_text.startswith(CONTEXT)
|
||||
assert original_text in result_text
|
||||
assert result_text == CONTEXT + original_text
|
||||
|
||||
def test_get_completion_prompt_without_context(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test completion prompt generation without context.
|
||||
|
||||
This test verifies:
|
||||
- Original template is preserved when no context
|
||||
- No modification to prompt text
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test prompt template
|
||||
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify original text is unchanged
|
||||
result_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert result_text == original_text
|
||||
assert CONTEXT not in result_text
|
||||
|
||||
def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test chat prompt generation with context.
|
||||
|
||||
This test verifies:
|
||||
- Proper context integration in chat prompts
|
||||
- Template structure preservation
|
||||
- Context placement at the beginning of first message
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test prompt template
|
||||
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "true", CONTEXT)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify context is prepended to original text
|
||||
result_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert result_text.startswith(CONTEXT)
|
||||
assert original_text in result_text
|
||||
assert result_text == CONTEXT + original_text
|
||||
|
||||
def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test chat prompt generation without context.
|
||||
|
||||
This test verifies:
|
||||
- Original template is preserved when no context
|
||||
- No modification to prompt text
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test prompt template
|
||||
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify original text is unchanged
|
||||
result_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert result_text == original_text
|
||||
assert CONTEXT not in result_text
|
||||
|
||||
def test_get_baichuan_prompt_chat_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for chat app with completion mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct Baichuan prompt template selection for chat app + completion mode
|
||||
- Proper Baichuan context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
assert "conversation_histories_role" in result["completion_prompt_config"]
|
||||
assert "stop" in result
|
||||
|
||||
# Verify Baichuan context is included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_chat_app_chat_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for chat app with chat mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct Baichuan prompt template selection for chat app + chat mode
|
||||
- Proper Baichuan context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "role" in result["chat_prompt_config"]["prompt"][0]
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify Baichuan context is included
|
||||
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_completion_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for completion app with completion mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct Baichuan prompt template selection for completion app + completion mode
|
||||
- Proper Baichuan context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
assert "stop" in result
|
||||
|
||||
# Verify Baichuan context is included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_completion_app_chat_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for completion app with chat mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct Baichuan prompt template selection for completion app + chat mode
|
||||
- Proper Baichuan context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "role" in result["chat_prompt_config"]["prompt"][0]
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify Baichuan context is included
|
||||
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test Baichuan prompt generation without context.
|
||||
|
||||
This test verifies:
|
||||
- Correct handling when has_context is "false"
|
||||
- Baichuan context is not included in prompt
|
||||
- Template structure remains intact
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify Baichuan context is NOT included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert BAICHUAN_CONTEXT not in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_unsupported_app_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation with unsupported app mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of unsupported app modes
|
||||
- Default empty dict return
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt("unsupported_mode", "completion", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
|
||||
def test_get_baichuan_prompt_unsupported_model_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation with unsupported model mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of unsupported model modes
|
||||
- Default empty dict return
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
|
||||
def test_get_prompt_all_app_modes_common_model(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test prompt generation for all app modes with common model.
|
||||
|
||||
This test verifies:
|
||||
- All app modes work correctly with common models
|
||||
- Proper template selection for each combination
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test all app modes
|
||||
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
|
||||
model_modes = ["completion", "chat"]
|
||||
|
||||
for app_mode in app_modes:
|
||||
for model_mode in model_modes:
|
||||
args = {
|
||||
"app_mode": app_mode,
|
||||
"model_mode": model_mode,
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify result is not empty
|
||||
assert result is not None
|
||||
assert result != {}
|
||||
|
||||
def test_get_prompt_all_app_modes_baichuan_model(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test prompt generation for all app modes with Baichuan model.
|
||||
|
||||
This test verifies:
|
||||
- All app modes work correctly with Baichuan models
|
||||
- Proper template selection for each combination
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test all app modes
|
||||
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
|
||||
model_modes = ["completion", "chat"]
|
||||
|
||||
for app_mode in app_modes:
|
||||
for model_mode in model_modes:
|
||||
args = {
|
||||
"app_mode": app_mode,
|
||||
"model_mode": model_mode,
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify result is not empty
|
||||
assert result is not None
|
||||
assert result != {}
|
||||
|
||||
    def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test prompt generation with edge cases.

        This test verifies:
        - Handling of edge case inputs
        - Proper error handling
        - Consistent behavior with unusual inputs
        """
        fake = Faker()

        # Test edge cases
        edge_cases = [
            {"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"},
            {"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
            {"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"},
            {
                "app_mode": AppMode.CHAT.value,
                "model_mode": "completion",
                "model_name": "gpt-3.5-turbo",
                "has_context": "",
            },
        ]

        for args in edge_cases:
            # Act: Execute the method under test
            result = AdvancedPromptTemplateService.get_prompt(args)

            # Assert: Verify method handles edge cases gracefully
            # Should either return a valid result or empty dict, but not crash
            assert result is not None

    def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test that original templates are not modified.

        This test verifies:
        - Original template constants are not modified
        - Deep copy is used properly
        - Template immutability is maintained
        """
        fake = Faker()

        # Store original templates
        original_chat_completion = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
        original_chat_chat = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
        original_completion_completion = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
        original_completion_chat = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG)

        # Test with context
        args = {
            "app_mode": AppMode.CHAT.value,
            "model_mode": "completion",
            "model_name": "gpt-3.5-turbo",
            "has_context": "true",
        }

        # Act: Execute the method under test
        result = AdvancedPromptTemplateService.get_prompt(args)

        # Assert: Verify original templates are unchanged
        assert original_chat_completion == CHAT_APP_COMPLETION_PROMPT_CONFIG
        assert original_chat_chat == CHAT_APP_CHAT_PROMPT_CONFIG
        assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
        assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG

    def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test that original Baichuan templates are not modified.

        This test verifies:
        - Original Baichuan template constants are not modified
        - Deep copy is used properly
        - Template immutability is maintained
        """
        fake = Faker()

        # Store original templates
        original_baichuan_chat_completion = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG)
        original_baichuan_chat_chat = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG)
        original_baichuan_completion_completion = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
        original_baichuan_completion_chat = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG)

        # Test with context
        args = {
            "app_mode": AppMode.CHAT.value,
            "model_mode": "completion",
            "model_name": "baichuan-13b-chat",
            "has_context": "true",
        }

        # Act: Execute the method under test
        result = AdvancedPromptTemplateService.get_prompt(args)

        # Assert: Verify original templates are unchanged
        assert original_baichuan_chat_completion == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
        assert original_baichuan_chat_chat == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
        assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
        assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG

    def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test consistency of context integration across different scenarios.

        This test verifies:
        - Context is always prepended correctly
        - Context integration is consistent across different templates
        - No context duplication or corruption
        """
        fake = Faker()

        # Test different scenarios
        test_scenarios = [
            {
                "app_mode": AppMode.CHAT.value,
                "model_mode": "completion",
                "model_name": "gpt-3.5-turbo",
                "has_context": "true",
            },
            {
                "app_mode": AppMode.CHAT.value,
                "model_mode": "chat",
                "model_name": "gpt-3.5-turbo",
                "has_context": "true",
            },
            {
                "app_mode": AppMode.COMPLETION.value,
                "model_mode": "completion",
                "model_name": "gpt-3.5-turbo",
                "has_context": "true",
            },
            {
                "app_mode": AppMode.COMPLETION.value,
                "model_mode": "chat",
                "model_name": "gpt-3.5-turbo",
                "has_context": "true",
            },
        ]

        for args in test_scenarios:
            # Act: Execute the method under test
            result = AdvancedPromptTemplateService.get_prompt(args)

            # Assert: Verify context integration is consistent
            assert result is not None
            assert result != {}

            # Check that context is properly integrated
            if "completion_prompt_config" in result:
                prompt_text = result["completion_prompt_config"]["prompt"]["text"]
                assert prompt_text.startswith(CONTEXT)
            elif "chat_prompt_config" in result:
                prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
                assert prompt_text.startswith(CONTEXT)

    def test_baichuan_context_integration_consistency(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test consistency of Baichuan context integration across different scenarios.

        This test verifies:
        - Baichuan context is always prepended correctly
        - Context integration is consistent across different templates
        - No context duplication or corruption
        """
        fake = Faker()

        # Test different scenarios
        test_scenarios = [
            {
                "app_mode": AppMode.CHAT.value,
                "model_mode": "completion",
                "model_name": "baichuan-13b-chat",
                "has_context": "true",
            },
            {
                "app_mode": AppMode.CHAT.value,
                "model_mode": "chat",
                "model_name": "baichuan-13b-chat",
                "has_context": "true",
            },
            {
                "app_mode": AppMode.COMPLETION.value,
                "model_mode": "completion",
                "model_name": "baichuan-13b-chat",
                "has_context": "true",
            },
            {
                "app_mode": AppMode.COMPLETION.value,
                "model_mode": "chat",
                "model_name": "baichuan-13b-chat",
                "has_context": "true",
            },
        ]

        for args in test_scenarios:
            # Act: Execute the method under test
            result = AdvancedPromptTemplateService.get_prompt(args)

            # Assert: Verify context integration is consistent
            assert result is not None
            assert result != {}

            # Check that Baichuan context is properly integrated
            if "completion_prompt_config" in result:
                prompt_text = result["completion_prompt_config"]["prompt"]["text"]
                assert prompt_text.startswith(BAICHUAN_CONTEXT)
            elif "chat_prompt_config" in result:
                prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
                assert prompt_text.startswith(BAICHUAN_CONTEXT)
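Taken together, the immutability and context-integration tests above pin down the observable contract of `AdvancedPromptTemplateService.get_prompt`: pick a template by app mode, model mode, and model name, deep-copy it, and prepend the context block when `has_context` is `"true"`. Below is a minimal sketch of that contract, not the service's actual internals; the helper name `get_prompt_sketch`, the `_TEMPLATES` table, and the import path for the constants are assumptions.

```python
import copy

# Constants as used by the tests; module path is an assumption.
from core.prompt.prompt_templates.advanced_prompt_templates import (
    CHAT_APP_CHAT_PROMPT_CONFIG,
    CHAT_APP_COMPLETION_PROMPT_CONFIG,
    COMPLETION_APP_CHAT_PROMPT_CONFIG,
    COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
    CONTEXT,
)

# Hypothetical selection table; the real dispatch lives inside the service.
_TEMPLATES = {
    ("chat", "completion"): CHAT_APP_COMPLETION_PROMPT_CONFIG,
    ("chat", "chat"): CHAT_APP_CHAT_PROMPT_CONFIG,
    ("completion", "completion"): COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
    ("completion", "chat"): COMPLETION_APP_CHAT_PROMPT_CONFIG,
}


def get_prompt_sketch(args: dict) -> dict:
    template = _TEMPLATES.get((args.get("app_mode"), args.get("model_mode")))
    if template is None:
        return {}  # matches the "graceful, non-crashing" edge-case behavior
    result = copy.deepcopy(template)  # keeps the module-level constants pristine
    if args.get("has_context") == "true":
        # Prepend the shared context block, which is what the
        # startswith(CONTEXT) assertions above observe.
        if "completion_prompt_config" in result:
            prompt = result["completion_prompt_config"]["prompt"]
            prompt["text"] = CONTEXT + prompt["text"]
        elif "chat_prompt_config" in result:
            prompt = result["chat_prompt_config"]["prompt"][0]
            prompt["text"] = CONTEXT + prompt["text"]
    return result
```

The deep copy is the load-bearing detail: because the templates are module-level dict constants, returning them directly would let callers mutate shared state, which is exactly what the two immutability tests guard against.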
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -0,0 +1,474 @@
from unittest.mock import MagicMock, patch

import pytest
from faker import Faker

from models.account import TenantAccountJoin, TenantAccountRole
from models.model import Account, Tenant
from models.provider import LoadBalancingModelConfig, Provider, ProviderModelSetting
from services.model_load_balancing_service import ModelLoadBalancingService


class TestModelLoadBalancingService:
    """Integration tests for ModelLoadBalancingService using testcontainers."""

    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies."""
        with (
            patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager,
            patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager,
            patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory,
            patch("services.model_load_balancing_service.encrypter") as mock_encrypter,
        ):
            # Setup default mock returns
            mock_provider_manager_instance = mock_provider_manager.return_value

            # Mock provider configuration
            mock_provider_config = MagicMock()
            mock_provider_config.provider.provider = "openai"
            mock_provider_config.custom_configuration.provider = None

            # Mock provider model setting
            mock_provider_model_setting = MagicMock()
            mock_provider_model_setting.load_balancing_enabled = False

            mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting

            # Mock provider configurations dict
            mock_provider_configs = {"openai": mock_provider_config}
            mock_provider_manager_instance.get_configurations.return_value = mock_provider_configs

            # Mock LBModelManager
            mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)

            # Mock ModelProviderFactory
            mock_model_provider_factory_instance = mock_model_provider_factory.return_value

            # Mock credential schemas
            mock_credential_schema = MagicMock()
            mock_credential_schema.credential_form_schemas = []

            # Mock provider configuration methods
            mock_provider_config.extract_secret_variables.return_value = []
            mock_provider_config.obfuscated_credentials.return_value = {}
            mock_provider_config._get_credential_schema.return_value = mock_credential_schema

            yield {
                "provider_manager": mock_provider_manager,
                "lb_model_manager": mock_lb_model_manager,
                "model_provider_factory": mock_model_provider_factory,
                "encrypter": mock_encrypter,
                "provider_config": mock_provider_config,
                "provider_model_setting": mock_provider_model_setting,
                "credential_schema": mock_credential_schema,
            }

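    # Note on the fixture above: the parenthesized multi-manager `with (...)` form
    # requires Python 3.10+, and each patch() target names the import site inside
    # services.model_load_balancing_service, so the service under test sees the
    # mocks regardless of where the originals are defined.
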
    def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test account and tenant for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies

        Returns:
            tuple: (account, tenant) - Created account and tenant instances
        """
        fake = Faker()

        # Create account
        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )

        from extensions.ext_database import db

        db.session.add(account)
        db.session.commit()

        # Create tenant for the account
        tenant = Tenant(
            name=fake.company(),
            status="normal",
        )
        db.session.add(tenant)
        db.session.commit()

        # Create tenant-account join
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER.value,
            current=True,
        )
        db.session.add(join)
        db.session.commit()

        # Set current tenant for account
        account.current_tenant = tenant

        return account, tenant

    def _create_test_provider_and_setting(
        self, db_session_with_containers, tenant_id, mock_external_service_dependencies
    ):
        """
        Helper method to create a test provider and provider model setting.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            tenant_id: Tenant ID for the provider
            mock_external_service_dependencies: Mock dependencies

        Returns:
            tuple: (provider, provider_model_setting) - Created provider and setting instances
        """
        fake = Faker()

        from extensions.ext_database import db

        # Create provider
        provider = Provider(
            tenant_id=tenant_id,
            provider_name="openai",
            provider_type="custom",
            is_valid=True,
        )
        db.session.add(provider)
        db.session.commit()

        # Create provider model setting
        provider_model_setting = ProviderModelSetting(
            tenant_id=tenant_id,
            provider_name="openai",
            model_name="gpt-3.5-turbo",
            model_type="text-generation",  # Use the origin model type that matches the query
            enabled=True,
            load_balancing_enabled=False,
        )
        db.session.add(provider_model_setting)
        db.session.commit()

        return provider, provider_model_setting

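    # Note on model_type in the helper above: the rows are stored with the origin
    # type string "text-generation", while the service methods below are called
    # with the "llm" alias; the service is assumed to translate between the two
    # (see the "origin model type" comments where the rows are created).
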
    def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful model load balancing enablement.

        This test verifies:
        - Proper provider configuration retrieval
        - Successful enablement of model load balancing
        - Correct method calls to provider configuration
        """
        # Arrange: Create test data
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        provider, provider_model_setting = self._create_test_provider_and_setting(
            db_session_with_containers, tenant.id, mock_external_service_dependencies
        )

        # Setup mocks for enable method
        mock_provider_config = mock_external_service_dependencies["provider_config"]
        mock_provider_config.enable_model_load_balancing = MagicMock()

        # Act: Execute the method under test
        service = ModelLoadBalancingService()
        service.enable_model_load_balancing(
            tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
        )

        # Assert: Verify the expected outcomes
        mock_provider_config.enable_model_load_balancing.assert_called_once()
        call_args = mock_provider_config.enable_model_load_balancing.call_args
        assert call_args.kwargs["model"] == "gpt-3.5-turbo"
        assert call_args.kwargs["model_type"].value == "llm"  # ModelType enum value

        # Verify database state
        from extensions.ext_database import db

        db.session.refresh(provider)
        db.session.refresh(provider_model_setting)
        assert provider.id is not None
        assert provider_model_setting.id is not None

    def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful model load balancing disablement.

        This test verifies:
        - Proper provider configuration retrieval
        - Successful disablement of model load balancing
        - Correct method calls to provider configuration
        """
        # Arrange: Create test data
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        provider, provider_model_setting = self._create_test_provider_and_setting(
            db_session_with_containers, tenant.id, mock_external_service_dependencies
        )

        # Setup mocks for disable method
        mock_provider_config = mock_external_service_dependencies["provider_config"]
        mock_provider_config.disable_model_load_balancing = MagicMock()

        # Act: Execute the method under test
        service = ModelLoadBalancingService()
        service.disable_model_load_balancing(
            tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
        )

        # Assert: Verify the expected outcomes
        mock_provider_config.disable_model_load_balancing.assert_called_once()
        call_args = mock_provider_config.disable_model_load_balancing.call_args
        assert call_args.kwargs["model"] == "gpt-3.5-turbo"
        assert call_args.kwargs["model_type"].value == "llm"  # ModelType enum value

        # Verify database state
        from extensions.ext_database import db

        db.session.refresh(provider)
        db.session.refresh(provider_model_setting)
        assert provider.id is not None
        assert provider_model_setting.id is not None

    def test_enable_model_load_balancing_provider_not_found(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test error handling when provider does not exist.

        This test verifies:
        - Proper error handling for non-existent provider
        - Correct exception type and message
        - No database state changes
        """
        # Arrange: Create test data
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Setup mocks to return empty provider configurations
        mock_provider_manager = mock_external_service_dependencies["provider_manager"]
        mock_provider_manager_instance = mock_provider_manager.return_value
        mock_provider_manager_instance.get_configurations.return_value = {}

        # Act & Assert: Verify proper error handling
        service = ModelLoadBalancingService()
        with pytest.raises(ValueError) as exc_info:
            service.enable_model_load_balancing(
                tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm"
            )

        # Verify correct error message
        assert "Provider nonexistent_provider does not exist." in str(exc_info.value)

        # Verify no database state changes occurred
        from extensions.ext_database import db

        db.session.rollback()

    def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful retrieval of load balancing configurations.

        This test verifies:
        - Proper provider configuration retrieval
        - Successful database query for load balancing configs
        - Correct return format and data structure
        """
        # Arrange: Create test data
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        provider, provider_model_setting = self._create_test_provider_and_setting(
            db_session_with_containers, tenant.id, mock_external_service_dependencies
        )

        # Create load balancing config
        from extensions.ext_database import db

        load_balancing_config = LoadBalancingModelConfig(
            tenant_id=tenant.id,
            provider_name="openai",
            model_name="gpt-3.5-turbo",
            model_type="text-generation",  # Use the origin model type that matches the query
            name="config1",
            encrypted_config='{"api_key": "test_key"}',
            enabled=True,
        )
        db.session.add(load_balancing_config)
        db.session.commit()

        # Verify the config was created
        db.session.refresh(load_balancing_config)
        assert load_balancing_config.id is not None

        # Setup mocks for get_load_balancing_configs method
        mock_provider_config = mock_external_service_dependencies["provider_config"]
        mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"]
        mock_provider_model_setting.load_balancing_enabled = True

        # Mock credential schema methods
        mock_credential_schema = mock_external_service_dependencies["credential_schema"]
        mock_credential_schema.credential_form_schemas = []

        # Mock encrypter
        mock_encrypter = mock_external_service_dependencies["encrypter"]
        mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher")

        # Mock _get_credential_schema method
        mock_provider_config._get_credential_schema.return_value = mock_credential_schema

        # Mock extract_secret_variables method
        mock_provider_config.extract_secret_variables.return_value = []

        # Mock obfuscated_credentials method
        mock_provider_config.obfuscated_credentials.return_value = {}

        # Mock LBModelManager.get_config_in_cooldown_and_ttl
        mock_lb_model_manager = mock_external_service_dependencies["lb_model_manager"]
        mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)

        # Act: Execute the method under test
        service = ModelLoadBalancingService()
        is_enabled, configs = service.get_load_balancing_configs(
            tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
        )

        # Assert: Verify the expected outcomes
        assert is_enabled is True
        assert len(configs) == 1
        assert configs[0]["id"] == load_balancing_config.id
        assert configs[0]["name"] == "config1"
        assert configs[0]["enabled"] is True
        assert configs[0]["in_cooldown"] is False
        assert configs[0]["ttl"] == 0

        # Verify database state
        db.session.refresh(load_balancing_config)
        assert load_balancing_config.id is not None

    def test_get_load_balancing_configs_provider_not_found(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test error handling when provider does not exist in get_load_balancing_configs.

        This test verifies:
        - Proper error handling for non-existent provider
        - Correct exception type and message
        - No database state changes
        """
        # Arrange: Create test data
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Setup mocks to return empty provider configurations
        mock_provider_manager = mock_external_service_dependencies["provider_manager"]
        mock_provider_manager_instance = mock_provider_manager.return_value
        mock_provider_manager_instance.get_configurations.return_value = {}

        # Act & Assert: Verify proper error handling
        service = ModelLoadBalancingService()
        with pytest.raises(ValueError) as exc_info:
            service.get_load_balancing_configs(
                tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm"
            )

        # Verify correct error message
        assert "Provider nonexistent_provider does not exist." in str(exc_info.value)

        # Verify no database state changes occurred
        from extensions.ext_database import db

        db.session.rollback()

    def test_get_load_balancing_configs_with_inherit_config(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test load balancing configs retrieval with inherit configuration.

        This test verifies:
        - Proper handling of inherit configuration
        - Correct ordering of configurations
        - Inherit config initialization when needed
        """
        # Arrange: Create test data
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        provider, provider_model_setting = self._create_test_provider_and_setting(
            db_session_with_containers, tenant.id, mock_external_service_dependencies
        )

        # Create load balancing config
        from extensions.ext_database import db

        load_balancing_config = LoadBalancingModelConfig(
            tenant_id=tenant.id,
            provider_name="openai",
            model_name="gpt-3.5-turbo",
            model_type="text-generation",  # Use the origin model type that matches the query
            name="config1",
            encrypted_config='{"api_key": "test_key"}',
            enabled=True,
        )
        db.session.add(load_balancing_config)
        db.session.commit()

        # Setup mocks for inherit config scenario
        mock_provider_config = mock_external_service_dependencies["provider_config"]
        mock_provider_config.custom_configuration.provider = MagicMock()  # Enable custom config

        mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"]
        mock_provider_model_setting.load_balancing_enabled = True

        # Mock credential schema methods
        mock_credential_schema = mock_external_service_dependencies["credential_schema"]
        mock_credential_schema.credential_form_schemas = []

        # Mock encrypter
        mock_encrypter = mock_external_service_dependencies["encrypter"]
        mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher")

        # Act: Execute the method under test
        service = ModelLoadBalancingService()
        is_enabled, configs = service.get_load_balancing_configs(
            tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
        )

        # Assert: Verify the expected outcomes
        assert is_enabled is True
        assert len(configs) == 2  # inherit config + existing config

        # First config should be inherit config
        assert configs[0]["name"] == "__inherit__"
        assert configs[0]["enabled"] is True

        # Second config should be the existing config
        assert configs[1]["id"] == load_balancing_config.id
        assert configs[1]["name"] == "config1"

        # Verify database state
        db.session.refresh(load_balancing_config)
        assert load_balancing_config.id is not None

        # Verify inherit config was created in database
        inherit_configs = (
            db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all()
        )
        assert len(inherit_configs) == 1
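The inherit-config test above fixes one more behavior worth spelling out: when a custom provider credential exists and load balancing is queried, an `__inherit__` entry is prepended ahead of the explicit configs, and persisted if it does not exist yet. Here is a rough sketch of that step against the same models; the helper name `_ensure_inherit_config` is an assumption, not the service's real private API.

```python
from extensions.ext_database import db
from models.provider import LoadBalancingModelConfig


def _ensure_inherit_config(tenant_id: str, provider_name: str, model_name: str, model_type: str):
    """Hypothetical helper: make sure an '__inherit__' entry exists and return it."""
    inherit = (
        db.session.query(LoadBalancingModelConfig)
        .filter_by(
            tenant_id=tenant_id,
            provider_name=provider_name,
            model_name=model_name,
            model_type=model_type,
            name="__inherit__",
        )
        .first()
    )
    if inherit is None:
        # Persist the placeholder so later queries see it, as the test's final
        # database assertion (len(inherit_configs) == 1) expects.
        inherit = LoadBalancingModelConfig(
            tenant_id=tenant_id,
            provider_name=provider_name,
            model_name=model_name,
            model_type=model_type,
            name="__inherit__",
            enabled=True,
        )
        db.session.add(inherit)
        db.session.commit()
    return inherit
```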
@ -4,8 +4,8 @@ from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden

from controllers.common.errors import FilenameNotExistsError
from controllers.console.error import (
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
Some files were not shown because too many files have changed in this diff