Merge branch 'main' into feat/model-auth

zxhlyh · 2025-08-18 18:05:03 +08:00 · commit 6aa5273c5e
304 changed files with 11488 additions and 2944 deletions

File diff suppressed because it is too large

@@ -82,7 +82,7 @@ jobs:
      - name: Install pnpm
        uses: pnpm/action-setup@v4
        with:
          version: 10
          package_json_file: web/package.json
          run_install: false
      - name: Setup NodeJS
@@ -95,10 +95,12 @@ jobs:
      - name: Web dependencies
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm install --frozen-lockfile
      - name: Web style check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm run lint
  docker-compose-template:

@@ -46,7 +46,7 @@ jobs:
      - name: Install pnpm
        uses: pnpm/action-setup@v4
        with:
          version: 10
          package_json_file: web/package.json
          run_install: false
      - name: Set up Node.js
@@ -59,10 +59,12 @@ jobs:
      - name: Install dependencies
        if: env.FILES_CHANGED == 'true'
        working-directory: ./web
        run: pnpm install --frozen-lockfile
      - name: Generate i18n translations
        if: env.FILES_CHANGED == 'true'
        working-directory: ./web
        run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
      - name: Create Pull Request

@@ -35,7 +35,7 @@ jobs:
        if: steps.changed-files.outputs.any_changed == 'true'
        uses: pnpm/action-setup@v4
        with:
          version: 10
          package_json_file: web/package.json
          run_install: false
      - name: Setup Node.js
@@ -48,8 +48,10 @@ jobs:
      - name: Install dependencies
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm install --frozen-lockfile
      - name: Run tests
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm test

.gitignore (vendored, 2 lines changed)

@@ -197,6 +197,8 @@ sdks/python-client/dify_client.egg-info
!.vscode/README.md
pyrightconfig.json
api/.vscode
# vscode Code History Extension
.history
.idea/

CLAUDE.md (new file, 83 lines)

@@ -0,0 +1,83 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.

The codebase consists of:

- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
- **Docker deployment** (`/docker`): Containerized deployment configurations

## Development Commands

### Backend (API)

All Python commands must be prefixed with `uv run --project api`:

```bash
# Start development servers
./dev/start-api      # Start API server
./dev/start-worker   # Start Celery worker

# Run tests
uv run --project api pytest                            # Run all tests
uv run --project api pytest tests/unit_tests/          # Unit tests only
uv run --project api pytest tests/integration_tests/   # Integration tests

# Code quality
./dev/reformat                                # Run all formatters and linters
uv run --project api ruff check --fix ./      # Fix linting issues
uv run --project api ruff format ./           # Format code
uv run --project api mypy .                   # Type checking
```

### Frontend (Web)

```bash
cd web
pnpm lint        # Run ESLint
pnpm eslint-fix  # Fix ESLint issues
pnpm test        # Run Jest tests
```

## Testing Guidelines

### Backend Testing

- Use `pytest` for all backend tests
- Write tests first (TDD approach)
- Test structure: Arrange-Act-Assert

## Code Style Requirements

### Python

- Use type hints for all functions and class attributes
- No `Any` types unless absolutely necessary
- Implement special methods (`__repr__`, `__str__`) appropriately

### TypeScript/JavaScript

- Strict TypeScript configuration
- ESLint with Prettier integration
- Avoid `any` type

## Important Notes

- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
- **Comments**: Only write meaningful comments that explain "why", not "what"
- **File Creation**: Always prefer editing existing files over creating new ones
- **Documentation**: Don't create documentation files unless explicitly requested
- **Code Quality**: Always run `./dev/reformat` before committing backend changes

## Common Development Tasks

### Adding a New API Endpoint

1. Create controller in `/api/controllers/`
2. Add service logic in `/api/services/`
3. Update routes in controller's `__init__.py`
4. Write tests in `/api/tests/` (a sketch of these steps follows)
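A minimal sketch of steps 1–3, with illustrative names only (not actual Dify modules):

```python
# /api/controllers/ping.py — 1. a thin controller (hypothetical example)
from flask_restful import Resource


class PingService:  # 2. service logic, normally in /api/services/ping_service.py
    @staticmethod
    def ping() -> str:
        return "pong"


class PingApi(Resource):
    def get(self):
        # Controllers stay thin and delegate to services.
        return {"result": PingService.ping()}


# 3. Route registration in the controller package's __init__.py:
# api.add_resource(PingApi, "/ping")
# 4. A matching pytest test goes in /api/tests/, following Arrange-Act-Assert.
```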
## Project-Specific Conventions

- All async tasks use Celery with Redis as broker

@@ -225,7 +225,8 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Using Alibaba Cloud Computing Nest

@@ -208,7 +208,8 @@ docker compose up -d
##### AWS
- [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK بواسطة @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK بواسطة @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### استخدام Alibaba Cloud للنشر
[بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)

@@ -225,7 +225,8 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud ব্যবহার করে ডিপ্লয়

@@ -223,7 +223,8 @@ docker compose up -d
使用 [CDK](https://aws.amazon.com/cdk/) 将 Dify 部署到 AWS
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### 使用 阿里云计算巢 部署

@@ -220,7 +220,8 @@ Stellen Sie Dify mit nur einem Klick mithilfe von [terraform](https://www.terraf
Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud

@@ -220,7 +220,8 @@ Despliega Dify en una plataforma en la nube con un solo clic utilizando [terrafo
Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK por @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK por @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud

@@ -218,7 +218,8 @@ Déployez Dify sur une plateforme cloud en un clic en utilisant [terraform](http
Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK par @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK par @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud

@@ -219,7 +219,8 @@ docker compose up -d
[CDK](https://aws.amazon.com/cdk/) を使用して、DifyをAWSにデプロイします
##### AWS
- [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [@KevinZhaoによるAWS CDK (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [@tmokmssによるAWS CDK (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)

@@ -218,7 +218,8 @@ wa'logh nIqHom neH ghun deployment toy'wI' [terraform](https://www.terraform.io/
wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo'laH.
##### AWS
- [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK qachlot @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK qachlot @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud

@@ -212,7 +212,8 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
[CDK](https://aws.amazon.com/cdk/)를 사용하여 AWS에 Dify 배포
##### AWS
- [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [KevinZhao의 AWS CDK (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [tmokmss의 AWS CDK (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud

@@ -217,7 +217,8 @@ Implante o Dify na Plataforma Cloud com um único clique usando [terraform](http
Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK por @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK por @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud

@@ -218,7 +218,8 @@ namestite Dify v Cloud Platform z enim klikom z uporabo [terraform](https://www.
Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud

@@ -211,7 +211,8 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
[CDK](https://aws.amazon.com/cdk/) kullanarak Dify'ı AWS'ye dağıtın
##### AWS
- [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK tarafından @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK tarafından @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud

@@ -223,7 +223,8 @@ Dify 的所有功能都提供相應的 API，因此您可以輕鬆地將 Dify
### AWS
- [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [由 @KevinZhao 提供的 AWS CDK (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [由 @tmokmss 提供的 AWS CDK (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### 使用 阿里云计算巢進行部署

@@ -213,7 +213,8 @@ Triển khai Dify lên nền tảng đám mây với một cú nhấp chuột b
Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK bởi @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
- [AWS CDK bởi @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws)
#### Alibaba Cloud

@@ -42,6 +42,15 @@ REDIS_PORT=6379
REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
# SSL configuration for Redis (when REDIS_USE_SSL=true)
REDIS_SSL_CERT_REQS=CERT_NONE
# Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
REDIS_SSL_CA_CERTS=
# Path to CA certificate file for SSL verification
REDIS_SSL_CERTFILE=
# Path to client certificate file for SSL authentication
REDIS_SSL_KEYFILE=
# Path to client private key file for SSL authentication
REDIS_DB=0
# redis Sentinel configuration.
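For orientation, a minimal sketch of how these variables could be wired into a redis-py client; the exact kwargs and the CERT_* string handling are assumptions for illustration, not taken from this diff:

```python
import os
import ssl

import redis

# Map the .env CERT_* strings onto ssl module constants (assumed convention).
CERT_REQS = {"CERT_NONE": ssl.CERT_NONE, "CERT_OPTIONAL": ssl.CERT_OPTIONAL, "CERT_REQUIRED": ssl.CERT_REQUIRED}

client = redis.Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    username=os.getenv("REDIS_USERNAME") or None,
    password=os.getenv("REDIS_PASSWORD") or None,
    ssl=os.getenv("REDIS_USE_SSL", "false").lower() == "true",
    ssl_cert_reqs=CERT_REQS.get(os.getenv("REDIS_SSL_CERT_REQS", "CERT_NONE"), ssl.CERT_NONE),
    ssl_ca_certs=os.getenv("REDIS_SSL_CA_CERTS") or None,
    ssl_certfile=os.getenv("REDIS_SSL_CERTFILE") or None,
    ssl_keyfile=os.getenv("REDIS_SSL_KEYFILE") or None,
    db=int(os.getenv("REDIS_DB", "0")),
)
```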

@@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp):
        ext_login,
        ext_mail,
        ext_migrate,
        ext_orjson,
        ext_otel,
        ext_proxy_fix,
        ext_redis,
@@ -67,6 +68,7 @@ def initialize_extensions(app: DifyApp):
        ext_logging,
        ext_warnings,
        ext_import_modules,
        ext_orjson,
        ext_set_secretkey,
        ext_compress,
        ext_code_based_extension,
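The `ext_orjson` module itself is not shown in this diff; as a rough sketch under stated assumptions, such an extension typically swaps Flask's JSON provider for orjson, along these lines (illustrative, not the actual implementation):

```python
from typing import Any

import orjson
from flask import Flask
from flask.json.provider import JSONProvider


class OrJSONProvider(JSONProvider):
    def dumps(self, obj: Any, **kwargs: Any) -> str:
        # orjson returns bytes; Flask's provider API expects str.
        return orjson.dumps(obj).decode("utf-8")

    def loads(self, s: str | bytes, **kwargs: Any) -> Any:
        return orjson.loads(s)


def init_app(app: Flask) -> None:
    app.json = OrJSONProvider(app)
```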

@@ -36,6 +36,7 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
@click.command("reset-password", help="Reset the account password.")
@@ -1202,3 +1203,138 @@ def setup_system_tool_oauth_client(provider, client_params):
    db.session.add(oauth_client)
    db.session.commit()
    click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))


def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
    """
    Find draft variables that reference non-existent apps.

    Args:
        batch_size: Maximum number of orphaned app IDs to return

    Returns:
        List of app IDs that have draft variables but don't exist in the apps table
    """
    query = """
        SELECT DISTINCT wdv.app_id
        FROM workflow_draft_variables AS wdv
        WHERE NOT EXISTS(
            SELECT 1 FROM apps WHERE apps.id = wdv.app_id
        )
        LIMIT :batch_size
    """
    with db.engine.connect() as conn:
        result = conn.execute(sa.text(query), {"batch_size": batch_size})
        return [row[0] for row in result]


def _count_orphaned_draft_variables() -> dict[str, Any]:
    """
    Count orphaned draft variables by app.

    Returns:
        Dictionary with statistics about orphaned variables
    """
    query = """
        SELECT
            wdv.app_id,
            COUNT(*) as variable_count
        FROM workflow_draft_variables AS wdv
        WHERE NOT EXISTS(
            SELECT 1 FROM apps WHERE apps.id = wdv.app_id
        )
        GROUP BY wdv.app_id
        ORDER BY variable_count DESC
    """
    with db.engine.connect() as conn:
        result = conn.execute(sa.text(query))
        orphaned_by_app = {row[0]: row[1] for row in result}

    total_orphaned = sum(orphaned_by_app.values())
    app_count = len(orphaned_by_app)

    return {
        "total_orphaned_variables": total_orphaned,
        "orphaned_app_count": app_count,
        "orphaned_by_app": orphaned_by_app,
    }


@click.command()
@click.option("--dry-run", is_flag=True, help="Show what would be deleted without actually deleting")
@click.option("--batch-size", default=1000, help="Number of records to process per batch (default 1000)")
@click.option("--max-apps", default=None, type=int, help="Maximum number of apps to process (default: no limit)")
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
def cleanup_orphaned_draft_variables(
    dry_run: bool,
    batch_size: int,
    max_apps: int | None,
    force: bool = False,
):
    """
    Clean up orphaned draft variables from the database.

    This script finds and removes draft variables that belong to apps
    that no longer exist in the database.
    """
    logger = logging.getLogger(__name__)

    # Get statistics
    stats = _count_orphaned_draft_variables()
    logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"])
    logger.info("Across %s non-existent apps", stats["orphaned_app_count"])

    if stats["total_orphaned_variables"] == 0:
        logger.info("No orphaned draft variables found. Exiting.")
        return

    if dry_run:
        logger.info("DRY RUN: Would delete the following:")
        for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
            :10
        ]:  # Show top 10
            logger.info("  App %s: %s variables", app_id, count)
        if len(stats["orphaned_by_app"]) > 10:
            logger.info("  ... and %s more apps", len(stats["orphaned_by_app"]) - 10)
        return

    # Confirm deletion
    if not force:
        click.confirm(
            f"Are you sure you want to delete {stats['total_orphaned_variables']} "
            f"orphaned draft variables from {stats['orphaned_app_count']} apps?",
            abort=True,
        )

    total_deleted = 0
    processed_apps = 0

    while True:
        if max_apps and processed_apps >= max_apps:
            logger.info("Reached maximum app limit (%s). Stopping.", max_apps)
            break

        orphaned_app_ids = _find_orphaned_draft_variables(batch_size=10)
        if not orphaned_app_ids:
            logger.info("No more orphaned draft variables found.")
            break

        for app_id in orphaned_app_ids:
            if max_apps and processed_apps >= max_apps:
                break
            try:
                deleted_count = delete_draft_variables_batch(app_id, batch_size)
                total_deleted += deleted_count
                processed_apps += 1
                logger.info("Deleted %s variables for app %s", deleted_count, app_id)
            except Exception:
                logger.exception("Error processing app %s", app_id)
                continue

    logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
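`delete_draft_variables_batch` is imported from `tasks.remove_app_and_related_data_task` and is not part of this diff; a plausible shape for such a helper is a LIMITed delete loop like the following sketch (illustrative only, reusing the same `db` engine as above):

```python
import sqlalchemy as sa


def delete_draft_variables_batch_sketch(app_id: str, batch_size: int) -> int:
    """Delete one app's draft variables in LIMITed batches; returns rows deleted."""
    total = 0
    while True:
        with db.engine.begin() as conn:
            result = conn.execute(
                sa.text(
                    """
                    DELETE FROM workflow_draft_variables
                    WHERE id IN (
                        SELECT id FROM workflow_draft_variables
                        WHERE app_id = :app_id LIMIT :batch_size
                    )
                    """
                ),
                {"app_id": app_id, "batch_size": batch_size},
            )
        total += result.rowcount
        if result.rowcount < batch_size:
            return total
```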

@@ -39,6 +39,26 @@ class RedisConfig(BaseSettings):
        default=False,
    )

    REDIS_SSL_CERT_REQS: str = Field(
        description="SSL certificate requirements (CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED)",
        default="CERT_NONE",
    )

    REDIS_SSL_CA_CERTS: Optional[str] = Field(
        description="Path to the CA certificate file for SSL verification",
        default=None,
    )

    REDIS_SSL_CERTFILE: Optional[str] = Field(
        description="Path to the client certificate file for SSL authentication",
        default=None,
    )

    REDIS_SSL_KEYFILE: Optional[str] = Field(
        description="Path to the client private key file for SSL authentication",
        default=None,
    )

    REDIS_USE_SENTINEL: Optional[bool] = Field(
        description="Enable Redis Sentinel mode for high availability",
        default=False,

@@ -1,5 +1,7 @@
from werkzeug.exceptions import HTTPException
from libs.exception import BaseHTTPException
class FilenameNotExistsError(HTTPException):
    code = 400
@@ -9,3 +11,27 @@ class FilenameNotExistsError(HTTPException):
class RemoteFileUploadError(HTTPException):
    code = 400
    description = "Error uploading remote file."


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400

@@ -3,9 +3,8 @@ from flask_login import current_user
from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_resource_check,

@@ -1,6 +1,7 @@
import logging
import flask_login
from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
@@ -24,6 +25,7 @@ from core.errors.error import (
    ProviderTokenNotInitError,
    QuotaExceededError,
)
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@@ -115,6 +117,10 @@ class ChatMessageApi(Resource):
        streaming = args["response_mode"] != "blocking"
        args["auto_generate_name"] = False

        external_trace_id = get_external_trace_id(request)
        if external_trace_id:
            args["external_trace_id"] = external_trace_id

        account = flask_login.current_user
        try:

@@ -79,18 +79,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class DraftWorkflowNotExist(BaseHTTPException):
    error_code = "draft_workflow_not_exist"
    description = "Draft workflow need to be initialized."

@@ -1,3 +1,5 @@
from collections.abc import Sequence
from flask_login import current_user
from flask_restful import Resource, reqparse
@@ -10,6 +12,9 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
@@ -107,6 +112,119 @@ class RuleStructuredOutputGenerateApi(Resource):
        return structured_output


class InstructionGenerateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("flow_id", type=str, required=True, default="", location="json")
        parser.add_argument("node_id", type=str, required=False, default="", location="json")
        parser.add_argument("current", type=str, required=False, default="", location="json")
        parser.add_argument("language", type=str, required=False, default="javascript", location="json")
        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
        args = parser.parse_args()

        providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
        code_provider: type[CodeNodeProvider] | None = next(
            (p for p in providers if p.is_accept_language(args["language"])), None
        )
        code_template = code_provider.get_default_code() if code_provider else ""

        try:
            # Generate from nothing for a workflow node
            if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
                from models import App, db
                from services.workflow_service import WorkflowService

                app = db.session.query(App).filter(App.id == args["flow_id"]).first()
                if not app:
                    return {"error": f"app {args['flow_id']} not found"}, 400
                workflow = WorkflowService().get_draft_workflow(app_model=app)
                if not workflow:
                    return {"error": f"workflow {args['flow_id']} not found"}, 400
                nodes: Sequence = workflow.graph_dict["nodes"]
                node = [node for node in nodes if node["id"] == args["node_id"]]
                if len(node) == 0:
                    return {"error": f"node {args['node_id']} not found"}, 400
                node_type = node[0]["data"]["type"]
                match node_type:
                    case "llm":
                        return LLMGenerator.generate_rule_config(
                            current_user.current_tenant_id,
                            instruction=args["instruction"],
                            model_config=args["model_config"],
                            no_variable=True,
                        )
                    case "agent":
                        return LLMGenerator.generate_rule_config(
                            current_user.current_tenant_id,
                            instruction=args["instruction"],
                            model_config=args["model_config"],
                            no_variable=True,
                        )
                    case "code":
                        return LLMGenerator.generate_code(
                            tenant_id=current_user.current_tenant_id,
                            instruction=args["instruction"],
                            model_config=args["model_config"],
                            code_language=args["language"],
                        )
                    case _:
                        return {"error": f"invalid node type: {node_type}"}
            if args["node_id"] == "" and args["current"] != "":  # For legacy app without a workflow
                return LLMGenerator.instruction_modify_legacy(
                    tenant_id=current_user.current_tenant_id,
                    flow_id=args["flow_id"],
                    current=args["current"],
                    instruction=args["instruction"],
                    model_config=args["model_config"],
                    ideal_output=args["ideal_output"],
                )
            if args["node_id"] != "" and args["current"] != "":  # For workflow node
                return LLMGenerator.instruction_modify_workflow(
                    tenant_id=current_user.current_tenant_id,
                    flow_id=args["flow_id"],
                    node_id=args["node_id"],
                    current=args["current"],
                    instruction=args["instruction"],
                    model_config=args["model_config"],
                    ideal_output=args["ideal_output"],
                )
            return {"error": "incompatible parameters"}, 400
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)


class InstructionGenerationTemplateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self) -> dict:
        parser = reqparse.RequestParser()
        parser.add_argument("type", type=str, required=True, default=False, location="json")
        args = parser.parse_args()
        match args["type"]:
            case "prompt":
                from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT

                return {"data": INSTRUCTION_GENERATE_TEMPLATE_PROMPT}
            case "code":
                from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_CODE

                return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
            case _:
                raise ValueError(f"Invalid type: {args['type']}")


api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
api.add_resource(InstructionGenerateApi, "/instruction-generate")
api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template")

@@ -27,7 +27,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@@ -124,17 +124,34 @@ class MessageFeedbackApi(Resource):
        parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
        args = parser.parse_args()
        try:
            MessageService.create_feedback(
                app_model=app_model,
                message_id=str(args["message_id"]),
                user=current_user,
                rating=args.get("rating"),
                content=None,
            )
        except MessageNotExistsError:
        message_id = str(args["message_id"])
        message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
        if not message:
            raise NotFound("Message Not Exists.")
        feedback = message.admin_feedback
        if not args["rating"] and feedback:
            db.session.delete(feedback)
        elif args["rating"] and feedback:
            feedback.rating = args["rating"]
        elif not args["rating"] and not feedback:
            raise ValueError("rating cannot be None when feedback not exists")
        else:
            feedback = MessageFeedback(
                app_id=app_model.id,
                conversation_id=message.conversation_id,
                message_id=message.id,
                rating=args["rating"],
                from_source="admin",
                from_account_id=current_user.id,
            )
            db.session.add(feedback)
        db.session.commit()
        return {"result": "success"}

@@ -23,6 +23,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@@ -185,6 +186,10 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
        args = parser.parse_args()

        external_trace_id = get_external_trace_id(request)
        if external_trace_id:
            args["external_trace_id"] = external_trace_id

        try:
            response = AppGenerateService.generate(
                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
@@ -373,6 +378,10 @@ class DraftWorkflowRunApi(Resource):
        parser.add_argument("files", type=list, required=False, location="json")
        args = parser.parse_args()

        external_trace_id = get_external_trace_id(request)
        if external_trace_id:
            args["external_trace_id"] = external_trace_id

        try:
            response = AppGenerateService.generate(
                app_model=app_model,

@@ -163,11 +163,11 @@ class WorkflowVariableCollectionApi(Resource):
            draft_var_srv = WorkflowDraftVariableService(
                session=session,
            )
            workflow_vars = draft_var_srv.list_variables_without_values(
                app_id=app_model.id,
                page=args.page,
                limit=args.limit,
            )
            workflow_vars = draft_var_srv.list_variables_without_values(
                app_id=app_model.id,
                page=args.page,
                limit=args.limit,
            )
        return workflow_vars

@@ -1,30 +1,6 @@
from libs.exception import BaseHTTPException


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class DatasetNotInitializedError(BaseHTTPException):
    error_code = "dataset_not_initialized"
    description = "The dataset is still being initialized or indexing. Please wait a moment."

@@ -76,30 +76,6 @@ class EmailSendIpLimitError(BaseHTTPException):
    code = 429


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class UnauthorizedAndForceLogout(BaseHTTPException):
    error_code = "unauthorized_and_force_logout"
    description = "Unauthorized and force logout."

@@ -8,7 +8,13 @@ from werkzeug.exceptions import Forbidden
import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import FilenameNotExistsError
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_resource_check,
@@ -18,13 +24,6 @@ from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from services.file_service import FileService

from .error import (
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)

PREVIEW_WORDS_LIMIT = 3000

@@ -7,18 +7,17 @@ from flask_restful import Resource, marshal_with, reqparse
import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from controllers.common.errors import (
    FileTooLargeError,
    RemoteFileUploadError,
    UnsupportedFileTypeError,
)
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from models.account import Account
from services.file_service import FileService

from .error import (
    FileTooLargeError,
    UnsupportedFileTypeError,
)


class RemoteFileInfoApi(Resource):
    @marshal_with(remote_file_info_fields)

@@ -32,7 +32,7 @@ class VersionApi(Resource):
            return result

        try:
            response = requests.get(check_update_url, {"current_version": args.get("current_version")})
            response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10))
        except Exception as error:
            logging.warning("Check update version error: %s.", str(error))
            result["version"] = args.get("current_version")

@@ -7,15 +7,15 @@ from sqlalchemy import select
from werkzeug.exceptions import Unauthorized

import services
from controllers.common.errors import FilenameNotExistsError
from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.datasets.error import (
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
    account_initialization_required,

@@ -1,7 +0,0 @@
from libs.exception import BaseHTTPException


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415

@@ -5,8 +5,8 @@ from flask_restful import Resource, reqparse
from werkzeug.exceptions import NotFound
import services
from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from services.account_service import TenantService
from services.file_service import FileService

@@ -4,8 +4,8 @@ from flask import Response
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db

@@ -5,11 +5,13 @@ from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden

import services
from controllers.common.errors import (
    FileTooLargeError,
    UnsupportedFileTypeError,
)
from controllers.console.wraps import setup_required
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from controllers.inner_api.plugin.wraps import get_user
from controllers.service_api.app.error import FileTooLargeError
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import file_fields

@@ -1,5 +1,3 @@
import json
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
@@ -136,12 +134,15 @@ class ConversationVariableDetailApi(Resource):
        variable_id = str(variable_id)

        parser = reqparse.RequestParser()
        parser.add_argument("value", required=True, location="json")
        # Use an identity lambda so the already-typed JSON value passes through
        # unmodified; without it, reqparse coerces the value to a string, which
        # json.loads cannot reliably convert back.
        parser.add_argument("value", required=True, location="json", type=lambda x: x)
        args = parser.parse_args()

        try:
            return ConversationService.update_conversation_variable(
                app_model, conversation_id, variable_id, end_user, json.loads(args["value"])
                app_model, conversation_id, variable_id, end_user, args["value"]
            )
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
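Why the identity lambda matters: reqparse applies `type` to each parsed value, and a string coercion would mangle non-string JSON values. A minimal illustration, outside any Flask context:

```python
# reqparse calls type(value) on the value it pulls from the JSON body.
identity = lambda x: x

assert identity([1, 2]) == [1, 2]  # lists/dicts/ints pass through intact
assert str([1, 2]) == "[1, 2]"     # what str coercion would produce — the old
                                   # code needed json.loads to undo exactly this
```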

@@ -85,30 +85,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class FileNotFoundError(BaseHTTPException):
    error_code = "file_not_found"
    description = "The requested file was not found."

@@ -2,14 +2,14 @@ from flask import request
from flask_restful import Resource, marshal_with

import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields
from models.model import App, EndUser

@@ -6,15 +6,15 @@ from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    ProviderNotInitializeError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import (
    ArchivedDocumentImmutableError,
    DocumentIndexingError,

@@ -1,30 +1,6 @@
from libs.exception import BaseHTTPException


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class DatasetNotInitializedError(BaseHTTPException):
    error_code = "dataset_not_initialized"
    description = "The dataset is still being initialized or indexing. Please wait a moment."

@@ -97,30 +97,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = "no_file_uploaded"
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."
    code = 400


class FileTooLargeError(BaseHTTPException):
    error_code = "file_too_large"
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class WebAppAuthRequiredError(BaseHTTPException):
    error_code = "web_sso_auth_required"
    description = "Web app authentication required."

@@ -2,8 +2,13 @@ from flask import request
from flask_restful import marshal_with

import services
from controllers.common.errors import FilenameNotExistsError
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.web.wraps import WebApiResource
from fields.file_fields import file_fields
from services.file_service import FileService

@@ -5,15 +5,17 @@ from flask_restful import marshal_with, reqparse
import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from controllers.common.errors import (
    FileTooLargeError,
    RemoteFileUploadError,
    UnsupportedFileTypeError,
)
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from services.file_service import FileService

from .error import FileTooLargeError, UnsupportedFileTypeError


class RemoteFileInfoApi(WebApiResource):
    @marshal_with(remote_file_info_fields)

@@ -74,6 +74,7 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models.enums import CreatorUserRole
@@ -896,6 +897,7 @@ class AdvancedChatAppGenerateTaskPipeline:
    def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
        message = self._get_message(session=session)
        message.answer = self._task_state.answer
        message.updated_at = naive_utc_now()
        message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
        message.message_metadata = self._task_state.metadata.model_dump_json()
        message_files = [

@@ -140,7 +140,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
        )

        # get tracing instance
        trace_manager = TraceQueueManager(app_id=app_model.id)
        trace_manager = TraceQueueManager(
            app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
        )

        # init application generate entity
        application_generate_entity = ChatAppGenerateEntity(

@@ -124,7 +124,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
        )

        # get tracing instance
        trace_manager = TraceQueueManager(app_model.id)
        trace_manager = TraceQueueManager(
            app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
        )

        # init application generate entity
        application_generate_entity = CompletionAppGenerateEntity(

@@ -6,7 +6,6 @@ from core.app.entities.queue_entities import (
    MessageQueueMessage,
    QueueAdvancedChatMessageEndEvent,
    QueueErrorEvent,
    QueueMessage,
    QueueMessageEndEvent,
    QueueStopEvent,
)
@@ -22,15 +21,6 @@ class MessageBasedAppQueueManager(AppQueueManager):
        self._app_mode = app_mode
        self._message_id = str(message_id)

    def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
        return MessageQueueMessage(
            task_id=self._task_id,
            message_id=self._message_id,
            conversation_id=self._conversation_id,
            app_mode=self._app_mode,
            event=event,
        )

    def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
        """
        Publish event to queue

@@ -57,6 +57,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
@@ -389,6 +390,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
            if llm_result.message.content
            else ""
        )
        message.updated_at = naive_utc_now()
        message.answer_tokens = usage.completion_tokens
        message.answer_unit_price = usage.completion_unit_price
        message.answer_price_unit = usage.completion_price_unit

@@ -5,7 +5,7 @@ from base64 import b64encode
from collections.abc import Mapping
from typing import Any
from core.variables.utils import SegmentJSONEncoder
from core.variables.utils import dumps_with_segments
class TemplateTransformer(ABC):
@@ -93,7 +93,7 @@ class TemplateTransformer(ABC):
    @classmethod
    def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
        inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode()
        inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
        input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
        return input_base64_encoded
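A round-trip of the encoding scheme above, assuming `dumps_with_segments` behaves like `json.dumps` for plain values (its Segment handling lives in `core.variables.utils` and is not shown here):

```python
import json
from base64 import b64decode, b64encode

inputs = {"query": "héllo"}  # non-ASCII survives thanks to ensure_ascii=False
encoded = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode("utf-8")
assert json.loads(b64decode(encoded)) == inputs
```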

@@ -16,15 +16,33 @@ def get_external_trace_id(request: Any) -> Optional[str]:
    """
    Retrieve the trace_id from the request.

    Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid.
    Priority:
    1. header ('X-Trace-Id')
    2. parameters
    3. JSON body
    4. Current OpenTelemetry context (if enabled)
    5. OpenTelemetry traceparent header (if present and valid)

    Returns None if no valid trace_id is provided.
    """
    trace_id = request.headers.get("X-Trace-Id")
    if not trace_id:
        trace_id = request.args.get("trace_id")
    if not trace_id and getattr(request, "is_json", False):
        json_data = getattr(request, "json", None)
        if json_data:
            trace_id = json_data.get("trace_id")
    if not trace_id:
        trace_id = get_trace_id_from_otel_context()
    if not trace_id:
        traceparent = request.headers.get("traceparent")
        if traceparent:
            trace_id = parse_traceparent_header(traceparent)
    if isinstance(trace_id, str) and is_valid_trace_id(trace_id):
        return trace_id
    return None
@@ -40,3 +58,49 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
    if trace_id:
        return {"external_trace_id": trace_id}
    return {}


def get_trace_id_from_otel_context() -> Optional[str]:
    """
    Retrieve the current trace ID from the active OpenTelemetry trace context.

    Returns None if:
    1. OpenTelemetry SDK is not installed or enabled.
    2. There is no active span or trace context.
    """
    try:
        from opentelemetry.trace import SpanContext, get_current_span
        from opentelemetry.trace.span import INVALID_TRACE_ID

        span = get_current_span()
        if not span:
            return None
        span_context: SpanContext = span.get_span_context()
        if not span_context or span_context.trace_id == INVALID_TRACE_ID:
            return None
        trace_id_hex = f"{span_context.trace_id:032x}"
        return trace_id_hex
    except Exception:
        return None


def parse_traceparent_header(traceparent: str) -> Optional[str]:
    """
    Parse the `traceparent` header to extract the trace_id.

    Expected format:
        'version-trace_id-span_id-flags'

    Reference:
        W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
    """
    try:
        parts = traceparent.split("-")
        if len(parts) == 4 and len(parts[1]) == 32:
            return parts[1]
    except Exception:
        pass
    return None
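A quick check of the parser against the example traceparent from the W3C spec:

```python
# version-trace_id-span_id-flags; trace_id is 32 lowercase hex chars
header = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
assert parse_traceparent_header(header) == "4bf92f3577b34da6a3ce929d0e0e4736"
assert parse_traceparent_header("garbage") is None
```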

@@ -1,6 +1,7 @@
import json
import logging
import re
from collections.abc import Sequence
from typing import Optional, cast
import json_repair
@@ -11,6 +12,8 @@ from core.llm_generator.prompts import (
    CONVERSATION_TITLE_PROMPT,
    GENERATOR_QA_PROMPT,
    JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
    LLM_MODIFY_CODE_SYSTEM,
    LLM_MODIFY_PROMPT_SYSTEM,
    PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
    SYSTEM_STRUCTURED_OUTPUT_GENERATE,
    WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
@@ -24,6 +27,9 @@ from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from models import App, Message, WorkflowNodeExecutionModel, db
class LLMGenerator:
@@ -388,3 +394,181 @@ class LLMGenerator:
        except Exception as e:
            logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}

    @staticmethod
    def instruction_modify_legacy(
        tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
    ) -> dict:
        app: App | None = db.session.query(App).filter(App.id == flow_id).first()
        last_run: Message | None = (
            db.session.query(Message).filter(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
        )
        if not last_run:
            return LLMGenerator.__instruction_modify_common(
                tenant_id=tenant_id,
                model_config=model_config,
                last_run=None,
                current=current,
                error_message="",
                instruction=instruction,
                node_type="llm",
                ideal_output=ideal_output,
            )
        last_run_dict = {
            "query": last_run.query,
            "answer": last_run.answer,
            "error": last_run.error,
        }
        return LLMGenerator.__instruction_modify_common(
            tenant_id=tenant_id,
            model_config=model_config,
            last_run=last_run_dict,
            current=current,
            error_message=str(last_run.error),
            instruction=instruction,
            node_type="llm",
            ideal_output=ideal_output,
        )

    @staticmethod
    def instruction_modify_workflow(
        tenant_id: str,
        flow_id: str,
        node_id: str,
        current: str,
        instruction: str,
        model_config: dict,
        ideal_output: str | None,
    ) -> dict:
        from services.workflow_service import WorkflowService

        app: App | None = db.session.query(App).filter(App.id == flow_id).first()
        if not app:
            raise ValueError("App not found.")
        workflow = WorkflowService().get_draft_workflow(app_model=app)
        if not workflow:
            raise ValueError("Workflow not found for the given app model.")
        last_run = WorkflowService().get_node_last_run(app_model=app, workflow=workflow, node_id=node_id)
        try:
            node_type = cast(WorkflowNodeExecutionModel, last_run).node_type
        except Exception:
            try:
                node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][
                    "type"
                ]
            except Exception:
                node_type = "llm"
        if not last_run:  # Node is not executed yet
            return LLMGenerator.__instruction_modify_common(
                tenant_id=tenant_id,
                model_config=model_config,
                last_run=None,
                current=current,
                error_message="",
                instruction=instruction,
                node_type=node_type,
                ideal_output=ideal_output,
            )

        def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
            raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)
            if not raw_agent_log:
                return []
            parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)

            def dict_of_event(event: AgentLogEvent) -> dict:
                return {
                    "status": event.status,
                    "error": event.error,
                    "data": event.data,
                }

            return [dict_of_event(event) for event in parsed]

        last_run_dict = {
            "inputs": last_run.inputs_dict,
            "status": last_run.status,
            "error": last_run.error,
            "agent_log": agent_log_of(last_run),
        }
        return LLMGenerator.__instruction_modify_common(
            tenant_id=tenant_id,
            model_config=model_config,
            last_run=last_run_dict,
            current=current,
            error_message=last_run.error,
            instruction=instruction,
            node_type=last_run.node_type,
            ideal_output=ideal_output,
        )

    @staticmethod
    def __instruction_modify_common(
        tenant_id: str,
        model_config: dict,
        last_run: dict | None,
        current: str | None,
        error_message: str | None,
        instruction: str,
        node_type: str,
        ideal_output: str | None,
    ) -> dict:
        LAST_RUN = "{{#last_run#}}"
        CURRENT = "{{#current#}}"
        ERROR_MESSAGE = "{{#error_message#}}"
        injected_instruction = instruction
        if LAST_RUN in injected_instruction:
            injected_instruction = injected_instruction.replace(LAST_RUN, json.dumps(last_run))
        if CURRENT in injected_instruction:
            injected_instruction = injected_instruction.replace(CURRENT, current or "null")
        if ERROR_MESSAGE in injected_instruction:
            injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=model_config.get("provider", ""),
            model=model_config.get("name", ""),
        )
        match node_type:
            case "llm" | "agent":
                system_prompt = LLM_MODIFY_PROMPT_SYSTEM
            case "code":
                system_prompt = LLM_MODIFY_CODE_SYSTEM
            case _:
                system_prompt = LLM_MODIFY_PROMPT_SYSTEM
        prompt_messages = [
            SystemPromptMessage(content=system_prompt),
            UserPromptMessage(
                content=json.dumps(
                    {
                        "current": current,
                        "last_run": last_run,
                        "instruction": injected_instruction,
                        "ideal_output": ideal_output,
                    }
                )
            ),
        ]
        model_parameters = {"temperature": 0.4}
        try:
            response = cast(
                LLMResult,
                model_instance.invoke_llm(
                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
                ),
            )
            generated_raw = cast(str, response.message.content)
            first_brace = generated_raw.find("{")
            last_brace = generated_raw.rfind("}")
            return {**json.loads(generated_raw[first_brace : last_brace + 1])}
        except InvokeError as e:
            error = str(e)
            return {"error": f"Failed to generate code. Error: {error}"}
        except Exception as e:
            logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e)
            return {"error": f"An unexpected error occurred: {str(e)}"}

@@ -309,3 +309,116 @@ eg:
Here is the JSON schema:
{{schema}}
""" # noqa: E501
LLM_MODIFY_PROMPT_SYSTEM = """
Both your input and output should be in JSON format.
! Below is the schema for input content !
{
"type": "object",
"description": "The user is trying to process some content with a prompt, but the output is not as expected. They hope to achieve their goal by modifying the prompt.",
"properties": {
"current": {
"type": "string",
"description": "The prompt before modification, where placeholders {{}} will be replaced with actual values for the large language model. The content in the placeholders should not be changed."
},
"last_run": {
"type": "object",
"description": "The output result from the large language model after receiving the prompt.",
},
"instruction": {
"type": "string",
"description": "User's instruction to edit the current prompt"
},
"ideal_output": {
"type": "string",
"description": "The ideal output that the user expects from the large language model after modifying the prompt. You should compare the last output with the ideal output and make changes to the prompt to achieve the goal."
}
}
}
! Above is the schema for input content !
! Below is the schema for output content !
{
"type": "object",
"description": "Your feedback to the user after they provide modification suggestions.",
"properties": {
"modified": {
"type": "string",
"description": "Your modified prompt. You should change the original prompt as little as possible to achieve the goal. Keep the language of prompt if not asked to change"
},
"message": {
"type": "string",
"description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user."
}
},
"required": [
"modified",
"message"
]
}
! Above is the schema for output content !
Your output must strictly follow the schema format, do not output any content outside of the JSON body.
""" # noqa: E501
LLM_MODIFY_CODE_SYSTEM = """
Both your input and output should be in JSON format.
! Below is the schema for input content !
{
"type": "object",
"description": "The user is trying to process some data with a code snippet, but the result is not as expected. They hope to achieve their goal by modifying the code.",
"properties": {
"current": {
"type": "string",
"description": "The code before modification."
},
"last_run": {
"type": "object",
"description": "The result of the code.",
},
"message": {
"type": "string",
"description": "User's instruction to edit the current code"
}
}
}
! Above is the schema for input content !
! Below is the schema for output content !
{
"type": "object",
"description": "Your feedback to the user after they provide modification suggestions.",
"properties": {
"modified": {
"type": "string",
"description": "Your modified code. You should change the original code as little as possible to achieve the goal. Keep the programming language of code if not asked to change"
},
"message": {
"type": "string",
"description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user."
}
},
"required": [
"modified",
"message"
]
}
! Above is the schema for output content !
When you are modifying the code, you should remember:
- Do not use print; it does not work in the Dify sandbox.
- Do not attempt dangerous calls such as deleting files. It's PROHIBITED.
- Do not use any library that is not built into Python.
- Get inputs from the parameters of the function and use explicit type annotations.
- Write proper imports at the top of the code.
- Use a return statement to return the result.
- You should return a `dict`.
Your output must strictly follow the schema format, do not output any content outside of the JSON body.
""" # noqa: E501
INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as expected: {{#last_run#}}.
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors described in {{#error_message#}}."""
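For reference, a minimal code snippet that satisfies the sandbox rules listed above might look like this (a sketch; the function name and parameters are illustrative):

```python
import json


def main(text: str, limit: int) -> dict:
    # Inputs come from typed parameters; only built-in libraries are used,
    # and the result is returned as a dict rather than printed.
    words = text.split()[:limit]
    return {"result": json.dumps(words)}
```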

View File

@ -10,8 +10,6 @@ from core.mcp.types import (
from models.tools import MCPToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"
class OAuthClientProvider:
mcp_provider: MCPToolProvider

View File

@ -7,6 +7,7 @@ from typing import Any, TypeAlias, final
from urllib.parse import urljoin, urlparse
import httpx
from httpx_sse import EventSource, ServerSentEvent
from sseclient import SSEClient
from core.mcp import types
@ -37,11 +38,6 @@ WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
def remove_request_params(url: str) -> str:
"""Remove request parameters from URL, keeping only the path."""
return urljoin(url, urlparse(url).path)
class SSETransport:
"""SSE client transport implementation."""
@ -114,7 +110,7 @@ class SSETransport:
logger.exception("Error parsing server message")
read_queue.put(exc)
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
"""Handle a single SSE event.
Args:
@ -130,7 +126,7 @@ class SSETransport:
case _:
logger.warning("Unknown SSE event: %s", sse.event)
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
"""Read and process SSE events.
Args:
@ -225,7 +221,7 @@ class SSETransport:
self,
executor: ThreadPoolExecutor,
client: httpx.Client,
event_source,
event_source: EventSource,
) -> tuple[ReadQueue, WriteQueue]:
"""Establish connection and start worker threads.

View File

@ -16,13 +16,14 @@ from extensions.ext_database import db
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService
"""
Apply to MCP HTTP streamable server with stateless http
"""
logger = logging.getLogger(__name__)
class MCPServerStreamableHTTPRequestHandler:
"""
Apply to MCP HTTP streamable server with stateless http
"""
def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):

View File

@ -1,6 +1,10 @@
import json
from collections.abc import Generator
from contextlib import AbstractContextManager
import httpx
import httpx_sse
from httpx_sse import connect_sse
from configs import dify_config
from core.mcp.types import ErrorData, JSONRPCError
@ -55,20 +59,42 @@ def create_ssrf_proxy_mcp_http_client(
)
def ssrf_proxy_sse_connect(url, **kwargs):
def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]:
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints.
to prevent SSRF attacks when connecting to external endpoints. It returns
a context manager that yields an EventSource object for SSE streaming.
The function handles HTTP client creation and cleanup automatically, but
also accepts a pre-configured client via kwargs.
Args:
url: The SSE endpoint URL
**kwargs: Additional arguments passed to the SSE connection
url (str): The SSE endpoint URL to connect to
**kwargs: Additional arguments passed to the SSE connection, including:
- client (httpx.Client, optional): Pre-configured HTTP client.
If not provided, one will be created with SSRF protection.
- method (str, optional): HTTP method to use, defaults to "GET"
- headers (dict, optional): HTTP headers to include in the request
- timeout (httpx.Timeout, optional): Timeout configuration for the connection
Returns:
EventSource object for SSE streaming
AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource
object for SSE streaming. The EventSource provides access to server-sent events.
Example:
```python
with ssrf_proxy_sse_connect(url, headers=headers) as event_source:
for sse in event_source.iter_sse():
print(sse.event, sse.data)
```
Note:
If a client is not provided in kwargs, one will be automatically created
with SSRF protection based on the application's configuration. If an
exception occurs during connection, any automatically created client
will be cleaned up automatically.
"""
from httpx_sse import connect_sse
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
@ -101,7 +127,9 @@ def ssrf_proxy_sse_connect(url, **kwargs):
raise
def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
def create_mcp_error_response(
request_id: int | str | None, code: int, message: str, data=None
) -> Generator[bytes, None, None]:
"""Create MCP error response"""
error_data = ErrorData(code=code, message=message, data=data)
json_response = JSONRPCError(

View File

@ -151,12 +151,9 @@ def jsonable_encoder(
return format(obj, "f")
if isinstance(obj, dict):
encoded_dict = {}
allowed_keys = set(obj.keys())
for key, value in obj.items():
if (
(not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa")))
and (value is not None or not exclude_none)
and key in allowed_keys
if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (
value is not None or not exclude_none
):
encoded_key = jsonable_encoder(
key,

View File

@ -4,15 +4,15 @@ from collections.abc import Sequence
from typing import Optional
from urllib.parse import urljoin
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
TraceClient,
convert_datetime_to_nanoseconds,
convert_string_to_id,
convert_to_span_id,
convert_to_trace_id,
create_link,
generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
@ -103,10 +103,11 @@ class AliyunDataTrace(BaseTraceInstance):
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
links = []
if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id)
links.append(create_link(trace_id_str=trace_info.trace_id))
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
self.add_workflow_span(trace_id, workflow_span_id, trace_info, links)
workflow_node_executions = self.get_workflow_node_executions(trace_info)
for node_execution in workflow_node_executions:
@ -132,8 +133,9 @@ class AliyunDataTrace(BaseTraceInstance):
status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id)
links.append(create_link(trace_id_str=trace_info.trace_id))
message_span_id = convert_to_span_id(message_id, "message")
message_span = SpanData(
@ -152,6 +154,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: str(trace_info.outputs),
},
status=status,
links=links,
)
self.trace_client.add_span(message_span)
@ -192,8 +195,9 @@ class AliyunDataTrace(BaseTraceInstance):
message_id = trace_info.message_id
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id)
links.append(create_link(trace_id_str=trace_info.trace_id))
documents_data = extract_retrieval_documents(trace_info.documents)
dataset_retrieval_span = SpanData(
@ -211,6 +215,7 @@ class AliyunDataTrace(BaseTraceInstance):
INPUT_VALUE: str(trace_info.inputs),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
},
links=links,
)
self.trace_client.add_span(dataset_retrieval_span)
@ -224,8 +229,9 @@ class AliyunDataTrace(BaseTraceInstance):
status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id)
links.append(create_link(trace_id_str=trace_info.trace_id))
tool_span = SpanData(
trace_id=trace_id,
@ -244,6 +250,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: str(trace_info.tool_outputs),
},
status=status,
links=links,
)
self.trace_client.add_span(tool_span)
@ -413,7 +420,9 @@ class AliyunDataTrace(BaseTraceInstance):
status=self.get_workflow_node_status(node_execution),
)
def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
def add_workflow_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, links: Sequence[Link]
):
message_span_id = None
if trace_info.message_id:
message_span_id = convert_to_span_id(trace_info.message_id, "message")
@ -438,6 +447,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
links=links,
)
self.trace_client.add_span(message_span)
@ -456,6 +466,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
links=links,
)
self.trace_client.add_span(workflow_span)
@ -466,8 +477,9 @@ class AliyunDataTrace(BaseTraceInstance):
status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id)
links.append(create_link(trace_id_str=trace_info.trace_id))
suggested_question_span = SpanData(
trace_id=trace_id,
@ -487,6 +499,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
},
status=status,
links=links,
)
self.trace_client.add_span(suggested_question_span)

View File

@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
@ -166,6 +167,16 @@ class SpanBuilder:
return span
def create_link(trace_id_str: str) -> Link:
placeholder_span_id = 0x0000000000000000
trace_id = int(trace_id_str, 16)
span_context = SpanContext(
trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED)
)
return Link(span_context)
def generate_span_id() -> int:
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:

View File

@ -523,7 +523,7 @@ class ProviderManager:
# Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
# FIXME: ignore the type error; only TrialHostingQuota has a limit, the logic needs to change
new_provider_record = Provider(
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.

View File

@ -1,7 +1,7 @@
import json
from collections import defaultdict
from typing import Any, Optional
import orjson
from pydantic import BaseModel
from configs import dify_config
@ -134,13 +134,13 @@ class Jieba(BaseKeyword):
dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit()
else:
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table
@ -156,12 +156,11 @@ class Jieba(BaseKeyword):
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(
dataset_keyword_table.keyword_table = dumps_with_sets(
{
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
},
cls=SetEncoder,
}
)
db.session.add(dataset_keyword_table)
db.session.commit()
@ -252,8 +251,13 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table)
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return super().default(obj)
def set_orjson_default(obj: Any) -> Any:
"""Default function for orjson serialization of set types"""
if isinstance(obj, set):
return list(obj)
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def dumps_with_sets(obj: Any) -> str:
"""JSON dumps with set support using orjson"""
return orjson.dumps(obj, default=set_orjson_default).decode("utf-8")
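A usage sketch for the new helper (assuming the definitions above are in scope):

```python
# orjson cannot serialize sets natively, so the default hook converts them:
table = {"__type__": "keyword_table", "ids": {1, 2, 3}}
print(dumps_with_sets(table))
# -> '{"__type__":"keyword_table","ids":[1,2,3]}' (set ordering may vary)
```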

View File

@ -4,8 +4,8 @@ import math
from typing import Any
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient # type: ignore
from sqlalchemy import JSON, Column, String, func
from pyobvector import VECTOR, FtsIndexParam, FtsParser, ObVecClient, l2_distance # type: ignore
from sqlalchemy import JSON, Column, String
from sqlalchemy.dialects.mysql import LONGTEXT
from configs import dify_config
@ -119,14 +119,21 @@ class OceanBaseVector(BaseVector):
)
try:
if self._hybrid_search_enabled:
self._client.perform_raw_text_sql(f"""ALTER TABLE {self._collection_name}
ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER ik""")
self._client.create_fts_idx_with_fts_index_param(
table_name=self._collection_name,
fts_idx_param=FtsIndexParam(
index_name="fulltext_index_for_col_text",
field_names=["text"],
parser_type=FtsParser.IK,
),
)
except Exception as e:
raise Exception(
"Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above "
+ "to support fulltext index and vector index in the same table",
e,
)
self._client.refresh_metadata([self._collection_name])
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _check_hybrid_search_support(self) -> bool:
@ -252,7 +259,7 @@ class OceanBaseVector(BaseVector):
vec_column_name="vector",
vec_data=query_vector,
topk=topk,
distance_func=func.l2_distance,
distance_func=l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
where_clause=_where_clause,

View File

@ -331,6 +331,12 @@ class QdrantVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score_threshold >= 1:
# return an empty list because some versions of Qdrant may respond with 400 Bad Request,
# while a score_threshold of 1 may still be valid for other vector stores
return []
filter = models.Filter(
must=[
models.FieldCondition(
@ -355,7 +361,7 @@ class QdrantVector(BaseVector):
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=float(kwargs.get("score_threshold") or 0.0),
score_threshold=score_threshold,
)
docs = []
for result in results:
@ -363,7 +369,6 @@ class QdrantVector(BaseVector):
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(

View File

@ -1 +0,0 @@

View File

@ -108,10 +108,18 @@ class ApiProviderAuthType(Enum):
:param value: mode value
:return: mode
"""
# 'api_key' deprecated in PR #21656
# normalize the value and keep a small alias for backward compatibility
v = (value or "").strip().lower()
if v == "api_key":
v = cls.API_KEY_HEADER.value
for mode in cls:
if mode.value == value:
if mode.value == v:
return mode
raise ValueError(f"invalid mode value {value}")
valid = ", ".join(m.value for m in cls)
raise ValueError(f"invalid mode value '{value}', expected one of: {valid}")
class ToolInvokeMessage(BaseModel):

View File

@ -1,5 +1,7 @@
import json
from collections.abc import Iterable, Sequence
from typing import Any
import orjson
from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, FileSegment, Segment
@ -12,15 +14,20 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
return selectors
class SegmentJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
elif isinstance(o, SegmentGroup):
return [self.default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
else:
super().default(o)
def segment_orjson_default(o: Any) -> Any:
"""Default function for orjson serialization of Segment types"""
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
elif isinstance(o, SegmentGroup):
return [segment_orjson_default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str:
"""JSON dumps with segment support using orjson"""
option = orjson.OPT_NON_STR_KEYS
return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8")

View File

@ -276,17 +276,26 @@ class Executor:
encoded_credentials = credentials
headers[authorization.config.header] = f"Basic {encoded_credentials}"
elif self.auth.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key or ""
if authorization.config.header and authorization.config.api_key:
headers[authorization.config.header] = authorization.config.api_key
# Handle Content-Type for multipart/form-data requests
# Fix for issue #22880: Missing boundary when using multipart/form-data
# Fix for issue #23829: Missing boundary when using multipart/form-data
body = self.node_data.body
if body and body.type == "form-data":
# For multipart/form-data with files, let httpx handle the boundary automatically
# by not setting Content-Type header when files are present
if not self.files or all(f[0] == "__multipart_placeholder__" for f in self.files):
# Only set Content-Type when there are no actual files
# This ensures httpx generates the correct boundary
# For multipart/form-data with files (including placeholder files),
# remove any manually set Content-Type header to let httpx handle
# For multipart/form-data, if any files are present (including placeholder files),
# we must remove any manually set Content-Type header. This is because httpx needs to
# automatically set the Content-Type and boundary for multipart encoding whenever files
# are included, even if they are placeholders, to avoid boundary issues and ensure correct
# file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the
# boundary, resulting in invalid requests.
if self.files:
# Remove Content-Type if it was manually set to avoid boundary issues
headers = {k: v for k, v in headers.items() if k.lower() != "content-type"}
else:
# No files at all, set Content-Type manually
if "content-type" not in (k.lower() for k in headers):
headers["Content-Type"] = "multipart/form-data"
elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE:
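To illustrate the boundary issue described in the comment above, a minimal sketch (the upload URL is a placeholder):

```python
import httpx

files = {"file": ("report.txt", b"hello", "text/plain")}
# httpx generates 'multipart/form-data; boundary=...' automatically when
# files are present; a manually set Content-Type would omit the boundary.
req = httpx.Request("POST", "http://example.com/upload", files=files)
assert "boundary=" in req.headers["content-type"]
```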

View File

@ -318,33 +318,6 @@ class ToolNode(BaseNode):
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": message.message.text,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])

View File

@ -5,7 +5,7 @@ import click
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.event_handlers.document_index_event import document_index_created
from events.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document

View File

@ -1,4 +1,6 @@
import ssl
from datetime import timedelta
from typing import Any, Optional
import pytz
from celery import Celery, Task # type: ignore
@ -8,6 +10,40 @@ from configs import dify_config
from dify_app import DifyApp
def _get_celery_ssl_options() -> Optional[dict[str, Any]]:
"""Get SSL configuration for Celery broker/backend connections."""
# Use REDIS_USE_SSL for consistency with the main Redis client
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.REDIS_USE_SSL:
return None
# Check if Celery is actually using Redis
broker_is_redis = dify_config.CELERY_BROKER_URL and (
dify_config.CELERY_BROKER_URL.startswith("redis://") or dify_config.CELERY_BROKER_URL.startswith("rediss://")
)
if not broker_is_redis:
return None
# Map certificate requirement strings to SSL constants
cert_reqs_map = {
"CERT_NONE": ssl.CERT_NONE,
"CERT_OPTIONAL": ssl.CERT_OPTIONAL,
"CERT_REQUIRED": ssl.CERT_REQUIRED,
}
ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
ssl_options = {
"ssl_cert_reqs": ssl_cert_reqs,
"ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
"ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
"ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
}
return ssl_options
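As a sketch, with `REDIS_USE_SSL=true` and a `rediss://` broker URL, the helper above would produce options along these lines (paths and values are illustrative):

```python
import ssl

# Illustrative result of _get_celery_ssl_options() under that configuration:
expected = {
    "ssl_cert_reqs": ssl.CERT_REQUIRED,  # from REDIS_SSL_CERT_REQS="CERT_REQUIRED"
    "ssl_ca_certs": "/certs/ca.pem",     # from REDIS_SSL_CA_CERTS
    "ssl_certfile": None,
    "ssl_keyfile": None,
}
```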
def init_app(app: DifyApp) -> Celery:
class FlaskTask(Task):
def __call__(self, *args: object, **kwargs: object) -> object:
@ -33,14 +69,6 @@ def init_app(app: DifyApp) -> Celery:
task_ignore_result=True,
)
# Add SSL options to the Celery configuration
ssl_options = {
"ssl_cert_reqs": None,
"ssl_ca_certs": None,
"ssl_certfile": None,
"ssl_keyfile": None,
}
celery_app.conf.update(
result_backend=dify_config.CELERY_RESULT_BACKEND,
broker_transport_options=broker_transport_options,
@ -51,9 +79,13 @@ def init_app(app: DifyApp) -> Celery:
timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
)
if dify_config.BROKER_USE_SSL:
# Apply SSL configuration if enabled
ssl_options = _get_celery_ssl_options()
if ssl_options:
celery_app.conf.update(
broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration
broker_use_ssl=ssl_options,
# Also apply SSL to the backend if it's Redis
redis_backend_use_ssl=ssl_options if dify_config.CELERY_BACKEND == "redis" else None,
)
if dify_config.LOG_FILE:
@ -113,7 +145,7 @@ def init_app(app: DifyApp) -> Celery:
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
),
}
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
imports.append("schedule.check_upgradable_plugin_task")
beat_schedule["check_upgradable_plugin_task"] = {
"task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",

View File

@ -4,6 +4,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
convert_to_agent_apps,
@ -42,6 +43,7 @@ def init_app(app: DifyApp):
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
setup_system_tool_oauth_client,
cleanup_orphaned_draft_variables,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@ -0,0 +1,8 @@
from flask_orjson import OrjsonProvider
from dify_app import DifyApp
def init_app(app: DifyApp) -> None:
"""Initialize Flask-Orjson extension for faster JSON serialization"""
app.json = OrjsonProvider(app)
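A standalone sketch of what this extension does (route and payload are illustrative):

```python
from flask import Flask, jsonify
from flask_orjson import OrjsonProvider

app = Flask(__name__)
app.json = OrjsonProvider(app)  # jsonify() and JSON responses now use orjson

@app.route("/ping")
def ping():
    return jsonify({"ok": True})
```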

View File

@ -1,5 +1,6 @@
import functools
import logging
import ssl
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union
@ -116,76 +117,132 @@ class RedisClientWrapper:
redis_client: RedisClientWrapper = RedisClientWrapper()
def init_app(app: DifyApp):
global redis_client
connection_class: type[Union[Connection, SSLConnection]] = Connection
if dify_config.REDIS_USE_SSL:
connection_class = SSLConnection
resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL
if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE:
if resp_protocol >= 3:
clientside_cache_config = CacheConfig()
else:
raise ValueError("Client side cache is only supported in RESP3")
else:
clientside_cache_config = None
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
"""Get SSL configuration for Redis connection."""
if not dify_config.REDIS_USE_SSL:
return Connection, {}
redis_params: dict[str, Any] = {
cert_reqs_map = {
"CERT_NONE": ssl.CERT_NONE,
"CERT_OPTIONAL": ssl.CERT_OPTIONAL,
"CERT_REQUIRED": ssl.CERT_REQUIRED,
}
ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
ssl_kwargs = {
"ssl_cert_reqs": ssl_cert_reqs,
"ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
"ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
"ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
}
return SSLConnection, ssl_kwargs
def _get_cache_configuration() -> CacheConfig | None:
"""Get client-side cache configuration if enabled."""
if not dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE:
return None
resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL
if resp_protocol < 3:
raise ValueError("Client side cache is only supported in RESP3")
return CacheConfig()
def _get_base_redis_params() -> dict[str, Any]:
"""Get base Redis connection parameters."""
return {
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD or None, # Temporary fix for empty password
"password": dify_config.REDIS_PASSWORD or None,
"db": dify_config.REDIS_DB,
"encoding": "utf-8",
"encoding_errors": "strict",
"decode_responses": False,
"protocol": resp_protocol,
"cache_config": clientside_cache_config,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
}
if dify_config.REDIS_USE_SENTINEL:
assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True"
assert dify_config.REDIS_SENTINEL_SERVICE_NAME is not None, (
"REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True"
)
sentinel_hosts = [
(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
]
sentinel = Sentinel(
sentinel_hosts,
sentinel_kwargs={
"socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
"username": dify_config.REDIS_SENTINEL_USERNAME,
"password": dify_config.REDIS_SENTINEL_PASSWORD,
},
)
master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
redis_client.initialize(master)
elif dify_config.REDIS_USE_CLUSTERS:
assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True"
nodes = [
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
for node in dify_config.REDIS_CLUSTERS.split(",")
]
redis_client.initialize(
RedisCluster(
startup_nodes=nodes,
password=dify_config.REDIS_CLUSTERS_PASSWORD,
protocol=resp_protocol,
cache_config=clientside_cache_config,
)
)
else:
redis_params.update(
{
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
"protocol": resp_protocol,
"cache_config": clientside_cache_config,
}
)
pool = redis.ConnectionPool(**redis_params)
redis_client.initialize(redis.Redis(connection_pool=pool))
def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
"""Create Redis client using Sentinel configuration."""
if not dify_config.REDIS_SENTINELS:
raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")
if not dify_config.REDIS_SENTINEL_SERVICE_NAME:
raise ValueError("REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True")
sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")]
sentinel = Sentinel(
sentinel_hosts,
sentinel_kwargs={
"socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
"username": dify_config.REDIS_SENTINEL_USERNAME,
"password": dify_config.REDIS_SENTINEL_PASSWORD,
},
)
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
return master
def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
"""Create Redis cluster client."""
if not dify_config.REDIS_CLUSTERS:
raise ValueError("REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True")
nodes = [
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
for node in dify_config.REDIS_CLUSTERS.split(",")
]
cluster: RedisCluster = RedisCluster(
startup_nodes=nodes,
password=dify_config.REDIS_CLUSTERS_PASSWORD,
protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
cache_config=_get_cache_configuration(),
)
return cluster
def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
"""Create standalone Redis client."""
connection_class, ssl_kwargs = _get_ssl_configuration()
redis_params.update(
{
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
}
)
if ssl_kwargs:
redis_params.update(ssl_kwargs)
pool = redis.ConnectionPool(**redis_params)
client: redis.Redis = redis.Redis(connection_pool=pool)
return client
def init_app(app: DifyApp):
"""Initialize Redis client and attach it to the app."""
global redis_client
# Determine Redis mode and create appropriate client
if dify_config.REDIS_USE_SENTINEL:
redis_params = _get_base_redis_params()
client = _create_sentinel_client(redis_params)
elif dify_config.REDIS_USE_CLUSTERS:
client = _create_cluster_client()
else:
redis_params = _get_base_redis_params()
client = _create_standalone_client(redis_params)
# Initialize the wrapper and attach to app
redis_client.initialize(client)
app.extensions["redis"] = redis_client

View File

@ -248,6 +248,8 @@ def _get_remote_file_info(url: str):
# Initialize mime_type from filename as fallback
mime_type, _ = mimetypes.guess_type(filename)
if mime_type is None:
mime_type = ""
resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
@ -256,7 +258,12 @@ def _get_remote_file_info(url: str):
filename = str(content_disposition.split("filename=")[-1].strip('"'))
# Re-guess mime_type from updated filename
mime_type, _ = mimetypes.guess_type(filename)
if mime_type is None:
mime_type = ""
file_size = int(resp.headers.get("Content-Length", file_size))
# Fallback to Content-Type header if mime_type is still empty
if not mime_type:
mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
return mime_type, filename, file_size
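The empty-string fallback matters because `mimetypes` returns `None` for unknown extensions, e.g.:

```python
import mimetypes

mime, _ = mimetypes.guess_type("report.unknownext")
assert mime is None  # hence the fallback to the Content-Type header above
```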

View File

@ -1153,7 +1153,7 @@ class WorkflowDraftVariable(Base):
value: The Segment object to store as the variable's value.
"""
self.__value = value
self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder)
self.value = variable_utils.dumps_with_segments(value)
self.value_type = value.value_type
def get_node_id(self) -> str | None:

View File

@ -5,8 +5,7 @@ check_untyped_defs = True
cache_fine_grained = True
sqlite_cache = True
exclude = (?x)(
core/model_runtime/model_providers/
| tests/
tests/
| migrations/
)

View File

@ -18,6 +18,7 @@ dependencies = [
"flask-cors~=6.0.0",
"flask-login~=0.6.3",
"flask-migrate~=4.0.7",
"flask-orjson~=2.0.0",
"flask-restful~=0.3.10",
"flask-sqlalchemy~=3.1.1",
"gevent~=24.11.1",
@ -204,7 +205,7 @@ vdb = [
"pgvector==0.2.5",
"pymilvus~=2.5.0",
"pymochow==1.3.1",
"pyobvector~=0.1.6",
"pyobvector~=0.2.15",
"qdrant-client==1.9.0",
"tablestore==6.2.0",
"tcvectordb~=1.6.4",

View File

@ -1,5 +1,6 @@
import datetime
import time
from typing import Optional, TypedDict
import click
from sqlalchemy import func, select
@ -14,168 +15,140 @@ from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Documen
from services.feature_service import FeatureService
class CleanupConfig(TypedDict):
clean_day: datetime.datetime
plan_filter: Optional[str]
add_logs: bool
@app.celery.task(queue="dataset")
def clean_unused_datasets_task():
click.echo(click.style("Start clean unused datasets indexes.", fg="green"))
plan_sandbox_clean_day_setting = dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING
plan_pro_clean_day_setting = dify_config.PLAN_PRO_CLEAN_DAY_SETTING
start_at = time.perf_counter()
plan_sandbox_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_sandbox_clean_day_setting)
plan_pro_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_pro_clean_day_setting)
while True:
try:
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at > plan_sandbox_clean_day,
)
.group_by(Document.dataset_id)
.subquery()
)
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at < plan_sandbox_clean_day,
)
.group_by(Document.dataset_id)
.subquery()
)
# Define cleanup configurations
cleanup_configs: list[CleanupConfig] = [
{
"clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING),
"plan_filter": None,
"add_logs": True,
},
{
"clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_PRO_CLEAN_DAY_SETTING),
"plan_filter": "sandbox",
"add_logs": False,
},
]
# Main query with join and filter
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.where(
Dataset.created_at < plan_sandbox_clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
)
for config in cleanup_configs:
clean_day = config["clean_day"]
plan_filter = config["plan_filter"]
add_logs = config["add_logs"]
datasets = db.paginate(stmt, page=1, per_page=50)
except SQLAlchemyError:
raise
if datasets.items is None or len(datasets.items) == 0:
break
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.where(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
try:
# add auto disable log
documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
while True:
try:
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at > clean_day,
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
# update document
db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False})
db.session.commit()
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
except Exception as e:
click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
while True:
try:
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at > plan_pro_clean_day,
.group_by(Document.dataset_id)
.subquery()
)
.group_by(Document.dataset_id)
.subquery()
)
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at < plan_pro_clean_day,
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at < clean_day,
)
.group_by(Document.dataset_id)
.subquery()
)
.group_by(Document.dataset_id)
.subquery()
)
# Main query with join and filter
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.where(
Dataset.created_at < plan_pro_clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
# Main query with join and filter
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.where(
Dataset.created_at < clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
)
.order_by(Dataset.created_at.desc())
)
datasets = db.paginate(stmt, page=1, per_page=50)
except SQLAlchemyError:
raise
if datasets.items is None or len(datasets.items) == 0:
break
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.where(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
try:
features_cache_key = f"features:{dataset.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(dataset.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
if plan == "sandbox":
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
datasets = db.paginate(stmt, page=1, per_page=50)
except SQLAlchemyError:
raise
if datasets.items is None or len(datasets.items) == 0:
break
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
try:
should_clean = True
# Check plan filter if specified
if plan_filter:
features_cache_key = f"features:{dataset.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(dataset.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
should_clean = plan == plan_filter
if should_clean:
# Add auto disable log if required
if add_logs:
documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# Remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
# Update document
db.session.query(Document).filter_by(dataset_id=dataset.id).update(
{Document.enabled: False}
)
db.session.commit()
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
except Exception as e:
click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
# update document
db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False})
db.session.commit()
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
except Exception as e:
click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
end_at = time.perf_counter()
click.echo(click.style(f"Cleaned unused dataset from db success latency: {end_at - start_at}", fg="green"))

View File

@ -103,10 +103,10 @@ class ConversationService:
@classmethod
def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation):
field_value = getattr(reference_conversation, sort_field)
if sort_direction == desc:
if sort_direction is desc:
return getattr(Conversation, sort_field) < field_value
else:
return getattr(Conversation, sort_field) > field_value
return getattr(Conversation, sort_field) > field_value
@classmethod
def rename(
@ -147,7 +147,7 @@ class ConversationService:
app_model.tenant_id, message.query, conversation.id, app_model.id
)
conversation.name = name
except:
except Exception:
pass
db.session.commit()
@ -277,6 +277,11 @@ class ConversationService:
# Validate that the new value type matches the expected variable type
expected_type = SegmentType(current_variable.value_type)
# The web UI shows the type as "number", but it is stored as int in the DB
if expected_type == SegmentType.INTEGER:
expected_type = SegmentType.NUMBER
if not expected_type.is_valid(new_value):
inferred_type = SegmentType.infer_segment_type(new_value)
raise ConversationVariableTypeMismatchError(

View File

@ -47,7 +47,9 @@ class OAuthProxyService(BasePluginClient):
if not context_id:
raise ValueError("context_id is required")
# get data from redis
data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}")
key = f"{OAuthProxyService.__KEY_PREFIX__}{context_id}"
data = redis_client.get(key)
if not data:
raise ValueError("context_id is invalid")
redis_client.delete(key)
return json.loads(data)

View File

@ -33,7 +33,11 @@ from models import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog
from models.workflow import (
ConversationVariable,
Workflow,
WorkflowAppLog,
)
from repositories.factory import DifyAPIRepositoryFactory
@ -62,6 +66,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_end_users(tenant_id, app_id)
_delete_trace_app_configs(tenant_id, app_id)
_delete_conversation_variables(app_id=app_id)
_delete_draft_variables(app_id)
end_at = time.perf_counter()
logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
@ -91,7 +96,12 @@ def _delete_app_site(tenant_id: str, app_id: str):
def del_site(site_id: str):
db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site")
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_site,
"site",
)
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
@ -111,7 +121,10 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token"
"""select id from api_tokens where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_api_token,
"api token",
)
@ -273,7 +286,10 @@ def _delete_app_messages(tenant_id: str, app_id: str):
db.session.query(Message).where(Message.id == message_id).delete()
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message"
"""select id from messages where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_message,
"message",
)
@ -329,6 +345,56 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
)
def _delete_draft_variables(app_id: str):
"""Delete all workflow draft variables for an app in batches."""
return delete_draft_variables_batch(app_id, batch_size=1000)
def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
"""
Delete draft variables for an app in batches.
Args:
app_id: The ID of the app whose draft variables should be deleted
batch_size: Number of records to delete per batch
Returns:
Total number of records deleted
"""
if batch_size <= 0:
raise ValueError("batch_size must be positive")
total_deleted = 0
while True:
with db.engine.begin() as conn:
# Get a batch of draft variable IDs
query_sql = """
SELECT id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""
result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
draft_var_ids = [row[0] for row in result]
if not draft_var_ids:
break
# Delete the batch
delete_sql = """
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""
deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
batch_deleted = deleted_result.rowcount
total_deleted += batch_deleted
logging.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
logging.info(click.style(f"Deleted {total_deleted} total draft variables for app {app_id}", fg="green"))
return total_deleted
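Usage sketch for the new helper (the app id is illustrative):

```python
# Delete leftovers for a single app in batches of 500:
deleted = delete_draft_variables_batch(
    "00000000-0000-0000-0000-000000000000", batch_size=500
)
print(f"removed {deleted} draft variables")
```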
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
with db.engine.begin() as conn:

View File

@ -0,0 +1,214 @@
import uuid
import pytest
from sqlalchemy import delete
from core.variables.segments import StringSegment
from models import Tenant, db
from models.model import App
from models.workflow import WorkflowDraftVariable
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
@pytest.fixture
def app_and_tenant(flask_req_ctx):
tenant_id = uuid.uuid4()
tenant = Tenant(
id=tenant_id,
name="test_tenant",
)
db.session.add(tenant)
app = App(
tenant_id=tenant_id, # Now tenant.id will have a value
name=f"Test App for tenant {tenant.id}",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app)
db.session.flush()
yield (tenant, app)
# Cleanup with proper error handling
db.session.delete(app)
db.session.delete(tenant)
class TestDeleteDraftVariablesIntegration:
@pytest.fixture
def setup_test_data(self, app_and_tenant):
"""Create test data with apps and draft variables."""
tenant, app = app_and_tenant
# Create a second app for testing
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app2)
db.session.commit()
# Create draft variables for both apps
variables_app1 = []
variables_app2 = []
for i in range(5):
var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var1)
variables_app1.append(var1)
var2 = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var2)
variables_app2.append(var2)
# Commit all the variables to the database
db.session.commit()
yield {
"app1": app,
"app2": app2,
"tenant": tenant,
"variables_app1": variables_app1,
"variables_app2": variables_app2,
}
# Cleanup - refresh session and check if objects still exist
db.session.rollback() # Clear any pending changes
# Clean up remaining variables
cleanup_query = (
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id.in_([app.id, app2.id]),
)
.execution_options(synchronize_session=False)
)
db.session.execute(cleanup_query)
# Clean up app2
app2_obj = db.session.get(App, app2.id)
if app2_obj:
db.session.delete(app2_obj)
db.session.commit()
def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data):
"""Test that batch deletion only removes variables for the specified app."""
data = setup_test_data
app1_id = data["app1"].id
app2_id = data["app2"].id
# Verify initial state
app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_before == 5
assert app2_vars_before == 5
# Delete app1 variables
deleted_count = delete_draft_variables_batch(app1_id, batch_size=10)
# Verify results
assert deleted_count == 5
app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_after == 0 # All app1 variables deleted
assert app2_vars_after == 5 # App2 variables unchanged
def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data):
"""Test batch deletion with small batch size processes all records."""
data = setup_test_data
app1_id = data["app1"].id
# Use small batch size to force multiple batches
deleted_count = delete_draft_variables_batch(app1_id, batch_size=2)
assert deleted_count == 5
# Verify all variables are deleted
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert remaining_vars == 0
def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
"""Test that deleting variables for nonexistent app returns 0."""
nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100)
assert deleted_count == 0
def test_delete_draft_variables_wrapper_function(self, setup_test_data):
"""Test that _delete_draft_variables wrapper function works correctly."""
data = setup_test_data
app1_id = data["app1"].id
# Verify initial state
vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert vars_before == 5
# Call wrapper function
deleted_count = _delete_draft_variables(app1_id)
# Verify results
assert deleted_count == 5
vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert vars_after == 0
def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
"""Test batch deletion with larger dataset to verify batching logic."""
tenant, app = app_and_tenant
# Create many draft variables
variables = []
for i in range(25):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var)
variables.append(var)
variable_ids = [i.id for i in variables]
# Commit the variables to the database
db.session.commit()
try:
# Use small batch size to force multiple batches
deleted_count = delete_draft_variables_batch(app.id, batch_size=8)
assert deleted_count == 25
# Verify all variables are deleted
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
assert remaining_vars == 0
finally:
query = (
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.id.in_(variable_ids),
)
.execution_options(synchronize_session=False)
)
db.session.execute(query)

View File

@ -1,4 +1,5 @@
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
@ -18,6 +19,14 @@ class QdrantVectorTest(AbstractVectorTest):
),
)
def search_by_vector(self):
super().search_by_vector()
# Qdrant-specific check; may not apply to other vector stores
hits_by_vector: list[Document] = self.vector.search_by_vector(
query_vector=self.example_embedding, score_threshold=1
)
assert len(hits_by_vector) == 0
def test_qdrant_vector(setup_mock_redis):
QdrantVectorTest().run_all_tests()

View File

@ -160,6 +160,177 @@ def test_custom_authorization_header(setup_http_mock):
assert "?A=b" in data
assert "X-Header: 123" in data
# Custom authorization header should be set (may be masked)
assert "X-Auth:" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.system_variable import SystemVariable
# Create variable pool
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="test", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
# Create node data with custom auth and empty api_key
node_data = HttpRequestNodeData(
title="http",
desc="",
url="http://example.com",
method="get",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={
"type": "custom",
"api_key": "", # Empty api_key
"header": "X-Custom-Auth",
},
),
headers="",
params="",
body=None,
ssl_verify=True,
)
# Create executor
executor = Executor(
node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
)
# Get assembled headers
headers = executor._assembling_headers()
# When api_key is empty, the custom header should NOT be set
assert "X-Custom-Auth" not in headers
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_bearer_authorization_with_custom_header_ignored(setup_http_mock):
"""
Test that when switching from custom to bearer authorization,
the custom header settings don't interfere with bearer token.
This test verifies the fix for issue #23554.
"""
node = init_http_node(
config={
"id": "1",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": {
"type": "api-key",
"config": {
"type": "bearer",
"api_key": "test-token",
"header": "", # Empty header - should default to Authorization
},
},
"headers": "",
"params": "",
"body": None,
},
}
)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
# In bearer mode, should use Authorization header (value is masked with *)
assert "Authorization: " in data
# Should contain masked Bearer token
assert "*" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
"""
Test that when switching from custom to basic authorization,
the custom header settings don't interfere with basic auth.
This test verifies the fix for issue #23554.
"""
node = init_http_node(
config={
"id": "1",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": {
"type": "api-key",
"config": {
"type": "basic",
"api_key": "user:pass",
"header": "", # Empty header - should default to Authorization
},
},
"headers": "",
"params": "",
"body": None,
},
}
)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
# In basic mode, should use Authorization header (value is masked with *)
assert "Authorization: " in data
# Should contain masked Basic credentials
assert "*" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_authorization_with_empty_api_key(setup_http_mock):
"""
Test that custom authorization doesn't set the header when api_key is empty.
This test verifies the fix for issue #23554.
"""
node = init_http_node(
config={
"id": "1",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": {
"type": "api-key",
"config": {
"type": "custom",
"api_key": "", # Empty api_key
"header": "X-Custom-Auth",
},
},
"headers": "",
"params": "",
"body": None,
},
}
)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
# Custom header should NOT be set when api_key is empty
assert "X-Custom-Auth:" not in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
@ -239,6 +410,7 @@ def test_json(setup_http_mock):
assert "X-Header: 123" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_x_www_form_urlencoded(setup_http_mock):
node = init_http_node(
config={
@ -285,6 +457,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
assert "X-Header: 123" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_form_data(setup_http_mock):
node = init_http_node(
config={
@ -334,6 +507,7 @@ def test_form_data(setup_http_mock):
assert "X-Header: 123" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_none_data(setup_http_mock):
node = init_http_node(
config={
@ -366,6 +540,7 @@ def test_none_data(setup_http_mock):
assert "123123123" not in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_mock_404(setup_http_mock):
node = init_http_node(
config={
@ -394,6 +569,7 @@ def test_mock_404(setup_http_mock):
assert "Not Found" in resp.get("body", "")
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_multi_colons_parse(setup_http_mock):
node = init_http_node(
config={

View File

@ -0,0 +1,885 @@
import copy
import pytest
from faker import Faker
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_CONTEXT,
CHAT_APP_CHAT_PROMPT_CONFIG,
CHAT_APP_COMPLETION_PROMPT_CONFIG,
COMPLETION_APP_CHAT_PROMPT_CONFIG,
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
CONTEXT,
)
from models.model import AppMode
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class TestAdvancedPromptTemplateService:
"""Integration tests for AdvancedPromptTemplateService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
# This service doesn't have external dependencies, but we keep the pattern
# for consistency with other test files
return {}
def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful prompt generation for Baichuan model.
This test verifies:
- Proper prompt generation for Baichuan models
- Correct model detection logic
- Appropriate prompt template selection
"""
fake = Faker()
# Test data for Baichuan model
args = {
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify context is included for Baichuan model
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful prompt generation for common models.
This test verifies:
- Proper prompt generation for non-Baichuan models
- Correct model detection logic
- Appropriate prompt template selection
"""
fake = Faker()
# Test data for common model
args = {
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify context is included for common model
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_prompt_case_insensitive_baichuan_detection(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan model detection is case insensitive.
This test verifies:
- Model name detection works regardless of case
- Proper prompt template selection for different case variations
"""
fake = Faker()
# Test different case variations
test_cases = ["Baichuan-13B-Chat", "BAICHUAN-13B-CHAT", "baichuan-13b-chat", "BaiChuan-13B-Chat"]
for model_name in test_cases:
args = {
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": model_name,
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify Baichuan template is used
assert result is not None
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert BAICHUAN_CONTEXT in prompt_text
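These cases imply the routing inside AdvancedPromptTemplateService.get_prompt is a lowercase substring check, with unsupported app/model mode combinations falling through to an empty dict. A hedged sketch, using only the signatures the tests themselves exercise:
def get_prompt(args: dict) -> dict:
    # Case-insensitive Baichuan detection, then delegation (sketch).
    if "baichuan" in args["model_name"].lower():
        return AdvancedPromptTemplateService.get_baichuan_prompt(
            args["app_mode"], args["model_mode"], args["has_context"]
        )
    return AdvancedPromptTemplateService.get_common_prompt(
        args["app_mode"], args["model_mode"], args["has_context"]
    )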
def test_get_common_prompt_chat_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test common prompt generation for chat app with completion mode.
This test verifies:
- Correct prompt template selection for chat app + completion mode
- Proper context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
assert "conversation_histories_role" in result["completion_prompt_config"]
assert "stop" in result
# Verify context is included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test common prompt generation for chat app with chat mode.
This test verifies:
- Correct prompt template selection for chat app + chat mode
- Proper context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "role" in result["chat_prompt_config"]["prompt"][0]
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify context is included
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_completion_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test common prompt generation for completion app with completion mode.
This test verifies:
- Correct prompt template selection for completion app + completion mode
- Proper context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
assert "stop" in result
# Verify context is included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_completion_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test common prompt generation for completion app with chat mode.
This test verifies:
- Correct prompt template selection for completion app + chat mode
- Proper context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "role" in result["chat_prompt_config"]["prompt"][0]
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify context is included
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test common prompt generation without context.
This test verifies:
- Correct handling when has_context is "false"
- Context is not included in prompt
- Template structure remains intact
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify context is NOT included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert CONTEXT not in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_common_prompt_unsupported_app_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test common prompt generation with unsupported app mode.
This test verifies:
- Proper handling of unsupported app modes
- Default empty dict return
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt("unsupported_mode", "completion", "true")
# Assert: Verify empty dict is returned
assert result == {}
def test_get_common_prompt_unsupported_model_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test common prompt generation with unsupported model mode.
This test verifies:
- Proper handling of unsupported model modes
- Default empty dict return
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
# Assert: Verify empty dict is returned
assert result == {}
def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test completion prompt generation with context.
This test verifies:
- Proper context integration in completion prompts
- Template structure preservation
- Context placement at the beginning
"""
fake = Faker()
# Create test prompt template
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "true", CONTEXT)
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify context is prepended to original text
result_text = result["completion_prompt_config"]["prompt"]["text"]
assert result_text.startswith(CONTEXT)
assert original_text in result_text
assert result_text == CONTEXT + original_text
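The exact-equality assertion pins the transformation down to a plain string prepend; a sketch of what get_completion_prompt presumably does (the chat variant, tested further below, would prepend to the first message's text instead):
def get_completion_prompt(prompt_template: dict, has_context: str, context: str) -> dict:
    # Prepend context only when has_context is the literal string "true" (sketch).
    if has_context == "true":
        prompt = prompt_template["completion_prompt_config"]["prompt"]
        prompt["text"] = context + prompt["text"]
    return prompt_template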
def test_get_completion_prompt_without_context(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test completion prompt generation without context.
This test verifies:
- Original template is preserved when no context
- No modification to prompt text
"""
fake = Faker()
# Create test prompt template
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT)
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify original text is unchanged
result_text = result["completion_prompt_config"]["prompt"]["text"]
assert result_text == original_text
assert CONTEXT not in result_text
def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test chat prompt generation with context.
This test verifies:
- Proper context integration in chat prompts
- Template structure preservation
- Context placement at the beginning of first message
"""
fake = Faker()
# Create test prompt template
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "true", CONTEXT)
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify context is prepended to original text
result_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert result_text.startswith(CONTEXT)
assert original_text in result_text
assert result_text == CONTEXT + original_text
def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test chat prompt generation without context.
This test verifies:
- Original template is preserved when no context
- No modification to prompt text
"""
fake = Faker()
# Create test prompt template
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT)
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify original text is unchanged
result_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert result_text == original_text
assert CONTEXT not in result_text
def test_get_baichuan_prompt_chat_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for chat app with completion mode.
This test verifies:
- Correct Baichuan prompt template selection for chat app + completion mode
- Proper Baichuan context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
assert "conversation_histories_role" in result["completion_prompt_config"]
assert "stop" in result
# Verify Baichuan context is included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_baichuan_prompt_chat_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for chat app with chat mode.
This test verifies:
- Correct Baichuan prompt template selection for chat app + chat mode
- Proper Baichuan context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "role" in result["chat_prompt_config"]["prompt"][0]
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify Baichuan context is included
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_completion_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for completion app with completion mode.
This test verifies:
- Correct Baichuan prompt template selection for completion app + completion mode
- Proper Baichuan context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
assert "stop" in result
# Verify Baichuan context is included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_completion_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for completion app with chat mode.
This test verifies:
- Correct Baichuan prompt template selection for completion app + chat mode
- Proper Baichuan context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "role" in result["chat_prompt_config"]["prompt"][0]
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify Baichuan context is included
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test Baichuan prompt generation without context.
This test verifies:
- Correct handling when has_context is "false"
- Baichuan context is not included in prompt
- Template structure remains intact
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify Baichuan context is NOT included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert BAICHUAN_CONTEXT not in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_baichuan_prompt_unsupported_app_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation with unsupported app mode.
This test verifies:
- Proper handling of unsupported app modes
- Default empty dict return
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt("unsupported_mode", "completion", "true")
# Assert: Verify empty dict is returned
assert result == {}
def test_get_baichuan_prompt_unsupported_model_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation with unsupported model mode.
This test verifies:
- Proper handling of unsupported model modes
- Default empty dict return
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
# Assert: Verify empty dict is returned
assert result == {}
def test_get_prompt_all_app_modes_common_model(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test prompt generation for all app modes with common model.
This test verifies:
- All app modes work correctly with common models
- Proper template selection for each combination
"""
fake = Faker()
# Test all app modes
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
model_modes = ["completion", "chat"]
for app_mode in app_modes:
for model_mode in model_modes:
args = {
"app_mode": app_mode,
"model_mode": model_mode,
"model_name": "gpt-3.5-turbo",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify result is not empty
assert result is not None
assert result != {}
def test_get_prompt_all_app_modes_baichuan_model(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test prompt generation for all app modes with Baichuan model.
This test verifies:
- All app modes work correctly with Baichuan models
- Proper template selection for each combination
"""
fake = Faker()
# Test all app modes
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
model_modes = ["completion", "chat"]
for app_mode in app_modes:
for model_mode in model_modes:
args = {
"app_mode": app_mode,
"model_mode": model_mode,
"model_name": "baichuan-13b-chat",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify result is not empty
assert result is not None
assert result != {}
def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test prompt generation with edge cases.
This test verifies:
- Handling of edge case inputs
- Proper error handling
- Consistent behavior with unusual inputs
"""
fake = Faker()
# Test edge cases
edge_cases = [
{"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"},
{"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
{"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"},
{
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "",
},
]
for args in edge_cases:
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify method handles edge cases gracefully
# Should either return a valid result or empty dict, but not crash
assert result is not None
def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test that original templates are not modified.
This test verifies:
- Original template constants are not modified
- Deep copy is used properly
- Template immutability is maintained
"""
fake = Faker()
# Store original templates
original_chat_completion = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_chat_chat = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
original_completion_completion = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
original_completion_chat = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG)
# Test with context
args = {
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify original templates are unchanged
assert original_chat_completion == CHAT_APP_COMPLETION_PROMPT_CONFIG
assert original_chat_chat == CHAT_APP_CHAT_PROMPT_CONFIG
assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG
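The pattern that makes this pass is presumably the same one the test itself uses: mutate a deep copy, never the module-level constant. Illustratively:
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
prompt_template["completion_prompt_config"]["prompt"]["text"] = "modified"  # constant untouched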
def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test that original Baichuan templates are not modified.
This test verifies:
- Original Baichuan template constants are not modified
- Deep copy is used properly
- Template immutability is maintained
"""
fake = Faker()
# Store original templates
original_baichuan_chat_completion = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_baichuan_chat_chat = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG)
original_baichuan_completion_completion = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
original_baichuan_completion_chat = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG)
# Test with context
args = {
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify original templates are unchanged
assert original_baichuan_chat_completion == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
assert original_baichuan_chat_chat == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test consistency of context integration across different scenarios.
This test verifies:
- Context is always prepended correctly
- Context integration is consistent across different templates
- No context duplication or corruption
"""
fake = Faker()
# Test different scenarios
test_scenarios = [
{
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.CHAT.value,
"model_mode": "chat",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"model_mode": "chat",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
]
for args in test_scenarios:
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify context integration is consistent
assert result is not None
assert result != {}
# Check that context is properly integrated
if "completion_prompt_config" in result:
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert prompt_text.startswith(CONTEXT)
elif "chat_prompt_config" in result:
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert prompt_text.startswith(CONTEXT)
def test_baichuan_context_integration_consistency(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test consistency of Baichuan context integration across different scenarios.
This test verifies:
- Baichuan context is always prepended correctly
- Context integration is consistent across different templates
- No context duplication or corruption
"""
fake = Faker()
# Test different scenarios
test_scenarios = [
{
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.CHAT.value,
"model_mode": "chat",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"model_mode": "chat",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
]
for args in test_scenarios:
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify context integration is consistent
assert result is not None
assert result != {}
# Check that Baichuan context is properly integrated
if "completion_prompt_config" in result:
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert prompt_text.startswith(BAICHUAN_CONTEXT)
elif "chat_prompt_config" in result:
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert prompt_text.startswith(BAICHUAN_CONTEXT)

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,474 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from models.account import TenantAccountJoin, TenantAccountRole
from models.model import Account, Tenant
from models.provider import LoadBalancingModelConfig, Provider, ProviderModelSetting
from services.model_load_balancing_service import ModelLoadBalancingService
class TestModelLoadBalancingService:
"""Integration tests for ModelLoadBalancingService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager,
patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager,
patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory,
patch("services.model_load_balancing_service.encrypter") as mock_encrypter,
):
# Setup default mock returns
mock_provider_manager_instance = mock_provider_manager.return_value
# Mock provider configuration
mock_provider_config = MagicMock()
mock_provider_config.provider.provider = "openai"
mock_provider_config.custom_configuration.provider = None
# Mock provider model setting
mock_provider_model_setting = MagicMock()
mock_provider_model_setting.load_balancing_enabled = False
mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting
# Mock provider configurations dict
mock_provider_configs = {"openai": mock_provider_config}
mock_provider_manager_instance.get_configurations.return_value = mock_provider_configs
# Mock LBModelManager
mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)
# Mock ModelProviderFactory
mock_model_provider_factory_instance = mock_model_provider_factory.return_value
# Mock credential schemas
mock_credential_schema = MagicMock()
mock_credential_schema.credential_form_schemas = []
# Mock provider configuration methods
mock_provider_config.extract_secret_variables.return_value = []
mock_provider_config.obfuscated_credentials.return_value = {}
mock_provider_config._get_credential_schema.return_value = mock_credential_schema
yield {
"provider_manager": mock_provider_manager,
"lb_model_manager": mock_lb_model_manager,
"model_provider_factory": mock_model_provider_factory,
"encrypter": mock_encrypter,
"provider_config": mock_provider_config,
"provider_model_setting": mock_provider_model_setting,
"credential_schema": mock_credential_schema,
}
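Note the patch targets: each string names the symbol where services.model_load_balancing_service looks it up, not where it is defined; patching the defining module would leave the service holding its original reference. A minimal illustration:
from unittest.mock import patch

# Patch the name at its use site (illustrative).
with patch("services.model_load_balancing_service.ProviderManager") as mocked_pm:
    mocked_pm.return_value.get_configurations.return_value = {}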
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant
def _create_test_provider_and_setting(
self, db_session_with_containers, tenant_id, mock_external_service_dependencies
):
"""
Helper method to create a test provider and provider model setting.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
tenant_id: Tenant ID for the provider
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (provider, provider_model_setting) - Created provider and setting instances
"""
fake = Faker()
from extensions.ext_database import db
# Create provider
provider = Provider(
tenant_id=tenant_id,
provider_name="openai",
provider_type="custom",
is_valid=True,
)
db.session.add(provider)
db.session.commit()
# Create provider model setting
provider_model_setting = ProviderModelSetting(
tenant_id=tenant_id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
enabled=True,
load_balancing_enabled=False,
)
db.session.add(provider_model_setting)
db.session.commit()
return provider, provider_model_setting
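The two type strings are deliberate: rows persist the origin model type ("text-generation") while the service API takes the short name ("llm"); presumably a ModelType helper maps between them. A hedged illustration — the import path and helper names are assumptions, not verified here:
from core.model_runtime.entities.model_entities import ModelType  # assumed path

model_type = ModelType.value_of("llm")        # assumed to accept the API-facing name
origin = model_type.to_origin_model_type()    # assumed to yield "text-generation"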
def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful model load balancing enablement.
This test verifies:
- Proper provider configuration retrieval
- Successful enablement of model load balancing
- Correct method calls to provider configuration
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider, provider_model_setting = self._create_test_provider_and_setting(
db_session_with_containers, tenant.id, mock_external_service_dependencies
)
# Setup mocks for enable method
mock_provider_config = mock_external_service_dependencies["provider_config"]
mock_provider_config.enable_model_load_balancing = MagicMock()
# Act: Execute the method under test
service = ModelLoadBalancingService()
service.enable_model_load_balancing(
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
)
# Assert: Verify the expected outcomes
mock_provider_config.enable_model_load_balancing.assert_called_once()
call_args = mock_provider_config.enable_model_load_balancing.call_args
assert call_args.kwargs["model"] == "gpt-3.5-turbo"
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
# Verify database state
from extensions.ext_database import db
db.session.refresh(provider)
db.session.refresh(provider_model_setting)
assert provider.id is not None
assert provider_model_setting.id is not None
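Read together with the provider-not-found test below, the mock assertions imply a thin delegation, roughly as sketched here (ProviderManager and ModelType are the names patched or imported above; the real method may differ in detail):
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
    provider_configurations = ProviderManager().get_configurations(tenant_id)
    provider_configuration = provider_configurations.get(provider)
    if not provider_configuration:
        raise ValueError(f"Provider {provider} does not exist.")
    provider_configuration.enable_model_load_balancing(
        model=model, model_type=ModelType.value_of(model_type)
    )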
def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful model load balancing disablement.
This test verifies:
- Proper provider configuration retrieval
- Successful disablement of model load balancing
- Correct method calls to provider configuration
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider, provider_model_setting = self._create_test_provider_and_setting(
db_session_with_containers, tenant.id, mock_external_service_dependencies
)
# Setup mocks for disable method
mock_provider_config = mock_external_service_dependencies["provider_config"]
mock_provider_config.disable_model_load_balancing = MagicMock()
# Act: Execute the method under test
service = ModelLoadBalancingService()
service.disable_model_load_balancing(
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
)
# Assert: Verify the expected outcomes
mock_provider_config.disable_model_load_balancing.assert_called_once()
call_args = mock_provider_config.disable_model_load_balancing.call_args
assert call_args.kwargs["model"] == "gpt-3.5-turbo"
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
# Verify database state
from extensions.ext_database import db
db.session.refresh(provider)
db.session.refresh(provider_model_setting)
assert provider.id is not None
assert provider_model_setting.id is not None
def test_enable_model_load_balancing_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error handling when provider does not exist.
This test verifies:
- Proper error handling for non-existent provider
- Correct exception type and message
- No database state changes
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Setup mocks to return empty provider configurations
mock_provider_manager = mock_external_service_dependencies["provider_manager"]
mock_provider_manager_instance = mock_provider_manager.return_value
mock_provider_manager_instance.get_configurations.return_value = {}
# Act & Assert: Verify proper error handling
service = ModelLoadBalancingService()
with pytest.raises(ValueError) as exc_info:
service.enable_model_load_balancing(
tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm"
)
# Verify correct error message
assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
# Verify no database state changes occurred
from extensions.ext_database import db
db.session.rollback()
def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful retrieval of load balancing configurations.
This test verifies:
- Proper provider configuration retrieval
- Successful database query for load balancing configs
- Correct return format and data structure
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider, provider_model_setting = self._create_test_provider_and_setting(
db_session_with_containers, tenant.id, mock_external_service_dependencies
)
# Create load balancing config
from extensions.ext_database import db
load_balancing_config = LoadBalancingModelConfig(
tenant_id=tenant.id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
name="config1",
encrypted_config='{"api_key": "test_key"}',
enabled=True,
)
db.session.add(load_balancing_config)
db.session.commit()
# Verify the config was created
db.session.refresh(load_balancing_config)
assert load_balancing_config.id is not None
# Setup mocks for get_load_balancing_configs method
mock_provider_config = mock_external_service_dependencies["provider_config"]
mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"]
mock_provider_model_setting.load_balancing_enabled = True
# Mock credential schema methods
mock_credential_schema = mock_external_service_dependencies["credential_schema"]
mock_credential_schema.credential_form_schemas = []
# Mock encrypter
mock_encrypter = mock_external_service_dependencies["encrypter"]
mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher")
# Mock _get_credential_schema method
mock_provider_config._get_credential_schema.return_value = mock_credential_schema
# Mock extract_secret_variables method
mock_provider_config.extract_secret_variables.return_value = []
# Mock obfuscated_credentials method
mock_provider_config.obfuscated_credentials.return_value = {}
# Mock LBModelManager.get_config_in_cooldown_and_ttl
mock_lb_model_manager = mock_external_service_dependencies["lb_model_manager"]
mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)
# Act: Execute the method under test
service = ModelLoadBalancingService()
is_enabled, configs = service.get_load_balancing_configs(
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
)
# Assert: Verify the expected outcomes
assert is_enabled is True
assert len(configs) == 1
assert configs[0]["id"] == load_balancing_config.id
assert configs[0]["name"] == "config1"
assert configs[0]["enabled"] is True
assert configs[0]["in_cooldown"] is False
assert configs[0]["ttl"] == 0
# Verify database state
db.session.refresh(load_balancing_config)
assert load_balancing_config.id is not None
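The assertions fix the public contract: a (bool, list[dict]) tuple whose entries carry at least id, name, enabled, in_cooldown and ttl. A usage sketch against that contract:
is_enabled, configs = service.get_load_balancing_configs(
    tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
)
for cfg in configs:
    print(cfg["name"], cfg["enabled"], cfg["in_cooldown"], cfg["ttl"])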
def test_get_load_balancing_configs_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error handling when provider does not exist in get_load_balancing_configs.
This test verifies:
- Proper error handling for non-existent provider
- Correct exception type and message
- No database state changes
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Setup mocks to return empty provider configurations
mock_provider_manager = mock_external_service_dependencies["provider_manager"]
mock_provider_manager_instance = mock_provider_manager.return_value
mock_provider_manager_instance.get_configurations.return_value = {}
# Act & Assert: Verify proper error handling
service = ModelLoadBalancingService()
with pytest.raises(ValueError) as exc_info:
service.get_load_balancing_configs(
tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm"
)
# Verify correct error message
assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
# Verify no database state changes occurred
from extensions.ext_database import db
db.session.rollback()
def test_get_load_balancing_configs_with_inherit_config(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test load balancing configs retrieval with inherit configuration.
This test verifies:
- Proper handling of inherit configuration
- Correct ordering of configurations
- Inherit config initialization when needed
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider, provider_model_setting = self._create_test_provider_and_setting(
db_session_with_containers, tenant.id, mock_external_service_dependencies
)
# Create load balancing config
from extensions.ext_database import db
load_balancing_config = LoadBalancingModelConfig(
tenant_id=tenant.id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
name="config1",
encrypted_config='{"api_key": "test_key"}',
enabled=True,
)
db.session.add(load_balancing_config)
db.session.commit()
# Setup mocks for inherit config scenario
mock_provider_config = mock_external_service_dependencies["provider_config"]
mock_provider_config.custom_configuration.provider = MagicMock() # Enable custom config
mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"]
mock_provider_model_setting.load_balancing_enabled = True
# Mock credential schema methods
mock_credential_schema = mock_external_service_dependencies["credential_schema"]
mock_credential_schema.credential_form_schemas = []
# Mock encrypter
mock_encrypter = mock_external_service_dependencies["encrypter"]
mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher")
# Act: Execute the method under test
service = ModelLoadBalancingService()
is_enabled, configs = service.get_load_balancing_configs(
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
)
# Assert: Verify the expected outcomes
assert is_enabled is True
assert len(configs) == 2 # inherit config + existing config
# First config should be inherit config
assert configs[0]["name"] == "__inherit__"
assert configs[0]["enabled"] is True
# Second config should be the existing config
assert configs[1]["id"] == load_balancing_config.id
assert configs[1]["name"] == "config1"
# Verify database state
db.session.refresh(load_balancing_config)
assert load_balancing_config.id is not None
# Verify inherit config was created in database
inherit_configs = (
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all()
)
assert len(inherit_configs) == 1
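A hedged reading of the __inherit__ behavior this test observes: when the provider has a custom configuration and no inherit entry exists yet, the service creates and persists one, then lists it ahead of user-defined configs. A sketch under those assumptions (names illustrative):
def ensure_inherit_config(db, tenant_id, provider, model, origin_model_type, user_configs):
    inherit = next((c for c in user_configs if c.name == "__inherit__"), None)
    if inherit is None:
        inherit = LoadBalancingModelConfig(
            tenant_id=tenant_id, provider_name=provider, model_name=model,
            model_type=origin_model_type, name="__inherit__", enabled=True,
        )
        db.session.add(inherit)
        db.session.commit()
    return [inherit, *(c for c in user_configs if c.name != "__inherit__")]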

View File

@ -4,8 +4,8 @@ from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.common.errors import FilenameNotExistsError
from controllers.console.error import (
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,

Some files were not shown because too many files have changed in this diff