diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh
index c25bde87b0..39a653953e 100755
--- a/.devcontainer/post_create_command.sh
+++ b/.devcontainer/post_create_command.sh
@@ -1,6 +1,6 @@
#!/bin/bash
-npm add -g pnpm@10.13.1
+npm add -g pnpm@10.15.0
cd web && pnpm install
pipx install uv
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index f5ba498c7d..dada6229db 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -26,6 +26,7 @@ jobs:
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
-
+ - name: mdformat
+ run: |
+ uvx mdformat .
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
-
diff --git a/README.md b/README.md
index 7e566a0b2f..90da1d3def 100644
--- a/README.md
+++ b/README.md
@@ -107,74 +107,6 @@ Monitor and analyze application logs and performance over time. You could contin
**7. Backend-as-a-Service**:
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
-## Feature Comparison
-
-| Feature | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
-| --- | --- | --- | --- | --- |
-| Programming Approach | API + App-oriented | Python Code | App-oriented | API-oriented |
-| Supported LLMs | Rich Variety | Rich Variety | Rich Variety | OpenAI-only |
-| RAG Engine | ✅ | ✅ | ✅ | ✅ |
-| Agent | ✅ | ✅ | ❌ | ✅ |
-| Workflow | ✅ | ❌ | ✅ | ❌ |
-| Observability | ✅ | ✅ | ❌ | ❌ |
-| Enterprise Feature (SSO/Access control) | ✅ | ❌ | ❌ | ❌ |
-| Local Deployment | ✅ | ✅ | ✅ | ❌ |
-
## Using Dify
- **Cloud **
diff --git a/README_AR.md b/README_AR.md
index 044ced98ed..2451757ab5 100644
--- a/README_AR.md
+++ b/README_AR.md
@@ -68,74 +68,6 @@
**7.الواجهة الخلفية (Backend) كخدمة**: تأتي جميع عروض Dify مع APIs مطابقة، حتى يمكنك دمج Dify بسهولة في منطق أعمالك الخاص.
-## مقارنة الميزات
-
-| الميزة | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
-| --- | --- | --- | --- | --- |
-| نهج البرمجة | موجّه لـ تطبيق + واجهة برمجة تطبيق (API) | برمجة Python | موجه لتطبيق | واجهة برمجة تطبيق (API) |
-| LLMs المدعومة | تنوع غني | تنوع غني | تنوع غني | فقط OpenAI |
-| محرك RAG | ✅ | ✅ | ✅ | ✅ |
-| الوكيل | ✅ | ✅ | ❌ | ✅ |
-| سير العمل | ✅ | ❌ | ✅ | ❌ |
-| الملاحظة | ✅ | ✅ | ❌ | ❌ |
-| ميزات الشركات (SSO / مراقبة الوصول) | ✅ | ❌ | ❌ | ❌ |
-| نشر محلي | ✅ | ✅ | ✅ | ❌ |
-
## استخدام Dify
- **سحابة **
diff --git a/README_BN.md b/README_BN.md
index f5a19ab434..ef24dea171 100644
--- a/README_BN.md
+++ b/README_BN.md
@@ -106,74 +106,6 @@ LLM ফাংশন কলিং বা ReAct উপর ভিত্তি ক
**7. ব্যাকএন্ড-অ্যাজ-এ-সার্ভিস**:
ডিফাই-এর সমস্ত অফার সংশ্লিষ্ট API-সহ আছে, যাতে আপনি অনায়াসে ডিফাইকে আপনার নিজস্ব বিজনেস লজিকে ইন্টেগ্রেট করতে পারেন।
-## বৈশিষ্ট্য তুলনা
-
-
-
## 使用 Dify
- **云 **
diff --git a/README_DE.md b/README_DE.md
index 88c36019e3..a593a12abf 100644
--- a/README_DE.md
+++ b/README_DE.md
@@ -106,74 +106,6 @@ Sie können Agenten basierend auf LLM Function Calling oder ReAct definieren und
**7. Backend-as-a-Service**:
Alle Dify-Angebote kommen mit entsprechenden APIs, sodass Sie Dify mühelos in Ihre eigene Geschäftslogik integrieren können.
-## Vergleich der Merkmale
-
-| Feature | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
-| --- | --- | --- | --- | --- |
-| Programming Approach | API + App-oriented | Python Code | App-oriented | API-oriented |
-| Supported LLMs | Rich Variety | Rich Variety | Rich Variety | OpenAI-only |
-| RAG Engine | ✅ | ✅ | ✅ | ✅ |
-| Agent | ✅ | ✅ | ❌ | ✅ |
-| Workflow | ✅ | ❌ | ✅ | ❌ |
-| Observability | ✅ | ✅ | ❌ | ❌ |
-| Enterprise Feature (SSO/Access control) | ✅ | ❌ | ❌ | ❌ |
-| Local Deployment | ✅ | ✅ | ✅ | ❌ |
-
## Dify verwenden
- **Cloud **
diff --git a/README_ES.md b/README_ES.md
index bc3b25f2d1..c7a18dc675 100644
--- a/README_ES.md
+++ b/README_ES.md
@@ -79,74 +79,6 @@ Supervisa y analiza registros de aplicaciones y rendimiento a lo largo del tiemp
**7. Backend como servicio**:
Todas las ofertas de Dify vienen con APIs correspondientes, por lo que podrías integrar Dify sin esfuerzo en tu propia lógica empresarial.
-## Comparación de características
-
-| Característica | Dify.AI | LangChain | Flowise | API de Asistentes de OpenAI |
-| --- | --- | --- | --- | --- |
-| Enfoque de programación | API + orientado a la aplicación | Código Python | Orientado a la aplicación | Orientado a la API |
-| LLMs admitidos | Gran variedad | Gran variedad | Gran variedad | Solo OpenAI |
-| Motor RAG | ✅ | ✅ | ✅ | ✅ |
-| Agente | ✅ | ✅ | ❌ | ✅ |
-| Flujo de trabajo | ✅ | ❌ | ✅ | ❌ |
-| Observabilidad | ✅ | ✅ | ❌ | ❌ |
-| Característica empresarial (SSO/Control de acceso) | ✅ | ❌ | ❌ | ❌ |
-| Implementación local | ✅ | ✅ | ✅ | ❌ |
-
## Usando Dify
- **Nube **
diff --git a/README_FR.md b/README_FR.md
index 7521753100..316d50c929 100644
--- a/README_FR.md
+++ b/README_FR.md
@@ -79,74 +79,6 @@ Surveillez et analysez les journaux d'application et les performances au fil du
**7. Backend-as-a-Service** :
Toutes les offres de Dify sont accompagnées d'API correspondantes, vous permettant d'intégrer facilement Dify dans votre propre logique métier.
-## Comparaison des fonctionnalités
-
-
-
## Difyの使用方法
- **クラウド **
diff --git a/README_KL.md b/README_KL.md
index 252a2b6db5..93da9a6140 100644
--- a/README_KL.md
+++ b/README_KL.md
@@ -79,74 +79,6 @@ Monitor and analyze application logs and performance over time. You could contin
**7. Backend-as-a-Service**:
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
-## Feature Comparison
-
-| Feature | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
-| --- | --- | --- | --- | --- |
-| Programming Approach | API + App-oriented | Python Code | App-oriented | API-oriented |
-| Supported LLMs | Rich Variety | Rich Variety | Rich Variety | OpenAI-only |
-| RAG Engine | ✅ | ✅ | ✅ | ✅ |
-| Agent | ✅ | ✅ | ❌ | ✅ |
-| Workflow | ✅ | ❌ | ✅ | ❌ |
-| Observability | ✅ | ✅ | ❌ | ❌ |
-| Enterprise Feature (SSO/Access control) | ✅ | ❌ | ❌ | ❌ |
-| Local Deployment | ✅ | ✅ | ✅ | ❌ |
-
## Using Dify
- **Cloud **
diff --git a/README_KR.md b/README_KR.md
index 278e3f6c33..3b58339e12 100644
--- a/README_KR.md
+++ b/README_KR.md
@@ -73,74 +73,6 @@ LLM 함수 호출 또는 ReAct를 기반으로 에이전트를 정의하고 에
**7. Backend-as-a-Service**:
Dify의 모든 제품에는 해당 API가 함께 제공되므로 Dify를 자신의 비즈니스 로직에 쉽게 통합할 수 있습니다.
-## 기능 비교
-
-| 기능 | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
-| --- | --- | --- | --- | --- |
-| 프로그래밍 접근 방식 | API + 앱 중심 | Python 코드 | 앱 중심 | API 중심 |
-| 지원되는 LLMs | 다양한 종류 | 다양한 종류 | 다양한 종류 | OpenAI 전용 |
-| RAG 엔진 | ✅ | ✅ | ✅ | ✅ |
-| 에이전트 | ✅ | ✅ | ❌ | ✅ |
-| 워크플로우 | ✅ | ❌ | ✅ | ❌ |
-| 가시성 | ✅ | ✅ | ❌ | ❌ |
-| 기업용 기능 (SSO/접근 제어) | ✅ | ❌ | ❌ | ❌ |
-| 로컬 배포 | ✅ | ✅ | ✅ | ❌ |
-
## Dify 사용하기
- **클라우드 **
diff --git a/README_PT.md b/README_PT.md
index 8bff880728..ec2e4245f6 100644
--- a/README_PT.md
+++ b/README_PT.md
@@ -79,74 +79,6 @@ Monitore e analise os registros e o desempenho do aplicativo ao longo do tempo.
**7. Backend como Serviço**:
Todas os recursos do Dify vêm com APIs correspondentes, permitindo que você integre o Dify sem esforço na lógica de negócios da sua empresa.
-## Comparação de recursos
-
-| Recurso | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
-| --- | --- | --- | --- | --- |
-| Abordagem de Programação | Orientada a API + Aplicativo | Código Python | Orientada a Aplicativo | Orientada a API |
-| LLMs Suportados | Variedade Rica | Variedade Rica | Variedade Rica | Apenas OpenAI |
-| RAG Engine | ✅ | ✅ | ✅ | ✅ |
-| Agente | ✅ | ✅ | ❌ | ✅ |
-| Workflow | ✅ | ❌ | ✅ | ❌ |
-| Observabilidade | ✅ | ✅ | ❌ | ❌ |
-| Recursos Empresariais (SSO/Controle de Acesso) | ✅ | ❌ | ❌ | ❌ |
-| Implantação Local | ✅ | ✅ | ✅ | ❌ |
-
## Usando o Dify
- **Nuvem **
diff --git a/README_SI.md b/README_SI.md
index be8c6320fb..c20dc3484f 100644
--- a/README_SI.md
+++ b/README_SI.md
@@ -103,74 +103,6 @@ Spremljajte in analizirajte dnevnike aplikacij in učinkovitost skozi čas. Pozi
**7. Backend-as-a-Service**:
Vse ponudbe Difyja so opremljene z ustreznimi API-ji, tako da lahko Dify brez težav integrirate v svojo poslovno logiko.
-## Primerjava Funkcij
-
-| Funkcija | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
-| --- | --- | --- | --- | --- |
-| Programski pristop | API + usmerjeno v aplikacije | Python koda | Usmerjeno v aplikacije | Usmerjeno v API |
-| Podprti LLM-ji | Bogata izbira | Bogata izbira | Bogata izbira | Samo OpenAI |
-| RAG pogon | ✅ | ✅ | ✅ | ✅ |
-| Agent | ✅ | ✅ | ❌ | ✅ |
-| Potek dela | ✅ | ❌ | ✅ | ❌ |
-| Spremljanje | ✅ | ✅ | ❌ | ❌ |
-| Funkcija za podjetja (SSO/nadzor dostopa) | ✅ | ❌ | ❌ | ❌ |
-| Lokalna namestitev | ✅ | ✅ | ✅ | ❌ |
-
## Uporaba Dify
- **Cloud **
diff --git a/README_TR.md b/README_TR.md
index e54b1f4589..510b112e68 100644
--- a/README_TR.md
+++ b/README_TR.md
@@ -74,74 +74,6 @@ Uygulama loglarını ve performans metriklerini zaman içinde izleme ve analiz e
**7. Hizmet Olarak Backend**:
Dify'ın tüm özellikleri ilgili API'lerle birlikte gelir, böylece Dify'ı kendi iş mantığınıza kolayca entegre edebilirsiniz.
-## Özellik karşılaştırması
-
-
-
## 使用 Dify
- **雲端服務 **
diff --git a/README_VI.md b/README_VI.md
index 8c5c333e8f..f161b20f9d 100644
--- a/README_VI.md
+++ b/README_VI.md
@@ -74,74 +74,6 @@ Giám sát và phân tích nhật ký và hiệu suất ứng dụng theo thời
**7. Backend-as-a-Service**:
Tất cả các dịch vụ của Dify đều đi kèm với các API tương ứng, vì vậy bạn có thể dễ dàng tích hợp Dify vào logic kinh doanh của riêng mình.
-## So sánh tính năng
-
-| Tính năng | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
-| --- | --- | --- | --- | --- |
-| Phương pháp lập trình | Hướng API + Ứng dụng | Mã Python | Hướng ứng dụng | Hướng API |
-| LLMs được hỗ trợ | Đa dạng phong phú | Đa dạng phong phú | Đa dạng phong phú | Chỉ OpenAI |
-| RAG Engine | ✅ | ✅ | ✅ | ✅ |
-| Agent | ✅ | ✅ | ❌ | ✅ |
-| Quy trình làm việc | ✅ | ❌ | ✅ | ❌ |
-| Khả năng quan sát | ✅ | ✅ | ❌ | ❌ |
-| Tính năng doanh nghiệp (SSO/Kiểm soát truy cập) | ✅ | ❌ | ❌ | ❌ |
-| Triển khai cục bộ | ✅ | ✅ | ✅ | ❌ |
-
## Sử dụng Dify
- **Cloud **
diff --git a/api/README.md b/api/README.md
index 5571fdd0fd..8309a0e69b 100644
--- a/api/README.md
+++ b/api/README.md
@@ -80,7 +80,7 @@
1. If you need to handle and debug async tasks (e.g. dataset importing and document indexing), please start the worker service.
```bash
-uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
+uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
```
In addition, if you want to debug the Celery scheduled tasks, you can use the following command in another terminal:
@@ -97,8 +97,16 @@ uv run celery -A app.celery beat
uv sync --dev
```
-1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
+1. Run the tests locally with the mocked system environment variables defined in the `tool.pytest_env` section of `pyproject.toml`; see [CLAUDE.md](../CLAUDE.md) for more details.
- ```bash
- uv run -P api bash dev/pytest/pytest_all_tests.sh
+ ```bash
+ uv run --project api pytest # Run all tests
+ uv run --project api pytest tests/unit_tests/ # Unit tests only
+ uv run --project api pytest tests/integration_tests/ # Integration tests
+
+ # Code quality
+ ./dev/reformat # Run all formatters and linters
+ uv run --project api ruff check --fix ./ # Fix linting issues
+ uv run --project api ruff format ./ # Format code
+ uv run --project api mypy . # Type checking
```
diff --git a/api/app_factory.py b/api/app_factory.py
index 032d6b17fc..8a0417dd72 100644
--- a/api/app_factory.py
+++ b/api/app_factory.py
@@ -5,6 +5,8 @@ from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp
+logger = logging.getLogger(__name__)
+
# ----------------------------
# Application Factory Function
@@ -32,7 +34,7 @@ def create_app() -> DifyApp:
initialize_extensions(app)
end_time = time.perf_counter()
if dify_config.DEBUG:
- logging.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
+ logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
return app
@@ -93,14 +95,14 @@ def initialize_extensions(app: DifyApp):
is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True
if not is_enabled:
if dify_config.DEBUG:
- logging.info("Skipped %s", short_name)
+ logger.info("Skipped %s", short_name)
continue
start_time = time.perf_counter()
ext.init_app(app)
end_time = time.perf_counter()
if dify_config.DEBUG:
- logging.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
+ logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
def create_migrations_app():
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py
index 3466eea1f6..df9de825de 100644
--- a/api/controllers/common/fields.py
+++ b/api/controllers/common/fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
from libs.helper import AppIconUrlField
@@ -10,6 +10,12 @@ parameters__system_parameters = {
"workflow_file_upload_limit": fields.Integer,
}
+
+def build_system_parameters_model(api_or_ns: Api | Namespace):
+ """Build the system parameters model for the API or Namespace."""
+ return api_or_ns.model("SystemParameters", parameters__system_parameters)
+
+
parameters_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
@@ -25,6 +31,14 @@ parameters_fields = {
"system_parameters": fields.Nested(parameters__system_parameters),
}
+
+def build_parameters_model(api_or_ns: Api | Namespace):
+ """Build the parameters model for the API or Namespace."""
+ copied_fields = parameters_fields.copy()
+ copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
+ return api_or_ns.model("Parameters", copied_fields)
+
+
site_fields = {
"title": fields.String,
"chat_color_theme": fields.String,
@@ -41,3 +55,8 @@ site_fields = {
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}
+
+
+def build_site_model(api_or_ns: Api | Namespace):
+ """Build the site model for the API or Namespace."""
+ return api_or_ns.model("Site", site_fields)
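For context, flask_restx resolves response schemas through named models registered on an `Api` or `Namespace`, which is why the plain field dicts above get wrapped by `build_*_model` helpers. A minimal standalone sketch of the pattern (the app and field names here are hypothetical, not Dify's):

```python
from flask import Flask
from flask_restx import Api, Resource, fields

app = Flask(__name__)
api = Api(app)

# Shared field dict, analogous to site_fields above
site_fields = {"title": fields.String, "icon": fields.String}

def build_site_model(api_or_ns):
    # Registering the dict as a named model makes it usable by marshal_with
    # and visible in the generated Swagger docs
    return api_or_ns.model("Site", site_fields)

site_model = build_site_model(api)

@api.route("/site")
class SiteApi(Resource):
    @api.marshal_with(site_model)  # response is filtered/serialized per the model
    def get(self):
        return {"title": "demo", "icon": "star", "extra": "dropped by marshalling"}
```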
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index 8a55197fb6..7e5c28200a 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -1,7 +1,7 @@
from functools import wraps
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py
index d7500c415c..401e88709a 100644
--- a/api/controllers/console/apikey.py
+++ b/api/controllers/console/apikey.py
@@ -1,8 +1,8 @@
-from typing import Any
+from typing import Any, Optional
-import flask_restful
+import flask_restx
from flask_login import current_user
-from flask_restful import Resource, fields, marshal_with
+from flask_restx import Resource, fields, marshal_with
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@@ -40,7 +40,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
).scalar_one_or_none()
if resource is None:
- flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
+ flask_restx.abort(404, message=f"{resource_model.__name__} not found.")
return resource
@@ -49,7 +49,7 @@ class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None
- resource_model: Any = None
+ resource_model: Optional[Any] = None
resource_id_field: str | None = None
token_prefix: str | None = None
max_keys = 10
@@ -81,7 +81,7 @@ class BaseApiKeyListResource(Resource):
)
if current_key_count >= self.max_keys:
- flask_restful.abort(
+ flask_restx.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
@@ -102,7 +102,7 @@ class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None
- resource_model: Any = None
+ resource_model: Optional[Any] = None
resource_id_field: str | None = None
def delete(self, resource_id, api_key_id):
@@ -126,7 +126,7 @@ class BaseApiKeyResource(Resource):
)
if key is None:
- flask_restful.abort(404, message="API key not found")
+ flask_restx.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()
diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py
index c228743fa5..c6cb6f6e3a 100644
--- a/api/controllers/console/app/advanced_prompt_template.py
+++ b/api/controllers/console/app/advanced_prompt_template.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py
index d433415894..a964154207 100644
--- a/api/controllers/console/app/agent.py
+++ b/api/controllers/console/app/agent.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py
index 2caa908d4a..37d23ccd9f 100644
--- a/api/controllers/console/app/annotation.py
+++ b/api/controllers/console/app/annotation.py
@@ -2,7 +2,7 @@ from typing import Literal
from flask import request
from flask_login import current_user
-from flask_restful import Resource, marshal, marshal_with, reqparse
+from flask_restx import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 1cc13d669c..a6eb86122d 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -2,7 +2,7 @@ import uuid
from typing import cast
from flask_login import current_user
-from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
+from flask_restx import Resource, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort
diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py
index 9ffb94e9f9..aee93a8814 100644
--- a/api/controllers/console/app/app_import.py
+++ b/api/controllers/console/app/app_import.py
@@ -1,7 +1,7 @@
from typing import cast
from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 665cf1aede..ea1869a587 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -1,7 +1,7 @@
import logging
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index ad94112f05..bd5e7d0924 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -2,7 +2,7 @@ import logging
import flask_login
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 6ddae6fad5..06f0218771 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -2,8 +2,8 @@ from datetime import datetime
import pytz # pip install pytz
from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import Resource, marshal_with, reqparse
+from flask_restx.inputs import int_range
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound
@@ -24,6 +24,8 @@ from libs.helper import DatetimeString
from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
+from services.conversation_service import ConversationService
+from services.errors.conversation import ConversationNotExistsError
class CompletionConversationApi(Resource):
@@ -46,7 +48,9 @@ class CompletionConversationApi(Resource):
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
- query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
+ query = db.select(Conversation).where(
+ Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
+ )
if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).where(
@@ -119,18 +123,11 @@ class CompletionConversationDetailApi(Resource):
raise Forbidden()
conversation_id = str(conversation_id)
- conversation = (
- db.session.query(Conversation)
- .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
- .first()
- )
-
- if not conversation:
+ try:
+ ConversationService.delete(app_model, conversation_id, current_user)
+ except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- conversation.is_deleted = True
- db.session.commit()
-
return {"result": "success"}, 204
@@ -171,7 +168,7 @@ class ChatConversationApi(Resource):
.subquery()
)
- query = db.select(Conversation).where(Conversation.app_id == app_model.id)
+ query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args["keyword"]:
keyword_filter = f"%{args['keyword']}%"
@@ -284,18 +281,11 @@ class ChatConversationDetailApi(Resource):
raise Forbidden()
conversation_id = str(conversation_id)
- conversation = (
- db.session.query(Conversation)
- .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
- .first()
- )
-
- if not conversation:
+ try:
+ ConversationService.delete(app_model, conversation_id, current_user)
+ except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- conversation.is_deleted = True
- db.session.commit()
-
return {"result": "success"}, 204
diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py
index d49f433ba1..5ca4c33f87 100644
--- a/api/controllers/console/app/conversation_variables.py
+++ b/api/controllers/console/app/conversation_variables.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 57dc1267d5..497fd53df7 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,7 +1,7 @@
from collections.abc import Sequence
from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from controllers.console import api
from controllers.console.app.error import (
diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py
index 2344fd5acb..541803e539 100644
--- a/api/controllers/console/app/mcp_server.py
+++ b/api/controllers/console/app/mcp_server.py
@@ -2,7 +2,7 @@ import json
from enum import StrEnum
from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index 680ac4a64c..57cc825fe9 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -1,8 +1,8 @@
import logging
from flask_login import current_user
-from flask_restful import Resource, fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import Resource, fields, marshal_with, reqparse
+from flask_restx.inputs import int_range
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api
diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index 029138fb6b..52ff9b923d 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -3,7 +3,7 @@ from typing import cast
from flask import request
from flask_login import current_user
-from flask_restful import Resource
+from flask_restx import Resource
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py
index 978c02412c..74c2867c2f 100644
--- a/api/controllers/console/app/ops_trace.py
+++ b/api/controllers/console/app/ops_trace.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import api
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 03418f1dd2..778ce92da6 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,5 +1,5 @@
from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index 343b7acd7b..27e405af38 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -5,7 +5,7 @@ import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index c58301b300..8dcffb1666 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -4,7 +4,7 @@ from collections.abc import Sequence
from typing import cast
from flask import abort, request
-from flask_restful import Resource, inputs, marshal_with, reqparse
+from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index 310146a5e7..8d8cdc93cf 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -1,6 +1,6 @@
from dateutil.parser import isoparse
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import Resource, marshal_with, reqparse
+from flask_restx.inputs import int_range
from sqlalchemy.orm import Session
from controllers.console import api
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index 414c07ef50..4e625db24d 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -2,7 +2,7 @@ import logging
from typing import Any, NoReturn
from flask import Response
-from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 9099700213..dccbfd8648 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,8 +1,8 @@
from typing import cast
from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import Resource, marshal_with, reqparse
+from flask_restx.inputs import int_range
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py
index 7f80afd83b..7cef175c14 100644
--- a/api/controllers/console/app/workflow_statistic.py
+++ b/api/controllers/console/app/workflow_statistic.py
@@ -5,7 +5,7 @@ import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index 2562fb5eb8..e82e403ec2 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from constants.languages import supported_language
from controllers.console import api
diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py
index b8c3c8f012..796e6916cc 100644
--- a/api/controllers/console/auth/data_source_bearer_auth.py
+++ b/api/controllers/console/auth/data_source_bearer_auth.py
@@ -1,5 +1,5 @@
from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
index 4940b48754..d4cf20549a 100644
--- a/api/controllers/console/auth/data_source_oauth.py
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -3,7 +3,7 @@ import logging
import requests
from flask import current_app, redirect, request
from flask_login import current_user
-from flask_restful import Resource
+from flask_restx import Resource
from werkzeug.exceptions import Forbidden
from configs import dify_config
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index 3bbe3177fc..ede0696854 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -2,7 +2,7 @@ import base64
import secrets
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 5f2a24322d..a5ad6a1cd7 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -2,7 +2,7 @@ from typing import cast
import flask_login
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
import services
from configs import dify_config
@@ -221,7 +221,7 @@ class EmailCodeLoginApi(Resource):
email=user_email, name=user_email, interface_language=languages[0]
)
except WorkSpaceNotAllowedCreateError:
- return NotAllowedCreateWorkspace()
+ raise NotAllowedCreateWorkspace()
except AccountRegisterError as are:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 4a6cb99390..3c76394cf9 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -3,7 +3,7 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
-from flask_restful import Resource
+from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py
index 4b0c82ae6c..8ebb745a60 100644
--- a/api/controllers/console/billing/billing.py
+++ b/api/controllers/console/billing/billing.py
@@ -1,5 +1,5 @@
from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from controllers.console import api
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py
index 9679632ac7..4bc073f679 100644
--- a/api/controllers/console/billing/compliance.py
+++ b/api/controllers/console/billing/compliance.py
@@ -1,6 +1,6 @@
from flask import request
from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip
from libs.login import login_required
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 39f8ab5787..6083a53bec 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -2,7 +2,7 @@ import json
from flask import request
from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 2befd2a651..a23536f82e 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -1,7 +1,7 @@
-import flask_restful
+import flask_restx
from flask import request
from flask_login import current_user
-from flask_restful import Resource, marshal, marshal_with, reqparse
+from flask_restx import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -589,7 +589,7 @@ class DatasetApiKeyApi(Resource):
)
if current_key_count >= self.max_keys:
- flask_restful.abort(
+ flask_restx.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
@@ -629,7 +629,7 @@ class DatasetApiDeleteApi(Resource):
)
if key is None:
- flask_restful.abort(404, message="API key not found")
+ flask_restx.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 413b018baa..f823ed603b 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -4,7 +4,7 @@ from typing import Literal, cast
from flask import request
from flask_login import current_user
-from flask_restful import Resource, marshal, marshal_with, reqparse
+from flask_restx import Resource, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 8c429044d7..463fd2d7ec 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -2,7 +2,7 @@ import uuid
from flask import request
from flask_login import current_user
-from flask_restful import Resource, marshal, reqparse
+from flask_restx import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
@@ -584,7 +584,12 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
- .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
+ .where(
+ ChildChunk.id == str(child_chunk_id),
+ ChildChunk.tenant_id == current_user.current_tenant_id,
+ ChildChunk.segment_id == segment.id,
+ ChildChunk.document_id == document_id,
+ )
.first()
)
if not child_chunk:
@@ -633,7 +638,12 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
- .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
+ .where(
+ ChildChunk.id == str(child_chunk_id),
+ ChildChunk.tenant_id == current_user.current_tenant_id,
+ ChildChunk.segment_id == segment.id,
+ ChildChunk.document_id == document_id,
+ )
.first()
)
if not child_chunk:
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index cf9081e154..043f39f623 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -1,6 +1,6 @@
from flask import request
from flask_login import current_user
-from flask_restful import Resource, marshal, reqparse
+from flask_restx import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index fba5d4c0f3..2ad192571b 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource
+from flask_restx import Resource
from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index 3b4c076863..304674db5f 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -1,7 +1,7 @@
import logging
from flask_login import current_user
-from flask_restful import marshal, reqparse
+from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services.dataset_service
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index 1b5570285d..6aa309f930 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -1,7 +1,7 @@
from typing import Literal
from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api
diff --git a/api/controllers/console/datasets/upload_file.py b/api/controllers/console/datasets/upload_file.py
index 2afdaf7f2b..617dbcaff2 100644
--- a/api/controllers/console/datasets/upload_file.py
+++ b/api/controllers/console/datasets/upload_file.py
@@ -1,5 +1,5 @@
from flask_login import current_user
-from flask_restful import Resource
+from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console import api
diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py
index fcdc91ec67..bdaa268462 100644
--- a/api/controllers/console/datasets/website.py
+++ b/api/controllers/console/datasets/website.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index d564a00a76..2a4d5be82f 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -65,7 +65,7 @@ class ChatAudioApi(InstalledAppResource):
class ChatTextApi(InstalledAppResource):
def post(self, installed_app):
- from flask_restful import reqparse
+ from flask_restx import reqparse
app_model = installed_app.app
try:
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index 4842fefc57..b444a2a197 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -1,7 +1,7 @@
import logging
from flask_login import current_user
-from flask_restful import reqparse
+from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py
index d7c161cc6d..a8d46954b5 100644
--- a/api/controllers/console/explore/conversation.py
+++ b/api/controllers/console/explore/conversation.py
@@ -1,6 +1,6 @@
from flask_login import current_user
-from flask_restful import marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import marshal_with, reqparse
+from flask_restx.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index ad62bd6e08..3ccedd654b 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -3,7 +3,7 @@ from typing import Any
from flask import request
from flask_login import current_user
-from flask_restful import Resource, inputs, marshal_with, reqparse
+from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index de95a9e7b0..6df3bca762 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -1,8 +1,8 @@
import logging
from flask_login import current_user
-from flask_restful import marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import marshal_with, reqparse
+from flask_restx.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.console.app.error import (
diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py
index a1280d91d1..c368744759 100644
--- a/api/controllers/console/explore/parameter.py
+++ b/api/controllers/console/explore/parameter.py
@@ -1,4 +1,4 @@
-from flask_restful import marshal_with
+from flask_restx import marshal_with
from controllers.common import fields
from controllers.console import api
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index ce85f495aa..62f9350b71 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -1,5 +1,5 @@
from flask_login import current_user
-from flask_restful import Resource, fields, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages
from controllers.console import api
diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py
index 339e7007a0..5353dbcad5 100644
--- a/api/controllers/console/explore/saved_message.py
+++ b/api/controllers/console/explore/saved_message.py
@@ -1,6 +1,6 @@
from flask_login import current_user
-from flask_restful import fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import fields, marshal_with, reqparse
+from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.console import api
diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py
index 3f625e6609..3d872fc1fc 100644
--- a/api/controllers/console/explore/workflow.py
+++ b/api/controllers/console/explore/workflow.py
@@ -1,6 +1,6 @@
import logging
-from flask_restful import reqparse
+from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError
from controllers.console.app.error import (
diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py
index de97fb149e..e86103184a 100644
--- a/api/controllers/console/explore/wraps.py
+++ b/api/controllers/console/explore/wraps.py
@@ -1,7 +1,7 @@
from functools import wraps
from flask_login import current_user
-from flask_restful import Resource
+from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py
index 07a241ef86..e157041c35 100644
--- a/api/controllers/console/extension.py
+++ b/api/controllers/console/extension.py
@@ -1,5 +1,5 @@
from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with, reqparse
from constants import HIDDEN_VALUE
from controllers.console import api
diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py
index 70ab4ff865..6236832d39 100644
--- a/api/controllers/console/feature.py
+++ b/api/controllers/console/feature.py
@@ -1,5 +1,5 @@
from flask_login import current_user
-from flask_restful import Resource
+from flask_restx import Resource
from libs.login import login_required
from services.feature_service import FeatureService
diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py
index a87d270e9c..101a49a32e 100644
--- a/api/controllers/console/files.py
+++ b/api/controllers/console/files.py
@@ -2,7 +2,7 @@ from typing import Literal
from flask import request
from flask_login import current_user
-from flask_restful import Resource, marshal_with
+from flask_restx import Resource, marshal_with
from werkzeug.exceptions import Forbidden
import services
diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py
index b19e331d2e..2a37b1708a 100644
--- a/api/controllers/console/init_validate.py
+++ b/api/controllers/console/init_validate.py
@@ -1,7 +1,7 @@
import os
from flask import session
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py
index cd28cc946e..1a53a2347e 100644
--- a/api/controllers/console/ping.py
+++ b/api/controllers/console/ping.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource
+from flask_restx import Resource
from controllers.console import api
diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py
index c356113c40..73014cfc97 100644
--- a/api/controllers/console/remote_files.py
+++ b/api/controllers/console/remote_files.py
@@ -3,7 +3,7 @@ from typing import cast
import httpx
from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with, reqparse
import services
from controllers.common import helpers
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
index e1f19a87a3..8e230496f0 100644
--- a/api/controllers/console/setup.py
+++ b/api/controllers/console/setup.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip
diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py
index cb5dedca21..c45e7dbb26 100644
--- a/api/controllers/console/tag/tags.py
+++ b/api/controllers/console/tag/tags.py
@@ -1,11 +1,11 @@
from flask import request
from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
-from fields.tag_fields import tag_fields
+from fields.tag_fields import dataset_tag_fields
from libs.login import login_required
from models.model import Tag
from services.tag_service import TagService
@@ -21,7 +21,7 @@ class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(tag_fields)
+ @marshal_with(dataset_tag_fields)
def get(self):
tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str)
diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py
index 894785abc8..96cf627b65 100644
--- a/api/controllers/console/version.py
+++ b/api/controllers/console/version.py
@@ -2,7 +2,7 @@ import json
import logging
import requests
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from packaging import version
from configs import dify_config
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 3f6d6bf54f..5b2828dbab 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -3,7 +3,7 @@ from datetime import datetime
import pytz
from flask import request
from flask_login import current_user
-from flask_restful import Resource, fields, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py
index 88c37767e3..08bab6fcb5 100644
--- a/api/controllers/console/workspace/agent_providers.py
+++ b/api/controllers/console/workspace/agent_providers.py
@@ -1,5 +1,5 @@
from flask_login import current_user
-from flask_restful import Resource
+from flask_restx import Resource
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py
index eb53dcb16e..96e873d42b 100644
--- a/api/controllers/console/workspace/endpoint.py
+++ b/api/controllers/console/workspace/endpoint.py
@@ -1,5 +1,5 @@
from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py
index b4eb5e246b..2a54511bf0 100644
--- a/api/controllers/console/workspace/load_balancing_config.py
+++ b/api/controllers/console/workspace/load_balancing_config.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index f7424923b9..f018fada3a 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -2,7 +2,7 @@ from urllib import parse
from flask import request
from flask_login import current_user
-from flask_restful import Resource, abort, marshal_with, reqparse
+from flask_restx import Resource, abort, marshal_with, reqparse
import services
from configs import dify_config
diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py
index ff0fcbda6e..281783b3d7 100644
--- a/api/controllers/console/workspace/model_providers.py
+++ b/api/controllers/console/workspace/model_providers.py
@@ -2,7 +2,7 @@ import io
from flask import send_file
from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 514d1084c4..b8dddb91dd 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -1,7 +1,7 @@
import logging
from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index 09846d5c94..fd5421fa64 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -2,7 +2,7 @@ import io
from flask import request, send_file
from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from configs import dify_config
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 8c8b73b45d..854ba7ac45 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -3,7 +3,7 @@ from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
from flask_login import current_user
-from flask_restful import (
+from flask_restx import (
Resource,
reqparse,
)
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index f4f0078da7..fb89f6bbbd 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -2,7 +2,7 @@ import logging
from flask import request
from flask_login import current_user
-from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index d862dac373..d3fd1d52e5 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -1,3 +1,4 @@
+import contextlib
import json
import os
import time
@@ -178,7 +179,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
def cloud_utm_record(view):
@wraps(view)
def decorated(*args, **kwargs):
- try:
+ with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
@@ -187,8 +188,7 @@ def cloud_utm_record(view):
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
- except Exception as e:
- pass
+
return view(*args, **kwargs)
return decorated
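For readers unfamiliar with it, `contextlib.suppress` gives the same "ignore any failure" behavior as the removed `try/except Exception: pass`, just more compactly and without an unused exception variable. A tiny self-contained sketch:

```python
import contextlib

with contextlib.suppress(Exception):
    raise RuntimeError("swallowed")  # execution resumes after the with-block
print("still running")  # prints; the error above was suppressed
```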
diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py
index d4c3245708..282a181997 100644
--- a/api/controllers/files/__init__.py
+++ b/api/controllers/files/__init__.py
@@ -1,9 +1,20 @@
from flask import Blueprint
+from flask_restx import Namespace
from libs.external_api import ExternalApi
-bp = Blueprint("files", __name__)
-api = ExternalApi(bp)
+bp = Blueprint("files", __name__, url_prefix="/files")
+api = ExternalApi(
+ bp,
+ version="1.0",
+ title="Files API",
+ description="API for file operations including upload and preview",
+ doc="/docs", # Enable Swagger UI at /files/docs
+)
+
+files_ns = Namespace("files", description="File operations")
from . import image_preview, tool_files, upload
+
+api.add_namespace(files_ns)
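+
+# With flask_restx, resources attach to the namespace via @files_ns.route(...)
+# decorators instead of api.add_resource(...), so the "from . import ..." above
+# must run after files_ns is created.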
diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py
index 91f7b27d1d..48baac6556 100644
--- a/api/controllers/files/image_preview.py
+++ b/api/controllers/files/image_preview.py
@@ -1,16 +1,17 @@
from urllib.parse import quote
from flask import Response, request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import NotFound
import services
from controllers.common.errors import UnsupportedFileTypeError
-from controllers.files import api
+from controllers.files import files_ns
from services.account_service import TenantService
from services.file_service import FileService
+@files_ns.route("//image-preview")
class ImagePreviewApi(Resource):
"""
Deprecated
@@ -39,6 +40,7 @@ class ImagePreviewApi(Resource):
return Response(generator, mimetype=mimetype)
+@files_ns.route("//file-preview")
class FilePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
@@ -94,6 +96,7 @@ class FilePreviewApi(Resource):
return response
+@files_ns.route("/workspaces//webapp-logo")
class WorkspaceWebappLogoApi(Resource):
def get(self, workspace_id):
workspace_id = str(workspace_id)
@@ -112,8 +115,3 @@ class WorkspaceWebappLogoApi(Resource):
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
-
-
-api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
-api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview")
-api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")
diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py
index d9c4e50511..faa9b733c2 100644
--- a/api/controllers/files/tool_files.py
+++ b/api/controllers/files/tool_files.py
@@ -1,17 +1,18 @@
from urllib.parse import quote
from flask import Response
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError
-from controllers.files import api
+from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db
-class ToolFilePreviewApi(Resource):
+@files_ns.route("/tools/.")
+class ToolFileApi(Resource):
def get(self, file_id, extension):
file_id = str(file_id)
@@ -52,6 +53,3 @@ class ToolFilePreviewApi(Resource):
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
return response
-
-
-api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>")
diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py
index bcc72d131c..7a2b3b0428 100644
--- a/api/controllers/files/upload.py
+++ b/api/controllers/files/upload.py
@@ -1,7 +1,9 @@
from mimetypes import guess_extension
+from typing import Optional
-from flask import request
-from flask_restful import Resource, marshal_with
+from flask_restx import Resource, reqparse
+from flask_restx.api import HTTPStatus
+from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
import services
@@ -10,39 +12,76 @@ from controllers.common.errors import (
UnsupportedFileTypeError,
)
from controllers.console.wraps import setup_required
-from controllers.files import api
+from controllers.files import files_ns
from controllers.inner_api.plugin.wraps import get_user
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
-from fields.file_fields import file_fields
+from fields.file_fields import build_file_model
+
+# Define parser for both documentation and validation
+upload_parser = reqparse.RequestParser()
+upload_parser.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
+upload_parser.add_argument(
+ "timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
+)
+upload_parser.add_argument(
+ "nonce", type=str, required=True, location="args", help="Random string for signature verification"
+)
+upload_parser.add_argument(
+ "sign", type=str, required=True, location="args", help="HMAC signature for request validation"
+)
+upload_parser.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
+upload_parser.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
+@files_ns.route("/upload/for-plugin")
class PluginUploadFileApi(Resource):
@setup_required
- @marshal_with(file_fields)
+ @files_ns.expect(upload_parser)
+ @files_ns.doc("upload_plugin_file")
+ @files_ns.doc(description="Upload a file for plugin usage with signature verification")
+ @files_ns.doc(
+ responses={
+ 201: "File uploaded successfully",
+ 400: "Invalid request parameters",
+ 403: "Forbidden - Invalid signature or missing parameters",
+ 413: "File too large",
+ 415: "Unsupported file type",
+ }
+ )
+ @files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED)
def post(self):
- # get file from request
- file = request.files["file"]
+ """Upload a file for plugin usage.
- timestamp = request.args.get("timestamp")
- nonce = request.args.get("nonce")
- sign = request.args.get("sign")
- tenant_id = request.args.get("tenant_id")
- if not tenant_id:
- raise Forbidden("Invalid request.")
+ Accepts a file upload with signature verification for security.
+ The file must be accompanied by valid timestamp, nonce, and signature parameters.
- user_id = request.args.get("user_id")
+ Returns:
+ dict: File metadata including ID, URLs, and properties
+ int: HTTP status code (201 for success)
+
+ Raises:
+ Forbidden: Invalid signature or missing required parameters
+ FileTooLargeError: File exceeds size limit
+ UnsupportedFileTypeError: File type not supported
+ """
+ # Parse and validate all arguments
+ args = upload_parser.parse_args()
+
+ file: FileStorage = args["file"]
+ timestamp: str = args["timestamp"]
+ nonce: str = args["nonce"]
+ sign: str = args["sign"]
+ tenant_id: str = args["tenant_id"]
+ user_id: Optional[str] = args.get("user_id")
user = get_user(tenant_id, user_id)
- filename = file.filename
- mimetype = file.mimetype
+ filename: Optional[str] = file.filename
+ mimetype: Optional[str] = file.mimetype
if not filename or not mimetype:
raise Forbidden("Invalid request.")
- if not timestamp or not nonce or not sign:
- raise Forbidden("Invalid request.")
-
if not verify_plugin_file_signature(
filename=filename,
mimetype=mimetype,
@@ -88,6 +127,3 @@ class PluginUploadFileApi(Resource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
-
-
-api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")
diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py
index 7b96f88f51..80bbc360de 100644
--- a/api/controllers/inner_api/mail.py
+++ b/api/controllers/inner_api/mail.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from controllers.console.wraps import setup_required
from controllers.inner_api import api
diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py
index 5dfe41eb6b..9b8d9457f0 100644
--- a/api/controllers/inner_api/plugin/plugin.py
+++ b/api/controllers/inner_api/plugin/plugin.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource
+from flask_restx import Resource
from controllers.console.wraps import setup_required
from controllers.inner_api import api
diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py
index b533614d4d..89b4ac7506 100644
--- a/api/controllers/inner_api/plugin/wraps.py
+++ b/api/controllers/inner_api/plugin/wraps.py
@@ -4,7 +4,7 @@ from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in
-from flask_restful import reqparse
+from flask_restx import reqparse
from pydantic import BaseModel
from sqlalchemy.orm import Session
diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py
index 77568b75f1..1c26416080 100644
--- a/api/controllers/inner_api/workspace/workspace.py
+++ b/api/controllers/inner_api/workspace/workspace.py
@@ -1,6 +1,6 @@
import json
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from controllers.console.wraps import setup_required
from controllers.inner_api import api
diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py
index 1b3e0a5621..1f5dae74e8 100644
--- a/api/controllers/mcp/__init__.py
+++ b/api/controllers/mcp/__init__.py
@@ -1,8 +1,20 @@
from flask import Blueprint
+from flask_restx import Namespace
from libs.external_api import ExternalApi
bp = Blueprint("mcp", __name__, url_prefix="/mcp")
-api = ExternalApi(bp)
+
+api = ExternalApi(
+ bp,
+ version="1.0",
+ title="MCP API",
+ description="API for Model Context Protocol operations",
+ doc="/docs", # Enable Swagger UI at /mcp/docs
+)
+
+mcp_ns = Namespace("mcp", description="MCP operations")
from . import mcp
+
+api.add_namespace(mcp_ns)
diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py
index 87d678796f..fc19749011 100644
--- a/api/controllers/mcp/mcp.py
+++ b/api/controllers/mcp/mcp.py
@@ -1,8 +1,10 @@
-from flask_restful import Resource, reqparse
+from typing import Optional, Union
+
+from flask_restx import Resource, reqparse
from pydantic import ValidationError
from controllers.console.app.mcp_server import AppMCPServerStatus
-from controllers.mcp import api
+from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
@@ -13,22 +15,58 @@ from libs import helper
from models.model import App, AppMCPServer, AppMode
+def int_or_str(value):
+ """Validate that a value is either an integer or string."""
+ if isinstance(value, (int, str)):
+ return value
+ else:
+ return None
+
+
+# Define parser for both documentation and validation
+mcp_request_parser = reqparse.RequestParser()
+mcp_request_parser.add_argument(
+ "jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')"
+)
+mcp_request_parser.add_argument("method", type=str, required=True, location="json", help="The method to invoke")
+mcp_request_parser.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
+mcp_request_parser.add_argument(
+ "id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses"
+)
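+# JSON-RPC 2.0 allows the request id to be a number or a string, hence the
+# int_or_str validator above. Example payload (illustrative):
+#   {"jsonrpc": "2.0", "method": "tools/list", "params": {}, "id": 1}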
+
+
+@mcp_ns.route("/server/<string:server_code>/mcp")
class MCPAppApi(Resource):
- def post(self, server_code):
- def int_or_str(value):
- if isinstance(value, (int, str)):
- return value
- else:
- return None
+ @mcp_ns.expect(mcp_request_parser)
+ @mcp_ns.doc("handle_mcp_request")
+ @mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server")
+ @mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"})
+ @mcp_ns.doc(
+ responses={
+ 200: "MCP response successfully processed",
+ 400: "Invalid MCP request or parameters",
+ 404: "Server or app not found",
+ }
+ )
+ def post(self, server_code: str):
+ """Handle MCP requests for a specific server.
- parser = reqparse.RequestParser()
- parser.add_argument("jsonrpc", type=str, required=True, location="json")
- parser.add_argument("method", type=str, required=True, location="json")
- parser.add_argument("params", type=dict, required=False, location="json")
- parser.add_argument("id", type=int_or_str, required=False, location="json")
- args = parser.parse_args()
+ Processes JSON-RPC formatted requests according to the Model Context Protocol specification.
+ Validates the server status and associated app before processing the request.
- request_id = args.get("id")
+ Args:
+ server_code: Unique identifier for the MCP server
+
+ Returns:
+ dict: JSON-RPC response from the MCP handler
+
+ Raises:
+ ValidationError: Invalid request format or parameters
+ """
+ # Parse and validate all arguments
+ args = mcp_request_parser.parse_args()
+
+ request_id: Optional[Union[int, str]] = args.get("id")
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not server:
@@ -99,6 +137,3 @@ class MCPAppApi(Resource):
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
response = mcp_server_handler.handle()
return helper.compact_generate_response(response)
-
-
-api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")
diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py
index b26f29d98d..aaa3c8f9a1 100644
--- a/api/controllers/service_api/__init__.py
+++ b/api/controllers/service_api/__init__.py
@@ -1,11 +1,23 @@
from flask import Blueprint
+from flask_restx import Namespace
from libs.external_api import ExternalApi
bp = Blueprint("service_api", __name__, url_prefix="/v1")
-api = ExternalApi(bp)
+
+api = ExternalApi(
+ bp,
+ version="1.0",
+ title="Service API",
+ description="API for application services",
+ doc="/docs", # Enable Swagger UI at /v1/docs
+)
+
+service_api_ns = Namespace("service_api", description="Service operations")
from . import index
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models
+
+api.add_namespace(service_api_ns)
diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py
index 23446bb702..6bc94af8c1 100644
--- a/api/controllers/service_api/app/annotation.py
+++ b/api/controllers/service_api/app/annotation.py
@@ -1,28 +1,51 @@
from typing import Literal
from flask import request
-from flask_restful import Resource, marshal, marshal_with, reqparse
+from flask_restx import Api, Namespace, Resource, fields, reqparse
+from flask_restx.api import HTTPStatus
from werkzeug.exceptions import Forbidden
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
-from fields.annotation_fields import (
- annotation_fields,
-)
+from fields.annotation_fields import annotation_fields, build_annotation_model
from libs.login import current_user
from models.model import App
from services.annotation_service import AppAnnotationService
+# Define parsers for annotation API
+annotation_create_parser = reqparse.RequestParser()
+annotation_create_parser.add_argument("question", required=True, type=str, location="json", help="Annotation question")
+annotation_create_parser.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
+annotation_reply_action_parser = reqparse.RequestParser()
+annotation_reply_action_parser.add_argument(
+ "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
+)
+annotation_reply_action_parser.add_argument(
+ "embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name"
+)
+annotation_reply_action_parser.add_argument(
+ "embedding_model_name", required=True, type=str, location="json", help="Embedding model name"
+)
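+# Example enable payload (illustrative values only):
+#   {"score_threshold": 0.8, "embedding_provider_name": "openai",
+#    "embedding_model_name": "text-embedding-3-small"}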
+
+
+@service_api_ns.route("/apps/annotation-reply/")
class AnnotationReplyActionApi(Resource):
+ @service_api_ns.expect(annotation_reply_action_parser)
+ @service_api_ns.doc("annotation_reply_action")
+ @service_api_ns.doc(description="Enable or disable annotation reply feature")
+ @service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Action completed successfully",
+ 401: "Unauthorized - invalid API token",
+ }
+ )
@validate_app_token
def post(self, app_model: App, action: Literal["enable", "disable"]):
- parser = reqparse.RequestParser()
- parser.add_argument("score_threshold", required=True, type=float, location="json")
- parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
- parser.add_argument("embedding_model_name", required=True, type=str, location="json")
- args = parser.parse_args()
+ """Enable or disable annotation reply feature."""
+ args = annotation_reply_action_parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
@@ -30,9 +53,21 @@ class AnnotationReplyActionApi(Resource):
return result, 200
+@service_api_ns.route("/apps/annotation-reply//status/")
class AnnotationReplyActionStatusApi(Resource):
+ @service_api_ns.doc("get_annotation_reply_action_status")
+ @service_api_ns.doc(description="Get the status of an annotation reply action job")
+ @service_api_ns.doc(params={"action": "Action type", "job_id": "Job ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Job status retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Job not found",
+ }
+ )
@validate_app_token
def get(self, app_model: App, job_id, action):
+ """Get the status of an annotation reply action job."""
job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key)
@@ -48,60 +83,111 @@ class AnnotationReplyActionStatusApi(Resource):
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
+# Define annotation list response model
+annotation_list_fields = {
+ "data": fields.List(fields.Nested(annotation_fields)),
+ "has_more": fields.Boolean,
+ "limit": fields.Integer,
+ "total": fields.Integer,
+ "page": fields.Integer,
+}
+
+
+def build_annotation_list_model(api_or_ns: Api | Namespace):
+ """Build the annotation list model for the API or Namespace."""
+ copied_annotation_list_fields = annotation_list_fields.copy()
+ copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
+ return api_or_ns.model("AnnotationList", copied_annotation_list_fields)
+
+
+@service_api_ns.route("/apps/annotations")
class AnnotationListApi(Resource):
+ @service_api_ns.doc("list_annotations")
+ @service_api_ns.doc(description="List annotations for the application")
+ @service_api_ns.doc(
+ responses={
+ 200: "Annotations retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ }
+ )
@validate_app_token
+ @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
def get(self, app_model: App):
+ """List annotations for the application."""
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
- response = {
- "data": marshal(annotation_list, annotation_fields),
+ return {
+ "data": annotation_list,
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
- return response, 200
+ @service_api_ns.expect(annotation_create_parser)
+ @service_api_ns.doc("create_annotation")
+ @service_api_ns.doc(description="Create a new annotation")
+ @service_api_ns.doc(
+ responses={
+ 201: "Annotation created successfully",
+ 401: "Unauthorized - invalid API token",
+ }
+ )
@validate_app_token
- @marshal_with(annotation_fields)
+ @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
def post(self, app_model: App):
- parser = reqparse.RequestParser()
- parser.add_argument("question", required=True, type=str, location="json")
- parser.add_argument("answer", required=True, type=str, location="json")
- args = parser.parse_args()
+ """Create a new annotation."""
+ args = annotation_create_parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
- return annotation
+ return annotation, 201
+@service_api_ns.route("/apps/annotations/")
class AnnotationUpdateDeleteApi(Resource):
+ @service_api_ns.expect(annotation_create_parser)
+ @service_api_ns.doc("update_annotation")
+ @service_api_ns.doc(description="Update an existing annotation")
+ @service_api_ns.doc(params={"annotation_id": "Annotation ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Annotation updated successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ 404: "Annotation not found",
+ }
+ )
@validate_app_token
- @marshal_with(annotation_fields)
+ @service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id):
+ """Update an existing annotation."""
if not current_user.is_editor:
raise Forbidden()
annotation_id = str(annotation_id)
- parser = reqparse.RequestParser()
- parser.add_argument("question", required=True, type=str, location="json")
- parser.add_argument("answer", required=True, type=str, location="json")
- args = parser.parse_args()
+ args = annotation_create_parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation
+ @service_api_ns.doc("delete_annotation")
+ @service_api_ns.doc(description="Delete an annotation")
+ @service_api_ns.doc(params={"annotation_id": "Annotation ID"})
+ @service_api_ns.doc(
+ responses={
+ 204: "Annotation deleted successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ 404: "Annotation not found",
+ }
+ )
@validate_app_token
def delete(self, app_model: App, annotation_id):
+ """Delete an annotation."""
if not current_user.is_editor:
raise Forbidden()
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 204
-
-
-api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/<string:action>")
-api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply/<string:action>/status/<uuid:job_id>")
-api.add_resource(AnnotationListApi, "/apps/annotations")
-api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/<uuid:annotation_id>")
diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py
index 89222d5e83..2dbeed1d68 100644
--- a/api/controllers/service_api/app/app.py
+++ b/api/controllers/service_api/app/app.py
@@ -1,7 +1,7 @@
-from flask_restful import Resource, marshal_with
+from flask_restx import Resource
-from controllers.common import fields
-from controllers.service_api import api
+from controllers.common.fields import build_parameters_model
+from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
@@ -9,13 +9,26 @@ from models.model import App, AppMode
from services.app_service import AppService
+@service_api_ns.route("/parameters")
class AppParameterApi(Resource):
"""Resource for app variables."""
+ @service_api_ns.doc("get_app_parameters")
+ @service_api_ns.doc(description="Retrieve application input parameters and configuration")
+ @service_api_ns.doc(
+ responses={
+ 200: "Parameters retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Application not found",
+ }
+ )
@validate_app_token
- @marshal_with(fields.parameters_fields)
+ @service_api_ns.marshal_with(build_parameters_model(service_api_ns))
def get(self, app_model: App):
- """Retrieve app parameters."""
+ """Retrieve app parameters.
+
+ Returns the input form parameters and configuration for the application.
+ """
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow
if workflow is None:
@@ -35,17 +48,43 @@ class AppParameterApi(Resource):
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+@service_api_ns.route("/meta")
class AppMetaApi(Resource):
+ @service_api_ns.doc("get_app_meta")
+ @service_api_ns.doc(description="Get application metadata")
+ @service_api_ns.doc(
+ responses={
+ 200: "Metadata retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Application not found",
+ }
+ )
@validate_app_token
def get(self, app_model: App):
- """Get app meta"""
+ """Get app metadata.
+
+ Returns metadata about the application including configuration and settings.
+ """
return AppService().get_app_meta(app_model)
+@service_api_ns.route("/info")
class AppInfoApi(Resource):
+ @service_api_ns.doc("get_app_info")
+ @service_api_ns.doc(description="Get basic application information")
+ @service_api_ns.doc(
+ responses={
+ 200: "Application info retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Application not found",
+ }
+ )
@validate_app_token
def get(self, app_model: App):
- """Get app information"""
+ """Get app information.
+
+ Returns basic information about the application including name, description, tags, and mode.
+ """
tags = [tag.name for tag in app_model.tags]
return {
"name": app_model.name,
@@ -54,8 +93,3 @@ class AppInfoApi(Resource):
"mode": app_model.mode,
"author_name": app_model.author_name,
}
-
-
-api.add_resource(AppParameterApi, "/parameters")
-api.add_resource(AppMetaApi, "/meta")
-api.add_resource(AppInfoApi, "/info")
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index 848863cf1b..61b3020a5f 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -1,11 +1,11 @@
import logging
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
AppUnavailableError,
AudioTooLargeError,
@@ -30,9 +30,26 @@ from services.errors.audio import (
)
+@service_api_ns.route("/audio-to-text")
class AudioApi(Resource):
+ @service_api_ns.doc("audio_to_text")
+ @service_api_ns.doc(description="Convert audio to text using speech-to-text")
+ @service_api_ns.doc(
+ responses={
+ 200: "Audio successfully transcribed",
+ 400: "Bad request - no audio or invalid audio",
+ 401: "Unauthorized - invalid API token",
+ 413: "Audio file too large",
+ 415: "Unsupported audio type",
+ 500: "Internal server error",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
def post(self, app_model: App, end_user: EndUser):
+ """Convert audio to text using speech-to-text.
+
+ Accepts an audio file upload and returns the transcribed text.
+ """
file = request.files["file"]
try:
@@ -65,16 +82,35 @@ class AudioApi(Resource):
raise InternalServerError()
+# Define parser for text-to-audio API
+text_to_audio_parser = reqparse.RequestParser()
+text_to_audio_parser.add_argument("message_id", type=str, required=False, location="json", help="Message ID")
+text_to_audio_parser.add_argument("voice", type=str, location="json", help="Voice to use for TTS")
+text_to_audio_parser.add_argument("text", type=str, location="json", help="Text to convert to audio")
+text_to_audio_parser.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
+
+
+@service_api_ns.route("/text-to-audio")
class TextApi(Resource):
+ @service_api_ns.expect(text_to_audio_parser)
+ @service_api_ns.doc("text_to_audio")
+ @service_api_ns.doc(description="Convert text to audio using text-to-speech")
+ @service_api_ns.doc(
+ responses={
+ 200: "Text successfully converted to audio",
+ 400: "Bad request - invalid parameters",
+ 401: "Unauthorized - invalid API token",
+ 500: "Internal server error",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser):
+ """Convert text to audio using text-to-speech.
+
+ Converts the provided text to audio using the specified voice.
+ """
try:
- parser = reqparse.RequestParser()
- parser.add_argument("message_id", type=str, required=False, location="json")
- parser.add_argument("voice", type=str, location="json")
- parser.add_argument("text", type=str, location="json")
- parser.add_argument("streaming", type=bool, location="json")
- args = parser.parse_args()
+ args = text_to_audio_parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
@@ -108,7 +144,3 @@ class TextApi(Resource):
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
-
-
-api.add_resource(AudioApi, "/audio-to-text")
-api.add_resource(TextApi, "/text-to-audio")
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index ea57f04850..dddb75d593 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -1,11 +1,11 @@
import logging
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
@@ -33,21 +33,68 @@ from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
+# Define parser for completion API
+completion_parser = reqparse.RequestParser()
+completion_parser.add_argument(
+ "inputs", type=dict, required=True, location="json", help="Input parameters for completion"
+)
+completion_parser.add_argument("query", type=str, location="json", default="", help="The query string")
+completion_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
+completion_parser.add_argument(
+ "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode"
+)
+completion_parser.add_argument(
+ "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source"
+)
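+# Example completion payload (illustrative):
+#   {"inputs": {"topic": "weather"}, "query": "", "response_mode": "blocking"}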
+# Define parser for chat API
+chat_parser = reqparse.RequestParser()
+chat_parser.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
+chat_parser.add_argument("query", type=str, required=True, location="json", help="The chat query")
+chat_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
+chat_parser.add_argument(
+ "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode"
+)
+chat_parser.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
+chat_parser.add_argument(
+ "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source"
+)
+chat_parser.add_argument(
+ "auto_generate_name",
+ type=bool,
+ required=False,
+ default=True,
+ location="json",
+ help="Auto generate conversation name",
+)
+chat_parser.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
+
+
+@service_api_ns.route("/completion-messages")
class CompletionApi(Resource):
+ @service_api_ns.expect(completion_parser)
+ @service_api_ns.doc("create_completion")
+ @service_api_ns.doc(description="Create a completion for the given prompt")
+ @service_api_ns.doc(
+ responses={
+ 200: "Completion created successfully",
+ 400: "Bad request - invalid parameters",
+ 401: "Unauthorized - invalid API token",
+ 404: "Conversation not found",
+ 500: "Internal server error",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
+ """Create a completion for the given prompt.
+
+ This endpoint generates a completion based on the provided inputs and query.
+ Supports both blocking and streaming response modes.
+ """
if app_model.mode != "completion":
raise AppUnavailableError()
- parser = reqparse.RequestParser()
- parser.add_argument("inputs", type=dict, required=True, location="json")
- parser.add_argument("query", type=str, location="json", default="")
- parser.add_argument("files", type=list, required=False, location="json")
- parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
- parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
-
- args = parser.parse_args()
+ args = completion_parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
@@ -88,9 +135,21 @@ class CompletionApi(Resource):
raise InternalServerError()
+@service_api_ns.route("/completion-messages//stop")
class CompletionStopApi(Resource):
+ @service_api_ns.doc("stop_completion")
+ @service_api_ns.doc(description="Stop a running completion task")
+ @service_api_ns.doc(params={"task_id": "The ID of the task to stop"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Task stopped successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Task not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
- def post(self, app_model: App, end_user: EndUser, task_id):
+ def post(self, app_model: App, end_user: EndUser, task_id: str):
+ """Stop a running completion task."""
if app_model.mode != "completion":
raise AppUnavailableError()
@@ -99,23 +158,33 @@ class CompletionStopApi(Resource):
return {"result": "success"}, 200
+@service_api_ns.route("/chat-messages")
class ChatApi(Resource):
+ @service_api_ns.expect(chat_parser)
+ @service_api_ns.doc("create_chat_message")
+ @service_api_ns.doc(description="Send a message in a chat conversation")
+ @service_api_ns.doc(
+ responses={
+ 200: "Message sent successfully",
+ 400: "Bad request - invalid parameters or workflow issues",
+ 401: "Unauthorized - invalid API token",
+ 404: "Conversation or workflow not found",
+ 429: "Rate limit exceeded",
+ 500: "Internal server error",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
+ """Send a message in a chat conversation.
+
+ This endpoint handles chat messages for chat, agent chat, and advanced chat applications.
+ Supports conversation management and both blocking and streaming response modes.
+ """
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
- parser = reqparse.RequestParser()
- parser.add_argument("inputs", type=dict, required=True, location="json")
- parser.add_argument("query", type=str, required=True, location="json")
- parser.add_argument("files", type=list, required=False, location="json")
- parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
- parser.add_argument("conversation_id", type=uuid_value, location="json")
- parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
- parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
- parser.add_argument("workflow_id", type=str, required=False, location="json")
- args = parser.parse_args()
+ args = chat_parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
@@ -159,9 +228,21 @@ class ChatApi(Resource):
raise InternalServerError()
+@service_api_ns.route("/chat-messages//stop")
class ChatStopApi(Resource):
+ @service_api_ns.doc("stop_chat_message")
+ @service_api_ns.doc(description="Stop a running chat message generation")
+ @service_api_ns.doc(params={"task_id": "The ID of the task to stop"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Task stopped successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Task not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
- def post(self, app_model: App, end_user: EndUser, task_id):
+ def post(self, app_model: App, end_user: EndUser, task_id: str):
+ """Stop a running chat message generation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@@ -169,9 +250,3 @@ class ChatStopApi(Resource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {"result": "success"}, 200
-
-
-api.add_resource(CompletionApi, "/completion-messages")
-api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
-api.add_resource(ChatApi, "/chat-messages")
-api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index 073307ac4a..4860bf3a79 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -1,48 +1,97 @@
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import Resource, reqparse
+from flask_restx.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
import services
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
- conversation_delete_fields,
- conversation_infinite_scroll_pagination_fields,
- simple_conversation_fields,
+ build_conversation_delete_model,
+ build_conversation_infinite_scroll_pagination_model,
+ build_simple_conversation_model,
)
from fields.conversation_variable_fields import (
- conversation_variable_fields,
- conversation_variable_infinite_scroll_pagination_fields,
+ build_conversation_variable_infinite_scroll_pagination_model,
+ build_conversation_variable_model,
)
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
+# Define parsers for conversation APIs
+conversation_list_parser = reqparse.RequestParser()
+conversation_list_parser.add_argument(
+ "last_id", type=uuid_value, location="args", help="Last conversation ID for pagination"
+)
+conversation_list_parser.add_argument(
+ "limit",
+ type=int_range(1, 100),
+ required=False,
+ default=20,
+ location="args",
+ help="Number of conversations to return",
+)
+conversation_list_parser.add_argument(
+ "sort_by",
+ type=str,
+ choices=["created_at", "-created_at", "updated_at", "-updated_at"],
+ required=False,
+ default="-updated_at",
+ location="args",
+ help="Sort order for conversations",
+)
+conversation_rename_parser = reqparse.RequestParser()
+conversation_rename_parser.add_argument("name", type=str, required=False, location="json", help="New conversation name")
+conversation_rename_parser.add_argument(
+ "auto_generate", type=bool, required=False, default=False, location="json", help="Auto-generate conversation name"
+)
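+# Example rename payloads (illustrative): {"name": "Trip planning"} to set a
+# name explicitly, or {"auto_generate": true} to have one generated.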
+
+conversation_variables_parser = reqparse.RequestParser()
+conversation_variables_parser.add_argument(
+ "last_id", type=uuid_value, location="args", help="Last variable ID for pagination"
+)
+conversation_variables_parser.add_argument(
+ "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of variables to return"
+)
+
+conversation_variable_update_parser = reqparse.RequestParser()
+# using lambda is for passing the already-typed value without modification
+# if no lambda, it will be converted to string
+# the string cannot be converted using json.loads
+conversation_variable_update_parser.add_argument(
+ "value", required=True, location="json", type=lambda x: x, help="New value for the conversation variable"
+)
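+# Example update payload (illustrative): {"value": ["a", "b"]}; the identity
+# lambda above passes the typed JSON value through unchanged.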
+
+
+@service_api_ns.route("/conversations")
class ConversationApi(Resource):
+ @service_api_ns.expect(conversation_list_parser)
+ @service_api_ns.doc("list_conversations")
+ @service_api_ns.doc(description="List all conversations for the current user")
+ @service_api_ns.doc(
+ responses={
+ 200: "Conversations retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Last conversation not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
- @marshal_with(conversation_infinite_scroll_pagination_fields)
+ @service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
+ """List all conversations for the current user.
+
+ Supports pagination using last_id and limit parameters.
+ """
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
- parser = reqparse.RequestParser()
- parser.add_argument("last_id", type=uuid_value, location="args")
- parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- parser.add_argument(
- "sort_by",
- type=str,
- choices=["created_at", "-created_at", "updated_at", "-updated_at"],
- required=False,
- default="-updated_at",
- location="args",
- )
- args = parser.parse_args()
+ args = conversation_list_parser.parse_args()
try:
with Session(db.engine) as session:
@@ -59,10 +108,22 @@ class ConversationApi(Resource):
raise NotFound("Last Conversation Not Exists.")
+@service_api_ns.route("/conversations/")
class ConversationDetailApi(Resource):
+ @service_api_ns.doc("delete_conversation")
+ @service_api_ns.doc(description="Delete a specific conversation")
+ @service_api_ns.doc(params={"c_id": "Conversation ID"})
+ @service_api_ns.doc(
+ responses={
+ 204: "Conversation deleted successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Conversation not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
- @marshal_with(conversation_delete_fields)
+ @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204)
def delete(self, app_model: App, end_user: EndUser, c_id):
+ """Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@@ -76,20 +137,30 @@ class ConversationDetailApi(Resource):
return {"result": "success"}, 204
+@service_api_ns.route("/conversations//name")
class ConversationRenameApi(Resource):
+ @service_api_ns.expect(conversation_rename_parser)
+ @service_api_ns.doc("rename_conversation")
+ @service_api_ns.doc(description="Rename a conversation or auto-generate a name")
+ @service_api_ns.doc(params={"c_id": "Conversation ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Conversation renamed successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Conversation not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
- @marshal_with(simple_conversation_fields)
+ @service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns))
def post(self, app_model: App, end_user: EndUser, c_id):
+ """Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
- parser = reqparse.RequestParser()
- parser.add_argument("name", type=str, required=False, location="json")
- parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
- args = parser.parse_args()
+ args = conversation_rename_parser.parse_args()
try:
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
@@ -97,10 +168,26 @@ class ConversationRenameApi(Resource):
raise NotFound("Conversation Not Exists.")
+@service_api_ns.route("/conversations//variables")
class ConversationVariablesApi(Resource):
+ @service_api_ns.expect(conversation_variables_parser)
+ @service_api_ns.doc("list_conversation_variables")
+ @service_api_ns.doc(description="List all variables for a conversation")
+ @service_api_ns.doc(params={"c_id": "Conversation ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Variables retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Conversation not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
- @marshal_with(conversation_variable_infinite_scroll_pagination_fields)
+ @service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser, c_id):
+ """List all variables for a conversation.
+
+ Conversational variables are only available for chat applications.
+ """
# conversational variable only for chat app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -108,10 +195,7 @@ class ConversationVariablesApi(Resource):
conversation_id = str(c_id)
- parser = reqparse.RequestParser()
- parser.add_argument("last_id", type=uuid_value, location="args")
- parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- args = parser.parse_args()
+ args = conversation_variables_parser.parse_args()
try:
return ConversationService.get_conversational_variable(
@@ -121,11 +205,28 @@ class ConversationVariablesApi(Resource):
raise NotFound("Conversation Not Exists.")
+@service_api_ns.route("/conversations//variables/")
class ConversationVariableDetailApi(Resource):
+ @service_api_ns.expect(conversation_variable_update_parser)
+ @service_api_ns.doc("update_conversation_variable")
+ @service_api_ns.doc(description="Update a conversation variable's value")
+ @service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Variable updated successfully",
+ 400: "Bad request - type mismatch",
+ 401: "Unauthorized - invalid API token",
+ 404: "Conversation or variable not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
- @marshal_with(conversation_variable_fields)
+ @service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
- """Update a conversation variable's value"""
+ """Update a conversation variable's value.
+
+ Allows updating the value of a specific conversation variable.
+ The value must match the variable's expected type.
+ """
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@@ -133,12 +234,7 @@ class ConversationVariableDetailApi(Resource):
conversation_id = str(c_id)
variable_id = str(variable_id)
- parser = reqparse.RequestParser()
- # using lambda is for passing the already-typed value without modification
- # if no lambda, it will be converted to string
- # the string cannot be converted using json.loads
- parser.add_argument("value", required=True, location="json", type=lambda x: x)
- args = parser.parse_args()
+ args = conversation_variable_update_parser.parse_args()
try:
return ConversationService.update_conversation_variable(
@@ -150,15 +246,3 @@ class ConversationVariableDetailApi(Resource):
raise NotFound("Conversation Variable Not Exists.")
except services.errors.conversation.ConversationVariableTypeMismatchError as e:
raise BadRequest(str(e))
-
-
-api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
-api.add_resource(ConversationApi, "/conversations")
-api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
-api.add_resource(ConversationVariablesApi, "/conversations/<uuid:c_id>/variables", endpoint="conversation_variables")
-api.add_resource(
- ConversationVariableDetailApi,
- "/conversations//variables/",
- endpoint="conversation_variable_detail",
- methods=["PUT"],
-)
diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py
index 37153ca5db..05f27545b3 100644
--- a/api/controllers/service_api/app/file.py
+++ b/api/controllers/service_api/app/file.py
@@ -1,5 +1,6 @@
from flask import request
-from flask_restful import Resource, marshal_with
+from flask_restx import Resource
+from flask_restx.api import HTTPStatus
import services
from controllers.common.errors import (
@@ -9,17 +10,33 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
-from fields.file_fields import file_fields
+from fields.file_fields import build_file_model
from models.model import App, EndUser
from services.file_service import FileService
+@service_api_ns.route("/files/upload")
class FileApi(Resource):
+ @service_api_ns.doc("upload_file")
+ @service_api_ns.doc(description="Upload a file for use in conversations")
+ @service_api_ns.doc(
+ responses={
+ 201: "File uploaded successfully",
+ 400: "Bad request - no file or invalid file",
+ 401: "Unauthorized - invalid API token",
+ 413: "File too large",
+ 415: "Unsupported file type",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
- @marshal_with(file_fields)
+ @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED)
def post(self, app_model: App, end_user: EndUser):
+ """Upload a file for use in conversations.
+
+ Accepts a single file upload via multipart/form-data.
+ """
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@@ -47,6 +64,3 @@ class FileApi(Resource):
raise UnsupportedFileTypeError()
return upload_file, 201
-
-
-api.add_resource(FileApi, "/files/upload")
diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py
index 57141033d1..84d80ea101 100644
--- a/api/controllers/service_api/app/file_preview.py
+++ b/api/controllers/service_api/app/file_preview.py
@@ -2,9 +2,9 @@ import logging
from urllib.parse import quote
from flask import Response
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
FileAccessDeniedError,
FileNotFoundError,
@@ -17,6 +17,14 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile
logger = logging.getLogger(__name__)
+# Define parser for file preview API
+file_preview_parser = reqparse.RequestParser()
+file_preview_parser.add_argument(
+ "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
+)
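+# Example: GET /v1/files/<file-uuid>/preview?as_attachment=true streams the
+# file with a Content-Disposition header so clients download it.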
+
+
+@service_api_ns.route("/files//preview")
class FilePreviewApi(Resource):
"""
Service API File Preview endpoint
@@ -25,33 +33,30 @@ class FilePreviewApi(Resource):
Files can only be accessed if they belong to messages within the requesting app's context.
"""
+ @service_api_ns.expect(file_preview_parser)
+ @service_api_ns.doc("preview_file")
+ @service_api_ns.doc(description="Preview or download a file uploaded via Service API")
+ @service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
+ @service_api_ns.doc(
+ responses={
+ 200: "File retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - file access denied",
+ 404: "File not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, file_id: str):
"""
- Preview/Download a file that was uploaded via Service API
+ Preview/Download a file that was uploaded via Service API.
- Args:
- app_model: The authenticated app model
- end_user: The authenticated end user (optional)
- file_id: UUID of the file to preview
-
- Query Parameters:
- user: Optional user identifier
- as_attachment: Boolean, whether to download as attachment (default: false)
-
- Returns:
- Stream response with file content
-
- Raises:
- FileNotFoundError: File does not exist
- FileAccessDeniedError: File access denied (not owned by app)
+ Provides secure file preview/download functionality.
+ Files can only be accessed if they belong to messages within the requesting app's context.
"""
file_id = str(file_id)
# Parse query parameters
- parser = reqparse.RequestParser()
- parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
- args = parser.parse_args()
+ args = file_preview_parser.parse_args()
# Validate file ownership and get file objects
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
@@ -180,7 +185,3 @@ class FilePreviewApi(Resource):
response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour
return response
-
-
-# Register the API endpoint
-api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index a4f95cb1cb..ad3fac7009 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -1,17 +1,17 @@
import json
import logging
-from flask_restful import Resource, fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import Api, Namespace, Resource, fields, reqparse
+from flask_restx.inputs import int_range
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
-from fields.conversation_fields import message_file_fields
-from fields.message_fields import agent_thought_fields, feedback_fields
+from fields.conversation_fields import build_message_file_model
+from fields.message_fields import build_agent_thought_model, build_feedback_model
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser
@@ -22,8 +22,37 @@ from services.errors.message import (
)
from services.message_service import MessageService
+# Define parsers for message APIs
+message_list_parser = reqparse.RequestParser()
+message_list_parser.add_argument(
+ "conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID"
+)
+message_list_parser.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
+message_list_parser.add_argument(
+ "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of messages to return"
+)
-class MessageListApi(Resource):
+message_feedback_parser = reqparse.RequestParser()
+message_feedback_parser.add_argument(
+ "rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating"
+)
+message_feedback_parser.add_argument("content", type=str, location="json", help="Feedback content")
+
+feedback_list_parser = reqparse.RequestParser()
+feedback_list_parser.add_argument("page", type=int, default=1, location="args", help="Page number")
+feedback_list_parser.add_argument(
+ "limit", type=int_range(1, 101), required=False, default=20, location="args", help="Number of feedbacks per page"
+)
+
+
+def build_message_model(api_or_ns: Api | Namespace):
+ """Build the message model for the API or Namespace."""
+ # First build the nested models
+ feedback_model = build_feedback_model(api_or_ns)
+ agent_thought_model = build_agent_thought_model(api_or_ns)
+ message_file_model = build_message_file_model(api_or_ns)
+
+ # Then build the message fields with nested models
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
@@ -31,37 +60,58 @@ class MessageListApi(Resource):
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
+ "message_files": fields.List(fields.Nested(message_file_model)),
+ "feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.Raw(
attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
if obj.message_metadata
else []
),
"created_at": TimestampField,
- "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
+ "agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"status": fields.String,
"error": fields.String,
}
+ return api_or_ns.model("Message", message_fields)
+
+
+def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
+ """Build the message infinite scroll pagination model for the API or Namespace."""
+ # Build the nested message model first
+ message_model = build_message_model(api_or_ns)
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_fields)),
+ "data": fields.List(fields.Nested(message_model)),
}
+ return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields)
+
+@service_api_ns.route("/messages")
+class MessageListApi(Resource):
+ @service_api_ns.expect(message_list_parser)
+ @service_api_ns.doc("list_messages")
+ @service_api_ns.doc(description="List messages in a conversation")
+ @service_api_ns.doc(
+ responses={
+ 200: "Messages retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Conversation or first message not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
- @marshal_with(message_infinite_scroll_pagination_fields)
+ @service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
+ """List messages in a conversation.
+
+ Retrieves messages with pagination support using first_id.
+ """
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
- parser = reqparse.RequestParser()
- parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
- parser.add_argument("first_id", type=uuid_value, location="args")
- parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- args = parser.parse_args()
+ args = message_list_parser.parse_args()
try:
return MessageService.pagination_by_first_id(
@@ -73,15 +123,28 @@ class MessageListApi(Resource):
raise NotFound("First Message Not Exists.")
+@service_api_ns.route("/messages//feedbacks")
class MessageFeedbackApi(Resource):
+ @service_api_ns.expect(message_feedback_parser)
+ @service_api_ns.doc("create_message_feedback")
+ @service_api_ns.doc(description="Submit feedback for a message")
+ @service_api_ns.doc(params={"message_id": "Message ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Feedback submitted successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Message not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, message_id):
+ """Submit feedback for a message.
+
+ Allows users to rate messages as like/dislike and provide optional feedback content.
+ """
message_id = str(message_id)
- parser = reqparse.RequestParser()
- parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
- parser.add_argument("content", type=str, location="json")
- args = parser.parse_args()
+ args = message_feedback_parser.parse_args()
try:
MessageService.create_feedback(
@@ -97,21 +160,48 @@ class MessageFeedbackApi(Resource):
return {"result": "success"}
+@service_api_ns.route("/app/feedbacks")
class AppGetFeedbacksApi(Resource):
+ @service_api_ns.expect(feedback_list_parser)
+ @service_api_ns.doc("get_app_feedbacks")
+ @service_api_ns.doc(description="Get all feedbacks for the application")
+ @service_api_ns.doc(
+ responses={
+ 200: "Feedbacks retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ }
+ )
@validate_app_token
def get(self, app_model: App):
- """Get All Feedbacks of an app"""
- parser = reqparse.RequestParser()
- parser.add_argument("page", type=int, default=1, location="args")
- parser.add_argument("limit", type=int_range(1, 101), required=False, default=20, location="args")
- args = parser.parse_args()
+ """Get all feedbacks for the application.
+
+ Returns paginated list of all feedback submitted for messages in this app.
+ """
+ args = feedback_list_parser.parse_args()
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
return {"data": feedbacks}
+@service_api_ns.route("/messages//suggested")
class MessageSuggestedApi(Resource):
+ @service_api_ns.doc("get_suggested_questions")
+ @service_api_ns.doc(description="Get suggested follow-up questions for a message")
+ @service_api_ns.doc(params={"message_id": "Message ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Suggested questions retrieved successfully",
+ 400: "Suggested questions feature is disabled",
+ 401: "Unauthorized - invalid API token",
+ 404: "Message not found",
+ 500: "Internal server error",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id):
+ """Get suggested follow-up questions for a message.
+
+ Returns AI-generated follow-up questions based on the message content.
+ """
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -130,9 +220,3 @@ class MessageSuggestedApi(Resource):
raise InternalServerError()
return {"result": "success", "data": questions}
-
-
-api.add_resource(MessageListApi, "/messages")
-api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
-api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested")
-api.add_resource(AppGetFeedbacksApi, "/app/feedbacks")
diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py
index c157b39f6b..9f8324a84e 100644
--- a/api/controllers/service_api/app/site.py
+++ b/api/controllers/service_api/app/site.py
@@ -1,30 +1,41 @@
-from flask_restful import Resource, marshal_with
+from flask_restx import Resource
from werkzeug.exceptions import Forbidden
-from controllers.common import fields
-from controllers.service_api import api
+from controllers.common.fields import build_site_model
+from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.account import TenantStatus
from models.model import App, Site
+@service_api_ns.route("/site")
class AppSiteApi(Resource):
"""Resource for app sites."""
+ @service_api_ns.doc("get_app_site")
+ @service_api_ns.doc(description="Get application site configuration")
+ @service_api_ns.doc(
+ responses={
+ 200: "Site configuration retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - site not found or tenant archived",
+ }
+ )
@validate_app_token
- @marshal_with(fields.site_fields)
+ @service_api_ns.marshal_with(build_site_model(service_api_ns))
def get(self, app_model: App):
- """Retrieve app site info."""
+ """Retrieve app site info.
+
+ Returns the site configuration for the application including theme, icons, and text.
+ """
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
+ assert app_model.tenant
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return site
-
-
-api.add_resource(AppSiteApi, "/site")
diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py
index cd8a5f03ac..19e2e67d7f 100644
--- a/api/controllers/service_api/app/workflow.py
+++ b/api/controllers/service_api/app/workflow.py
@@ -2,12 +2,12 @@ import logging
from dateutil.parser import isoparse
from flask import request
-from flask_restful import Resource, fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import Api, Namespace, Resource, fields, reqparse
+from flask_restx.inputs import int_range
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
CompletionRequestError,
NotWorkflowAppError,
@@ -28,7 +28,7 @@ from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
-from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
+from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
@@ -40,6 +40,34 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__)
+# Define parsers for workflow APIs
+workflow_run_parser = reqparse.RequestParser()
+workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+workflow_run_parser.add_argument("files", type=list, required=False, location="json")
+workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+
+workflow_log_parser = reqparse.RequestParser()
+workflow_log_parser.add_argument("keyword", type=str, location="args")
+workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
+workflow_log_parser.add_argument("created_at__before", type=str, location="args")
+workflow_log_parser.add_argument("created_at__after", type=str, location="args")
+workflow_log_parser.add_argument(
+ "created_by_end_user_session_id",
+ type=str,
+ location="args",
+ required=False,
+ default=None,
+)
+workflow_log_parser.add_argument(
+ "created_by_account",
+ type=str,
+ location="args",
+ required=False,
+ default=None,
+)
+workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
+workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
+
workflow_run_fields = {
"id": fields.String,
"workflow_id": fields.String,
@@ -55,12 +83,29 @@ workflow_run_fields = {
}
+def build_workflow_run_model(api_or_ns: Api | Namespace):
+ """Build the workflow run model for the API or Namespace."""
+ return api_or_ns.model("WorkflowRun", workflow_run_fields)
+
+
+@service_api_ns.route("/workflows/run/")
class WorkflowRunDetailApi(Resource):
+ @service_api_ns.doc("get_workflow_run_detail")
+ @service_api_ns.doc(description="Get workflow run details")
+ @service_api_ns.doc(params={"workflow_run_id": "Workflow run ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Workflow run details retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Workflow run not found",
+ }
+ )
@validate_app_token
- @marshal_with(workflow_run_fields)
+ @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
def get(self, app_model: App, workflow_run_id: str):
- """
- Get a workflow task running detail
+ """Get a workflow task running detail.
+
+ Returns detailed information about a specific workflow run.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
@@ -78,21 +123,33 @@ class WorkflowRunDetailApi(Resource):
return workflow_run
+@service_api_ns.route("/workflows/run")
class WorkflowRunApi(Resource):
+ @service_api_ns.expect(workflow_run_parser)
+ @service_api_ns.doc("run_workflow")
+ @service_api_ns.doc(description="Execute a workflow")
+ @service_api_ns.doc(
+ responses={
+ 200: "Workflow executed successfully",
+ 400: "Bad request - invalid parameters or workflow issues",
+ 401: "Unauthorized - invalid API token",
+ 404: "Workflow not found",
+ 429: "Rate limit exceeded",
+ 500: "Internal server error",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
- """
- Run workflow
+ """Execute a workflow.
+
+ Runs a workflow with the provided inputs and returns the results.
+ Supports both blocking and streaming response modes.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
- parser = reqparse.RequestParser()
- parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
- parser.add_argument("files", type=list, required=False, location="json")
- parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
- args = parser.parse_args()
+ args = workflow_run_parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
@@ -121,21 +178,33 @@ class WorkflowRunApi(Resource):
raise InternalServerError()
+@service_api_ns.route("/workflows//run")
class WorkflowRunByIdApi(Resource):
+ @service_api_ns.expect(workflow_run_parser)
+ @service_api_ns.doc("run_workflow_by_id")
+ @service_api_ns.doc(description="Execute a specific workflow by ID")
+ @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Workflow executed successfully",
+ 400: "Bad request - invalid parameters or workflow issues",
+ 401: "Unauthorized - invalid API token",
+ 404: "Workflow not found",
+ 429: "Rate limit exceeded",
+ 500: "Internal server error",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, workflow_id: str):
- """
- Run specific workflow by ID
+ """Run specific workflow by ID.
+
+ Executes a specific workflow version identified by its ID.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
- parser = reqparse.RequestParser()
- parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
- parser.add_argument("files", type=list, required=False, location="json")
- parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
- args = parser.parse_args()
+ args = workflow_run_parser.parse_args()
# Add workflow_id to args for AppGenerateService
args["workflow_id"] = workflow_id
@@ -174,12 +243,21 @@ class WorkflowRunByIdApi(Resource):
raise InternalServerError()
+@service_api_ns.route("/workflows/tasks//stop")
class WorkflowTaskStopApi(Resource):
+ @service_api_ns.doc("stop_workflow_task")
+ @service_api_ns.doc(description="Stop a running workflow task")
+ @service_api_ns.doc(params={"task_id": "Task ID to stop"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Task stopped successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Task not found",
+ }
+ )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id: str):
- """
- Stop workflow task
- """
+ """Stop a running workflow task."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
@@ -189,35 +267,25 @@ class WorkflowTaskStopApi(Resource):
return {"result": "success"}
+@service_api_ns.route("/workflows/logs")
class WorkflowAppLogApi(Resource):
+ @service_api_ns.expect(workflow_log_parser)
+ @service_api_ns.doc("get_workflow_logs")
+ @service_api_ns.doc(description="Get workflow execution logs")
+ @service_api_ns.doc(
+ responses={
+ 200: "Logs retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ }
+ )
@validate_app_token
- @marshal_with(workflow_app_log_pagination_fields)
+ @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
def get(self, app_model: App):
+ """Get workflow app logs.
+
+ Returns paginated workflow execution logs with filtering options.
"""
- Get workflow app logs
- """
- parser = reqparse.RequestParser()
- parser.add_argument("keyword", type=str, location="args")
- parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
- parser.add_argument("created_at__before", type=str, location="args")
- parser.add_argument("created_at__after", type=str, location="args")
- parser.add_argument(
- "created_by_end_user_session_id",
- type=str,
- location="args",
- required=False,
- default=None,
- )
- parser.add_argument(
- "created_by_account",
- type=str,
- location="args",
- required=False,
- default=None,
- )
- parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
- parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
- args = parser.parse_args()
+ args = workflow_log_parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
@@ -243,10 +311,3 @@ class WorkflowAppLogApi(Resource):
)
return workflow_app_log_pagination
-
-
-api.add_resource(WorkflowRunApi, "/workflows/run")
-api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_run_id>")
-api.add_resource(WorkflowRunByIdApi, "/workflows/<string:workflow_id>/run")
-api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
-api.add_resource(WorkflowAppLogApi, "/workflows/logs")
diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index 35b1efeff6..c486b0480b 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -1,11 +1,11 @@
from typing import Literal
from flask import request
-from flask_restful import marshal, marshal_with, reqparse
+from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import (
DatasetApiResource,
@@ -16,7 +16,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
-from fields.tag_fields import tag_fields
+from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@@ -36,12 +36,171 @@ def _validate_description_length(description):
return description
+# Define parsers for dataset operations
+dataset_create_parser = reqparse.RequestParser()
+dataset_create_parser.add_argument(
+ "name",
+ nullable=False,
+ required=True,
+ help="type is required. Name must be between 1 to 40 characters.",
+ type=_validate_name,
+)
+dataset_create_parser.add_argument(
+ "description",
+ type=_validate_description_length,
+ nullable=True,
+ required=False,
+ default="",
+)
+dataset_create_parser.add_argument(
+ "indexing_technique",
+ type=str,
+ location="json",
+ choices=Dataset.INDEXING_TECHNIQUE_LIST,
+ help="Invalid indexing technique.",
+)
+dataset_create_parser.add_argument(
+ "permission",
+ type=str,
+ location="json",
+ choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
+ help="Invalid permission.",
+ required=False,
+ nullable=False,
+)
+dataset_create_parser.add_argument(
+ "external_knowledge_api_id",
+ type=str,
+ nullable=True,
+ required=False,
+ default="_validate_name",
+)
+dataset_create_parser.add_argument(
+ "provider",
+ type=str,
+ nullable=True,
+ required=False,
+ default="vendor",
+)
+dataset_create_parser.add_argument(
+ "external_knowledge_id",
+ type=str,
+ nullable=True,
+ required=False,
+)
+dataset_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
+dataset_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+dataset_create_parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
+
+dataset_update_parser = reqparse.RequestParser()
+dataset_update_parser.add_argument(
+ "name",
+ nullable=False,
+ help="type is required. Name must be between 1 to 40 characters.",
+ type=_validate_name,
+)
+dataset_update_parser.add_argument(
+ "description", location="json", store_missing=False, type=_validate_description_length
+)
+dataset_update_parser.add_argument(
+ "indexing_technique",
+ type=str,
+ location="json",
+ choices=Dataset.INDEXING_TECHNIQUE_LIST,
+ nullable=True,
+ help="Invalid indexing technique.",
+)
+dataset_update_parser.add_argument(
+ "permission",
+ type=str,
+ location="json",
+ choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
+ help="Invalid permission.",
+)
+dataset_update_parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
+dataset_update_parser.add_argument(
+ "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
+)
+dataset_update_parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
+dataset_update_parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
+dataset_update_parser.add_argument(
+ "external_retrieval_model",
+ type=dict,
+ required=False,
+ nullable=True,
+ location="json",
+ help="Invalid external retrieval model.",
+)
+dataset_update_parser.add_argument(
+ "external_knowledge_id",
+ type=str,
+ required=False,
+ nullable=True,
+ location="json",
+ help="Invalid external knowledge id.",
+)
+dataset_update_parser.add_argument(
+ "external_knowledge_api_id",
+ type=str,
+ required=False,
+ nullable=True,
+ location="json",
+ help="Invalid external knowledge api id.",
+)
+
+tag_create_parser = reqparse.RequestParser()
+tag_create_parser.add_argument(
+ "name",
+ nullable=False,
+ required=True,
+ help="Name must be between 1 to 50 characters.",
+ type=lambda x: x
+ if x and 1 <= len(x) <= 50
+ else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
+)
+
+tag_update_parser = reqparse.RequestParser()
+tag_update_parser.add_argument(
+ "name",
+ nullable=False,
+ required=True,
+ help="Name must be between 1 to 50 characters.",
+ type=lambda x: x
+ if x and 1 <= len(x) <= 50
+ else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
+)
+tag_update_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
+
+tag_delete_parser = reqparse.RequestParser()
+tag_delete_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
+
+tag_binding_parser = reqparse.RequestParser()
+tag_binding_parser.add_argument(
+ "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
+)
+tag_binding_parser.add_argument(
+ "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
+)
+
+tag_unbinding_parser = reqparse.RequestParser()
+tag_unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
+tag_unbinding_parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
+
+
+@service_api_ns.route("/datasets")
class DatasetListApi(DatasetApiResource):
"""Resource for datasets."""
+ @service_api_ns.doc("list_datasets")
+ @service_api_ns.doc(description="List all datasets")
+ @service_api_ns.doc(
+ responses={
+ 200: "Datasets retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ }
+ )
def get(self, tenant_id):
"""Resource for getting datasets."""
-
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
# provider = request.args.get("provider", default="vendor")
@@ -76,65 +235,20 @@ class DatasetListApi(DatasetApiResource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200
+ @service_api_ns.expect(dataset_create_parser)
+ @service_api_ns.doc("create_dataset")
+ @service_api_ns.doc(description="Create a new dataset")
+ @service_api_ns.doc(
+ responses={
+ 200: "Dataset created successfully",
+ 401: "Unauthorized - invalid API token",
+ 400: "Bad request - invalid parameters",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id):
"""Resource for creating datasets."""
- parser = reqparse.RequestParser()
- parser.add_argument(
- "name",
- nullable=False,
- required=True,
- help="type is required. Name must be between 1 to 40 characters.",
- type=_validate_name,
- )
- parser.add_argument(
- "description",
- type=_validate_description_length,
- nullable=True,
- required=False,
- default="",
- )
- parser.add_argument(
- "indexing_technique",
- type=str,
- location="json",
- choices=Dataset.INDEXING_TECHNIQUE_LIST,
- help="Invalid indexing technique.",
- )
- parser.add_argument(
- "permission",
- type=str,
- location="json",
- choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
- help="Invalid permission.",
- required=False,
- nullable=False,
- )
- parser.add_argument(
- "external_knowledge_api_id",
- type=str,
- nullable=True,
- required=False,
- default="_validate_name",
- )
- parser.add_argument(
- "provider",
- type=str,
- nullable=True,
- required=False,
- default="vendor",
- )
- parser.add_argument(
- "external_knowledge_id",
- type=str,
- nullable=True,
- required=False,
- )
- parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
- parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
- parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
-
- args = parser.parse_args()
+ args = dataset_create_parser.parse_args()
if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
@@ -174,9 +288,21 @@ class DatasetListApi(DatasetApiResource):
return marshal(dataset, dataset_detail_fields), 200
+@service_api_ns.route("/datasets/")
class DatasetApi(DatasetApiResource):
"""Resource for dataset."""
+ @service_api_ns.doc("get_dataset")
+ @service_api_ns.doc(description="Get a specific dataset by ID")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Dataset retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ 404: "Dataset not found",
+ }
+ )
def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@@ -216,6 +342,18 @@ class DatasetApi(DatasetApiResource):
return data, 200
+ @service_api_ns.expect(dataset_update_parser)
+ @service_api_ns.doc("update_dataset")
+ @service_api_ns.doc(description="Update an existing dataset")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Dataset updated successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ 404: "Dataset not found",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
@@ -223,63 +361,7 @@ class DatasetApi(DatasetApiResource):
if dataset is None:
raise NotFound("Dataset not found.")
- parser = reqparse.RequestParser()
- parser.add_argument(
- "name",
- nullable=False,
- help="type is required. Name must be between 1 to 40 characters.",
- type=_validate_name,
- )
- parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
- parser.add_argument(
- "indexing_technique",
- type=str,
- location="json",
- choices=Dataset.INDEXING_TECHNIQUE_LIST,
- nullable=True,
- help="Invalid indexing technique.",
- )
- parser.add_argument(
- "permission",
- type=str,
- location="json",
- choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
- help="Invalid permission.",
- )
- parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
- parser.add_argument(
- "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
- )
- parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
- parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
-
- parser.add_argument(
- "external_retrieval_model",
- type=dict,
- required=False,
- nullable=True,
- location="json",
- help="Invalid external retrieval model.",
- )
-
- parser.add_argument(
- "external_knowledge_id",
- type=str,
- required=False,
- nullable=True,
- location="json",
- help="Invalid external knowledge id.",
- )
-
- parser.add_argument(
- "external_knowledge_api_id",
- type=str,
- required=False,
- nullable=True,
- location="json",
- help="Invalid external knowledge api id.",
- )
- args = parser.parse_args()
+ args = dataset_update_parser.parse_args()
data = request.get_json()
# check embedding model setting
@@ -327,6 +409,17 @@ class DatasetApi(DatasetApiResource):
return result_data, 200
+ @service_api_ns.doc("delete_dataset")
+ @service_api_ns.doc(description="Delete a dataset")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 204: "Dataset deleted successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset not found",
+ 409: "Conflict - dataset is in use",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, _, dataset_id):
"""
@@ -357,9 +450,27 @@ class DatasetApi(DatasetApiResource):
raise DatasetInUseError()
+@service_api_ns.route("/datasets//documents/status/")
class DocumentStatusApi(DatasetApiResource):
"""Resource for batch document status operations."""
+ @service_api_ns.doc("update_document_status")
+ @service_api_ns.doc(description="Batch update document status")
+ @service_api_ns.doc(
+ params={
+ "dataset_id": "Dataset ID",
+ "action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'",
+ }
+ )
+ @service_api_ns.doc(
+ responses={
+ 200: "Document status updated successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ 404: "Dataset not found",
+ 400: "Bad request - invalid action",
+ }
+ )
def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
"""
Batch update document status.
@@ -407,53 +518,65 @@ class DocumentStatusApi(DatasetApiResource):
return {"result": "success"}, 200
+@service_api_ns.route("/datasets/tags")
class DatasetTagsApi(DatasetApiResource):
+ @service_api_ns.doc("list_dataset_tags")
+ @service_api_ns.doc(description="Get all knowledge type tags")
+ @service_api_ns.doc(
+ responses={
+ 200: "Tags retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ }
+ )
@validate_dataset_token
- @marshal_with(tag_fields)
+ @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _, dataset_id):
"""Get all knowledge type tags."""
tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
return tags, 200
+ @service_api_ns.expect(tag_create_parser)
+ @service_api_ns.doc("create_dataset_tag")
+ @service_api_ns.doc(description="Add a knowledge type tag")
+ @service_api_ns.doc(
+ responses={
+ 200: "Tag created successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ }
+ )
+ @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token
def post(self, _, dataset_id):
"""Add a knowledge type tag."""
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
- parser = reqparse.RequestParser()
- parser.add_argument(
- "name",
- nullable=False,
- required=True,
- help="Name must be between 1 to 50 characters.",
- type=DatasetTagsApi._validate_tag_name,
- )
-
- args = parser.parse_args()
+ args = tag_create_parser.parse_args()
args["type"] = "knowledge"
tag = TagService.save_tags(args)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
-
return response, 200
+ @service_api_ns.expect(tag_update_parser)
+ @service_api_ns.doc("update_dataset_tag")
+ @service_api_ns.doc(description="Update a knowledge type tag")
+ @service_api_ns.doc(
+ responses={
+ 200: "Tag updated successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ }
+ )
+ @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token
def patch(self, _, dataset_id):
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
- parser = reqparse.RequestParser()
- parser.add_argument(
- "name",
- nullable=False,
- required=True,
- help="Name must be between 1 to 50 characters.",
- type=DatasetTagsApi._validate_tag_name,
- )
- parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
- args = parser.parse_args()
+ args = tag_update_parser.parse_args()
args["type"] = "knowledge"
tag = TagService.update_tags(args, args.get("tag_id"))
@@ -463,66 +586,88 @@ class DatasetTagsApi(DatasetApiResource):
return response, 200
+ @service_api_ns.expect(tag_delete_parser)
+ @service_api_ns.doc("delete_dataset_tag")
+ @service_api_ns.doc(description="Delete a knowledge type tag")
+ @service_api_ns.doc(
+ responses={
+ 204: "Tag deleted successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ }
+ )
@validate_dataset_token
def delete(self, _, dataset_id):
"""Delete a knowledge type tag."""
if not current_user.is_editor:
raise Forbidden()
- parser = reqparse.RequestParser()
- parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
- args = parser.parse_args()
+ args = tag_delete_parser.parse_args()
TagService.delete_tag(args.get("tag_id"))
return 204
- @staticmethod
- def _validate_tag_name(name):
- if not name or len(name) < 1 or len(name) > 50:
- raise ValueError("Name must be between 1 to 50 characters.")
- return name
-
+@service_api_ns.route("/datasets/tags/binding")
class DatasetTagBindingApi(DatasetApiResource):
+ @service_api_ns.expect(tag_binding_parser)
+ @service_api_ns.doc("bind_dataset_tags")
+ @service_api_ns.doc(description="Bind tags to a dataset")
+ @service_api_ns.doc(
+ responses={
+ 204: "Tags bound successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ }
+ )
@validate_dataset_token
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
- parser = reqparse.RequestParser()
- parser.add_argument(
- "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
- )
- parser.add_argument(
- "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
- )
-
- args = parser.parse_args()
+ args = tag_binding_parser.parse_args()
args["type"] = "knowledge"
TagService.save_tag_binding(args)
return 204
+@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
+ @service_api_ns.expect(tag_unbinding_parser)
+ @service_api_ns.doc("unbind_dataset_tag")
+ @service_api_ns.doc(description="Unbind a tag from a dataset")
+ @service_api_ns.doc(
+ responses={
+ 204: "Tag unbound successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ }
+ )
@validate_dataset_token
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
- parser = reqparse.RequestParser()
- parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
- parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
-
- args = parser.parse_args()
+ args = tag_unbinding_parser.parse_args()
args["type"] = "knowledge"
TagService.delete_tag_binding(args)
return 204
+@service_api_ns.route("/datasets//tags")
class DatasetTagsBindingStatusApi(DatasetApiResource):
+ @service_api_ns.doc("get_dataset_tags_binding_status")
+ @service_api_ns.doc(description="Get tags bound to a specific dataset")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Tags retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ }
+ )
@validate_dataset_token
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
@@ -531,12 +676,3 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
response = {"data": tags_list, "total": len(tags)}
return response, 200
-
-
-api.add_resource(DatasetListApi, "/datasets")
-api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
-api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>")
-api.add_resource(DatasetTagsApi, "/datasets/tags")
-api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
-api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
-api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index d0354f7851..43232229c8 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -1,7 +1,7 @@
import json
from flask import request
-from flask_restful import marshal, reqparse
+from flask_restx import marshal, reqparse
from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound
@@ -13,7 +13,7 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import (
ArchivedDocumentImmutableError,
@@ -34,32 +34,64 @@ from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService
+# Define parsers for document operations
+document_text_create_parser = reqparse.RequestParser()
+document_text_create_parser.add_argument("name", type=str, required=True, nullable=False, location="json")
+document_text_create_parser.add_argument("text", type=str, required=True, nullable=False, location="json")
+document_text_create_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
+document_text_create_parser.add_argument("original_document_id", type=str, required=False, location="json")
+document_text_create_parser.add_argument(
+ "doc_form", type=str, default="text_model", required=False, nullable=False, location="json"
+)
+document_text_create_parser.add_argument(
+ "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+)
+document_text_create_parser.add_argument(
+ "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
+)
+document_text_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
+document_text_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+document_text_create_parser.add_argument(
+ "embedding_model_provider", type=str, required=False, nullable=True, location="json"
+)
+document_text_update_parser = reqparse.RequestParser()
+document_text_update_parser.add_argument("name", type=str, required=False, nullable=True, location="json")
+document_text_update_parser.add_argument("text", type=str, required=False, nullable=True, location="json")
+document_text_update_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
+document_text_update_parser.add_argument(
+ "doc_form", type=str, default="text_model", required=False, nullable=False, location="json"
+)
+document_text_update_parser.add_argument(
+ "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+)
+document_text_update_parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+
+
+@service_api_ns.route(
+ "/datasets//document/create_by_text",
+ "/datasets//document/create-by-text",
+)
class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
+ @service_api_ns.expect(document_text_create_parser)
+ @service_api_ns.doc("create_document_by_text")
+ @service_api_ns.doc(description="Create a new document by providing text content")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Document created successfully",
+ 401: "Unauthorized - invalid API token",
+ 400: "Bad request - invalid parameters",
+ }
+ )
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by text."""
- parser = reqparse.RequestParser()
- parser.add_argument("name", type=str, required=True, nullable=False, location="json")
- parser.add_argument("text", type=str, required=True, nullable=False, location="json")
- parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
- parser.add_argument("original_document_id", type=str, required=False, location="json")
- parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
- parser.add_argument(
- "doc_language", type=str, default="English", required=False, nullable=False, location="json"
- )
- parser.add_argument(
- "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
- )
- parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
- parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
- parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
-
- args = parser.parse_args()
+ args = document_text_create_parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@@ -117,23 +149,29 @@ class DocumentAddByTextApi(DatasetApiResource):
return documents_and_batch_fields, 200
+@service_api_ns.route(
+ "/datasets//documents//update_by_text",
+ "/datasets//documents//update-by-text",
+)
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
+ @service_api_ns.expect(document_text_update_parser)
+ @service_api_ns.doc("update_document_by_text")
+ @service_api_ns.doc(description="Update an existing document by providing text content")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Document updated successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Document not found",
+ }
+ )
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by text."""
- parser = reqparse.RequestParser()
- parser.add_argument("name", type=str, required=False, nullable=True, location="json")
- parser.add_argument("text", type=str, required=False, nullable=True, location="json")
- parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
- parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
- parser.add_argument(
- "doc_language", type=str, default="English", required=False, nullable=False, location="json"
- )
- parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
- args = parser.parse_args()
+ args = document_text_update_parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@@ -187,9 +225,23 @@ class DocumentUpdateByTextApi(DatasetApiResource):
return documents_and_batch_fields, 200
+@service_api_ns.route(
+ "/datasets//document/create_by_file",
+ "/datasets//document/create-by-file",
+)
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
+ @service_api_ns.doc("create_document_by_file")
+ @service_api_ns.doc(description="Create a new document by uploading a file")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Document created successfully",
+ 401: "Unauthorized - invalid API token",
+ 400: "Bad request - invalid file or parameters",
+ }
+ )
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
@@ -281,9 +333,23 @@ class DocumentAddByFileApi(DatasetApiResource):
return documents_and_batch_fields, 200
+@service_api_ns.route(
+ "/datasets//documents//update_by_file",
+ "/datasets//documents//update-by-file",
+)
class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents."""
+ @service_api_ns.doc("update_document_by_file")
+ @service_api_ns.doc(description="Update an existing document by uploading a file")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Document updated successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Document not found",
+ }
+ )
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
@@ -358,7 +424,18 @@ class DocumentUpdateByFileApi(DatasetApiResource):
return documents_and_batch_fields, 200
+@service_api_ns.route("/datasets//documents")
class DocumentListApi(DatasetApiResource):
+ @service_api_ns.doc("list_documents")
+ @service_api_ns.doc(description="List all documents in a dataset")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Documents retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset not found",
+ }
+ )
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@@ -391,7 +468,18 @@ class DocumentListApi(DatasetApiResource):
return response
+@service_api_ns.route("/datasets//documents//indexing-status")
class DocumentIndexingStatusApi(DatasetApiResource):
+ @service_api_ns.doc("get_document_indexing_status")
+ @service_api_ns.doc(description="Get indexing status for documents in a batch")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "batch": "Batch ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Indexing status retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset or documents not found",
+ }
+ )
def get(self, tenant_id, dataset_id, batch):
dataset_id = str(dataset_id)
batch = str(batch)
@@ -440,9 +528,21 @@ class DocumentIndexingStatusApi(DatasetApiResource):
return data
+@service_api_ns.route("/datasets//documents/")
class DocumentApi(DatasetApiResource):
METADATA_CHOICES = {"all", "only", "without"}
+ @service_api_ns.doc("get_document")
+ @service_api_ns.doc(description="Get a specific document by ID")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Document retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - insufficient permissions",
+ 404: "Document not found",
+ }
+ )
def get(self, tenant_id, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
@@ -534,6 +634,17 @@ class DocumentApi(DatasetApiResource):
return response
+ @service_api_ns.doc("delete_document")
+ @service_api_ns.doc(description="Delete a document")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @service_api_ns.doc(
+ responses={
+ 204: "Document deleted successfully",
+ 401: "Unauthorized - invalid API token",
+ 403: "Forbidden - document is archived",
+ 404: "Document not found",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
@@ -564,28 +675,3 @@ class DocumentApi(DatasetApiResource):
raise DocumentIndexingError("Cannot delete document during indexing.")
return 204
-
-
-api.add_resource(
-    DocumentAddByTextApi,
-    "/datasets/<uuid:dataset_id>/document/create_by_text",
-    "/datasets/<uuid:dataset_id>/document/create-by-text",
-)
-api.add_resource(
-    DocumentAddByFileApi,
-    "/datasets/<uuid:dataset_id>/document/create_by_file",
-    "/datasets/<uuid:dataset_id>/document/create-by-file",
-)
-api.add_resource(
-    DocumentUpdateByTextApi,
-    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
-    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
-)
-api.add_resource(
-    DocumentUpdateByFileApi,
-    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
-    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
-)
-api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
-api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
-api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py
index 52e9bca5da..d81287d56f 100644
--- a/api/controllers/service_api/dataset/hit_testing.py
+++ b/api/controllers/service_api/dataset/hit_testing.py
@@ -1,11 +1,26 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
+@service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve")
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
+ @service_api_ns.doc("dataset_hit_testing")
+ @service_api_ns.doc(description="Perform hit testing on a dataset")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Hit testing results",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset not found",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
+ """Perform hit testing on a dataset.
+
+ Tests retrieval performance for the specified dataset.
+ """
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
@@ -13,6 +28,3 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)
-
-
-api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py
index 75a0b18285..9defe6af03 100644
--- a/api/controllers/service_api/dataset/metadata.py
+++ b/api/controllers/service_api/dataset/metadata.py
@@ -1,10 +1,10 @@
from typing import Literal
from flask_login import current_user # type: ignore
-from flask_restful import marshal, reqparse
+from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields
from services.dataset_service import DatasetService
@@ -14,14 +14,43 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.metadata_service import MetadataService
+# Define parsers for metadata APIs
+metadata_create_parser = reqparse.RequestParser()
+metadata_create_parser.add_argument(
+ "type", type=str, required=True, nullable=False, location="json", help="Metadata type"
+)
+metadata_create_parser.add_argument(
+ "name", type=str, required=True, nullable=False, location="json", help="Metadata name"
+)
+metadata_update_parser = reqparse.RequestParser()
+metadata_update_parser.add_argument(
+ "name", type=str, required=True, nullable=False, location="json", help="New metadata name"
+)
+
+document_metadata_parser = reqparse.RequestParser()
+document_metadata_parser.add_argument(
+ "operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data"
+)
+
+
+@service_api_ns.route("/datasets//metadata")
class DatasetMetadataCreateServiceApi(DatasetApiResource):
+ @service_api_ns.expect(metadata_create_parser)
+ @service_api_ns.doc("create_dataset_metadata")
+ @service_api_ns.doc(description="Create metadata for a dataset")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 201: "Metadata created successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset not found",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
- parser = reqparse.RequestParser()
- parser.add_argument("type", type=str, required=True, nullable=False, location="json")
- parser.add_argument("name", type=str, required=True, nullable=False, location="json")
- args = parser.parse_args()
+ """Create metadata for a dataset."""
+ args = metadata_create_parser.parse_args()
metadata_args = MetadataArgs(**args)
dataset_id_str = str(dataset_id)
@@ -33,7 +62,18 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return marshal(metadata, dataset_metadata_fields), 201
+ @service_api_ns.doc("get_dataset_metadata")
+ @service_api_ns.doc(description="Get all metadata for a dataset")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Metadata retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset not found",
+ }
+ )
def get(self, tenant_id, dataset_id):
+ """Get all metadata for a dataset."""
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -41,12 +81,23 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
return MetadataService.get_dataset_metadatas(dataset), 200
+@service_api_ns.route("/datasets//metadata/")
class DatasetMetadataServiceApi(DatasetApiResource):
+ @service_api_ns.expect(metadata_update_parser)
+ @service_api_ns.doc("update_dataset_metadata")
+ @service_api_ns.doc(description="Update metadata name")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Metadata updated successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset or metadata not found",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id):
- parser = reqparse.RequestParser()
- parser.add_argument("name", type=str, required=True, nullable=False, location="json")
- args = parser.parse_args()
+ """Update metadata name."""
+ args = metadata_update_parser.parse_args()
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@@ -58,8 +109,19 @@ class DatasetMetadataServiceApi(DatasetApiResource):
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
return marshal(metadata, dataset_metadata_fields), 200
+ @service_api_ns.doc("delete_dataset_metadata")
+ @service_api_ns.doc(description="Delete metadata")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
+ @service_api_ns.doc(
+ responses={
+ 204: "Metadata deleted successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset or metadata not found",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, metadata_id):
+ """Delete metadata."""
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@@ -71,15 +133,37 @@ class DatasetMetadataServiceApi(DatasetApiResource):
return 204
+@service_api_ns.route("/datasets/metadata/built-in")
class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
+ @service_api_ns.doc("get_built_in_fields")
+ @service_api_ns.doc(description="Get all built-in metadata fields")
+ @service_api_ns.doc(
+ responses={
+ 200: "Built-in fields retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ }
+ )
def get(self, tenant_id):
+ """Get all built-in metadata fields."""
built_in_fields = MetadataService.get_built_in_fields()
return {"fields": built_in_fields}, 200
+@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
+ @service_api_ns.doc("toggle_built_in_field")
+ @service_api_ns.doc(description="Enable or disable built-in metadata field")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "action": "Action to perform: 'enable' or 'disable'"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Action completed successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset not found",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
+ """Enable or disable built-in metadata field."""
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -93,29 +177,31 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
return 200
+@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
class DocumentMetadataEditServiceApi(DatasetApiResource):
+ @service_api_ns.expect(document_metadata_parser)
+ @service_api_ns.doc("update_documents_metadata")
+ @service_api_ns.doc(description="Update metadata for multiple documents")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Documents metadata updated successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset not found",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
+ """Update metadata for multiple documents."""
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
- parser = reqparse.RequestParser()
- parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
- args = parser.parse_args()
+ args = document_metadata_parser.parse_args()
metadata_args = MetadataOperationData(**args)
MetadataService.update_documents_metadata(dataset, metadata_args)
return 200
-
-
-api.add_resource(DatasetMetadataCreateServiceApi, "/datasets/<uuid:dataset_id>/metadata")
-api.add_resource(DatasetMetadataServiceApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
-api.add_resource(DatasetMetadataBuiltInFieldServiceApi, "/datasets/metadata/built-in")
-api.add_resource(
- DatasetMetadataBuiltInFieldActionServiceApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>"
-)
-api.add_resource(DocumentMetadataEditServiceApi, "/datasets/<uuid:dataset_id>/documents/metadata")
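# A minimal sketch of the registration pattern this file now follows: flask_restx
# routes are declared with a namespace decorator instead of api.add_resource(),
# and module-level RequestParser objects are shared between @ns.expect() (which
# feeds the generated Swagger page) and parse_args() (runtime validation).
# Everything below is illustrative scaffolding, not code from this PR.
from flask import Flask
from flask_restx import Api, Namespace, Resource, reqparse

app = Flask(__name__)
api = Api(app)
ns = Namespace("example", description="Demo namespace", path="/v1")
api.add_namespace(ns)

item_create_parser = reqparse.RequestParser()
item_create_parser.add_argument("name", type=str, required=True, nullable=False, location="json")

@ns.route("/items/<uuid:item_id>/children")
class ItemChildrenApi(Resource):
    @ns.expect(item_create_parser)
    @ns.doc("create_item_child", description="Create a child under an item")
    def post(self, item_id):
        args = item_create_parser.parse_args()  # returns 400 on missing "name"
        return {"item_id": str(item_id), "name": args["name"]}, 201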
diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py
index 31f862dc8f..f5e2010ca4 100644
--- a/api/controllers/service_api/dataset/segment.py
+++ b/api/controllers/service_api/dataset/segment.py
@@ -1,9 +1,9 @@
from flask import request
from flask_login import current_user
-from flask_restful import marshal, reqparse
+from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
DatasetApiResource,
@@ -19,34 +19,59 @@ from fields.segment_fields import child_chunk_fields, segment_fields
from models.dataset import Dataset
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
-from services.errors.chunk import (
- ChildChunkDeleteIndexError,
- ChildChunkIndexingError,
-)
-from services.errors.chunk import (
- ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError,
-)
-from services.errors.chunk import (
- ChildChunkIndexingError as ChildChunkIndexingServiceError,
-)
+from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
+from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
+from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
+
+# Define parsers for segment operations
+segment_create_parser = reqparse.RequestParser()
+segment_create_parser.add_argument("segments", type=list, required=False, nullable=True, location="json")
+
+segment_list_parser = reqparse.RequestParser()
+segment_list_parser.add_argument("status", type=str, action="append", default=[], location="args")
+segment_list_parser.add_argument("keyword", type=str, default=None, location="args")
+
+segment_update_parser = reqparse.RequestParser()
+segment_update_parser.add_argument("segment", type=dict, required=False, nullable=True, location="json")
+
+child_chunk_create_parser = reqparse.RequestParser()
+child_chunk_create_parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+
+child_chunk_list_parser = reqparse.RequestParser()
+child_chunk_list_parser.add_argument("limit", type=int, default=20, location="args")
+child_chunk_list_parser.add_argument("keyword", type=str, default=None, location="args")
+child_chunk_list_parser.add_argument("page", type=int, default=1, location="args")
+
+child_chunk_update_parser = reqparse.RequestParser()
+child_chunk_update_parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
+ @service_api_ns.expect(segment_create_parser)
+ @service_api_ns.doc("create_segments")
+ @service_api_ns.doc(description="Create segments in a document")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Segments created successfully",
+ 400: "Bad request - segments data is missing",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset or document not found",
+ }
+ )
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
- def post(self, tenant_id, dataset_id, document_id):
+ def post(self, tenant_id: str, dataset_id: str, document_id: str):
"""Create single segment."""
# check dataset
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
- document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
@@ -71,9 +96,7 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
- parser = reqparse.RequestParser()
- parser.add_argument("segments", type=list, required=False, nullable=True, location="json")
- args = parser.parse_args()
+ args = segment_create_parser.parse_args()
if args["segments"] is not None:
for args_item in args["segments"]:
SegmentService.segment_create_args_validate(args_item, document)
@@ -82,18 +105,26 @@ class SegmentApi(DatasetApiResource):
else:
return {"error": "Segments is required"}, 400
- def get(self, tenant_id, dataset_id, document_id):
+ @service_api_ns.expect(segment_list_parser)
+ @service_api_ns.doc("list_segments")
+ @service_api_ns.doc(description="List segments in a document")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Segments retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset or document not found",
+ }
+ )
+ def get(self, tenant_id: str, dataset_id: str, document_id: str):
"""Get segments."""
# check dataset
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
- document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
@@ -114,10 +145,7 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
- parser = reqparse.RequestParser()
- parser.add_argument("status", type=str, action="append", default=[], location="args")
- parser.add_argument("keyword", type=str, default=None, location="args")
- args = parser.parse_args()
+ args = segment_list_parser.parse_args()
segments, total = SegmentService.get_segments(
document_id=document_id,
@@ -140,43 +168,62 @@ class SegmentApi(DatasetApiResource):
return response, 200
+@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetSegmentApi(DatasetApiResource):
+ @service_api_ns.doc("delete_segment")
+ @service_api_ns.doc(description="Delete a specific segment")
+ @service_api_ns.doc(
+ params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to delete"}
+ )
+ @service_api_ns.doc(
+ responses={
+ 204: "Segment deleted successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset, document, or segment not found",
+ }
+ )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
- def delete(self, tenant_id, dataset_id, document_id, segment_id):
+ def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
# check dataset
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
- document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
- segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
return 204
+ @service_api_ns.expect(segment_update_parser)
+ @service_api_ns.doc("update_segment")
+ @service_api_ns.doc(description="Update a specific segment")
+ @service_api_ns.doc(
+ params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to update"}
+ )
+ @service_api_ns.doc(
+ responses={
+ 200: "Segment updated successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset, document, or segment not found",
+ }
+ )
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
- def post(self, tenant_id, dataset_id, document_id, segment_id):
+ def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
# check dataset
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
- document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
@@ -197,37 +244,39 @@ class DatasetSegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
- segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# validate args
- parser = reqparse.RequestParser()
- parser.add_argument("segment", type=dict, required=False, nullable=True, location="json")
- args = parser.parse_args()
+ args = segment_update_parser.parse_args()
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs(**args["segment"]), segment, document, dataset
)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
- def get(self, tenant_id, dataset_id, document_id, segment_id):
+ @service_api_ns.doc("get_segment")
+ @service_api_ns.doc(description="Get a specific segment by ID")
+ @service_api_ns.doc(
+ responses={
+ 200: "Segment retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset, document, or segment not found",
+ }
+ )
+ def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
# check dataset
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
- document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
- segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@@ -235,29 +284,41 @@ class DatasetSegmentApi(DatasetApiResource):
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
+@service_api_ns.route(
+ "/datasets//documents//segments//child_chunks"
+)
class ChildChunkApi(DatasetApiResource):
"""Resource for child chunks."""
+ @service_api_ns.expect(child_chunk_create_parser)
+ @service_api_ns.doc("create_child_chunk")
+ @service_api_ns.doc(description="Create a new child chunk for a segment")
+ @service_api_ns.doc(
+ params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"}
+ )
+ @service_api_ns.doc(
+ responses={
+ 200: "Child chunk created successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset, document, or segment not found",
+ }
+ )
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
- def post(self, tenant_id, dataset_id, document_id, segment_id):
+ def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
"""Create child chunk."""
# check dataset
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
- document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
- segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@@ -280,43 +341,46 @@ class ChildChunkApi(DatasetApiResource):
raise ProviderNotInitializeError(ex.description)
# validate args
- parser = reqparse.RequestParser()
- parser.add_argument("content", type=str, required=True, nullable=False, location="json")
- args = parser.parse_args()
+ args = child_chunk_create_parser.parse_args()
try:
- child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
+ child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
- def get(self, tenant_id, dataset_id, document_id, segment_id):
+ @service_api_ns.expect(child_chunk_list_parser)
+ @service_api_ns.doc("list_child_chunks")
+ @service_api_ns.doc(description="List child chunks for a segment")
+ @service_api_ns.doc(
+ params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"}
+ )
+ @service_api_ns.doc(
+ responses={
+ 200: "Child chunks retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset, document, or segment not found",
+ }
+ )
+ def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
"""Get child chunks."""
# check dataset
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
- document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
- segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
- parser = reqparse.RequestParser()
- parser.add_argument("limit", type=int, default=20, location="args")
- parser.add_argument("keyword", type=str, default=None, location="args")
- parser.add_argument("page", type=int, default=1, location="args")
- args = parser.parse_args()
+ args = child_chunk_list_parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
@@ -333,40 +397,63 @@ class ChildChunkApi(DatasetApiResource):
}, 200
+@service_api_ns.route(
+ "/datasets//documents//segments//child_chunks/"
+)
class DatasetChildChunkApi(DatasetApiResource):
"""Resource for updating child chunks."""
+ @service_api_ns.doc("delete_child_chunk")
+ @service_api_ns.doc(description="Delete a specific child chunk")
+ @service_api_ns.doc(
+ params={
+ "dataset_id": "Dataset ID",
+ "document_id": "Document ID",
+ "segment_id": "Parent segment ID",
+ "child_chunk_id": "Child chunk ID to delete",
+ }
+ )
+ @service_api_ns.doc(
+ responses={
+ 204: "Child chunk deleted successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset, document, segment, or child chunk not found",
+ }
+ )
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
- def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
+ def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
"""Delete child chunk."""
# check dataset
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
- document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
- segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
+ # validate segment belongs to the specified document
+ if segment.document_id != document_id:
+ raise NotFound("Document not found.")
+
# check child chunk
- child_chunk_id = str(child_chunk_id)
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
if not child_chunk:
raise NotFound("Child chunk not found.")
+ # validate child chunk belongs to the specified segment
+ if child_chunk.segment_id != segment.id:
+ raise NotFound("Child chunk not found.")
+
try:
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
@@ -374,14 +461,30 @@ class DatasetChildChunkApi(DatasetApiResource):
return 204
+ @service_api_ns.expect(child_chunk_update_parser)
+ @service_api_ns.doc("update_child_chunk")
+ @service_api_ns.doc(description="Update a specific child chunk")
+ @service_api_ns.doc(
+ params={
+ "dataset_id": "Dataset ID",
+ "document_id": "Document ID",
+ "segment_id": "Parent segment ID",
+ "child_chunk_id": "Child chunk ID to update",
+ }
+ )
+ @service_api_ns.doc(
+ responses={
+ 200: "Child chunk updated successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset, document, segment, or child chunk not found",
+ }
+ )
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
- def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
+ def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
"""Update child chunk."""
# check dataset
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@@ -396,6 +499,10 @@ class DatasetChildChunkApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
+ # validate segment belongs to the specified document
+ if segment.document_id != document_id:
+ raise NotFound("Segment not found.")
+
# get child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
@@ -403,29 +510,16 @@ class DatasetChildChunkApi(DatasetApiResource):
if not child_chunk:
raise NotFound("Child chunk not found.")
+ # validate child chunk belongs to the specified segment
+ if child_chunk.segment_id != segment.id:
+ raise NotFound("Child chunk not found.")
+
# validate args
- parser = reqparse.RequestParser()
- parser.add_argument("content", type=str, required=True, nullable=False, location="json")
- args = parser.parse_args()
+ args = child_chunk_update_parser.parse_args()
try:
- child_chunk = SegmentService.update_child_chunk(
- args.get("content"), child_chunk, segment, document, dataset
- )
+ child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
-
-
-api.add_resource(SegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
-api.add_resource(
- DatasetSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>"
-)
-api.add_resource(
- ChildChunkApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks"
-)
-api.add_resource(
- DatasetChildChunkApi,
- "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
-)
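# The newly added ownership checks above follow one pattern: every level of the
# URL hierarchy must own the next, otherwise the addressed resource is treated
# as nonexistent (404 rather than 403, so the response does not leak that an ID
# exists under a different parent). A condensed sketch with stand-in objects:
from werkzeug.exceptions import NotFound

def validate_chunk_hierarchy(segment, child_chunk, document_id: str) -> None:
    # a segment fetched by ID may belong to a different document than the URL names
    if segment.document_id != document_id:
        raise NotFound("Segment not found.")
    # likewise, a child chunk must hang off the segment named in the URL
    if child_chunk.segment_id != segment.id:
        raise NotFound("Child chunk not found.")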
diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py
index 3b4721b5b0..27b36a6402 100644
--- a/api/controllers/service_api/dataset/upload_file.py
+++ b/api/controllers/service_api/dataset/upload_file.py
@@ -1,6 +1,6 @@
from werkzeug.exceptions import NotFound
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.wraps import (
DatasetApiResource,
)
@@ -11,9 +11,23 @@ from models.model import UploadFile
from services.dataset_service import DocumentService
+@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
class UploadFileApi(DatasetApiResource):
+ @service_api_ns.doc("get_upload_file")
+ @service_api_ns.doc(description="Get upload file information and download URL")
+ @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Upload file information retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ 404: "Dataset, document, or upload file not found",
+ }
+ )
def get(self, tenant_id, dataset_id, document_id):
- """Get upload file."""
+ """Get upload file information and download URL.
+
+ Returns information about an uploaded file including its download URL.
+ """
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@@ -49,6 +63,3 @@ class UploadFileApi(DatasetApiResource):
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.timestamp(),
}, 200
-
-
-api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py
index 9bb5df4c4e..a9d2d6fadc 100644
--- a/api/controllers/service_api/index.py
+++ b/api/controllers/service_api/index.py
@@ -1,9 +1,10 @@
-from flask_restful import Resource
+from flask_restx import Resource
from configs import dify_config
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
+@service_api_ns.route("/")
class IndexApi(Resource):
def get(self):
return {
@@ -11,6 +12,3 @@ class IndexApi(Resource):
"api_version": "v1",
"server_version": dify_config.project.version,
}
-
-
-api.add_resource(IndexApi, "/")
diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py
index 3f18474674..536cf81a2f 100644
--- a/api/controllers/service_api/workspace/models.py
+++ b/api/controllers/service_api/workspace/models.py
@@ -1,21 +1,32 @@
from flask_login import current_user
-from flask_restful import Resource
+from flask_restx import Resource
-from controllers.service_api import api
+from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_dataset_token
from core.model_runtime.utils.encoders import jsonable_encoder
from services.model_provider_service import ModelProviderService
+@service_api_ns.route("/workspaces/current/models/model-types/<string:model_type>")
class ModelProviderAvailableModelApi(Resource):
+ @service_api_ns.doc("get_available_models")
+ @service_api_ns.doc(description="Get available models by model type")
+ @service_api_ns.doc(params={"model_type": "Type of model to retrieve"})
+ @service_api_ns.doc(
+ responses={
+ 200: "Models retrieved successfully",
+ 401: "Unauthorized - invalid API token",
+ }
+ )
@validate_dataset_token
def get(self, _, model_type):
+ """Get available models by model type.
+
+ Returns a list of available models for the specified model type.
+ """
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({"data": models})
-
-
-api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index da81cc8bc3..8aac3de4c3 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -7,7 +7,7 @@ from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in # type: ignore
-from flask_restful import Resource
+from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py
index 197859e8f3..0680903635 100644
--- a/api/controllers/web/app.py
+++ b/api/controllers/web/app.py
@@ -1,5 +1,7 @@
+import logging
+
from flask import request
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Unauthorized
from controllers.common import fields
@@ -87,8 +89,11 @@ class AppWebAuthPermission(Resource):
decoded = PassportService().verify(tk)
user_id = decoded.get("user_id", "visitor")
- except Exception as e:
- pass
+ except Unauthorized:
+ raise
+ except Exception:
+ logging.exception("Unexpected error during auth verification")
+ raise
features = FeatureService.get_system_features()
if not features.webapp_auth.enabled:
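# The web/app.py change replaces a silent `except Exception: pass` with the
# standard narrow-first pattern: expected auth failures propagate untouched,
# anything unexpected is logged with a traceback and re-raised. A reduced
# sketch of that shape (verify_token is a hypothetical stand-in):
import logging
from werkzeug.exceptions import Unauthorized

def resolve_user_id(verify_token, token: str) -> str:
    try:
        decoded = verify_token(token)
        return decoded.get("user_id", "visitor")
    except Unauthorized:
        raise  # expected failure: let the 401 handler respond
    except Exception:
        logging.exception("Unexpected error during auth verification")
        raise  # never swallow the error silently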
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index 2919ca9af4..241d0874db 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -65,7 +65,7 @@ class AudioApi(WebApiResource):
class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
- from flask_restful import reqparse
+ from flask_restx import reqparse
try:
parser = reqparse.RequestParser()
diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py
index fd3b9aa804..c19afee9b7 100644
--- a/api/controllers/web/completion.py
+++ b/api/controllers/web/completion.py
@@ -1,6 +1,6 @@
import logging
-from flask_restful import reqparse
+from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py
index 98cea3974f..cea8e442f3 100644
--- a/api/controllers/web/conversation.py
+++ b/api/controllers/web/conversation.py
@@ -1,5 +1,5 @@
-from flask_restful import marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import marshal_with, reqparse
+from flask_restx.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py
index 0563ed2238..478b3d2e31 100644
--- a/api/controllers/web/feature.py
+++ b/api/controllers/web/feature.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource
+from flask_restx import Resource
from controllers.web import api
from services.feature_service import FeatureService
diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py
index 0c30435825..b05e2a2e65 100644
--- a/api/controllers/web/files.py
+++ b/api/controllers/web/files.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import marshal_with
+from flask_restx import marshal_with
import services
from controllers.common.errors import (
diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py
index 0da8d65efc..d436657f06 100644
--- a/api/controllers/web/forgot_password.py
+++ b/api/controllers/web/forgot_password.py
@@ -2,7 +2,7 @@ import base64
import secrets
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py
index 01c4f4a262..d4eafd532b 100644
--- a/api/controllers/web/login.py
+++ b/api/controllers/web/login.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restx import Resource, reqparse
from jwt import InvalidTokenError # type: ignore
import services
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 7bb81cd0d3..f348221d80 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -1,7 +1,7 @@
import logging
-from flask_restful import fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import fields, marshal_with, reqparse
+from flask_restx.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.web import api
diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py
index acd3a8b539..1ac20e6531 100644
--- a/api/controllers/web/passport.py
+++ b/api/controllers/web/passport.py
@@ -2,7 +2,7 @@ import uuid
from datetime import UTC, datetime, timedelta
from flask import request
-from flask_restful import Resource
+from flask_restx import Resource
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized
diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py
index 4e19716c3d..930b9d96e9 100644
--- a/api/controllers/web/remote_files.py
+++ b/api/controllers/web/remote_files.py
@@ -1,7 +1,7 @@
import urllib.parse
import httpx
-from flask_restful import marshal_with, reqparse
+from flask_restx import marshal_with, reqparse
import services
from controllers.common import helpers
diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py
index d7188ef0b3..a0912499ff 100644
--- a/api/controllers/web/saved_message.py
+++ b/api/controllers/web/saved_message.py
@@ -1,5 +1,5 @@
-from flask_restful import fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restx import fields, marshal_with, reqparse
+from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.web import api
diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py
index 3c133499b7..b2a887a0de 100644
--- a/api/controllers/web/site.py
+++ b/api/controllers/web/site.py
@@ -1,4 +1,4 @@
-from flask_restful import fields, marshal_with
+from flask_restx import fields, marshal_with
from werkzeug.exceptions import Forbidden
from configs import dify_config
diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py
index 590fd3f2c7..331587cc28 100644
--- a/api/controllers/web/workflow.py
+++ b/api/controllers/web/workflow.py
@@ -1,6 +1,6 @@
import logging
-from flask_restful import reqparse
+from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError
from controllers.web import api
diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py
index ae6f14a689..94fa5d5626 100644
--- a/api/controllers/web/wraps.py
+++ b/api/controllers/web/wraps.py
@@ -2,7 +2,7 @@ from datetime import UTC, datetime
from functools import wraps
from flask import request
-from flask_restful import Resource
+from flask_restx import Resource
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index ad9b625350..f7c83f927f 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -512,7 +512,6 @@ class BaseAgentRunner(AppRunner):
if not file_objs:
return UserPromptMessage(content=message.query)
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
- prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
@@ -520,4 +519,6 @@ class BaseAgentRunner(AppRunner):
image_detail_config=image_detail_config,
)
)
+ prompt_message_contents.append(TextPromptMessageContent(data=message.query))
+
return UserPromptMessage(content=prompt_message_contents)
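# The agent runners now assemble multimodal user messages with the attached
# file parts first and the text query appended last, so all three runners
# produce the same part ordering. A schematic version (the dict parts stand
# in for the real PromptMessageContent types):
def build_user_content(query: str, file_parts: list[dict]) -> list[dict]:
    contents: list[dict] = []
    contents.extend(file_parts)                        # attachments first
    contents.append({"type": "text", "text": query})   # the query that refers to them
    return contents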
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 565fb42478..6cb1077126 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -197,7 +197,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
final_answer = scratchpad.action.action_input
else:
final_answer = f"{scratchpad.action.action_input}"
- except json.JSONDecodeError:
+ except TypeError:
final_answer = f"{scratchpad.action.action_input}"
else:
function_call_state = True
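# The except clause above was unreachable as written: json.JSONDecodeError is
# raised by json.loads() when parsing, while serialization failures from
# json.dumps() surface as TypeError. A quick demonstration:
import json

try:
    json.dumps({"payload": object()})  # object() is not JSON serializable
except TypeError as exc:
    print(f"dumps raises TypeError: {exc}")

try:
    json.loads("{not valid json}")
except json.JSONDecodeError as exc:
    print(f"loads raises JSONDecodeError: {exc}")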
diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py
index 5ff89bdacb..4d1d94eadc 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -39,9 +39,6 @@ class CotChatAgentRunner(CotAgentRunner):
Organize user query
"""
if self.files:
- prompt_message_contents: list[PromptMessageContentUnionTypes] = []
- prompt_message_contents.append(TextPromptMessageContent(data=query))
-
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
@@ -52,6 +49,8 @@ class CotChatAgentRunner(CotAgentRunner):
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
+
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
@@ -59,6 +58,7 @@ class CotChatAgentRunner(CotAgentRunner):
image_detail_config=image_detail_config,
)
)
+ prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index 4df71ce9de..9eb853aa74 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -126,8 +126,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
- except json.JSONDecodeError:
- # ensure ascii to avoid encoding error
+ except TypeError:
+ # fallback: retry with json.dumps defaults (ASCII-escaped output)
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if chunk.delta.message and chunk.delta.message.content:
@@ -153,8 +153,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
- except json.JSONDecodeError:
- # ensure ascii to avoid encoding error
+ except TypeError:
+ # fallback: retry with json.dumps defaults (ASCII-escaped output)
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if result.usage:
@@ -395,9 +395,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
Organize user query
"""
if self.files:
- prompt_message_contents: list[PromptMessageContentUnionTypes] = []
- prompt_message_contents.append(TextPromptMessageContent(data=query))
-
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
@@ -408,6 +405,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
+
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
@@ -415,6 +414,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
image_detail_config=image_detail_config,
)
)
+ prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 0df0aa59b2..0db1d52779 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -167,7 +167,7 @@ class ModelConfig(BaseModel):
provider: str
name: str
mode: LLMMode
- completion_params: dict[str, Any] = {}
+ completion_params: dict[str, Any] = Field(default_factory=dict)
class Condition(BaseModel):
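# Note on the Field(default_factory=dict) changes here and in task_entities.py:
# pydantic already deep-copies mutable defaults per instance, so `= {}` was not
# the shared-state bug it would be on a plain class; the factory form is still
# preferable because it is explicit, satisfies linters, and skips building a
# throwaway default object at class-definition time. A minimal check:
from pydantic import BaseModel, Field

class ExampleConfig(BaseModel):
    provider: str
    completion_params: dict = Field(default_factory=dict)

a = ExampleConfig(provider="openai")
b = ExampleConfig(provider="anthropic")
a.completion_params["temperature"] = 0.2
assert b.completion_params == {}  # each instance gets its own dict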
diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py
index 34a1da2227..1a89237333 100644
--- a/api/core/app/apps/common/workflow_response_converter.py
+++ b/api/core/app/apps/common/workflow_response_converter.py
@@ -50,6 +50,7 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
+from libs.datetime_utils import naive_utc_now
from models import (
Account,
CreatorUserRole,
@@ -399,7 +400,7 @@ class WorkflowResponseConverter:
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
- elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
+ elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
@@ -478,7 +479,7 @@ class WorkflowResponseConverter:
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
- elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
+ elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
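# naive_utc_now() centralizes the datetime.now(UTC).replace(tzinfo=None) idiom
# that this diff removes from call sites. Presumably the helper is a thin
# wrapper along these lines (a sketch, not the actual libs.datetime_utils code):
from datetime import UTC, datetime

def naive_utc_now() -> datetime:
    """Current UTC time as a naive datetime, matching tz-less DB columns."""
    return datetime.now(UTC).replace(tzinfo=None)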
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index 42e6a1519c..d663dbb175 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -610,7 +610,7 @@ class QueueErrorEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.ERROR
- error: Any = None
+ error: Optional[Any] = None
class QueuePingEvent(AppQueueEvent):
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index 25c889e922..a1c0368354 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -142,7 +142,7 @@ class MessageEndStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
- metadata: dict = {}
+ metadata: dict = Field(default_factory=dict)
files: Optional[Sequence[Mapping[str, Any]]] = None
@@ -261,7 +261,7 @@ class NodeStartStreamResponse(StreamResponse):
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
created_at: int
- extras: dict = {}
+ extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
@@ -503,7 +503,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
- extras: dict = {}
+ extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
@@ -531,7 +531,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
index: int
created_at: int
pre_iteration_output: Optional[Any] = None
- extras: dict = {}
+ extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
@@ -590,7 +590,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
- extras: dict = {}
+ extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
@@ -618,7 +618,7 @@ class LoopNodeNextStreamResponse(StreamResponse):
index: int
created_at: int
pre_loop_output: Optional[Any] = None
- extras: dict = {}
+ extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
@@ -764,7 +764,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
conversation_id: str
message_id: str
answer: str
- metadata: dict = {}
+ metadata: dict = Field(default_factory=dict)
created_at: int
data: Data
@@ -784,7 +784,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
mode: str
message_id: str
answer: str
- metadata: dict = {}
+ metadata: dict = Field(default_factory=dict)
created_at: int
data: Data
diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py
index 014c7fd4f5..8c0a442158 100644
--- a/api/core/app/task_pipeline/based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py
@@ -52,7 +52,8 @@ class BasedGenerateTaskPipeline:
elif isinstance(e, InvokeError | ValueError):
err = e
else:
- err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
+ description = getattr(e, "description", None)
+ err = Exception(description if description is not None else str(e))
if not message_id or not session:
return err
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 8bfbd82e1f..646e0e21e9 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -1,4 +1,3 @@
-import datetime
import json
import logging
from collections import defaultdict
@@ -29,6 +28,7 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.provider import (
LoadBalancingModelConfig,
Provider,
@@ -261,7 +261,7 @@ class ProviderConfiguration(BaseModel):
if provider_record:
provider_record.encrypted_config = json.dumps(credentials)
provider_record.is_valid = True
- provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ provider_record.updated_at = naive_utc_now()
db.session.commit()
else:
provider_record = Provider()
@@ -426,7 +426,7 @@ class ProviderConfiguration(BaseModel):
if provider_model_record:
provider_model_record.encrypted_config = json.dumps(credentials)
provider_model_record.is_valid = True
- provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ provider_model_record.updated_at = naive_utc_now()
db.session.commit()
else:
provider_model_record = ProviderModel()
@@ -501,7 +501,7 @@ class ProviderConfiguration(BaseModel):
if model_setting:
model_setting.enabled = True
- model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ model_setting.updated_at = naive_utc_now()
db.session.commit()
else:
model_setting = ProviderModelSetting()
@@ -526,7 +526,7 @@ class ProviderConfiguration(BaseModel):
if model_setting:
model_setting.enabled = False
- model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ model_setting.updated_at = naive_utc_now()
db.session.commit()
else:
model_setting = ProviderModelSetting()
@@ -599,7 +599,7 @@ class ProviderConfiguration(BaseModel):
if model_setting:
model_setting.load_balancing_enabled = True
- model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ model_setting.updated_at = naive_utc_now()
db.session.commit()
else:
model_setting = ProviderModelSetting()
@@ -638,7 +638,7 @@ class ProviderConfiguration(BaseModel):
if model_setting:
model_setting.load_balancing_enabled = False
- model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ model_setting.updated_at = naive_utc_now()
db.session.commit()
else:
model_setting = ProviderModelSetting()
diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py
index 557f7eb1ed..ae4671a381 100644
--- a/api/core/extension/extensible.py
+++ b/api/core/extension/extensible.py
@@ -17,7 +17,7 @@ class ExtensionModule(enum.Enum):
class ModuleExtension(BaseModel):
- extension_class: Any = None
+ extension_class: Optional[Any] = None
name: str
label: Optional[dict] = None
form_schema: Optional[list] = None
diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py
index 9eb9e0306b..50c3f9b5f4 100644
--- a/api/core/extension/extension.py
+++ b/api/core/extension/extension.py
@@ -38,6 +38,7 @@ class Extension:
def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
module_extension = self.module_extension(module, extension_name)
+ assert module_extension.extension_class is not None
t: type = module_extension.extension_class
return t
diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py
index df42837796..5cd0ea5c66 100644
--- a/api/core/helper/trace_id_helper.py
+++ b/api/core/helper/trace_id_helper.py
@@ -1,3 +1,4 @@
+import contextlib
import re
from collections.abc import Mapping
from typing import Any, Optional
@@ -97,10 +98,8 @@ def parse_traceparent_header(traceparent: str) -> Optional[str]:
Reference:
W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
"""
- try:
+ with contextlib.suppress(Exception):
parts = traceparent.split("-")
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
- except Exception:
- pass
return None
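# contextlib.suppress collapses the try/except/pass boilerplate: when one of
# the suppressed exception types is raised inside the with-block, control
# simply resumes after the block. Toy equivalent of the rewrite above:
import contextlib

def first_int(parts: list[str]) -> int | None:
    with contextlib.suppress(ValueError, IndexError):
        return int(parts[0])  # bad or missing input falls through
    return None

assert first_int(["42"]) == 42
assert first_int(["x"]) is None
assert first_int([]) is None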
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index b40278c76b..9876194608 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -1,5 +1,4 @@
import concurrent.futures
-import datetime
import json
import logging
import re
@@ -9,7 +8,6 @@ import uuid
from typing import Any, Optional, cast
from flask import current_app
-from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@@ -35,6 +33,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
+from libs.datetime_utils import naive_utc_now
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
@@ -88,7 +87,7 @@ class IndexingRunner:
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
- dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ dataset_document.stopped_at = naive_utc_now()
db.session.commit()
except ObjectDeletedError:
logging.warning("Document deleted, document id: %s", dataset_document.id)
@@ -96,7 +95,7 @@ class IndexingRunner:
logging.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
- dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ dataset_document.stopped_at = naive_utc_now()
db.session.commit()
def run_in_splitting_status(self, dataset_document: DatasetDocument):
@@ -151,13 +150,13 @@ class IndexingRunner:
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
- dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ dataset_document.stopped_at = naive_utc_now()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
- dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ dataset_document.stopped_at = naive_utc_now()
db.session.commit()
def run_in_indexing_status(self, dataset_document: DatasetDocument):
@@ -226,13 +225,13 @@ class IndexingRunner:
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
- dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ dataset_document.stopped_at = naive_utc_now()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
- dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ dataset_document.stopped_at = naive_utc_now()
db.session.commit()
def indexing_estimate(
@@ -295,7 +294,7 @@ class IndexingRunner:
text_docs,
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule.to_dict(),
- tenant_id=current_user.current_tenant_id,
+ tenant_id=tenant_id,
doc_language=doc_language,
preview=True,
)
@@ -401,7 +400,7 @@ class IndexingRunner:
after_indexing_status="splitting",
extra_update_params={
DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
- DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ DatasetDocument.parsing_completed_at: naive_utc_now(),
},
)
@@ -584,7 +583,7 @@ class IndexingRunner:
after_indexing_status="completed",
extra_update_params={
DatasetDocument.tokens: tokens,
- DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ DatasetDocument.completed_at: naive_utc_now(),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.error: None,
},
@@ -609,7 +608,7 @@ class IndexingRunner:
{
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
- DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ DocumentSegment.completed_at: naive_utc_now(),
}
)
@@ -640,7 +639,7 @@ class IndexingRunner:
{
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
- DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ DocumentSegment.completed_at: naive_utc_now(),
}
)
@@ -728,7 +727,7 @@ class IndexingRunner:
doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
# update document status to indexing
- cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ cur_time = naive_utc_now()
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
@@ -743,7 +742,7 @@ class IndexingRunner:
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
- DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ DocumentSegment.indexing_at: naive_utc_now(),
},
)
pass
diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py
index 3f98aa94ae..031f01f411 100644
--- a/api/core/mcp/session/base_session.py
+++ b/api/core/mcp/session/base_session.py
@@ -4,7 +4,7 @@ from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta
from types import TracebackType
-from typing import Any, Generic, Self, TypeVar
+from typing import Any, Generic, Optional, Self, TypeVar
from httpx import HTTPStatusError
from pydantic import BaseModel
@@ -209,7 +209,7 @@ class BaseSession(
request: SendRequestT,
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
- metadata: MessageMetadata = None,
+ metadata: Optional[MessageMetadata] = None,
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
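# The Optional[...] annotations added in this file and in mcp/types.py spell
# out what was previously "implicit Optional": PEP 484 dropped the rule that a
# `= None` default implies Optional, and mypy now rejects the implicit form by
# default (no_implicit_optional). Minimal illustration:
from typing import Optional

def send_request(payload: bytes, metadata: Optional[dict] = None) -> int:
    # without Optional[dict], strict checkers flag the None default
    effective = metadata if metadata is not None else {}
    return len(payload) + len(effective)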
diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py
index 99d985a781..49aa8e4498 100644
--- a/api/core/mcp/types.py
+++ b/api/core/mcp/types.py
@@ -1173,7 +1173,7 @@ class SessionMessage:
"""A message with specific metadata for transport-specific features."""
message: JSONRPCMessage
- metadata: MessageMetadata = None
+ metadata: Optional[MessageMetadata] = None
class OAuthClientMetadata(BaseModel):
diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py
index ace2c1f770..dc6032e405 100644
--- a/api/core/model_runtime/entities/llm_entities.py
+++ b/api/core/model_runtime/entities/llm_entities.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
from collections.abc import Mapping, Sequence
from decimal import Decimal
from enum import StrEnum
-from typing import Any, Optional
+from typing import Any, Optional, TypedDict, Union
from pydantic import BaseModel, Field
@@ -18,6 +20,26 @@ class LLMMode(StrEnum):
CHAT = "chat"
+class LLMUsageMetadata(TypedDict, total=False):
+ """
+ TypedDict for LLM usage metadata.
+ All fields are optional.
+ """
+
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ prompt_unit_price: Union[float, str]
+ completion_unit_price: Union[float, str]
+ total_price: Union[float, str]
+ currency: str
+ prompt_price_unit: Union[float, str]
+ completion_price_unit: Union[float, str]
+ prompt_price: Union[float, str]
+ completion_price: Union[float, str]
+ latency: float
+
+
class LLMUsage(ModelUsage):
"""
Model class for llm usage.
@@ -54,23 +76,27 @@ class LLMUsage(ModelUsage):
)
@classmethod
- def from_metadata(cls, metadata: dict) -> "LLMUsage":
+ def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
"""
Create LLMUsage instance from metadata dictionary with default values.
Args:
- metadata: Dictionary containing usage metadata
+ metadata: TypedDict containing usage metadata
Returns:
LLMUsage instance with values from metadata or defaults
"""
- total_tokens = metadata.get("total_tokens", 0)
+ prompt_tokens = metadata.get("prompt_tokens", 0)
completion_tokens = metadata.get("completion_tokens", 0)
- if total_tokens > 0 and completion_tokens == 0:
- completion_tokens = total_tokens
+ total_tokens = metadata.get("total_tokens", 0)
+
+ # If total_tokens is not provided but prompt and completion tokens are,
+ # calculate total_tokens
+ if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
+ total_tokens = prompt_tokens + completion_tokens
return cls(
- prompt_tokens=metadata.get("prompt_tokens", 0),
+ prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
@@ -84,7 +110,7 @@ class LLMUsage(ModelUsage):
latency=metadata.get("latency", 0.0),
)
- def plus(self, other: "LLMUsage") -> "LLMUsage":
+ def plus(self, other: LLMUsage) -> LLMUsage:
"""
Add two LLMUsage instances together.
@@ -109,7 +135,7 @@ class LLMUsage(ModelUsage):
latency=self.latency + other.latency,
)
- def __add__(self, other: "LLMUsage") -> "LLMUsage":
+ def __add__(self, other: LLMUsage) -> LLMUsage:
"""
Overload the + operator to add two LLMUsage instances.
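The rewritten `from_metadata` also changes the token-reconciliation rule: instead of copying `total_tokens` into `completion_tokens` when the latter is missing, it derives a missing `total_tokens` from the two parts. A standalone sketch of the new rule:

```python
# Standalone sketch of the reconciliation logic now inside from_metadata().
def reconcile(prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = 0) -> int:
    # If the total is absent but either part is present, derive it from the parts.
    if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
        total_tokens = prompt_tokens + completion_tokens
    return total_tokens


assert reconcile(prompt_tokens=10, completion_tokens=5) == 15
assert reconcile(total_tokens=20) == 20  # explicit totals pass through unchanged
```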
diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py
index b7db0b78bc..68d30112d9 100644
--- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py
+++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py
@@ -1,10 +1,10 @@
import logging
from threading import Lock
-from typing import Any
+from typing import Any, Optional
logger = logging.getLogger(__name__)
-_tokenizer: Any = None
+_tokenizer: Optional[Any] = None
_lock = Lock()
diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py
index 332381555b..af51b72cd5 100644
--- a/api/core/moderation/api/api.py
+++ b/api/core/moderation/api/api.py
@@ -1,6 +1,6 @@
from typing import Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token
@@ -11,7 +11,7 @@ from models.api_based_extension import APIBasedExtension
class ModerationInputParams(BaseModel):
app_id: str = ""
- inputs: dict = {}
+ inputs: dict = Field(default_factory=dict)
query: str = ""
diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py
index d8c392d097..99bd0049c0 100644
--- a/api/core/moderation/base.py
+++ b/api/core/moderation/base.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from core.extension.extensible import Extensible, ExtensionModule
@@ -16,7 +16,7 @@ class ModerationInputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
- inputs: dict = {}
+ inputs: dict = Field(default_factory=dict)
query: str = ""
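Several models in this change swap `dict = {}` defaults for `Field(default_factory=dict)`. Pydantic already copies mutable defaults per instance, so the swap mainly makes the fresh-dict-per-instance contract explicit and linter-clean; a minimal sketch with a hypothetical stand-in model:

```python
# Minimal sketch: default_factory builds a fresh dict for every instance.
from pydantic import BaseModel, Field


class ModerationResult(BaseModel):  # hypothetical stand-in model
    inputs: dict = Field(default_factory=dict)


a, b = ModerationResult(), ModerationResult()
a.inputs["k"] = "v"
assert b.inputs == {}  # no state shared between instances
```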
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 0f0fe65f27..16c145f936 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -125,11 +125,11 @@ class AdvancedPromptTransform(PromptTransform):
if files:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
- prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
+ prompt_message_contents.append(TextPromptMessageContent(data=prompt))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
@@ -196,16 +196,17 @@ class AdvancedPromptTransform(PromptTransform):
query = parser.format(prompt_inputs)
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
if memory and memory_config:
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
if files and query is not None:
- prompt_message_contents: list[PromptMessageContentUnionTypes] = []
- prompt_message_contents.append(TextPromptMessageContent(data=query))
for file in files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
+ prompt_message_contents.append(TextPromptMessageContent(data=query))
+
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
@@ -215,27 +216,27 @@ class AdvancedPromptTransform(PromptTransform):
last_message = prompt_messages[-1] if prompt_messages else None
if last_message and last_message.role == PromptMessageRole.USER:
# get last user message content and add files
- prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
for file in files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
+ prompt_message_contents.append(TextPromptMessageContent(data=cast(str, last_message.content)))
last_message.content = prompt_message_contents
else:
- prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
for file in files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
+ prompt_message_contents.append(TextPromptMessageContent(data=""))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
- prompt_message_contents = [TextPromptMessageContent(data=query)]
for file in files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
+ prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
elif query:
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index e19c6419ca..13f4163d80 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -265,11 +265,11 @@ class SimplePromptTransform(PromptTransform):
) -> UserPromptMessage:
if files:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
- prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
+ prompt_message_contents.append(TextPromptMessageContent(data=prompt))
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:
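Across both prompt transforms, the text part of a multimodal user message now comes after the file parts rather than before them. Under that ordering, a message built from one image plus a query would look roughly like:

```python
# Hypothetical shape of the reordered multimodal content list.
prompt_message_contents = [
    {"type": "image", "url": "https://example.com/chart.png"},  # file parts first
    {"type": "text", "text": "What does this chart show?"},     # query text last
]
```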
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 9250497d29..39fec951bb 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -1,3 +1,4 @@
+import contextlib
import json
from collections import defaultdict
from json import JSONDecodeError
@@ -624,14 +625,12 @@ class ProviderManager:
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
- try:
+ with contextlib.suppress(ValueError):
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable) or "", # type: ignore
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
- except ValueError:
- pass
# cache provider credentials
provider_credentials_cache.set(credentials=provider_credentials)
@@ -672,14 +671,12 @@ class ProviderManager:
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
- try:
+ with contextlib.suppress(ValueError):
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
- except ValueError:
- pass
# cache provider model credentials
provider_model_credentials_cache.set(credentials=provider_model_credentials)
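This is the first of many hunks replacing `try: ... except X: pass` with `contextlib.suppress`, the idiomatic one-liner for intentionally ignored exceptions:

```python
# Equivalent behavior to try/except/pass, with the intent stated up front.
import contextlib

with contextlib.suppress(ValueError):
    int("not a number")  # raises ValueError, which suppress() swallows
print("still running")   # control resumes after the with block
```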
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
index 2df17181a4..bb61b71bb1 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
@@ -105,9 +105,11 @@ class AnalyticdbVectorBySql:
conn.close()
self.pool = self._create_connection_pool()
with self._get_cursor() as cur:
+ conn = cur.connection
try:
cur.execute("CREATE EXTENSION IF NOT EXISTS zhparser;")
except Exception as e:
+ conn.rollback()
raise RuntimeError(
"Failed to create zhparser extension. Please ensure it is available in your AnalyticDB."
) from e
@@ -115,6 +117,7 @@ class AnalyticdbVectorBySql:
cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
except Exception as e:
+ conn.rollback()
if "already exists" not in str(e):
raise e
cur.execute(
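The added `conn.rollback()` calls matter because a failed statement leaves a psycopg2-style connection in an aborted transaction; every later statement on that connection fails with "current transaction is aborted" until a rollback. A sketch of the recovery pattern, assuming a hypothetical connection object `conn`:

```python
# Sketch of the recovery pattern (hypothetical psycopg2-style connection).
def ensure_extension(conn) -> None:
    cur = conn.cursor()
    try:
        cur.execute("CREATE EXTENSION IF NOT EXISTS zhparser;")
    except Exception:
        conn.rollback()  # clear the aborted transaction so later SQL can run
        raise
    cur.execute("SELECT 1;")  # safe: the transaction is still healthy here
```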
diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
index 1059b855a2..6e8077ffd9 100644
--- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
+++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
@@ -1,3 +1,4 @@
+import contextlib
import json
import logging
import queue
@@ -214,10 +215,8 @@ class ClickzettaConnectionPool:
return connection
else:
# Connection expired or invalid, close it
- try:
+ with contextlib.suppress(Exception):
connection.close()
- except Exception:
- pass
# No valid connection found, create new one
return self._create_connection(config)
@@ -228,10 +227,8 @@ class ClickzettaConnectionPool:
if config_key not in self._pool_locks:
# Pool was cleaned up, just close the connection
- try:
+ with contextlib.suppress(Exception):
connection.close()
- except Exception:
- pass
return
with self._pool_locks[config_key]:
@@ -243,10 +240,8 @@ class ClickzettaConnectionPool:
logger.debug("Returned ClickZetta connection to pool")
else:
# Pool full or connection invalid, close it
- try:
+ with contextlib.suppress(Exception):
connection.close()
- except Exception:
- pass
def _cleanup_expired_connections(self) -> None:
"""Clean up expired connections from all pools."""
@@ -265,10 +260,8 @@ class ClickzettaConnectionPool:
if current_time - last_used < self._connection_timeout:
valid_connections.append((connection, last_used))
else:
- try:
+ with contextlib.suppress(Exception):
connection.close()
- except Exception:
- pass
self._pools[config_key] = valid_connections
@@ -299,10 +292,8 @@ class ClickzettaConnectionPool:
with self._pool_locks[config_key]:
pool = self._pools[config_key]
for connection, _ in pool:
- try:
+ with contextlib.suppress(Exception):
connection.close()
- except Exception:
- pass
pool.clear()
diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
index d64f366e0e..112f07844c 100644
--- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py
+++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
@@ -101,7 +101,7 @@ class MilvusVector(BaseVector):
if "Zilliz Cloud" in milvus_version:
return True
# For standard Milvus installations, check version number
- return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
+ return version.parse(milvus_version) >= version.parse("2.5.0")
except Exception as e:
logger.warning("Failed to check Milvus version: %s. Disabling hybrid search.", str(e))
return False
diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
index 8efe105bbf..556d03940e 100644
--- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
+++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
@@ -152,7 +152,7 @@ class OceanBaseVector(BaseVector):
ob_full_version = result.fetchone()[0]
ob_version = ob_full_version.split()[1]
logger.debug("Current OceanBase version is %s", ob_version)
- return version.parse(ob_version).base_version >= version.parse("4.3.5.1").base_version
+ return version.parse(ob_version) >= version.parse("4.3.5.1")
except Exception as e:
logger.warning("Failed to check OceanBase version: %s. Disabling hybrid search.", str(e))
return False
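Both vector-store hunks fix the same latent bug: `.base_version` is a string, so the old code compared versions lexicographically and would, for example, treat 4.10.0 as older than 4.3.5.1. Comparing `Version` objects restores numeric ordering:

```python
from packaging import version

# Lexicographic string comparison gets multi-digit components wrong:
assert not (version.parse("4.10.0").base_version >= version.parse("4.3.5.1").base_version)

# Comparing Version objects orders each component numerically:
assert version.parse("4.10.0") >= version.parse("4.3.5.1")
```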
diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py
index a3b35458df..7cc554c74d 100644
--- a/api/core/rag/extractor/excel_extractor.py
+++ b/api/core/rag/extractor/excel_extractor.py
@@ -34,9 +34,8 @@ class ExcelExtractor(BaseExtractor):
for sheet_name in wb.sheetnames:
sheet = wb[sheet_name]
data = sheet.values
- try:
- cols = next(data)
- except StopIteration:
+ cols = next(data, None)
+ if cols is None:
continue
df = pd.DataFrame(data, columns=cols)
diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py
index 04033dec3f..7dfe2e357c 100644
--- a/api/core/rag/extractor/pdf_extractor.py
+++ b/api/core/rag/extractor/pdf_extractor.py
@@ -1,5 +1,6 @@
"""Abstract interface for document loader implementations."""
+import contextlib
from collections.abc import Iterator
from typing import Optional, cast
@@ -25,12 +26,10 @@ class PdfExtractor(BaseExtractor):
def extract(self) -> list[Document]:
plaintext_file_exists = False
if self._file_cache_key:
- try:
+ with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
- except FileNotFoundError:
- pass
documents = list(self.load())
text_list = []
for document in documents:
diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py
index f1fa5dde5c..856a9bce18 100644
--- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py
+++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py
@@ -1,4 +1,5 @@
import base64
+import contextlib
import logging
from typing import Optional
@@ -33,7 +34,7 @@ class UnstructuredEmailExtractor(BaseExtractor):
elements = partition_email(filename=self._file_path)
# noinspection PyBroadException
- try:
+ with contextlib.suppress(Exception):
for element in elements:
element_text = element.text.strip()
@@ -43,8 +44,6 @@ class UnstructuredEmailExtractor(BaseExtractor):
element_decode = base64.b64decode(element_text)
soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser")
element.text = soup.get_text()
- except Exception:
- pass
from unstructured.chunking.title import chunk_by_title
diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py
index 21fbb2100f..da03fc67a6 100644
--- a/api/core/rag/extractor/watercrawl/provider.py
+++ b/api/core/rag/extractor/watercrawl/provider.py
@@ -1,6 +1,6 @@
from collections.abc import Generator
from datetime import datetime
-from typing import Any
+from typing import Any, Optional
from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient
@@ -9,7 +9,7 @@ class WaterCrawlProvider:
def __init__(self, api_key, base_url: str | None = None):
self.client = WaterCrawlAPIClient(api_key, base_url)
- def crawl_url(self, url, options: dict | Any = None) -> dict:
+ def crawl_url(self, url, options: Optional[dict | Any] = None) -> dict:
options = options or {}
spider_options = {
"max_depth": 1,
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 0eff7c186a..f3b162e3d3 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -1,6 +1,5 @@
"""Abstract interface for document loader implementations."""
-import datetime
import logging
import mimetypes
import os
@@ -19,6 +18,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
+from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
@@ -117,10 +117,10 @@ class WordExtractor(BaseExtractor):
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatorUserRole.ACCOUNT,
- created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
used=True,
used_by=self.user_id,
- used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ used_at=naive_utc_now(),
)
db.session.add(upload_file)
diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py
index 04a3428ad8..ff63a6780e 100644
--- a/api/core/rag/models/document.py
+++ b/api/core/rag/models/document.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
class ChildDocument(BaseModel):
@@ -15,7 +15,7 @@ class ChildDocument(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
- metadata: dict = {}
+ metadata: dict = Field(default_factory=dict)
class Document(BaseModel):
@@ -28,7 +28,7 @@ class Document(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
- metadata: dict = {}
+ metadata: dict = Field(default_factory=dict)
provider: Optional[str] = "dify"
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index a25bc65646..cd4af72832 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -1012,7 +1012,7 @@ class DatasetRetrieval:
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
):
- if value is None:
+ if value is None and condition not in ("empty", "not empty"):
return
key = f"{metadata_name}_{sequence}"
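The guard change (duplicated in `knowledge_retrieval_node.py` below) fixes metadata filters whose conditions carry no comparison value: `empty` and `not empty` legitimately arrive with `value is None`, so the old early return silently dropped them. In sketch form, with hypothetical names:

```python
# Illustration of the corrected guard (names are hypothetical).
def should_skip_filter(condition: str, value) -> bool:
    # Only bail out on a missing value when the condition actually needs one.
    return value is None and condition not in ("empty", "not empty")


assert should_skip_filter("contains", None) is True
assert should_skip_filter("empty", None) is False  # now builds an IS NULL-style filter
```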
diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py
index 09c775f3a6..854c122331 100644
--- a/api/core/repositories/factory.py
+++ b/api/core/repositories/factory.py
@@ -5,10 +5,7 @@ This module provides a Django-like settings system for repository implementation
allowing users to configure different repository backends through string paths.
"""
-import importlib
-import inspect
-import logging
-from typing import Protocol, Union
+from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@@ -16,12 +13,11 @@ from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from libs.module_loading import import_string
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowNodeExecutionTriggeredFrom
-logger = logging.getLogger(__name__)
-
class RepositoryImportError(Exception):
"""Raised when a repository implementation cannot be imported or instantiated."""
@@ -37,96 +33,6 @@ class DifyCoreRepositoryFactory:
are specified as module paths (e.g., 'module.submodule.ClassName').
"""
- @staticmethod
- def _import_class(class_path: str) -> type:
- """
- Import a class from a module path string.
-
- Args:
- class_path: Full module path to the class (e.g., 'module.submodule.ClassName')
-
- Returns:
- The imported class
-
- Raises:
- RepositoryImportError: If the class cannot be imported
- """
- try:
- module_path, class_name = class_path.rsplit(".", 1)
- module = importlib.import_module(module_path)
- repo_class = getattr(module, class_name)
- assert isinstance(repo_class, type)
- return repo_class
- except (ValueError, ImportError, AttributeError) as e:
- raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e
-
- @staticmethod
- def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore
- """
- Validate that a class implements the expected repository interface.
-
- Args:
- repository_class: The class to validate
- expected_interface: The expected interface/protocol
-
- Raises:
- RepositoryImportError: If the class doesn't implement the interface
- """
- # Check if the class has all required methods from the protocol
- required_methods = [
- method
- for method in dir(expected_interface)
- if not method.startswith("_") and callable(getattr(expected_interface, method, None))
- ]
-
- missing_methods = []
- for method_name in required_methods:
- if not hasattr(repository_class, method_name):
- missing_methods.append(method_name)
-
- if missing_methods:
- raise RepositoryImportError(
- f"Repository class '{repository_class.__name__}' does not implement required methods "
- f"{missing_methods} from interface '{expected_interface.__name__}'"
- )
-
- @staticmethod
- def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
- """
- Validate that a repository class constructor accepts required parameters.
- Args:
- repository_class: The class to validate
- required_params: List of required parameter names
- Raises:
- RepositoryImportError: If the constructor doesn't accept required parameters
- """
-
- try:
- # MyPy may flag the line below with the following error:
- #
- # > Accessing "__init__" on an instance is unsound, since
- # > instance.__init__ could be from an incompatible subclass.
- #
- # Despite this, we need to ensure that the constructor of `repository_class`
- # has a compatible signature.
- signature = inspect.signature(repository_class.__init__) # type: ignore[misc]
- param_names = list(signature.parameters.keys())
-
- # Remove 'self' parameter
- if "self" in param_names:
- param_names.remove("self")
-
- missing_params = [param for param in required_params if param not in param_names]
- if missing_params:
- raise RepositoryImportError(
- f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: "
- f"{missing_params}. Expected parameters: {required_params}"
- )
- except Exception as e:
- raise RepositoryImportError(
- f"Failed to validate constructor signature for '{repository_class.__name__}': {e}"
- ) from e
-
@classmethod
def create_workflow_execution_repository(
cls,
@@ -151,24 +57,16 @@ class DifyCoreRepositoryFactory:
RepositoryImportError: If the configured repository cannot be created
"""
class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY
- logger.debug("Creating WorkflowExecutionRepository from: %s", class_path)
try:
- repository_class = cls._import_class(class_path)
- cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
-
- # All repository types now use the same constructor parameters
+ repository_class = import_string(class_path)
return repository_class( # type: ignore[no-any-return]
session_factory=session_factory,
user=user,
app_id=app_id,
triggered_from=triggered_from,
)
- except RepositoryImportError:
- # Re-raise our custom errors as-is
- raise
- except Exception as e:
- logger.exception("Failed to create WorkflowExecutionRepository")
+ except Exception as e:
raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e
@classmethod
@@ -195,24 +93,16 @@ class DifyCoreRepositoryFactory:
RepositoryImportError: If the configured repository cannot be created
"""
class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY
- logger.debug("Creating WorkflowNodeExecutionRepository from: %s", class_path)
try:
- repository_class = cls._import_class(class_path)
- cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
-
- # All repository types now use the same constructor parameters
+ repository_class = import_string(class_path)
return repository_class( # type: ignore[no-any-return]
session_factory=session_factory,
user=user,
app_id=app_id,
triggered_from=triggered_from,
)
- except RepositoryImportError:
- # Re-raise our custom errors as-is
- raise
- except Exception as e:
- logger.exception("Failed to create WorkflowNodeExecutionRepository")
+ except Exception as e:
raise RepositoryImportError(
f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}"
) from e
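The factory now delegates dynamic class loading to `libs.module_loading.import_string`, dropping the bespoke import helper plus interface and constructor validation; construction problems now surface at instantiation instead. A minimal sketch of what a Django-style `import_string` typically does, assuming the usual semantics:

```python
# Minimal sketch of a Django-style import_string (assumed semantics).
from importlib import import_module


def import_string(dotted_path: str):
    """Import 'package.module.ClassName' and return the ClassName attribute."""
    try:
        module_path, class_name = dotted_path.rsplit(".", 1)
    except ValueError as e:
        raise ImportError(f"{dotted_path} doesn't look like a module path") from e
    module = import_module(module_path)
    try:
        return getattr(module, class_name)
    except AttributeError as e:
        raise ImportError(f"Module {module_path!r} has no attribute {class_name!r}") from e
```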
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 5ffba07b44..df599a09a3 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -1,4 +1,5 @@
import base64
+import contextlib
import enum
from collections.abc import Mapping
from enum import Enum
@@ -227,10 +228,8 @@ class ToolInvokeMessage(BaseModel):
@classmethod
def decode_blob_message(cls, v):
if isinstance(v, dict) and "blob" in v:
- try:
+ with contextlib.suppress(Exception):
v["blob"] = base64.b64decode(v["blob"])
- except Exception:
- pass
return v
@field_serializer("message")
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 83444c02d8..10db4d9503 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -1,3 +1,4 @@
+import contextlib
import json
from collections.abc import Generator, Iterable
from copy import deepcopy
@@ -69,10 +70,8 @@ class ToolEngine:
if parameters and len(parameters) == 1:
tool_parameters = {parameters[0].name: tool_parameters}
else:
- try:
+ with contextlib.suppress(Exception):
tool_parameters = json.loads(tool_parameters)
- except Exception:
- pass
if not isinstance(tool_parameters, dict):
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
@@ -270,14 +269,12 @@ class ToolEngine:
if response.meta.get("mime_type"):
mimetype = response.meta.get("mime_type")
else:
- try:
+ with contextlib.suppress(Exception):
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
extension = url.suffix
guess_type_result, _ = guess_type(f"a{extension}")
if guess_type_result:
mimetype = guess_type_result
- except Exception:
- pass
if not mimetype:
mimetype = "image/jpeg"
diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py
index aceba6e69f..3a9391dbb1 100644
--- a/api/core/tools/utils/configuration.py
+++ b/api/core/tools/utils/configuration.py
@@ -1,3 +1,4 @@
+import contextlib
from copy import deepcopy
from typing import Any
@@ -137,11 +138,9 @@ class ToolParameterConfigurationManager:
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
- try:
- has_secret_input = True
+ has_secret_input = True
+ with contextlib.suppress(Exception):
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
- except Exception:
- pass
if has_secret_input:
cache.set(parameters)
diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py
index 5fdfd3b9d1..d771293e11 100644
--- a/api/core/tools/utils/encryption.py
+++ b/api/core/tools/utils/encryption.py
@@ -1,3 +1,4 @@
+import contextlib
from copy import deepcopy
from typing import Any, Optional, Protocol
@@ -111,14 +112,12 @@ class ProviderConfigEncrypter:
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
- try:
+ with contextlib.suppress(Exception):
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
- except Exception:
- pass
self.provider_config_cache.set(data)
return data
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index 770c0ef7bd..d8403c2e15 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -80,7 +80,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
else:
content = response.text
- article = extract_using_readability(content)
+ article = extract_using_readabilipy(content)
if not article.text:
return ""
@@ -101,7 +101,7 @@ class Article:
text: Sequence[dict]
-def extract_using_readability(html: str):
+def extract_using_readabilipy(html: str):
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
article = Article(
title=json_article.get("title") or "",
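The rename tracks the actual backing library: the extractor is built on readabilipy's `simple_json_from_html_string`, so `extract_using_readabilipy` is the truthful name. A sketch of the underlying call (with `use_readability=True`, readabilipy shells out to the Node readability package, so Node.js must be available):

```python
# Sketch of the underlying readabilipy call; key names reflect its JSON output.
from readabilipy import simple_json_from_html_string

html = "<html><body><h1>Title</h1><p>Body text.</p></body></html>"
article = simple_json_from_html_string(html, use_readability=True)
print(article.get("title"), article.get("plain_text"))  # plain_text: list of {"text": ...}
```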
diff --git a/api/core/variables/types.py b/api/core/variables/types.py
index d28fb11401..6629056042 100644
--- a/api/core/variables/types.py
+++ b/api/core/variables/types.py
@@ -126,7 +126,7 @@ class SegmentType(StrEnum):
"""
if self.is_array_type():
return self._validate_array(value, array_validation)
- elif self == SegmentType.NUMBER:
+ elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]:
return isinstance(value, (int, float))
elif self == SegmentType.STRING:
return isinstance(value, str)
@@ -166,7 +166,6 @@ _ARRAY_TYPES = frozenset(
]
)
-
_NUMERICAL_TYPES = frozenset(
[
SegmentType.NUMBER,
diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py
index 781be4b3c6..f00dc11aa6 100644
--- a/api/core/workflow/entities/workflow_execution.py
+++ b/api/core/workflow/entities/workflow_execution.py
@@ -6,12 +6,14 @@ implementation details like tenant_id, app_id, etc.
"""
from collections.abc import Mapping
-from datetime import UTC, datetime
+from datetime import datetime
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, Field
+from libs.datetime_utils import naive_utc_now
+
class WorkflowType(StrEnum):
"""
@@ -60,7 +62,7 @@ class WorkflowExecution(BaseModel):
Calculate elapsed time in seconds.
If workflow is not finished, use current time.
"""
- end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None)
+ end_time = self.finished_at or naive_utc_now()
return (end_time - self.started_at).total_seconds()
@classmethod
diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py
index a62ffe46c9..e2ec7b17f0 100644
--- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py
+++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py
@@ -22,7 +22,7 @@ class GraphRuntimeState(BaseModel):
#
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
# after a serialization and deserialization round trip.
- outputs: dict[str, Any] = {}
+ outputs: dict[str, Any] = Field(default_factory=dict)
node_run_steps: int = 0
"""node run steps"""
diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py
index f2d9c98936..a4ddfafab5 100644
--- a/api/core/workflow/graph_engine/entities/runtime_route_state.py
+++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py
@@ -1,5 +1,5 @@
import uuid
-from datetime import UTC, datetime
+from datetime import datetime
from enum import Enum
from typing import Optional
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
+from libs.datetime_utils import naive_utc_now
class RouteNodeState(BaseModel):
@@ -71,7 +72,7 @@ class RouteNodeState(BaseModel):
raise Exception(f"Invalid route status {run_result.status}")
self.node_run_result = run_result
- self.finished_at = datetime.now(UTC).replace(tzinfo=None)
+ self.finished_at = naive_utc_now()
class RuntimeRouteState(BaseModel):
@@ -89,7 +90,7 @@ class RuntimeRouteState(BaseModel):
:param node_id: node id
"""
- state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
+ state = RouteNodeState(node_id=node_id, start_at=naive_utc_now())
self.node_state_mapping[state.id] = state
return state
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index b9663d32f7..03b920ccbb 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -6,7 +6,6 @@ import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
-from datetime import UTC, datetime
from typing import Any, Optional, cast
from flask import Flask, current_app
@@ -51,6 +50,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
+from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -640,7 +640,7 @@ class GraphEngine:
while should_continue_retry and retries <= max_retries:
try:
# run node
- retry_start_at = datetime.now(UTC).replace(tzinfo=None)
+ retry_start_at = naive_utc_now()
# yield control to other threads
time.sleep(0.001)
event_stream = node.run()
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 2b6382a8a6..144f036aa4 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -13,8 +13,9 @@ from core.agent.strategy.plugin import PluginAgentStrategy
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
@@ -558,7 +559,7 @@ class AgentNode(BaseNode):
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
- llm_usage = LLMUsage.from_metadata(msg_metadata)
+ llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
@@ -692,7 +693,13 @@ class AgentNode(BaseNode):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
- outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
+ outputs={
+ "text": text,
+ "usage": jsonable_encoder(llm_usage),
+ "files": ArrayFileSegment(value=files),
+ "json": json_output,
+ **variables,
+ },
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
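With this change the agent node's success outputs gain a serialized `usage` entry alongside `text`, `files`, and `json`, so downstream nodes can reference token usage directly. Hypothetically, the outputs now look like:

```python
# Hypothetical shape of the agent node's outputs after this change.
outputs = {
    "text": "final answer",
    "usage": {"prompt_tokens": 120, "completion_tokens": 45, "total_tokens": 165},
    "files": [],
    "json": [],
}
```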
diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py
index c9f7fa1221..a5a578a6ff 100644
--- a/api/core/workflow/nodes/http_request/executor.py
+++ b/api/core/workflow/nodes/http_request/executor.py
@@ -12,6 +12,7 @@ from json_repair import repair_json
from configs import dify_config
from core.file import file_manager
+from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.entities.variable_pool import VariablePool
@@ -228,7 +229,9 @@ class Executor:
files: dict[str, list[tuple[str | None, bytes, str]]] = {}
for key, files_in_segment in files_list:
for file in files_in_segment:
- if file.related_id is not None:
+ if file.related_id is not None or (
+ file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None
+ ):
file_tuple = (
file.filename,
file_manager.download(file),
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index def1e1cfa3..7f591a3ea9 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -4,7 +4,7 @@ import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
-from datetime import UTC, datetime
+from datetime import datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast
@@ -41,6 +41,7 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
+from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from .exc import (
@@ -179,7 +180,7 @@ class IterationNode(BaseNode):
thread_pool_id=self.thread_pool_id,
)
- start_at = datetime.now(UTC).replace(tzinfo=None)
+ start_at = naive_utc_now()
yield IterationRunStartedEvent(
iteration_id=self.id,
@@ -428,7 +429,7 @@ class IterationNode(BaseNode):
"""
run single iteration
"""
- iter_start_at = datetime.now(UTC).replace(tzinfo=None)
+ iter_start_at = naive_utc_now()
try:
rst = graph_engine.run()
@@ -505,7 +506,7 @@ class IterationNode(BaseNode):
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
- duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
+ duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
@@ -526,7 +527,7 @@ class IterationNode(BaseNode):
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
- duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
+ duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
@@ -602,7 +603,7 @@ class IterationNode(BaseNode):
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
- duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
+ duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 7303b68501..5e5c9f520e 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -175,7 +175,7 @@ class KnowledgeRetrievalNode(BaseNode):
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
- with Session(db.engine) as session:
+ with sessionmaker(db.engine).begin() as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
@@ -183,7 +183,6 @@ class KnowledgeRetrievalNode(BaseNode):
operation="knowledge",
)
session.add(rate_limit_log)
- session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@@ -389,6 +388,15 @@ class KnowledgeRetrievalNode(BaseNode):
"segment_id": segment.id,
"retriever_from": "workflow",
"score": record.score or 0.0,
+ "child_chunks": [
+ {
+ "id": str(getattr(chunk, "id", "")),
+ "content": str(getattr(chunk, "content", "")),
+ "position": int(getattr(chunk, "position", 0)),
+ "score": float(getattr(chunk, "score", 0.0)),
+ }
+ for chunk in (record.child_chunks or [])
+ ],
"segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count,
"segment_position": segment.position,
@@ -572,7 +580,7 @@ class KnowledgeRetrievalNode(BaseNode):
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
):
- if value is None:
+ if value is None and condition not in ("empty", "not empty"):
return
key = f"{metadata_name}_{sequence}"
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index 4bb62d35a2..e6f8abeba0 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -13,7 +13,7 @@ class ModelConfig(BaseModel):
provider: str
name: str
mode: LLMMode
- completion_params: dict[str, Any] = {}
+ completion_params: dict[str, Any] = Field(default_factory=dict)
class ContextConfig(BaseModel):
diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py
index 0966c87a1d..2441e30c87 100644
--- a/api/core/workflow/nodes/llm/llm_utils.py
+++ b/api/core/workflow/nodes/llm/llm_utils.py
@@ -1,5 +1,4 @@
from collections.abc import Sequence
-from datetime import UTC, datetime
from typing import Optional, cast
from sqlalchemy import select, update
@@ -20,6 +19,7 @@ from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegme
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
+from libs.datetime_utils import naive_utc_now
from models import db
from models.model import Conversation
from models.provider import Provider, ProviderType
@@ -149,7 +149,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
)
.values(
quota_used=Provider.quota_used + used_quota,
- last_used=datetime.now(tz=UTC).replace(tzinfo=None),
+ last_used=naive_utc_now(),
)
)
session.execute(stmt)
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py
index 9a288c6133..b2ab943129 100644
--- a/api/core/workflow/nodes/loop/loop_node.py
+++ b/api/core/workflow/nodes/loop/loop_node.py
@@ -2,7 +2,7 @@ import json
import logging
import time
from collections.abc import Generator, Mapping, Sequence
-from datetime import UTC, datetime
+from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from configs import dify_config
@@ -36,6 +36,7 @@ from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type
+from libs.datetime_utils import naive_utc_now
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
@@ -143,7 +144,7 @@ class LoopNode(BaseNode):
thread_pool_id=self.thread_pool_id,
)
- start_at = datetime.now(UTC).replace(tzinfo=None)
+ start_at = naive_utc_now()
condition_processor = ConditionProcessor()
# Start Loop event
@@ -171,7 +172,7 @@ class LoopNode(BaseNode):
try:
check_break_result = False
for i in range(loop_count):
- loop_start_time = datetime.now(UTC).replace(tzinfo=None)
+ loop_start_time = naive_utc_now()
# run single loop
loop_result = yield from self._run_single_loop(
graph_engine=graph_engine,
@@ -185,7 +186,7 @@ class LoopNode(BaseNode):
start_at=start_at,
inputs=inputs,
)
- loop_end_time = datetime.now(UTC).replace(tzinfo=None)
+ loop_end_time = naive_utc_now()
single_loop_variable = {}
for key, selector in loop_variable_selectors.items():
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
index 45c5e0a62c..49c4c142e1 100644
--- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -1,3 +1,4 @@
+import contextlib
import json
import logging
import uuid
@@ -666,10 +667,8 @@ class ParameterExtractorNode(BaseNode):
if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:])
if json_str:
- try:
+ with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
- except Exception:
- pass
logger.info("extra error: %s", result)
return None
@@ -686,10 +685,9 @@ class ParameterExtractorNode(BaseNode):
if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:])
if json_str:
- try:
+ with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
- except Exception:
- pass
+
logger.info("extra error: %s", result)
return None
diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh
index da147fe895..e21092349e 100755
--- a/api/docker/entrypoint.sh
+++ b/api/docker/entrypoint.sh
@@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
- -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage}
+ -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation}
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py
index 00a66f50ad..bbc913b7cf 100644
--- a/api/events/event_handlers/clean_when_document_deleted.py
+++ b/api/events/event_handlers/clean_when_document_deleted.py
@@ -8,4 +8,6 @@ def handle(sender, **kwargs):
dataset_id = kwargs.get("dataset_id")
doc_form = kwargs.get("doc_form")
file_id = kwargs.get("file_id")
+ assert dataset_id is not None
+ assert doc_form is not None
clean_document_task.delay(document_id, dataset_id, doc_form, file_id)
diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py
index c607161e2a..1b0321f42e 100644
--- a/api/events/event_handlers/create_document_index.py
+++ b/api/events/event_handlers/create_document_index.py
@@ -1,3 +1,4 @@
+import contextlib
import logging
import time
@@ -38,12 +39,11 @@ def handle(sender, **kwargs):
db.session.add(document)
db.session.commit()
- try:
- indexing_runner = IndexingRunner()
- indexing_runner.run(documents)
- end_at = time.perf_counter()
- logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logging.info(click.style(str(ex), fg="yellow"))
- except Exception:
- pass
+ with contextlib.suppress(Exception):
+ try:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run(documents)
+ end_at = time.perf_counter()
+ logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logging.info(click.style(str(ex), fg="yellow"))
diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py
index 2ed42c71ea..f01dd58900 100644
--- a/api/events/event_handlers/update_provider_when_message_created.py
+++ b/api/events/event_handlers/update_provider_when_message_created.py
@@ -188,7 +188,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
# Use SQLAlchemy's context manager for transaction management
# This automatically handles commit/rollback
- with Session(db.engine) as session:
+ with Session(db.engine) as session, session.begin():
# Use a single transaction for all updates
for update_operation in updates_to_perform:
filters = update_operation.filters
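Both this handler and the knowledge-retrieval rate-limit hunk above move to SQLAlchemy's transactional context managers, which commit on success and roll back on error without an explicit `session.commit()`. A self-contained sketch of the two equivalent forms (the demo model is hypothetical):

```python
from sqlalchemy import Integer, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class DemoLog(Base):  # hypothetical stand-in for RateLimitLog
    __tablename__ = "demo_log"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# Form used in update_provider_when_message_created.py:
with Session(engine) as session, session.begin():
    session.add(DemoLog())  # committed when the begin() block exits cleanly

# Equivalent form used in knowledge_retrieval_node.py:
with sessionmaker(engine).begin() as session:
    session.add(DemoLog())
```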
diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py
index a4d013ffc0..1024fd9ce6 100644
--- a/api/extensions/ext_blueprints.py
+++ b/api/extensions/ext_blueprints.py
@@ -29,7 +29,6 @@ def init_app(app: DifyApp):
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
-
app.register_blueprint(web_bp)
CORS(
@@ -40,10 +39,13 @@ def init_app(app: DifyApp):
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
-
app.register_blueprint(console_app_bp)
- CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
+ CORS(
+ files_bp,
+ allow_headers=["Content-Type"],
+ methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+ )
app.register_blueprint(files_bp)
app.register_blueprint(inner_api_bp)
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 00e0bd9a16..fb5352ca8f 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -3,8 +3,8 @@ from datetime import timedelta
from typing import Any, Optional
import pytz
-from celery import Celery, Task # type: ignore
-from celery.schedules import crontab # type: ignore
+from celery import Celery, Task
+from celery.schedules import crontab
from configs import dify_config
from dify_app import DifyApp
@@ -66,7 +66,6 @@ def init_app(app: DifyApp) -> Celery:
task_cls=FlaskTask,
broker=dify_config.CELERY_BROKER_URL,
backend=dify_config.CELERY_BACKEND,
- task_ignore_result=True,
)
celery_app.conf.update(
@@ -77,6 +76,7 @@ def init_app(app: DifyApp) -> Celery:
worker_task_log_format=dify_config.LOG_FORMAT,
worker_hijack_root_logger=False,
timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
+ task_ignore_result=True,
)
# Apply SSL configuration if enabled
diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py
index 9b18e25eaa..9e5c71fb1d 100644
--- a/api/extensions/ext_login.py
+++ b/api/extensions/ext_login.py
@@ -20,6 +20,10 @@ login_manager = flask_login.LoginManager()
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
+ # Skip authentication for documentation endpoints
+ if request.path.endswith("/docs") or request.path.endswith("/swagger.json"):
+ return None
+
auth_header = request.headers.get("Authorization", "")
auth_token: str | None = None
if auth_header:
diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py
index a8f025a750..544a2dc625 100644
--- a/api/extensions/ext_otel.py
+++ b/api/extensions/ext_otel.py
@@ -1,4 +1,5 @@
import atexit
+import contextlib
import logging
import os
import platform
@@ -7,7 +8,7 @@ import sys
from typing import Union
import flask
-from celery.signals import worker_init # type: ignore
+from celery.signals import worker_init
from flask_login import user_loaded_from_request, user_logged_in # type: ignore
from configs import dify_config
@@ -106,7 +107,7 @@ def init_app(app: DifyApp):
"""Custom logging handler that creates spans for logging.exception() calls"""
def emit(self, record: logging.LogRecord):
- try:
+ with contextlib.suppress(Exception):
if record.exc_info:
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
with tracer.start_as_current_span(
@@ -126,9 +127,6 @@ def init_app(app: DifyApp):
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
- except Exception:
- pass
-
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py
index f5f544679f..1b22886fc1 100644
--- a/api/extensions/ext_redis.py
+++ b/api/extensions/ext_redis.py
@@ -3,7 +3,7 @@ import logging
import ssl
from collections.abc import Callable
from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
import redis
from redis import RedisError
@@ -246,7 +246,7 @@ def init_app(app: DifyApp):
app.extensions["redis"] = redis_client
-def redis_fallback(default_return: Any = None):
+def redis_fallback(default_return: Optional[Any] = None):
"""
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py
index 379dcc6d16..38835d5ac7 100644
--- a/api/fields/annotation_fields.py
+++ b/api/fields/annotation_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
from libs.helper import TimestampField
@@ -11,6 +11,12 @@ annotation_fields = {
# 'account': fields.Nested(simple_account_fields, allow_null=True)
}
+
+def build_annotation_model(api_or_ns: Api | Namespace):
+ """Build the annotation model for the API or Namespace."""
+ return api_or_ns.model("Annotation", annotation_fields)
+
+
annotation_list_fields = {
"data": fields.List(fields.Nested(annotation_fields)),
}
diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py
index a85d4a34db..a2dda1dc15 100644
--- a/api/fields/api_based_extension_fields.py
+++ b/api/fields/api_based_extension_fields.py
@@ -1,10 +1,10 @@
-from flask_restful import fields
+from flask_restx import fields
from libs.helper import TimestampField
class HiddenAPIKey(fields.Raw):
- def output(self, key, obj):
+ def output(self, key, obj, **kwargs):
api_key = obj.api_key
# If the length of the api_key is less than 8 characters, show the first and last characters
if len(api_key) <= 8:
diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py
index 1a5fcabf97..1f14d663b8 100644
--- a/api/fields/app_fields.py
+++ b/api/fields/app_fields.py
@@ -1,6 +1,6 @@
import json
-from flask_restful import fields
+from flask_restx import fields
from fields.workflow_fields import workflow_partial_fields
from libs.helper import AppIconUrlField, TimestampField
diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py
index 370e8a5a58..ecc267cf38 100644
--- a/api/fields/conversation_fields.py
+++ b/api/fields/conversation_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@@ -45,6 +45,12 @@ message_file_fields = {
"upload_file_id": fields.String(default=None),
}
+
+def build_message_file_model(api_or_ns: Api | Namespace):
+ """Build the message file fields for the API or Namespace."""
+ return api_or_ns.model("MessageFile", message_file_fields)
+
+
agent_thought_fields = {
"id": fields.String,
"chain_id": fields.String,
@@ -209,3 +215,22 @@ conversation_infinite_scroll_pagination_fields = {
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(simple_conversation_fields)),
}
+
+
+def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
+ """Build the conversation infinite scroll pagination model for the API or Namespace."""
+ simple_conversation_model = build_simple_conversation_model(api_or_ns)
+
+ copied_fields = conversation_infinite_scroll_pagination_fields.copy()
+ copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model))
+ return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
+
+
+def build_conversation_delete_model(api_or_ns: Api | Namespace):
+ """Build the conversation delete model for the API or Namespace."""
+ return api_or_ns.model("ConversationDelete", conversation_delete_fields)
+
+
+def build_simple_conversation_model(api_or_ns: Api | Namespace):
+ """Build the simple conversation model for the API or Namespace."""
+ return api_or_ns.model("SimpleConversation", simple_conversation_fields)
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index c5a0c9a49d..7d5e311591 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
from libs.helper import TimestampField
@@ -27,3 +27,19 @@ conversation_variable_infinite_scroll_pagination_fields = {
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(conversation_variable_fields)),
}
+
+
+def build_conversation_variable_model(api_or_ns: Api | Namespace):
+ """Build the conversation variable model for the API or Namespace."""
+ return api_or_ns.model("ConversationVariable", conversation_variable_fields)
+
+
+def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
+ """Build the conversation variable infinite scroll pagination model for the API or Namespace."""
+ # Build the nested variable model first
+ conversation_variable_model = build_conversation_variable_model(api_or_ns)
+
+ copied_fields = conversation_variable_infinite_scroll_pagination_fields.copy()
+ copied_fields["data"] = fields.List(fields.Nested(conversation_variable_model))
+
+ return api_or_ns.model("ConversationVariableInfiniteScrollPagination", copied_fields)
diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py
index 071071376f..93f6e447dc 100644
--- a/api/fields/data_source_fields.py
+++ b/api/fields/data_source_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import fields
from libs.helper import TimestampField
diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py
index 32a88cc5db..5a3082516e 100644
--- a/api/fields/dataset_fields.py
+++ b/api/fields/dataset_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import fields
from libs.helper import TimestampField
diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py
index 7fd43e8dbe..9be59f7454 100644
--- a/api/fields/document_fields.py
+++ b/api/fields/document_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import fields
from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField
diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py
index 99e529f9d1..ea43e3b5fd 100644
--- a/api/fields/end_user_fields.py
+++ b/api/fields/end_user_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
simple_end_user_fields = {
"id": fields.String,
@@ -6,3 +6,7 @@ simple_end_user_fields = {
"is_anonymous": fields.Boolean,
"session_id": fields.String,
}
+
+
+def build_simple_end_user_model(api_or_ns: Api | Namespace):
+ return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py
index 8b4839ef97..dd359e2f5f 100644
--- a/api/fields/file_fields.py
+++ b/api/fields/file_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
from libs.helper import TimestampField
@@ -11,6 +11,19 @@ upload_config_fields = {
"workflow_file_upload_limit": fields.Integer,
}
+
+def build_upload_config_model(api_or_ns: Api | Namespace):
+ """Build the upload config model for the API or Namespace.
+
+ Args:
+ api_or_ns: Flask-RestX Api or Namespace instance
+
+ Returns:
+ The registered model
+ """
+ return api_or_ns.model("UploadConfig", upload_config_fields)
+
+
file_fields = {
"id": fields.String,
"name": fields.String,
@@ -22,12 +35,37 @@ file_fields = {
"preview_url": fields.String,
}
+
+def build_file_model(api_or_ns: Api | Namespace):
+ """Build the file model for the API or Namespace.
+
+ Args:
+ api_or_ns: Flask-RestX Api or Namespace instance
+
+ Returns:
+ The registered model
+ """
+ return api_or_ns.model("File", file_fields)
+
+
remote_file_info_fields = {
"file_type": fields.String(attribute="file_type"),
"file_length": fields.Integer(attribute="file_length"),
}
+def build_remote_file_info_model(api_or_ns: Api | Namespace):
+ """Build the remote file info model for the API or Namespace.
+
+ Args:
+ api_or_ns: Flask-RestX Api or Namespace instance
+
+ Returns:
+ The registered model
+ """
+ return api_or_ns.model("RemoteFileInfo", remote_file_info_fields)
+
+
file_fields_with_signed_url = {
"id": fields.String,
"name": fields.String,
@@ -38,3 +76,15 @@ file_fields_with_signed_url = {
"created_by": fields.String,
"created_at": TimestampField,
}
+
+
+def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
+ """Build the file with signed URL model for the API or Namespace.
+
+ Args:
+ api_or_ns: Flask-RestX Api or Namespace instance
+
+ Returns:
+ The registered model
+ """
+ return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url)
diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py
index 9d67999ea4..75bdff1803 100644
--- a/api/fields/hit_testing_fields.py
+++ b/api/fields/hit_testing_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import fields
from libs.helper import TimestampField
diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py
index e0b3e340f6..16dd26a10e 100644
--- a/api/fields/installed_app_fields.py
+++ b/api/fields/installed_app_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import fields
from libs.helper import AppIconUrlField, TimestampField
diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py
index 8007b7e052..08e38a6931 100644
--- a/api/fields/member_fields.py
+++ b/api/fields/member_fields.py
@@ -1,8 +1,17 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
from libs.helper import AvatarUrlField, TimestampField
-simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String}
+simple_account_fields = {
+ "id": fields.String,
+ "name": fields.String,
+ "email": fields.String,
+}
+
+
+def build_simple_account_model(api_or_ns: Api | Namespace):
+ return api_or_ns.model("SimpleAccount", simple_account_fields)
+
account_fields = {
"id": fields.String,
diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py
index e6aebd810f..a419da2e18 100644
--- a/api/fields/message_fields.py
+++ b/api/fields/message_fields.py
@@ -1,11 +1,19 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
from .raws import FilesContainedField
-feedback_fields = {"rating": fields.String}
+feedback_fields = {
+ "rating": fields.String,
+}
+
+
+def build_feedback_model(api_or_ns: Api | Namespace):
+ """Build the feedback model for the API or Namespace."""
+ return api_or_ns.model("Feedback", feedback_fields)
+
agent_thought_fields = {
"id": fields.String,
@@ -21,6 +29,12 @@ agent_thought_fields = {
"files": fields.List(fields.String),
}
+
+def build_agent_thought_model(api_or_ns: Api | Namespace):
+ """Build the agent thought model for the API or Namespace."""
+ return api_or_ns.model("AgentThought", agent_thought_fields)
+
+
retriever_resource_fields = {
"id": fields.String,
"message_id": fields.String,
diff --git a/api/fields/raws.py b/api/fields/raws.py
index 15ec16ab13..9bc6a12c78 100644
--- a/api/fields/raws.py
+++ b/api/fields/raws.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import fields
from core.file import File
diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py
index 4126c24598..2ff917d6bc 100644
--- a/api/fields/segment_fields.py
+++ b/api/fields/segment_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import fields
from libs.helper import TimestampField
diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py
index 9af4fc57dd..d5b7c86a04 100644
--- a/api/fields/tag_fields.py
+++ b/api/fields/tag_fields.py
@@ -1,3 +1,12 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
-tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String}
+dataset_tag_fields = {
+ "id": fields.String,
+ "name": fields.String,
+ "type": fields.String,
+ "binding_count": fields.String,
+}
+
+
+def build_dataset_tag_fields(api_or_ns: Api | Namespace):
+ return api_or_ns.model("DataSetTag", dataset_tag_fields)
diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py
index 823c99ec6b..243efd817c 100644
--- a/api/fields/workflow_app_log_fields.py
+++ b/api/fields/workflow_app_log_fields.py
@@ -1,8 +1,8 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
-from fields.end_user_fields import simple_end_user_fields
-from fields.member_fields import simple_account_fields
-from fields.workflow_run_fields import workflow_run_for_log_fields
+from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
+from fields.member_fields import build_simple_account_model, simple_account_fields
+from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields
from libs.helper import TimestampField
workflow_app_log_partial_fields = {
@@ -15,6 +15,24 @@ workflow_app_log_partial_fields = {
"created_at": TimestampField,
}
+
+def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
+ """Build the workflow app log partial model for the API or Namespace."""
+ workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
+ simple_account_model = build_simple_account_model(api_or_ns)
+ simple_end_user_model = build_simple_end_user_model(api_or_ns)
+
+ copied_fields = workflow_app_log_partial_fields.copy()
+ copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True)
+ copied_fields["created_by_account"] = fields.Nested(
+ simple_account_model, attribute="created_by_account", allow_null=True
+ )
+ copied_fields["created_by_end_user"] = fields.Nested(
+ simple_end_user_model, attribute="created_by_end_user", allow_null=True
+ )
+ return api_or_ns.model("WorkflowAppLogPartial", copied_fields)
+
+
workflow_app_log_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer,
@@ -22,3 +40,13 @@ workflow_app_log_pagination_fields = {
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(workflow_app_log_partial_fields)),
}
+
+
+def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
+ """Build the workflow app log pagination model for the API or Namespace."""
+ # Build the nested partial model first
+ workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)
+
+ copied_fields = workflow_app_log_pagination_fields.copy()
+ copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model))
+ return api_or_ns.model("WorkflowAppLogPagination", copied_fields)
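Worth noting: every pagination builder copies the module-level dict before swapping the raw nested dict for a registered model. Mutating the shared dict in place would leak restx model objects into every other consumer of the plain fields. A minimal sketch of the pattern (the `Item`/`Page` names are illustrative):

```python
from flask_restx import Api, Namespace, fields

item_fields = {"id": fields.String}

page_fields = {
    "total": fields.Integer,
    "data": fields.List(fields.Nested(item_fields)),
}


def build_page_model(api_or_ns: Api | Namespace):
    item_model = api_or_ns.model("Item", item_fields)
    # Copy before overriding so the module-level dict keeps its plain
    # nested dict and stays usable by non-restx marshalling code.
    copied = page_fields.copy()
    copied["data"] = fields.List(fields.Nested(item_model))
    return api_or_ns.model("Page", copied)
```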
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index 930e59cc1c..f048d0f3b6 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import fields
from core.helper import encrypter
from core.variables import SecretVariable, SegmentType, Variable
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index a106728e9c..6462d8ce5a 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restx import Api, Namespace, fields
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@@ -17,6 +17,11 @@ workflow_run_for_log_fields = {
"exceptions_count": fields.Integer,
}
+
+def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
+ return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
+
+
workflow_run_for_list_fields = {
"id": fields.String,
"version": fields.String,
diff --git a/api/libs/external_api.py b/api/libs/external_api.py
index 2070df3e55..95d13cd0e6 100644
--- a/api/libs/external_api.py
+++ b/api/libs/external_api.py
@@ -1,119 +1,111 @@
import re
import sys
+from collections.abc import Mapping
from typing import Any
from flask import current_app, got_request_exception
-from flask_restful import Api, http_status_message
-from werkzeug.datastructures import Headers
+from flask_restx import Api
from werkzeug.exceptions import HTTPException
+from werkzeug.http import HTTP_STATUS_CODES
from core.errors.error import AppInvokeQuotaExceededError
-class ExternalApi(Api):
- def handle_error(self, e):
- """Error handler for the API transforms a raised exception into a Flask
- response, with the appropriate HTTP status code and body.
+def http_status_message(code):
+ return HTTP_STATUS_CODES.get(code, "")
- :param e: the raised Exception object
- :type e: Exception
- """
+def register_external_error_handlers(api: Api) -> None:
+ @api.errorhandler(HTTPException)
+ def handle_http_exception(e: HTTPException):
got_request_exception.send(current_app, exception=e)
- headers = Headers()
- if isinstance(e, HTTPException):
- if e.response is not None:
- resp = e.get_response()
- return resp
+ # If Werkzeug already prepared a Response, just use it.
+ if getattr(e, "response", None) is not None:
+ return e.response
- status_code = e.code
- default_data = {
- "code": re.sub(r"(?= 500:
- exc_info: Any = sys.exc_info()
- if exc_info[1] is None:
- exc_info = None
- current_app.log_exception(exc_info)
-
- if status_code == 406 and self.default_mediatype is None:
- # if we are handling NotAcceptable (406), make sure that
- # make_response uses a representation we support as the
- # default mediatype (so that make_response doesn't throw
- # another NotAcceptable error).
- supported_mediatypes = list(self.representations.keys()) # only supported application/json
- fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
- data = {"code": "not_acceptable", "message": data.get("message")}
- resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype)
+ # Payload per status
+ if status_code == 406 and api.default_mediatype is None:
+ data = {"code": "not_acceptable", "message": default_data["message"], "status": status_code}
+ return data, status_code, headers
elif status_code == 400:
- if isinstance(data.get("message"), dict):
- param_key, param_value = list(data.get("message", {}).items())[0]
- data = {"code": "invalid_param", "message": param_value, "params": param_key}
+ msg = default_data["message"]
+ if isinstance(msg, Mapping) and msg:
+ # Convert param errors like {"field": "reason"} into a friendly shape
+ param_key, param_value = next(iter(msg.items()))
+ data = {
+ "code": "invalid_param",
+ "message": str(param_value),
+ "params": param_key,
+ "status": status_code,
+ }
else:
- if "code" not in data:
- data["code"] = "unknown"
-
- resp = self.make_response(data, status_code, headers)
+ data = {**default_data}
+ data.setdefault("code", "unknown")
+ return data, status_code, headers
else:
- if "code" not in data:
- data["code"] = "unknown"
+ data = {**default_data}
+ data.setdefault("code", "unknown")
+            # Attach a WWW-Authenticate challenge to 401 responses
+ if status_code == 401:
+ headers["WWW-Authenticate"] = 'Bearer realm="api"'
+ return data, status_code, headers
- resp = self.make_response(data, status_code, headers)
+ @api.errorhandler(ValueError)
+ def handle_value_error(e: ValueError):
+ got_request_exception.send(current_app, exception=e)
+ status_code = 400
+ data = {"code": "invalid_param", "message": str(e), "status": status_code}
+ return data, status_code
- if status_code == 401:
- resp = self.unauthorized(resp)
- return resp
+ @api.errorhandler(AppInvokeQuotaExceededError)
+ def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
+ got_request_exception.send(current_app, exception=e)
+ status_code = 429
+ data = {"code": "too_many_requests", "message": str(e), "status": status_code}
+ return data, status_code
+
+ @api.errorhandler(Exception)
+ def handle_general_exception(e: Exception):
+ got_request_exception.send(current_app, exception=e)
+
+ status_code = 500
+ data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})
+
+        # Normalize non-mapping data (e.g., if someone set e.data = Response)
+ if not isinstance(data, Mapping):
+ data = {"message": str(e)}
+
+ data.setdefault("code", "unknown")
+ data.setdefault("status", status_code)
+
+ # Log stack
+ exc_info: Any = sys.exc_info()
+ if exc_info[1] is None:
+ exc_info = None
+ current_app.log_exception(exc_info)
+
+ return data, status_code
+
+
+class ExternalApi(Api):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ register_external_error_handlers(self)
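The net effect of the rewrite: instead of overriding `Api.handle_error`, each exception type gets an `@api.errorhandler` that returns a `(payload, status, headers)` tuple for flask-restx to serialize. A hedged sketch of the resulting behavior (the app wiring and route are illustrative):

```python
from flask import Flask
from flask_restx import Resource

from libs.external_api import ExternalApi

app = Flask(__name__)
api = ExternalApi(app)


@api.route("/boom")
class Boom(Resource):  # illustrative endpoint
    def get(self):
        # Handled by the ValueError handler registered in __init__:
        # HTTP 400 {"code": "invalid_param", "message": "bad input", "status": 400}
        raise ValueError("bad input")
```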
diff --git a/api/libs/helper.py b/api/libs/helper.py
index b36f972e19..70986fedd3 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
from zoneinfo import available_timezones
from flask import Response, stream_with_context
-from flask_restful import fields
+from flask_restx import fields
from pydantic import BaseModel
from configs import dify_config
@@ -57,7 +57,7 @@ def run(script):
class AppIconUrlField(fields.Raw):
- def output(self, key, obj):
+ def output(self, key, obj, **kwargs):
if obj is None:
return None
@@ -72,7 +72,7 @@ class AppIconUrlField(fields.Raw):
class AvatarUrlField(fields.Raw):
- def output(self, key, obj):
+ def output(self, key, obj, **kwargs):
if obj is None:
return None
diff --git a/api/libs/module_loading.py b/api/libs/module_loading.py
new file mode 100644
index 0000000000..616d072a1b
--- /dev/null
+++ b/api/libs/module_loading.py
@@ -0,0 +1,55 @@
+"""
+Module loading utilities similar to Django's module_loading.
+
+Reference implementation from Django:
+https://github.com/django/django/blob/main/django/utils/module_loading.py
+"""
+
+import sys
+from importlib import import_module
+from typing import Any
+
+
+def cached_import(module_path: str, class_name: str) -> Any:
+ """
+ Import a module and return the named attribute/class from it, with caching.
+
+ Args:
+ module_path: The module path to import from
+ class_name: The attribute/class name to retrieve
+
+ Returns:
+ The imported attribute/class
+ """
+ if not (
+ (module := sys.modules.get(module_path))
+ and (spec := getattr(module, "__spec__", None))
+ and getattr(spec, "_initializing", False) is False
+ ):
+ module = import_module(module_path)
+ return getattr(module, class_name)
+
+
+def import_string(dotted_path: str) -> Any:
+ """
+ Import a dotted module path and return the attribute/class designated by
+ the last name in the path. Raise ImportError if the import failed.
+
+ Args:
+ dotted_path: Full module path to the class (e.g., 'module.submodule.ClassName')
+
+ Returns:
+ The imported class or attribute
+
+ Raises:
+ ImportError: If the module or attribute cannot be imported
+ """
+ try:
+ module_path, class_name = dotted_path.rsplit(".", 1)
+ except ValueError as err:
+ raise ImportError(f"{dotted_path} doesn't look like a module path") from err
+
+ try:
+ return cached_import(module_path, class_name)
+ except AttributeError as err:
+ raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') from err
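A quick sketch of the `import_string` contract that the repository factory below relies on (the dotted path here is illustrative; the real paths come from `dify_config`):

```python
from libs.module_loading import import_string

OrderedDict = import_string("collections.OrderedDict")  # illustrative path
print(OrderedDict)  # -> <class 'collections.OrderedDict'>

try:
    import_string("collections")  # missing the attribute segment
except ImportError as err:
    print(err)  # -> collections doesn't look like a module path
```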
diff --git a/api/models/task.py b/api/models/task.py
index ab700c553c..9a52fcfb41 100644
--- a/api/models/task.py
+++ b/api/models/task.py
@@ -2,7 +2,7 @@ from datetime import datetime
from typing import Optional
import sqlalchemy as sa
-from celery import states # type: ignore
+from celery import states
from sqlalchemy import DateTime, String
from sqlalchemy.orm import Mapped, mapped_column
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 7ff463e08f..2fea3fcd78 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4
import sqlalchemy as sa
-from flask_login import current_user
from sqlalchemy import DateTime, orm
from core.file.constants import maybe_file_object
@@ -18,7 +17,6 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB
from core.workflow.nodes.enums import NodeType
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
-from libs.helper import extract_tenant_id
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
@@ -351,8 +349,8 @@ class Workflow(Base):
if self._environment_variables is None:
self._environment_variables = "{}"
- # Get tenant_id from current_user (Account or EndUser)
- tenant_id = extract_tenant_id(current_user)
+ # Use workflow.tenant_id to avoid relying on request user in background threads
+ tenant_id = self.tenant_id
if not tenant_id:
return []
@@ -382,8 +380,8 @@ class Workflow(Base):
self._environment_variables = "{}"
return
- # Get tenant_id from current_user (Account or EndUser)
- tenant_id = extract_tenant_id(current_user)
+ # Use workflow.tenant_id to avoid relying on request user in background threads
+ tenant_id = self.tenant_id
if not tenant_id:
self._environment_variables = "{}"
diff --git a/api/mypy.ini b/api/mypy.ini
index 3a6a54afe1..44a01068e9 100644
--- a/api/mypy.ini
+++ b/api/mypy.ini
@@ -12,8 +12,11 @@ exclude = (?x)(
[mypy-flask_login]
ignore_missing_imports=True
-[mypy-flask_restful]
+[mypy-flask_restx]
ignore_missing_imports=True
-[mypy-flask_restful.inputs]
+[mypy-flask_restx.api]
+ignore_missing_imports=True
+
+[mypy-flask_restx.inputs]
ignore_missing_imports=True
diff --git a/api/pyproject.toml b/api/pyproject.toml
index ce642aa9c8..6aa4746d2f 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -13,13 +13,12 @@ dependencies = [
"cachetools~=5.3.0",
"celery~=5.5.2",
"chardet~=5.1.0",
- "flask~=3.1.0",
+ "flask~=3.1.2",
"flask-compress~=1.17",
"flask-cors~=6.0.0",
"flask-login~=0.6.3",
"flask-migrate~=4.0.7",
"flask-orjson~=2.0.0",
- "flask-restful~=0.3.10",
"flask-sqlalchemy~=3.1.1",
"gevent~=24.11.1",
"gmpy2~=2.2.1",
@@ -88,6 +87,7 @@ dependencies = [
"sseclient-py>=1.8.0",
"httpx-sse>=0.4.0",
"sendgrid~=6.12.3",
+ "flask-restx>=1.3.0",
]
# Before adding a new dependency, consider placing it in
# alphabetical order (a-z) and in a suitable group.
@@ -110,7 +110,7 @@ dev = [
"dotenv-linter~=0.5.0",
"faker~=32.1.0",
"lxml-stubs~=0.5.1",
- "mypy~=1.16.0",
+ "mypy~=1.17.1",
"ruff~=0.12.3",
"pytest~=8.3.2",
"pytest-benchmark~=4.0.0",
@@ -164,6 +164,7 @@ dev = [
"scipy-stubs>=1.15.3.0",
"types-python-http-client>=3.3.7.20240910",
"types-redis>=4.6.0.20241004",
+ "celery-types>=0.23.0",
]
############################################################
diff --git a/api/repositories/factory.py b/api/repositories/factory.py
index 1f0320054c..0be9c8908c 100644
--- a/api/repositories/factory.py
+++ b/api/repositories/factory.py
@@ -5,17 +5,14 @@ This factory is specifically designed for DifyAPI repositories that handle
service-layer operations with dependency injection patterns.
"""
-import logging
-
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError
+from libs.module_loading import import_string
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
-logger = logging.getLogger(__name__)
-
class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
"""
@@ -50,17 +47,9 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY
try:
- repository_class = cls._import_class(class_path)
- cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository)
- # Service repository requires session_maker parameter
- cls._validate_constructor_signature(repository_class, ["session_maker"])
-
+ repository_class = import_string(class_path)
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
- except RepositoryImportError:
- # Re-raise our custom errors as-is
- raise
- except Exception as e:
- logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository")
+        except Exception as e:
raise RepositoryImportError(
f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
) from e
@@ -87,15 +76,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY
try:
- repository_class = cls._import_class(class_path)
- cls._validate_repository_interface(repository_class, APIWorkflowRunRepository)
- # Service repository requires session_maker parameter
- cls._validate_constructor_signature(repository_class, ["session_maker"])
-
+ repository_class = import_string(class_path)
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
- except RepositoryImportError:
- # Re-raise our custom errors as-is
- raise
- except Exception as e:
- logger.exception("Failed to create APIWorkflowRunRepository")
+        except Exception as e:
raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 1cce8e67a4..0bb903fbbc 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -425,7 +425,7 @@ class AccountService:
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
- language: Optional[str] = "en-US",
+ language: str = "en-US",
):
account_email = account.email if account else email
if account_email is None:
@@ -452,12 +452,14 @@ class AccountService:
account: Optional[Account] = None,
email: Optional[str] = None,
old_email: Optional[str] = None,
- language: Optional[str] = "en-US",
+ language: str = "en-US",
phase: Optional[str] = None,
):
account_email = account.email if account else email
if account_email is None:
raise ValueError("Email must be provided.")
+ if not phase:
+ raise ValueError("phase must be provided.")
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
@@ -480,7 +482,7 @@ class AccountService:
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
- language: Optional[str] = "en-US",
+ language: str = "en-US",
):
account_email = account.email if account else email
if account_email is None:
@@ -496,7 +498,7 @@ class AccountService:
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
- language: Optional[str] = "en-US",
+ language: str = "en-US",
workspace_name: Optional[str] = "",
):
account_email = account.email if account else email
@@ -509,6 +511,7 @@ class AccountService:
raise OwnerTransferRateLimitExceededError()
code, token = cls.generate_owner_transfer_token(account_email, account)
+ workspace_name = workspace_name or ""
send_owner_transfer_confirm_task.delay(
language=language,
@@ -524,13 +527,14 @@ class AccountService:
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
- language: Optional[str] = "en-US",
+ language: str = "en-US",
workspace_name: Optional[str] = "",
- new_owner_email: Optional[str] = "",
+ new_owner_email: str = "",
):
account_email = account.email if account else email
if account_email is None:
raise ValueError("Email must be provided.")
+ workspace_name = workspace_name or ""
send_old_owner_transfer_notify_email_task.delay(
language=language,
@@ -544,12 +548,13 @@ class AccountService:
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
- language: Optional[str] = "en-US",
+ language: str = "en-US",
workspace_name: Optional[str] = "",
):
account_email = account.email if account else email
if account_email is None:
raise ValueError("Email must be provided.")
+ workspace_name = workspace_name or ""
send_new_owner_transfer_notify_email_task.delay(
language=language,
@@ -633,7 +638,10 @@ class AccountService:
@classmethod
def send_email_code_login_email(
- cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
+ cls,
+ account: Optional[Account] = None,
+ email: Optional[str] = None,
+ language: str = "en-US",
):
email = account.email if account else email
if email is None:
@@ -1260,10 +1268,11 @@ class RegisterService:
raise AccountAlreadyInTenantError("Account already in tenant.")
token = cls.generate_invite_token(tenant, account)
+ language = account.interface_language or "en-US"
# send email
send_invite_member_mail_task.delay(
- language=account.interface_language,
+ language=language,
to=email,
token=token,
inviter_name=inviter.name if inviter else "Dify",
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index 1a0fdfa420..45b246af1e 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -1,4 +1,3 @@
-import datetime
import uuid
from typing import cast
@@ -10,6 +9,7 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
@@ -473,7 +473,7 @@ class AppAnnotationService:
raise NotFound("App annotation not found")
annotation_setting.score_threshold = args["score_threshold"]
annotation_setting.updated_user_id = current_user.id
- annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ annotation_setting.updated_at = naive_utc_now()
db.session.add(annotation_setting)
db.session.commit()
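This hunk is the first of many that swap the inline `datetime.datetime.now(datetime.UTC).replace(tzinfo=None)` idiom for `naive_utc_now()`. The helper's definition is outside this diff, but given what it replaces it is presumably equivalent to:

```python
from datetime import UTC, datetime


def naive_utc_now() -> datetime:
    # Current UTC time with tzinfo stripped, matching the naive-UTC
    # convention used by the existing DateTime columns.
    return datetime.now(UTC).replace(tzinfo=None)
```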
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index 4f3dd3c762..ac603d3cc9 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -1,3 +1,5 @@
+import contextlib
+import logging
from collections.abc import Callable, Sequence
from typing import Any, Optional, Union
@@ -22,6 +24,9 @@ from services.errors.conversation import (
LastConversationNotExistsError,
)
from services.errors.message import MessageNotExistsError
+from tasks.delete_conversation_task import delete_conversation_related_data
+
+logger = logging.getLogger(__name__)
class ConversationService:
@@ -142,13 +147,11 @@ class ConversationService:
raise MessageNotExistsError()
# generate conversation name
- try:
+ with contextlib.suppress(Exception):
name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, message.query, conversation.id, app_model.id
)
conversation.name = name
- except Exception:
- pass
db.session.commit()
@@ -176,11 +179,21 @@ class ConversationService:
@classmethod
def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
- conversation = cls.get_conversation(app_model, conversation_id, user)
+ try:
+ logger.info(
+ "Initiating conversation deletion for app_name %s, conversation_id: %s",
+ app_model.name,
+ conversation_id,
+ )
- conversation.is_deleted = True
- conversation.updated_at = naive_utc_now()
- db.session.commit()
+ db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
+ db.session.commit()
+
+ delete_conversation_related_data.delay(conversation_id)
+
+ except Exception as e:
+ db.session.rollback()
+ raise e
@classmethod
def get_conversational_variable(
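The `contextlib.suppress` rewrite above is behavior-preserving: the conversation name is assigned only when generation succeeds, and the commit runs either way. A self-contained illustration:

```python
import contextlib


def flaky_name() -> str:  # stand-in for the LLM generation call
    raise RuntimeError("generation failed")


name = "untitled"
with contextlib.suppress(Exception):
    name = flaky_name()  # on failure, the previous value survives

print(name)  # -> untitled
```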
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 9fb048fac4..fc2cbba78b 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -1234,7 +1234,7 @@ class DocumentService:
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
- document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
@@ -1552,7 +1552,7 @@ class DocumentService:
document.parsing_completed_at = None
document.cleaning_completed_at = None
document.splitting_completed_at = None
- document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = document_data.doc_form
db.session.add(document)
@@ -1912,7 +1912,7 @@ class DocumentService:
Returns:
dict: Update information or None if no update needed
"""
- now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ now = naive_utc_now()
if action == "enable":
return DocumentService._prepare_enable_update(document, now)
@@ -2040,8 +2040,8 @@ class SegmentService:
word_count=len(content),
tokens=tokens,
status="completed",
- indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
- completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ indexing_at=naive_utc_now(),
+ completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
@@ -2061,7 +2061,7 @@ class SegmentService:
except Exception as e:
logging.exception("create segment index failed")
segment_document.enabled = False
- segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
@@ -2117,8 +2117,8 @@ class SegmentService:
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
- indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
- completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ indexing_at=naive_utc_now(),
+ completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
@@ -2145,7 +2145,7 @@ class SegmentService:
logging.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
- segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
@@ -2162,7 +2162,7 @@ class SegmentService:
if segment.enabled != action:
if not action:
segment.enabled = action
- segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ segment.disabled_at = naive_utc_now()
segment.disabled_by = current_user.id
db.session.add(segment)
db.session.commit()
@@ -2260,10 +2260,10 @@ class SegmentService:
segment.word_count = len(content)
segment.tokens = tokens
segment.status = "completed"
- segment.indexing_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
- segment.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ segment.indexing_at = naive_utc_now()
+ segment.completed_at = naive_utc_now()
segment.updated_by = current_user.id
- segment.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ segment.updated_at = naive_utc_now()
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
@@ -2316,7 +2316,7 @@ class SegmentService:
except Exception as e:
logging.exception("update segment index failed")
segment.enabled = False
- segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
@@ -2344,13 +2344,9 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
- # Check if segment_ids is not empty to avoid WHERE false condition
- if not segment_ids or len(segment_ids) == 0:
- return
- index_node_ids = (
- db.session.query(DocumentSegment)
- .with_entities(DocumentSegment.index_node_id)
- .where(
+ segments = (
+ db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
+        .where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@@ -2358,7 +2354,15 @@ class SegmentService:
)
.all()
)
- index_node_ids = [index_node_id[0] for index_node_id in index_node_ids]
+
+ if not segments:
+ return
+
+ index_node_ids = [seg.index_node_id for seg in segments]
+ total_words = sum(seg.word_count for seg in segments)
+
+ document.word_count -= total_words
+ db.session.add(document)
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
@@ -2418,7 +2422,7 @@ class SegmentService:
if cache_result is not None:
continue
segment.enabled = False
- segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ segment.disabled_at = naive_utc_now()
segment.disabled_by = current_user.id
db.session.add(segment)
real_deal_segment_ids.append(segment.id)
@@ -2508,7 +2512,7 @@ class SegmentService:
child_chunk.content = child_chunk_update_args.content
child_chunk.word_count = len(child_chunk.content)
child_chunk.updated_by = current_user.id
- child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ child_chunk.updated_at = naive_utc_now()
child_chunk.type = "customized"
update_child_chunks.append(child_chunk)
else:
@@ -2565,7 +2569,7 @@ class SegmentService:
child_chunk.content = content
child_chunk.word_count = len(content)
child_chunk.updated_by = current_user.id
- child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ child_chunk.updated_at = naive_utc_now()
child_chunk.type = "customized"
db.session.add(child_chunk)
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
diff --git a/api/services/file_service.py b/api/services/file_service.py
index e234c2f325..4c0a0f451c 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -1,4 +1,3 @@
-import datetime
import hashlib
import os
import uuid
@@ -18,6 +17,7 @@ from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
+from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from models.account import Account
from models.enums import CreatorUserRole
@@ -80,7 +80,7 @@ class FileService:
mime_type=mimetype,
created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
created_by=user.id,
- created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
used=False,
hash=hashlib.sha3_256(content).hexdigest(),
source_url=source_url,
@@ -131,10 +131,10 @@ class FileService:
mime_type="text/plain",
created_by=current_user.id,
created_by_role=CreatorUserRole.ACCOUNT,
- created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
used=True,
used_by=current_user.id,
- used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ used_at=naive_utc_now(),
)
db.session.add(upload_file)
diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py
index 2a83588f41..fd222f59d3 100644
--- a/api/services/metadata_service.py
+++ b/api/services/metadata_service.py
@@ -1,5 +1,4 @@
import copy
-import datetime
import logging
from typing import Optional
@@ -8,6 +7,7 @@ from flask_login import current_user
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
@@ -69,7 +69,7 @@ class MetadataService:
old_name = metadata.name
metadata.name = name
metadata.updated_by = current_user.id
- metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ metadata.updated_at = naive_utc_now()
# update related documents
dataset_metadata_bindings = (
diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py
index fe28aa006e..f8dd70c790 100644
--- a/api/services/model_load_balancing_service.py
+++ b/api/services/model_load_balancing_service.py
@@ -1,4 +1,3 @@
-import datetime
import json
import logging
from json import JSONDecodeError
@@ -17,6 +16,7 @@ from core.model_runtime.entities.provider_entities import (
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.provider import LoadBalancingModelConfig
logger = logging.getLogger(__name__)
@@ -371,7 +371,7 @@ class ModelLoadBalancingService:
load_balancing_config.name = name
load_balancing_config.enabled = enabled
- load_balancing_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ load_balancing_config.updated_at = naive_utc_now()
db.session.commit()
self._clear_credentials_cache(tenant_id, config_id)
diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py
index 59d5b50e23..f245dd7527 100644
--- a/api/services/tools/tools_manage_service.py
+++ b/api/services/tools/tools_manage_service.py
@@ -1,4 +1,5 @@
import logging
+from typing import Optional
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
@@ -9,7 +10,7 @@ logger = logging.getLogger(__name__)
class ToolCommonService:
@staticmethod
- def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
+ def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
"""
list tool providers
diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py
index a9df8d0d73..8d21335c86 100644
--- a/api/services/webapp_auth_service.py
+++ b/api/services/webapp_auth_service.py
@@ -63,7 +63,7 @@ class WebAppAuthService:
@classmethod
def send_email_code_login_email(
- cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
+ cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US"
):
email = account.email if account else email
if email is None:
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index afcf1f7621..00b02f8091 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -402,7 +402,7 @@ class WorkflowConverter:
)
role_prefix = None
- prompts: Any = None
+ prompts: Optional[Any] = None
# Chat Model
if model_config.mode == LLMMode.CHAT.value:
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index b52f4924ba..9f01bcb668 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -1,5 +1,4 @@
import dataclasses
-import datetime
import logging
from collections.abc import Mapping, Sequence
from enum import StrEnum
@@ -23,6 +22,7 @@ from core.workflow.nodes.variable_assigner.common.helpers import get_updated_var
from core.workflow.variable_loader import VariableLoader
from factories.file_factory import StorageKeyLoader
from factories.variable_factory import build_segment, segment_to_variable
+from libs.datetime_utils import naive_utc_now
from models import App, Conversation
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable
@@ -231,7 +231,7 @@ class WorkflowDraftVariableService:
variable.set_name(name)
if value is not None:
variable.set_value(value)
- variable.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ variable.last_edited_at = naive_utc_now()
self._session.flush()
return variable
diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py
index c5ee4ce3f9..8834229e16 100644
--- a/api/tasks/add_document_to_index_task.py
+++ b/api/tasks/add_document_to_index_task.py
@@ -1,15 +1,15 @@
-import datetime
import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -95,7 +95,7 @@ def add_document_to_index_task(dataset_document_id: str):
DocumentSegment.enabled: True,
DocumentSegment.disabled_at: None,
DocumentSegment.disabled_by: None,
- DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ DocumentSegment.updated_at: naive_utc_now(),
}
)
db.session.commit()
@@ -107,7 +107,7 @@ def add_document_to_index_task(dataset_document_id: str):
except Exception as e:
logging.exception("add document to index failed")
dataset_document.enabled = False
- dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ dataset_document.disabled_at = naive_utc_now()
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
db.session.commit()
diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py
index e436f00133..5bf8e7c33e 100644
--- a/api/tasks/annotation/add_annotation_to_index_task.py
+++ b/api/tasks/annotation/add_annotation_to_index_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py
index 47dc3ee90e..fd33feea16 100644
--- a/api/tasks/annotation/batch_import_annotations_task.py
+++ b/api/tasks/annotation/batch_import_annotations_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py
index f016400e16..1894031a80 100644
--- a/api/tasks/annotation/delete_annotation_index_task.py
+++ b/api/tasks/annotation/delete_annotation_index_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py
index 0076113ce8..a8375dfa26 100644
--- a/api/tasks/annotation/disable_annotation_reply_task.py
+++ b/api/tasks/annotation/disable_annotation_reply_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py
index 44c65c0783..9ffaf81af6 100644
--- a/api/tasks/annotation/enable_annotation_reply_task.py
+++ b/api/tasks/annotation/enable_annotation_reply_task.py
@@ -1,14 +1,14 @@
-import datetime
import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
from services.dataset_service import DatasetCollectionBindingService
@@ -72,7 +72,7 @@ def enable_annotation_reply_task(
annotation_setting.score_threshold = score_threshold
annotation_setting.collection_binding_id = dataset_collection_binding.id
annotation_setting.updated_user_id = user_id
- annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ annotation_setting.updated_at = naive_utc_now()
db.session.add(annotation_setting)
else:
new_app_annotation_setting = AppAnnotationSetting(
diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py
index 5f11d5aa00..337434b768 100644
--- a/api/tasks/annotation/update_annotation_to_index_task.py
+++ b/api/tasks/annotation/update_annotation_to_index_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py
index e64a799146..ed47b62e1b 100644
--- a/api/tasks/batch_clean_document_task.py
+++ b/api/tasks/batch_clean_document_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index dee43cd854..50293f38a7 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -1,4 +1,3 @@
-import datetime
import logging
import tempfile
import time
@@ -7,7 +6,7 @@ from pathlib import Path
import click
import pandas as pd
-from celery import shared_task # type: ignore
+from celery import shared_task
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -17,6 +16,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from services.vector_service import VectorService
@@ -123,9 +123,9 @@ def batch_create_segment_to_index_task(
word_count=len(content),
tokens=tokens,
created_by=user_id,
- indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ indexing_at=naive_utc_now(),
status="completed",
- completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index 7b940847c9..3d3fadbd0a 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 5479ba8e8f..c18329a9c2 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -3,7 +3,7 @@ import time
from typing import Optional
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py
index bf1a92f038..3ad6257cda 100644
--- a/api/tasks/clean_notion_document_task.py
+++ b/api/tasks/clean_notion_document_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index 543a512851..db2f69596d 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -1,15 +1,15 @@
-import datetime
import logging
import time
from typing import Optional
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@@ -41,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "indexing",
- DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ DocumentSegment.indexing_at: naive_utc_now(),
}
)
db.session.commit()
@@ -79,7 +79,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "completed",
- DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ DocumentSegment.completed_at: naive_utc_now(),
}
)
db.session.commit()
@@ -89,7 +89,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
except Exception as e:
logging.exception("create segment to index failed")
segment.enabled = False
- segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index 5ab377c232..512ea1048a 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -3,7 +3,7 @@ import time
from typing import Literal
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py
index ef50adf8d5..29f5a2450d 100644
--- a/api/tasks/delete_account_task.py
+++ b/api/tasks/delete_account_task.py
@@ -1,6 +1,6 @@
import logging
-from celery import shared_task # type: ignore
+from celery import shared_task
from extensions.ext_database import db
from models.account import Account
diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py
new file mode 100644
index 0000000000..4279dd2c17
--- /dev/null
+++ b/api/tasks/delete_conversation_task.py
@@ -0,0 +1,68 @@
+import logging
+import time
+
+import click
+from celery import shared_task
+
+from extensions.ext_database import db
+from models import ConversationVariable
+from models.model import Message, MessageAnnotation, MessageFeedback
+from models.tools import ToolConversationVariables, ToolFile
+from models.web import PinnedConversation
+
+
+@shared_task(queue="conversation")
+def delete_conversation_related_data(conversation_id: str) -> None:
+ """
+    Delete conversation-related data in the correct order from the database to respect foreign key constraints.
+
+    Args:
+        conversation_id: ID of the conversation whose related data should be deleted
+ """
+
+ logging.info(
+ click.style(f"Starting to delete conversation data from db for conversation_id {conversation_id}", fg="green")
+ )
+ start_at = time.perf_counter()
+
+ try:
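+        # Delete child tables first (annotations, feedback, tool data, conversation
+        # variables, messages, pinned conversations) so foreign key constraints are
+        # respected before the single commit below.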
+ db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ db.session.query(ToolConversationVariables).where(
+ ToolConversationVariables.conversation_id == conversation_id
+ ).delete(synchronize_session=False)
+
+ db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
+
+ db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
+
+ db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ db.session.commit()
+
+ end_at = time.perf_counter()
+ logging.info(
+ click.style(
+ f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+
+ except Exception as e:
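+        # Roll back the partial deletes and re-raise so Celery marks the task as failed.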
+ logging.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
+ db.session.rollback()
+ raise e
+ finally:
+ db.session.close()
diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py
index da12355d23..f091085fb8 100644
--- a/api/tasks/delete_segment_from_index_task.py
+++ b/api/tasks/delete_segment_from_index_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py
index fa4ec15f8a..c813a9dca6 100644
--- a/api/tasks/disable_segment_from_index_task.py
+++ b/api/tasks/disable_segment_from_index_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py
index f033f05084..252321ba83 100644
--- a/api/tasks/disable_segments_from_index_task.py
+++ b/api/tasks/disable_segments_from_index_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index 993b2ac404..4afd13eb13 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -1,14 +1,14 @@
-import datetime
import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.source import DataSourceOauthBinding
@@ -72,7 +72,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = naive_utc_now()
db.session.commit()
# delete all document segment and index
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index 728db2e2dc..c414b01d0e 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
index 053c0c5f41..31bbc8b570 100644
--- a/api/tasks/document_indexing_update_task.py
+++ b/api/tasks/document_indexing_update_task.py
@@ -1,13 +1,13 @@
-import datetime
import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@@ -31,7 +31,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
return
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = naive_utc_now()
db.session.commit()
# delete all document segment and index
diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py
index faa7e2b8d0..f3850b7e3b 100644
--- a/api/tasks/duplicate_document_indexing_task.py
+++ b/api/tasks/duplicate_document_indexing_task.py
@@ -1,14 +1,14 @@
-import datetime
import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@@ -55,7 +55,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
if document:
document.indexing_status = "error"
document.error = str(e)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
return
@@ -86,7 +86,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
db.session.commit()
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py
index f801c9d9ee..a4bcc043e3 100644
--- a/api/tasks/enable_segment_to_index_task.py
+++ b/api/tasks/enable_segment_to_index_task.py
@@ -1,15 +1,15 @@
-import datetime
import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@@ -89,7 +89,7 @@ def enable_segment_to_index_task(segment_id: str):
except Exception as e:
logging.exception("enable segment to index failed")
segment.enabled = False
- segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py
index 777380631f..1db984f0d3 100644
--- a/api/tasks/enable_segments_to_index_task.py
+++ b/api/tasks/enable_segments_to_index_task.py
@@ -1,15 +1,15 @@
-import datetime
import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -103,7 +103,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
{
"error": str(e),
"status": "error",
- "disabled_at": datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ "disabled_at": naive_utc_now(),
"enabled": False,
}
)
diff --git a/api/tasks/mail_account_deletion_task.py b/api/tasks/mail_account_deletion_task.py
index 38b5ca1800..43ddbfc03b 100644
--- a/api/tasks/mail_account_deletion_task.py
+++ b/api/tasks/mail_account_deletion_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
diff --git a/api/tasks/mail_change_mail_task.py b/api/tasks/mail_change_mail_task.py
index 054053558d..a56109705a 100644
--- a/api/tasks/mail_change_mail_task.py
+++ b/api/tasks/mail_change_mail_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py
index a82ab55384..53ea3709cd 100644
--- a/api/tasks/mail_email_code_login.py
+++ b/api/tasks/mail_email_code_login.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
diff --git a/api/tasks/mail_inner_task.py b/api/tasks/mail_inner_task.py
index 101f7ebaa4..cad4657bc8 100644
--- a/api/tasks/mail_inner_task.py
+++ b/api/tasks/mail_inner_task.py
@@ -3,7 +3,7 @@ import time
from collections.abc import Mapping
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from flask import render_template_string
from extensions.ext_mail import mail
diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py
index ff351f08af..f4f7f58416 100644
--- a/api/tasks/mail_invite_member_task.py
+++ b/api/tasks/mail_invite_member_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from configs import dify_config
from extensions.ext_mail import mail
diff --git a/api/tasks/mail_owner_transfer_task.py b/api/tasks/mail_owner_transfer_task.py
index 3856bf294a..db7158e786 100644
--- a/api/tasks/mail_owner_transfer_task.py
+++ b/api/tasks/mail_owner_transfer_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py
index b01af7827b..066d648530 100644
--- a/api/tasks/mail_reset_password_task.py
+++ b/api/tasks/mail_reset_password_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py
index c7e0047664..a4ef60b13c 100644
--- a/api/tasks/ops_trace_task.py
+++ b/api/tasks/ops_trace_task.py
@@ -1,7 +1,7 @@
import json
import logging
-from celery import shared_task # type: ignore
+from celery import shared_task
from flask import current_app
from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
index 9ea6aa6214..ec0b534546 100644
--- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
+++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
@@ -2,7 +2,7 @@ import traceback
import typing
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.helper import marketplace
from core.helper.marketplace import MarketplacePluginDeclaration
diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py
index ff489340cd..998fc6b32d 100644
--- a/api/tasks/recover_document_indexing_task.py
+++ b/api/tasks/recover_document_indexing_task.py
@@ -2,7 +2,7 @@ import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index 828c52044f..3d623c09d1 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -4,7 +4,7 @@ from collections.abc import Callable
import click
import sqlalchemy as sa
-from celery import shared_task # type: ignore
+from celery import shared_task
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
@@ -370,8 +370,8 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
with db.engine.begin() as conn:
# Get a batch of draft variable IDs
query_sql = """
- SELECT id FROM workflow_draft_variables
- WHERE app_id = :app_id
+ SELECT id FROM workflow_draft_variables
+ WHERE app_id = :app_id
LIMIT :batch_size
"""
result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
@@ -382,7 +382,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
# Delete the batch
delete_sql = """
- DELETE FROM workflow_draft_variables
+ DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""
deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py
index 524130a297..6356b1c46c 100644
--- a/api/tasks/remove_document_from_index_task.py
+++ b/api/tasks/remove_document_from_index_task.py
@@ -1,13 +1,13 @@
-import datetime
import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from models.dataset import Document, DocumentSegment
@@ -54,9 +54,9 @@ def remove_document_from_index_task(document_id: str):
db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
{
DocumentSegment.enabled: False,
- DocumentSegment.disabled_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ DocumentSegment.disabled_at: naive_utc_now(),
DocumentSegment.disabled_by: document.disabled_by,
- DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ DocumentSegment.updated_at: naive_utc_now(),
}
)
db.session.commit()
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index 26b41aff2e..67af857f40 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -1,14 +1,14 @@
-import datetime
import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@@ -51,7 +51,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
if document:
document.indexing_status = "error"
document.error = str(e)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
@@ -79,7 +79,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
db.session.commit()
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = naive_utc_now()
db.session.add(document)
db.session.commit()
@@ -89,7 +89,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg="yellow"))
diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py
index f112a97d2f..ad782f9b88 100644
--- a/api/tasks/sync_website_document_indexing_task.py
+++ b/api/tasks/sync_website_document_indexing_task.py
@@ -1,14 +1,14 @@
-import datetime
import logging
import time
import click
-from celery import shared_task # type: ignore
+from celery import shared_task
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
if document:
document.indexing_status = "error"
document.error = str(e)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(sync_indexing_cache_key)
@@ -72,7 +72,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
db.session.commit()
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = naive_utc_now()
db.session.add(document)
db.session.commit()
@@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg="yellow"))
diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py
index 2f9fb628ca..77ddf83023 100644
--- a/api/tasks/workflow_execution_tasks.py
+++ b/api/tasks/workflow_execution_tasks.py
@@ -8,7 +8,7 @@ improving performance by offloading storage operations to background workers.
import json
import logging
-from celery import shared_task # type: ignore[import-untyped]
+from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py
index dfc8a33564..16356086cf 100644
--- a/api/tasks/workflow_node_execution_tasks.py
+++ b/api/tasks/workflow_node_execution_tasks.py
@@ -8,7 +8,7 @@ improving performance by offloading storage operations to background workers.
import json
import logging
-from celery import shared_task # type: ignore[import-untyped]
+from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py
index 83f4d70ce9..2f0f38e0b8 100644
--- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py
+++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py
@@ -1,5 +1,5 @@
from flask import Flask, request
-from flask_restful import Api, Resource
+from flask_restx import Api, Resource
app = Flask(__name__)
api = Api(app)
diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py
index 4af35a8bef..be5b4de5a2 100644
--- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py
+++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py
@@ -1,5 +1,6 @@
import os
from collections import UserDict
+from typing import Optional
from unittest.mock import MagicMock
import pytest
@@ -21,7 +22,7 @@ class MockBaiduVectorDBClass:
def mock_vector_db_client(
self,
config=None,
- adapter: HTTPAdapter = None,
+ adapter: Optional[HTTPAdapter] = None,
):
self.conn = MagicMock()
self._config = MagicMock()
diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py
index ae5f9761b4..02f658aad6 100644
--- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py
+++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py
@@ -23,7 +23,7 @@ class MockTcvectordbClass:
key="",
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
timeout=10,
- adapter: HTTPAdapter = None,
+ adapter: Optional[HTTPAdapter] = None,
pool_size: int = 2,
proxies: Optional[dict] = None,
password: Optional[str] = None,
@@ -72,11 +72,11 @@ class MockTcvectordbClass:
shard: int,
replicas: int,
description: Optional[str] = None,
- index: Index = None,
- embedding: Embedding = None,
+ index: Optional[Index] = None,
+ embedding: Optional[Embedding] = None,
timeout: Optional[float] = None,
ttl_config: Optional[dict] = None,
- filter_index_config: FilterIndexConfig = None,
+ filter_index_config: Optional[FilterIndexConfig] = None,
indexes: Optional[list[IndexField]] = None,
) -> RPCCollection:
return RPCCollection(
@@ -113,7 +113,7 @@ class MockTcvectordbClass:
database_name: str,
collection_name: str,
vectors: list[list[float]],
- filter: Filter = None,
+ filter: Optional[Filter] = None,
params=None,
retrieve_vector: bool = False,
limit: int = 10,
@@ -128,7 +128,7 @@ class MockTcvectordbClass:
collection_name: str,
ann: Optional[Union[list[AnnSearch], AnnSearch]] = None,
match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None,
- filter: Union[Filter, str] = None,
+ filter: Optional[Union[Filter, str]] = None,
rerank: Optional[Rerank] = None,
retrieve_vector: Optional[bool] = None,
output_fields: Optional[list[str]] = None,
@@ -158,7 +158,7 @@ class MockTcvectordbClass:
database_name: str,
collection_name: str,
document_ids: Optional[list[str]] = None,
- filter: Filter = None,
+ filter: Optional[Filter] = None,
timeout: Optional[float] = None,
):
return {"code": 0, "msg": "operation success"}
diff --git a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py
index 8b57132772..21de8be6e3 100644
--- a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py
+++ b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py
@@ -1,3 +1,4 @@
+import contextlib
import os
import pytest
@@ -44,10 +45,8 @@ class TestClickzettaVector(AbstractVectorTest):
yield vector
# Cleanup: delete the test collection
- try:
+ with contextlib.suppress(Exception):
vector.delete()
- except Exception:
- pass
def test_clickzetta_vector_basic_operations(self, vector_store):
"""Test basic CRUD operations on Clickzetta vector store."""
diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py
index 3d7be0df7d..415e65ce51 100644
--- a/api/tests/test_containers_integration_tests/services/test_account_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_account_service.py
@@ -1639,7 +1639,7 @@ class TestTenantService:
email = fake.email()
name = fake.name()
password = fake.password(length=12)
- invalid_action = fake.word()
+ invalid_action = "invalid_action_that_doesnt_exist"
# Setup mocks
mock_external_service_dependencies[
"feature_service"
diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py
index 8816698af8..92d93d601e 100644
--- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py
@@ -410,18 +410,18 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotations with specific keywords
- unique_keyword = fake.word()
+ unique_keyword = f"unique_{fake.uuid4()[:8]}"
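+        # Deriving the keyword from a UUID guarantees it only appears in the
+        # annotations created below, keeping the keyword search deterministic.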
annotation_args = {
"question": f"Question with {unique_keyword} keyword",
"answer": f"Answer with {unique_keyword} keyword",
}
AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id)
-
# Create another annotation without the keyword
other_args = {
- "question": "Question without keyword",
- "answer": "Answer without keyword",
+ "question": "Different question without special term",
+ "answer": "Different answer without special content",
}
+
AppAnnotationService.insert_app_annotation_directly(other_args, app.id)
# Search with keyword
diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py
new file mode 100644
index 0000000000..2d5cdf426d
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py
@@ -0,0 +1,1192 @@
+from unittest.mock import patch
+
+import pytest
+from faker import Faker
+from werkzeug.exceptions import NotFound
+
+from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
+from models.dataset import Dataset
+from models.model import App, Tag, TagBinding
+from services.tag_service import TagService
+
+
+class TestTagService:
+ """Integration tests for TagService using testcontainers."""
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.tag_service.current_user") as mock_current_user,
+ ):
+ # Setup default mock returns
+ mock_current_user.current_tenant_id = "test-tenant-id"
+ mock_current_user.id = "test-user-id"
+
+ yield {
+ "current_user": mock_current_user,
+ }
+
+ def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Helper method to create a test account and tenant for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+
+ Returns:
+ tuple: (account, tenant) - Created account and tenant instances
+ """
+ fake = Faker()
+
+ # Create account
+ account = Account(
+ email=fake.email(),
+ name=fake.name(),
+ interface_language="en-US",
+ status="active",
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(account)
+ db.session.commit()
+
+ # Create tenant for the account
+ tenant = Tenant(
+ name=fake.company(),
+ status="normal",
+ )
+ db.session.add(tenant)
+ db.session.commit()
+
+ # Create tenant-account join
+ join = TenantAccountJoin(
+ tenant_id=tenant.id,
+ account_id=account.id,
+ role=TenantAccountRole.OWNER.value,
+ current=True,
+ )
+ db.session.add(join)
+ db.session.commit()
+
+ # Set current tenant for account
+ account.current_tenant = tenant
+
+ # Update mock to use real tenant ID
+ mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
+ mock_external_service_dependencies["current_user"].id = account.id
+
+ return account, tenant
+
+ def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, tenant_id):
+ """
+ Helper method to create a test dataset for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+ tenant_id: Tenant ID for the dataset
+
+ Returns:
+ Dataset: Created dataset instance
+ """
+ fake = Faker()
+
+ dataset = Dataset(
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ provider="vendor",
+ permission="only_me",
+ data_source_type="upload",
+ indexing_technique="high_quality",
+ tenant_id=tenant_id,
+ created_by=mock_external_service_dependencies["current_user"].id,
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(dataset)
+ db.session.commit()
+
+ return dataset
+
+ def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant_id):
+ """
+ Helper method to create a test app for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+ tenant_id: Tenant ID for the app
+
+ Returns:
+ App: Created app instance
+ """
+ fake = Faker()
+
+ app = App(
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ mode="chat",
+ icon_type="emoji",
+ icon="🤖",
+ icon_background="#FF6B6B",
+ enable_site=False,
+ enable_api=False,
+ tenant_id=tenant_id,
+ created_by=mock_external_service_dependencies["current_user"].id,
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(app)
+ db.session.commit()
+
+ return app
+
+ def _create_test_tags(
+ self, db_session_with_containers, mock_external_service_dependencies, tenant_id, tag_type, count=3
+ ):
+ """
+ Helper method to create test tags for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+ tenant_id: Tenant ID for the tags
+ tag_type: Type of tags to create
+ count: Number of tags to create
+
+ Returns:
+ list: List of created tag instances
+ """
+ fake = Faker()
+ tags = []
+
+ for i in range(count):
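+            # Encode the tag type and index into the name so each generated tag
+            # is unique and easy to identify in assertions.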
+ tag = Tag(
+ name=f"tag_{tag_type}_{i}_{fake.word()}",
+ type=tag_type,
+ tenant_id=tenant_id,
+ created_by=mock_external_service_dependencies["current_user"].id,
+ )
+ tags.append(tag)
+
+ from extensions.ext_database import db
+
+ for tag in tags:
+ db.session.add(tag)
+ db.session.commit()
+
+ return tags
+
+ def _create_test_tag_bindings(
+ self, db_session_with_containers, mock_external_service_dependencies, tags, target_id, tenant_id
+ ):
+ """
+ Helper method to create test tag bindings for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+ tags: List of tags to bind
+ target_id: Target ID to bind tags to
+ tenant_id: Tenant ID for the bindings
+
+ Returns:
+ list: List of created tag binding instances
+ """
+ tag_bindings = []
+
+ for tag in tags:
+ tag_binding = TagBinding(
+ tag_id=tag.id,
+ target_id=target_id,
+ tenant_id=tenant_id,
+ created_by=mock_external_service_dependencies["current_user"].id,
+ )
+ tag_bindings.append(tag_binding)
+
+ from extensions.ext_database import db
+
+ for tag_binding in tag_bindings:
+ db.session.add(tag_binding)
+ db.session.commit()
+
+ return tag_bindings
+
+ def test_get_tags_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful retrieval of tags with binding count.
+
+ This test verifies:
+ - Proper tag retrieval with binding count
+ - Correct filtering by tag type and tenant
+ - Proper ordering by creation date
+ - Binding count calculation
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tags
+ tags = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 3
+ )
+
+ # Create dataset and bind tags
+ dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+ self._create_test_tag_bindings(
+ db_session_with_containers, mock_external_service_dependencies, tags[:2], dataset.id, tenant.id
+ )
+
+ # Act: Execute the method under test
+ result = TagService.get_tags("knowledge", tenant.id)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert len(result) == 3
+
+ # Verify tag data structure
+ for tag_result in result:
+ assert hasattr(tag_result, "id")
+ assert hasattr(tag_result, "type")
+ assert hasattr(tag_result, "name")
+ assert hasattr(tag_result, "binding_count")
+ assert tag_result.type == "knowledge"
+
+ # Verify binding count
+ tag_with_bindings = next((t for t in result if t.binding_count > 0), None)
+ assert tag_with_bindings is not None
+ assert tag_with_bindings.binding_count >= 1
+
+        # Ordering (newest first) is applied by the database via ORDER BY on
+        # created_at, which is not part of the SELECT list, so we only verify
+        # that all tags are returned.
+        assert len(result) == 3
+
+ def test_get_tags_with_keyword_filter(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag retrieval with keyword filtering.
+
+ This test verifies:
+ - Proper keyword filtering functionality
+ - Case-insensitive search
+ - Partial match functionality
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tags with specific names
+ tags = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 3
+ )
+
+ # Update tag names to make them searchable
+ from extensions.ext_database import db
+
+ tags[0].name = "python_development"
+ tags[1].name = "machine_learning"
+ tags[2].name = "web_development"
+ db.session.commit()
+
+ # Act: Execute the method under test with keyword filter
+ result = TagService.get_tags("app", tenant.id, keyword="development")
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert len(result) == 2 # Should find python_development and web_development
+
+ # Verify filtered results contain the keyword
+ for tag_result in result:
+ assert "development" in tag_result.name.lower()
+
+ # Verify no results for non-matching keyword
+ result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent")
+ assert len(result_no_match) == 0
+
+ def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag retrieval when no tags exist.
+
+ This test verifies:
+ - Proper handling of empty tag sets
+ - Correct return value for no results
+ """
+ # Arrange: Create test data without tags
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Act: Execute the method under test
+ result = TagService.get_tags("knowledge", tenant.id)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert len(result) == 0
+ assert isinstance(result, list)
+
+ def test_get_target_ids_by_tag_ids_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful retrieval of target IDs by tag IDs.
+
+ This test verifies:
+ - Proper target ID retrieval for valid tag IDs
+ - Correct filtering by tag type and tenant
+ - Proper handling of tag bindings
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tags
+ tags = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 3
+ )
+
+ # Create multiple datasets and bind tags
+ datasets = []
+ for i in range(2):
+ dataset = self._create_test_dataset(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id
+ )
+ datasets.append(dataset)
+ # Bind first two tags to first dataset, last tag to second dataset
+ tags_to_bind = tags[:2] if i == 0 else tags[2:]
+ self._create_test_tag_bindings(
+ db_session_with_containers, mock_external_service_dependencies, tags_to_bind, dataset.id, tenant.id
+ )
+
+ # Act: Execute the method under test
+ tag_ids = [tag.id for tag in tags]
+ result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert len(result) == 3 # Should find 3 target IDs (2 from first dataset, 1 from second)
+
+ # Verify all dataset IDs are returned
+ dataset_ids = [dataset.id for dataset in datasets]
+ for target_id in result:
+ assert target_id in dataset_ids
+
+ # Verify the first dataset appears twice (for the first two tags)
+ first_dataset_count = result.count(datasets[0].id)
+ assert first_dataset_count == 2
+
+ # Verify the second dataset appears once (for the last tag)
+ second_dataset_count = result.count(datasets[1].id)
+ assert second_dataset_count == 1
+
+ def test_get_target_ids_by_tag_ids_empty_tag_ids(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test target ID retrieval with empty tag IDs list.
+
+ This test verifies:
+ - Proper handling of empty tag IDs
+ - Correct return value for empty input
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Act: Execute the method under test with empty tag IDs
+ result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, [])
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert len(result) == 0
+ assert isinstance(result, list)
+
+ def test_get_target_ids_by_tag_ids_no_matching_tags(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test target ID retrieval when no tags match the criteria.
+
+ This test verifies:
+ - Proper handling of non-existent tag IDs
+ - Correct return value for no matches
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create non-existent tag IDs
+ import uuid
+
+ non_existent_tag_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
+
+ # Act: Execute the method under test
+ result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, non_existent_tag_ids)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert len(result) == 0
+ assert isinstance(result, list)
+
+ def test_get_tag_by_tag_name_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful retrieval of tags by tag name.
+
+ This test verifies:
+ - Proper tag retrieval by name
+ - Correct filtering by tag type and tenant
+ - Proper return value structure
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tags with specific names
+ tags = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 2
+ )
+
+ # Update tag names to make them searchable
+ from extensions.ext_database import db
+
+ tags[0].name = "python_tag"
+ tags[1].name = "ml_tag"
+ db.session.commit()
+
+ # Act: Execute the method under test
+ result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag")
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert len(result) == 1
+ assert result[0].name == "python_tag"
+ assert result[0].type == "app"
+ assert result[0].tenant_id == tenant.id
+
+ def test_get_tag_by_tag_name_no_matches(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag retrieval by name when no matches exist.
+
+ This test verifies:
+ - Proper handling of non-existent tag names
+ - Correct return value for no matches
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Act: Execute the method under test with non-existent tag name
+ result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag")
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert len(result) == 0
+ assert isinstance(result, list)
+
+ def test_get_tag_by_tag_name_empty_parameters(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag retrieval by name with empty parameters.
+
+ This test verifies:
+ - Proper handling of empty tag type
+ - Proper handling of empty tag name
+ - Correct return value for invalid input
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Act: Execute the method under test with empty parameters
+ result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag")
+ result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "")
+
+ # Assert: Verify the expected outcomes
+ assert result_empty_type is not None
+ assert len(result_empty_type) == 0
+ assert result_empty_name is not None
+ assert len(result_empty_name) == 0
+
+ def test_get_tags_by_target_id_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful retrieval of tags by target ID.
+
+ This test verifies:
+ - Proper tag retrieval for a specific target
+ - Correct filtering by tag type and tenant
+ - Proper join with tag bindings
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tags
+ tags = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 3
+ )
+
+ # Create app and bind tags
+ app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+ self._create_test_tag_bindings(
+ db_session_with_containers, mock_external_service_dependencies, tags, app.id, tenant.id
+ )
+
+ # Act: Execute the method under test
+ result = TagService.get_tags_by_target_id("app", tenant.id, app.id)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert len(result) == 3
+
+ # Verify all tags are returned
+ for tag in result:
+ assert tag.type == "app"
+ assert tag.tenant_id == tenant.id
+ assert tag.id in [t.id for t in tags]
+
+ def test_get_tags_by_target_id_no_bindings(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag retrieval by target ID when no tags are bound.
+
+ This test verifies:
+ - Proper handling of targets with no tag bindings
+ - Correct return value for no results
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create app without binding any tags
+ app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+
+ # Act: Execute the method under test
+ result = TagService.get_tags_by_target_id("app", tenant.id, app.id)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert len(result) == 0
+ assert isinstance(result, list)
+
+ def test_save_tags_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful tag creation.
+
+ This test verifies:
+ - Proper tag creation with all required fields
+ - Correct database state after creation
+ - Proper UUID generation
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ tag_args = {"name": "test_tag_name", "type": "knowledge"}
+
+ # Act: Execute the method under test
+ result = TagService.save_tags(tag_args)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert result.name == "test_tag_name"
+ assert result.type == "knowledge"
+ assert result.tenant_id == tenant.id
+ assert result.created_by == account.id
+ assert result.id is not None
+
+ # Verify database state
+ from extensions.ext_database import db
+
+ db.session.refresh(result)
+ assert result.id is not None
+
+ # Verify tag was actually saved to database
+ saved_tag = db.session.query(Tag).where(Tag.id == result.id).first()
+ assert saved_tag is not None
+ assert saved_tag.name == "test_tag_name"
+
+ def test_save_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag creation with duplicate name.
+
+ This test verifies:
+ - Proper error handling for duplicate tag names
+ - Correct exception type and message
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create first tag
+ tag_args = {"name": "duplicate_tag", "type": "app"}
+ TagService.save_tags(tag_args)
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ TagService.save_tags(tag_args)
+ assert "Tag name already exists" in str(exc_info.value)
+
+ def test_update_tags_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful tag update.
+
+ This test verifies:
+ - Proper tag update with new name
+ - Correct database state after update
+ - Proper error handling for non-existent tags
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create a tag to update
+ tag_args = {"name": "original_name", "type": "knowledge"}
+ tag = TagService.save_tags(tag_args)
+
+ # Update args
+ update_args = {"name": "updated_name", "type": "knowledge"}
+
+ # Act: Execute the method under test
+ result = TagService.update_tags(update_args, tag.id)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert result.name == "updated_name"
+ assert result.type == "knowledge"
+ assert result.id == tag.id
+
+ # Verify database state
+ from extensions.ext_database import db
+
+ db.session.refresh(result)
+ assert result.name == "updated_name"
+
+ # Verify tag was actually updated in database
+ updated_tag = db.session.query(Tag).where(Tag.id == tag.id).first()
+ assert updated_tag is not None
+ assert updated_tag.name == "updated_name"
+
+ def test_update_tags_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag update for non-existent tag.
+
+ This test verifies:
+ - Proper error handling for non-existent tags
+ - Correct exception type
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create non-existent tag ID
+ import uuid
+
+ non_existent_tag_id = str(uuid.uuid4())
+
+ update_args = {"name": "updated_name", "type": "knowledge"}
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(NotFound) as exc_info:
+ TagService.update_tags(update_args, non_existent_tag_id)
+ assert "Tag not found" in str(exc_info.value)
+
+ def test_update_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag update with duplicate name.
+
+ This test verifies:
+ - Proper error handling for duplicate tag names during update
+ - Correct exception type and message
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create two tags
+ tag1_args = {"name": "first_tag", "type": "app"}
+ tag1 = TagService.save_tags(tag1_args)
+
+ tag2_args = {"name": "second_tag", "type": "app"}
+ tag2 = TagService.save_tags(tag2_args)
+
+ # Try to update second tag with first tag's name
+ update_args = {"name": "first_tag", "type": "app"}
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ TagService.update_tags(update_args, tag2.id)
+ assert "Tag name already exists" in str(exc_info.value)
+
+ def test_get_tag_binding_count_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful retrieval of tag binding count.
+
+ This test verifies:
+ - Proper binding count calculation
+ - Correct handling of tags with no bindings
+ - Proper database query execution
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tags
+ tags = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2
+ )
+
+ # Create dataset and bind first tag
+ dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+ self._create_test_tag_bindings(
+ db_session_with_containers, mock_external_service_dependencies, [tags[0]], dataset.id, tenant.id
+ )
+
+ # Act: Execute the method under test
+ result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id)
+ result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id)
+
+ # Assert: Verify the expected outcomes
+ assert result_tag_with_bindings == 1
+ assert result_tag_without_bindings == 0
+
+ def test_get_tag_binding_count_non_existent_tag(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test binding count retrieval for non-existent tag.
+
+ This test verifies:
+ - Proper handling of non-existent tag IDs
+ - Correct return value for non-existent tags
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create non-existent tag ID
+ import uuid
+
+ non_existent_tag_id = str(uuid.uuid4())
+
+ # Act: Execute the method under test
+ result = TagService.get_tag_binding_count(non_existent_tag_id)
+
+ # Assert: Verify the expected outcomes
+ assert result == 0
+
+ def test_delete_tag_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful tag deletion.
+
+ This test verifies:
+ - Proper tag deletion from database
+ - Proper cleanup of associated tag bindings
+ - Correct database state after deletion
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tag with bindings
+ tag = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 1
+ )[0]
+
+ # Create app and bind tag
+ app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+ self._create_test_tag_bindings(
+ db_session_with_containers, mock_external_service_dependencies, [tag], app.id, tenant.id
+ )
+
+ # Verify tag and binding exist before deletion
+ from extensions.ext_database import db
+
+ tag_before = db.session.query(Tag).where(Tag.id == tag.id).first()
+ assert tag_before is not None
+
+ binding_before = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
+ assert binding_before is not None
+
+ # Act: Execute the method under test
+ TagService.delete_tag(tag.id)
+
+ # Assert: Verify the expected outcomes
+ # Verify tag was deleted
+ tag_after = db.session.query(Tag).where(Tag.id == tag.id).first()
+ assert tag_after is None
+
+ # Verify tag binding was deleted
+ binding_after = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
+ assert binding_after is None
+
+ def test_delete_tag_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag deletion for non-existent tag.
+
+ This test verifies:
+ - Proper error handling for non-existent tags
+ - Correct exception type
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create non-existent tag ID
+ import uuid
+
+ non_existent_tag_id = str(uuid.uuid4())
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(NotFound) as exc_info:
+ TagService.delete_tag(non_existent_tag_id)
+ assert "Tag not found" in str(exc_info.value)
+
+ def test_save_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful tag binding creation.
+
+ This test verifies:
+ - Proper tag binding creation
+ - Correct handling of duplicate bindings
+ - Proper database state after creation
+ """
+ # Arrange: Create test data
+ fake = Faker()
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tags
+ tags = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2
+ )
+
+ # Create dataset
+ dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+
+ # Act: Execute the method under test
+ binding_args = {"type": "knowledge", "target_id": dataset.id, "tag_ids": [tag.id for tag in tags]}
+ TagService.save_tag_binding(binding_args)
+
+ # Assert: Verify the expected outcomes
+ from extensions.ext_database import db
+
+ # Verify tag bindings were created
+ for tag in tags:
+ binding = (
+ db.session.query(TagBinding)
+ .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
+ .first()
+ )
+ assert binding is not None
+ assert binding.tenant_id == tenant.id
+ assert binding.created_by == account.id
+
+ def test_save_tag_binding_duplicate_handling(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag binding creation with duplicate bindings.
+
+ This test verifies:
+ - Proper handling of duplicate tag bindings
+ - No errors when trying to create existing bindings
+ - Correct database state after operation
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tag
+ tag = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 1
+ )[0]
+
+ # Create app
+ app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+
+ # Create first binding
+ binding_args = {"type": "app", "target_id": app.id, "tag_ids": [tag.id]}
+ TagService.save_tag_binding(binding_args)
+
+ # Act: Try to create duplicate binding
+ TagService.save_tag_binding(binding_args)
+
+ # Assert: Verify the expected outcomes
+ from extensions.ext_database import db
+
+ # Verify only one binding exists
+ bindings = (
+ db.session.query(TagBinding)
+ .where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
+ .all()
+ )
+ assert len(bindings) == 1
+
+ def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test tag binding creation with invalid target type.
+
+ This test verifies:
+ - Proper error handling for invalid target types
+ - Correct exception type
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tag
+ tag = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 1
+ )[0]
+
+ # Create non-existent target ID
+ import uuid
+
+ non_existent_target_id = str(uuid.uuid4())
+
+ # Act & Assert: Verify proper error handling
+ binding_args = {"type": "invalid_type", "target_id": non_existent_target_id, "tag_ids": [tag.id]}
+
+ with pytest.raises(NotFound) as exc_info:
+ TagService.save_tag_binding(binding_args)
+ assert "Invalid binding type" in str(exc_info.value)
+
+ def test_delete_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful tag binding deletion.
+
+ This test verifies:
+ - Proper tag binding deletion from database
+ - Correct database state after deletion
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tag
+ tag = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 1
+ )[0]
+
+ # Create dataset and bind tag
+ dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+ self._create_test_tag_bindings(
+ db_session_with_containers, mock_external_service_dependencies, [tag], dataset.id, tenant.id
+ )
+
+ # Verify binding exists before deletion
+ from extensions.ext_database import db
+
+ binding_before = (
+ db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first()
+ )
+ assert binding_before is not None
+
+ # Act: Execute the method under test
+ delete_args = {"type": "knowledge", "target_id": dataset.id, "tag_id": tag.id}
+ TagService.delete_tag_binding(delete_args)
+
+ # Assert: Verify the expected outcomes
+ # Verify tag binding was deleted
+ binding_after = (
+ db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first()
+ )
+ assert binding_after is None
+
+ def test_delete_tag_binding_non_existent_binding(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test tag binding deletion for non-existent binding.
+
+ This test verifies:
+ - Proper handling of non-existent tag bindings
+ - No errors when trying to delete non-existent bindings
+ - Correct database state after operation
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create tag and app without binding
+ tag = self._create_test_tags(
+ db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 1
+ )[0]
+ app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+
+ # Act: Try to delete non-existent binding
+ delete_args = {"type": "app", "target_id": app.id, "tag_id": tag.id}
+ TagService.delete_tag_binding(delete_args)
+
+ # Assert: Verify the expected outcomes
+ # No error should be raised, and database state should remain unchanged
+ from extensions.ext_database import db
+
+ bindings = (
+ db.session.query(TagBinding)
+ .where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
+ .all()
+ )
+ assert len(bindings) == 0
+
+ def test_check_target_exists_knowledge_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful target existence check for knowledge type.
+
+ This test verifies:
+ - Proper validation of knowledge dataset existence
+ - No exception raised for an existing dataset
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create dataset
+ dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+
+ # Act: Execute the method under test
+ TagService.check_target_exists("knowledge", dataset.id)
+
+ # Assert: Verify the expected outcomes
+ # No exception should be raised for existing dataset
+
+ def test_check_target_exists_knowledge_not_found(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test target existence check for non-existent knowledge dataset.
+
+ This test verifies:
+ - Proper error handling for non-existent knowledge datasets
+ - Correct exception type and message
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create non-existent dataset ID
+ import uuid
+
+ non_existent_dataset_id = str(uuid.uuid4())
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(NotFound) as exc_info:
+ TagService.check_target_exists("knowledge", non_existent_dataset_id)
+ assert "Dataset not found" in str(exc_info.value)
+
+ def test_check_target_exists_app_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful target existence check for app type.
+
+ This test verifies:
+ - Proper validation of app existence
+ - No exception raised for an existing app
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create app
+ app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
+
+ # Act: Execute the method under test
+ TagService.check_target_exists("app", app.id)
+
+ # Assert: Verify the expected outcomes
+ # No exception should be raised for existing app
+
+ def test_check_target_exists_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test target existence check for non-existent app.
+
+ This test verifies:
+ - Proper error handling for non-existent apps
+ - Correct exception type and message
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create non-existent app ID
+ import uuid
+
+ non_existent_app_id = str(uuid.uuid4())
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(NotFound) as exc_info:
+ TagService.check_target_exists("app", non_existent_app_id)
+ assert "App not found" in str(exc_info.value)
+
+ def test_check_target_exists_invalid_type(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test target existence check for invalid type.
+
+ This test verifies:
+ - Proper error handling for invalid target types
+ - Correct exception type and message
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Create non-existent target ID
+ import uuid
+
+ non_existent_target_id = str(uuid.uuid4())
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(NotFound) as exc_info:
+ TagService.check_target_exists("invalid_type", non_existent_target_id)
+ assert "Invalid binding type" in str(exc_info.value)
diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py
new file mode 100644
index 0000000000..6d6f1dab72
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py
@@ -0,0 +1,574 @@
+from unittest.mock import patch
+
+import pytest
+from faker import Faker
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from models.account import Account
+from models.model import Conversation, EndUser
+from models.web import PinnedConversation
+from services.account_service import AccountService, TenantService
+from services.app_service import AppService
+from services.web_conversation_service import WebConversationService
+
+
+class TestWebConversationService:
+ """Integration tests for WebConversationService using testcontainers."""
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.app_service.FeatureService") as mock_feature_service,
+ patch("services.app_service.EnterpriseService") as mock_enterprise_service,
+ patch("services.app_service.ModelManager") as mock_model_manager,
+ patch("services.account_service.FeatureService") as mock_account_feature_service,
+ ):
+ # Setup default mock returns for app service
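+ # Patched classes are MagicMocks, so nested attributes like webapp_auth.enabled spring into existence on first access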
+ mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
+ mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
+ mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
+
+ # Setup default mock returns for account service
+ mock_account_feature_service.get_system_features.return_value.is_allow_register = True
+
+ # Mock ModelManager for model configuration
+ mock_model_instance = mock_model_manager.return_value
+ mock_model_instance.get_default_model_instance.return_value = None
+ mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
+
+ yield {
+ "feature_service": mock_feature_service,
+ "enterprise_service": mock_enterprise_service,
+ "model_manager": mock_model_manager,
+ "account_feature_service": mock_account_feature_service,
+ }
+
+ def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Helper method to create a test app and account for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+
+ Returns:
+ tuple: (app, account) - Created app and account instances
+ """
+ fake = Faker()
+
+ # Setup mocks for account creation
+ mock_external_service_dependencies[
+ "account_feature_service"
+ ].get_system_features.return_value.is_allow_register = True
+
+ # Create account and tenant
+ account = AccountService.create_account(
+ email=fake.email(),
+ name=fake.name(),
+ interface_language="en-US",
+ password=fake.password(length=12),
+ )
+ TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
+ tenant = account.current_tenant
+
+ # Create app with realistic data
+ app_args = {
+ "name": fake.company(),
+ "description": fake.text(max_nb_chars=100),
+ "mode": "chat",
+ "icon_type": "emoji",
+ "icon": "🤖",
+ "icon_background": "#FF6B6B",
+ "api_rph": 100,
+ "api_rpm": 10,
+ }
+
+ app_service = AppService()
+ app = app_service.create_app(tenant.id, app_args, account)
+
+ return app, account
+
+ def _create_test_end_user(self, db_session_with_containers, app):
+ """
+ Helper method to create a test end user for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ app: App instance
+
+ Returns:
+ EndUser: Created end user instance
+ """
+ fake = Faker()
+
+ end_user = EndUser(
+ session_id=fake.uuid4(),
+ app_id=app.id,
+ type="normal",
+ is_anonymous=False,
+ tenant_id=app.tenant_id,
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(end_user)
+ db.session.commit()
+
+ return end_user
+
+ def _create_test_conversation(self, db_session_with_containers, app, user, fake):
+ """
+ Helper method to create a test conversation for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ app: App instance
+ user: User instance (Account or EndUser)
+ fake: Faker instance
+
+ Returns:
+ Conversation: Created conversation instance
+ """
+ conversation = Conversation(
+ app_id=app.id,
+ app_model_config_id=app.app_model_config_id,
+ model_provider="openai",
+ model_id="gpt-3.5-turbo",
+ mode="chat",
+ name=fake.sentence(nb_words=3),
+ summary=fake.text(max_nb_chars=100),
+ inputs={},
+ introduction=fake.text(max_nb_chars=200),
+ system_instruction=fake.text(max_nb_chars=300),
+ system_instruction_tokens=50,
+ status="normal",
+ invoke_from=InvokeFrom.WEB_APP.value,
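+ # Record the origin of the conversation: console for Account users, api for EndUsers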
+ from_source="console" if isinstance(user, Account) else "api",
+ from_end_user_id=user.id if isinstance(user, EndUser) else None,
+ from_account_id=user.id if isinstance(user, Account) else None,
+ dialogue_count=0,
+ is_deleted=False,
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(conversation)
+ db.session.commit()
+
+ return conversation
+
+ def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful pagination by last ID with basic parameters.
+ """
+ fake = Faker()
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Create multiple conversations
+ conversations = []
+ for _ in range(5):
+ conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
+ conversations.append(conversation)
+
+ # Test pagination without pinned filter
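+ # pinned=None applies no pin filter; True/False restrict results to pinned/unpinned conversations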
+ result = WebConversationService.pagination_by_last_id(
+ session=db_session_with_containers,
+ app_model=app,
+ user=account,
+ last_id=None,
+ limit=3,
+ invoke_from=InvokeFrom.WEB_APP,
+ pinned=None,
+ sort_by="-updated_at",
+ )
+
+ # Verify results
+ assert result.limit == 3
+ assert len(result.data) == 3
+ assert result.has_more is True
+
+ # Verify conversations are in descending order by updated_at
+ assert result.data[0].updated_at >= result.data[1].updated_at
+ assert result.data[1].updated_at >= result.data[2].updated_at
+
+ def test_pagination_by_last_id_with_pinned_filter(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test pagination by last ID with pinned conversation filter.
+ """
+ fake = Faker()
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Create conversations
+ conversations = []
+ for _ in range(5):
+ conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
+ conversations.append(conversation)
+
+ # Pin some conversations
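+ # created_by_role separates console-account pins from web end-user pins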
+ pinned_conversation1 = PinnedConversation(
+ app_id=app.id,
+ conversation_id=conversations[0].id,
+ created_by_role="account",
+ created_by=account.id,
+ )
+ pinned_conversation2 = PinnedConversation(
+ app_id=app.id,
+ conversation_id=conversations[2].id,
+ created_by_role="account",
+ created_by=account.id,
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(pinned_conversation1)
+ db.session.add(pinned_conversation2)
+ db.session.commit()
+
+ # Test pagination with pinned filter
+ result = WebConversationService.pagination_by_last_id(
+ session=db_session_with_containers,
+ app_model=app,
+ user=account,
+ last_id=None,
+ limit=10,
+ invoke_from=InvokeFrom.WEB_APP,
+ pinned=True,
+ sort_by="-updated_at",
+ )
+
+ # Verify only pinned conversations are returned
+ assert result.limit == 10
+ assert len(result.data) == 2
+ assert result.has_more is False
+
+ # Verify the returned conversations are the pinned ones
+ returned_ids = [conv.id for conv in result.data]
+ expected_ids = [conversations[0].id, conversations[2].id]
+ assert set(returned_ids) == set(expected_ids)
+
+ def test_pagination_by_last_id_with_unpinned_filter(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test pagination by last ID with unpinned conversation filter.
+ """
+ fake = Faker()
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Create conversations
+ conversations = []
+ for _ in range(5):
+ conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
+ conversations.append(conversation)
+
+ # Pin one conversation
+ pinned_conversation = PinnedConversation(
+ app_id=app.id,
+ conversation_id=conversations[0].id,
+ created_by_role="account",
+ created_by=account.id,
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(pinned_conversation)
+ db.session.commit()
+
+ # Test pagination with unpinned filter
+ result = WebConversationService.pagination_by_last_id(
+ session=db_session_with_containers,
+ app_model=app,
+ user=account,
+ last_id=None,
+ limit=10,
+ invoke_from=InvokeFrom.WEB_APP,
+ pinned=False,
+ sort_by="-updated_at",
+ )
+
+ # Verify unpinned conversations are returned (should be 4 out of 5)
+ assert result.limit == 10
+ assert len(result.data) == 4
+ assert result.has_more is False
+
+ # Verify the pinned conversation is not in the results
+ returned_ids = [conv.id for conv in result.data]
+ assert conversations[0].id not in returned_ids
+
+ # Verify all other conversations are in the results
+ expected_unpinned_ids = [conv.id for conv in conversations[1:]]
+ assert set(returned_ids) == set(expected_unpinned_ids)
+
+ def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful pinning of a conversation.
+ """
+ fake = Faker()
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Create a conversation
+ conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
+
+ # Pin the conversation
+ WebConversationService.pin(app, conversation.id, account)
+
+ # Verify the conversation was pinned
+ from extensions.ext_database import db
+
+ pinned_conversation = (
+ db.session.query(PinnedConversation)
+ .where(
+ PinnedConversation.app_id == app.id,
+ PinnedConversation.conversation_id == conversation.id,
+ PinnedConversation.created_by_role == "account",
+ PinnedConversation.created_by == account.id,
+ )
+ .first()
+ )
+
+ assert pinned_conversation is not None
+ assert pinned_conversation.app_id == app.id
+ assert pinned_conversation.conversation_id == conversation.id
+ assert pinned_conversation.created_by_role == "account"
+ assert pinned_conversation.created_by == account.id
+
+ def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test pinning a conversation that is already pinned (should not create a duplicate).
+ """
+ fake = Faker()
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Create a conversation
+ conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
+
+ # Pin the conversation first time
+ WebConversationService.pin(app, conversation.id, account)
+
+ # Pin the conversation again
+ WebConversationService.pin(app, conversation.id, account)
+
+ # Verify only one pinned conversation record exists
+ from extensions.ext_database import db
+
+ pinned_conversations = (
+ db.session.query(PinnedConversation)
+ .where(
+ PinnedConversation.app_id == app.id,
+ PinnedConversation.conversation_id == conversation.id,
+ PinnedConversation.created_by_role == "account",
+ PinnedConversation.created_by == account.id,
+ )
+ .all()
+ )
+
+ assert len(pinned_conversations) == 1
+
+ def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test pinning a conversation with an end user.
+ """
+ fake = Faker()
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Create an end user
+ end_user = self._create_test_end_user(db_session_with_containers, app)
+
+ # Create a conversation for the end user
+ conversation = self._create_test_conversation(db_session_with_containers, app, end_user, fake)
+
+ # Pin the conversation
+ WebConversationService.pin(app, conversation.id, end_user)
+
+ # Verify the conversation was pinned
+ from extensions.ext_database import db
+
+ pinned_conversation = (
+ db.session.query(PinnedConversation)
+ .where(
+ PinnedConversation.app_id == app.id,
+ PinnedConversation.conversation_id == conversation.id,
+ PinnedConversation.created_by_role == "end_user",
+ PinnedConversation.created_by == end_user.id,
+ )
+ .first()
+ )
+
+ assert pinned_conversation is not None
+ assert pinned_conversation.app_id == app.id
+ assert pinned_conversation.conversation_id == conversation.id
+ assert pinned_conversation.created_by_role == "end_user"
+ assert pinned_conversation.created_by == end_user.id
+
+ def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful unpinning of a conversation.
+ """
+ fake = Faker()
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Create a conversation
+ conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
+
+ # Pin the conversation first
+ WebConversationService.pin(app, conversation.id, account)
+
+ # Verify it was pinned
+ from extensions.ext_database import db
+
+ pinned_conversation = (
+ db.session.query(PinnedConversation)
+ .where(
+ PinnedConversation.app_id == app.id,
+ PinnedConversation.conversation_id == conversation.id,
+ PinnedConversation.created_by_role == "account",
+ PinnedConversation.created_by == account.id,
+ )
+ .first()
+ )
+
+ assert pinned_conversation is not None
+
+ # Unpin the conversation
+ WebConversationService.unpin(app, conversation.id, account)
+
+ # Verify it was unpinned
+ pinned_conversation = (
+ db.session.query(PinnedConversation)
+ .where(
+ PinnedConversation.app_id == app.id,
+ PinnedConversation.conversation_id == conversation.id,
+ PinnedConversation.created_by_role == "account",
+ PinnedConversation.created_by == account.id,
+ )
+ .first()
+ )
+
+ assert pinned_conversation is None
+
+ def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test unpinning a conversation that is not pinned (should not raise an error).
+ """
+ fake = Faker()
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Create a conversation
+ conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
+
+ # Try to unpin a conversation that was never pinned
+ WebConversationService.unpin(app, conversation.id, account)
+
+ # Verify no pinned conversation record exists
+ from extensions.ext_database import db
+
+ pinned_conversation = (
+ db.session.query(PinnedConversation)
+ .where(
+ PinnedConversation.app_id == app.id,
+ PinnedConversation.conversation_id == conversation.id,
+ PinnedConversation.created_by_role == "account",
+ PinnedConversation.created_by == account.id,
+ )
+ .first()
+ )
+
+ assert pinned_conversation is None
+
+ def test_pagination_by_last_id_user_required_error(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test that pagination_by_last_id raises ValueError when user is None.
+ """
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Test with None user
+ with pytest.raises(ValueError, match="User is required"):
+ WebConversationService.pagination_by_last_id(
+ session=db_session_with_containers,
+ app_model=app,
+ user=None,
+ last_id=None,
+ limit=10,
+ invoke_from=InvokeFrom.WEB_APP,
+ pinned=None,
+ sort_by="-updated_at",
+ )
+
+ def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test that pin method returns early when user is None.
+ """
+ fake = Faker()
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Create a conversation
+ conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
+
+ # Try to pin with None user
+ WebConversationService.pin(app, conversation.id, None)
+
+ # Verify no pinned conversation was created
+ from extensions.ext_database import db
+
+ pinned_conversation = (
+ db.session.query(PinnedConversation)
+ .where(
+ PinnedConversation.app_id == app.id,
+ PinnedConversation.conversation_id == conversation.id,
+ )
+ .first()
+ )
+
+ assert pinned_conversation is None
+
+ def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test that unpin method returns early when user is None.
+ """
+ fake = Faker()
+ app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Create a conversation
+ conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
+
+ # Pin the conversation first
+ WebConversationService.pin(app, conversation.id, account)
+
+ # Verify it was pinned
+ from extensions.ext_database import db
+
+ pinned_conversation = (
+ db.session.query(PinnedConversation)
+ .where(
+ PinnedConversation.app_id == app.id,
+ PinnedConversation.conversation_id == conversation.id,
+ PinnedConversation.created_by_role == "account",
+ PinnedConversation.created_by == account.id,
+ )
+ .first()
+ )
+
+ assert pinned_conversation is not None
+
+ # Try to unpin with None user
+ WebConversationService.unpin(app, conversation.id, None)
+
+ # Verify the conversation is still pinned
+ pinned_conversation = (
+ db.session.query(PinnedConversation)
+ .where(
+ PinnedConversation.app_id == app.id,
+ PinnedConversation.conversation_id == conversation.id,
+ PinnedConversation.created_by_role == "account",
+ PinnedConversation.created_by == account.id,
+ )
+ .first()
+ )
+
+ assert pinned_conversation is not None
diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py
new file mode 100644
index 0000000000..666b083ba6
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py
@@ -0,0 +1,877 @@
+from unittest.mock import patch
+
+import pytest
+from faker import Faker
+from werkzeug.exceptions import NotFound, Unauthorized
+
+from libs.password import hash_password
+from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole
+from models.model import App, Site
+from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
+from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
+
+
+class TestWebAppAuthService:
+ """Integration tests for WebAppAuthService using testcontainers."""
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.webapp_auth_service.PassportService") as mock_passport_service,
+ patch("services.webapp_auth_service.TokenManager") as mock_token_manager,
+ patch("services.webapp_auth_service.send_email_code_login_mail_task") as mock_mail_task,
+ patch("services.webapp_auth_service.AppService") as mock_app_service,
+ patch("services.webapp_auth_service.EnterpriseService") as mock_enterprise_service,
+ ):
+ # Setup default mock returns
+ mock_passport_service.return_value.issue.return_value = "mock_jwt_token"
+ mock_token_manager.generate_token.return_value = "mock_token"
+ mock_token_manager.get_token_data.return_value = {"code": "123456"}
+ mock_mail_task.delay.return_value = None
+ mock_app_service.get_app_id_by_code.return_value = "mock_app_id"
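+ # type(name, bases, attrs)() builds a throwaway object exposing only an access_mode attribute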
+ mock_enterprise_service.WebAppAuth.get_app_access_mode_by_id.return_value = type(
+ "MockWebAppAuth", (), {"access_mode": "private"}
+ )()
+ mock_enterprise_service.WebAppAuth.get_app_access_mode_by_code.return_value = type(
+ "MockWebAppAuth", (), {"access_mode": "private"}
+ )()
+
+ yield {
+ "passport_service": mock_passport_service,
+ "token_manager": mock_token_manager,
+ "mail_task": mock_mail_task,
+ "app_service": mock_app_service,
+ "enterprise_service": mock_enterprise_service,
+ }
+
+ def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Helper method to create a test account and tenant for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+
+ Returns:
+ tuple: (account, tenant) - Created account and tenant instances
+ """
+ fake = Faker()
+
+ # Create account
+ account = Account(
+ email=fake.email(),
+ name=fake.name(),
+ interface_language="en-US",
+ status="active",
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(account)
+ db.session.commit()
+
+ # Create tenant for the account
+ tenant = Tenant(
+ name=fake.company(),
+ status="normal",
+ )
+ db.session.add(tenant)
+ db.session.commit()
+
+ # Create tenant-account join
+ join = TenantAccountJoin(
+ tenant_id=tenant.id,
+ account_id=account.id,
+ role=TenantAccountRole.OWNER.value,
+ current=True,
+ )
+ db.session.add(join)
+ db.session.commit()
+
+ # Set current tenant for account
+ account.current_tenant = tenant
+
+ return account, tenant
+
+ def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Helper method to create a test account with password for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+
+ Returns:
+ tuple: (account, tenant, password) - Created account, tenant and password
+ """
+ fake = Faker()
+ password = fake.password(length=12)
+
+ # Create account with password
+ account = Account(
+ email=fake.email(),
+ name=fake.name(),
+ interface_language="en-US",
+ status="active",
+ )
+
+ # Hash password
+ salt = b"test_salt_16_bytes"
+ password_hash = hash_password(password, salt)
+
+ # Convert to base64 for storage
+ import base64
+
+ account.password = base64.b64encode(password_hash).decode()
+ account.password_salt = base64.b64encode(salt).decode()
+
+ from extensions.ext_database import db
+
+ db.session.add(account)
+ db.session.commit()
+
+ # Create tenant for the account
+ tenant = Tenant(
+ name=fake.company(),
+ status="normal",
+ )
+ db.session.add(tenant)
+ db.session.commit()
+
+ # Create tenant-account join
+ join = TenantAccountJoin(
+ tenant_id=tenant.id,
+ account_id=account.id,
+ role=TenantAccountRole.OWNER.value,
+ current=True,
+ )
+ db.session.add(join)
+ db.session.commit()
+
+ # Set current tenant for account
+ account.current_tenant = tenant
+
+ return account, tenant, password
+
+ def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant):
+ """
+ Helper method to create a test app and site for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+ tenant: Tenant instance to associate with
+
+ Returns:
+ tuple: (app, site) - Created app and site instances
+ """
+ fake = Faker()
+
+ # Create app
+ app = App(
+ tenant_id=tenant.id,
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ mode="chat",
+ icon_type="emoji",
+ icon="🤖",
+ icon_background="#FF6B6B",
+ api_rph=100,
+ api_rpm=10,
+ enable_site=True,
+ enable_api=True,
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(app)
+ db.session.commit()
+
+ # Create site
+ site = Site(
+ app_id=app.id,
+ title=fake.company(),
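+ # lexify replaces each "?" with a random letter, yielding a unique six-character site code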
+ code=fake.unique.lexify(text="??????"),
+ description=fake.text(max_nb_chars=100),
+ default_language="en-US",
+ status="normal",
+ customize_token_strategy="not_allow",
+ )
+ db.session.add(site)
+ db.session.commit()
+
+ return app, site
+
+ def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful authentication with valid email and password.
+
+ This test verifies:
+ - Proper authentication with valid credentials
+ - Correct account return
+ - Database state consistency
+ """
+ # Arrange: Create test data
+ account, tenant, password = self._create_test_account_with_password(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Act: Execute authentication
+ result = WebAppAuthService.authenticate(account.email, password)
+
+ # Assert: Verify successful authentication
+ assert result is not None
+ assert result.id == account.id
+ assert result.email == account.email
+ assert result.name == account.name
+ assert result.status == AccountStatus.ACTIVE.value
+
+ # Verify database state
+ from extensions.ext_database import db
+
+ db.session.refresh(result)
+ assert result.id is not None
+ assert result.password is not None
+ assert result.password_salt is not None
+
+ def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test authentication with non-existent email.
+
+ This test verifies:
+ - Proper error handling for non-existent accounts
+ - Correct exception type and message
+ """
+ # Arrange: Use non-existent email
+ fake = Faker()
+ non_existent_email = fake.email()
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(AccountNotFoundError):
+ WebAppAuthService.authenticate(non_existent_email, "any_password")
+
+ def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test authentication with banned account.
+
+ This test verifies:
+ - Proper error handling for banned accounts
+ - Correct exception type and message
+ """
+ # Arrange: Create banned account
+ fake = Faker()
+ password = fake.password(length=12)
+
+ account = Account(
+ email=fake.email(),
+ name=fake.name(),
+ interface_language="en-US",
+ status=AccountStatus.BANNED.value,
+ )
+
+ # Hash password
+ salt = b"test_salt_16_bytes"
+ password_hash = hash_password(password, salt)
+
+ # Convert to base64 for storage
+ import base64
+
+ account.password = base64.b64encode(password_hash).decode()
+ account.password_salt = base64.b64encode(salt).decode()
+
+ from extensions.ext_database import db
+
+ db.session.add(account)
+ db.session.commit()
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(AccountLoginError) as exc_info:
+ WebAppAuthService.authenticate(account.email, password)
+
+ assert "Account is banned." in str(exc_info.value)
+
+ def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test authentication with invalid password.
+
+ This test verifies:
+ - Proper error handling for invalid passwords
+ - Correct exception type and message
+ """
+ # Arrange: Create account with password
+ account, tenant, correct_password = self._create_test_account_with_password(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Act & Assert: Verify proper error handling with wrong password
+ with pytest.raises(AccountPasswordError) as exc_info:
+ WebAppAuthService.authenticate(account.email, "wrong_password")
+
+ assert "Invalid email or password." in str(exc_info.value)
+
+ def test_authenticate_account_without_password(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test authentication for an account without a password.
+
+ This test verifies:
+ - Proper error handling for accounts without a password
+ - Correct exception type and message
+ """
+ # Arrange: Create account without password
+ fake = Faker()
+
+ account = Account(
+ email=fake.email(),
+ name=fake.name(),
+ interface_language="en-US",
+ status="active",
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(account)
+ db.session.commit()
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(AccountPasswordError) as exc_info:
+ WebAppAuthService.authenticate(account.email, "any_password")
+
+ assert "Invalid email or password." in str(exc_info.value)
+
+ def test_login_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful login and JWT token generation.
+
+ This test verifies:
+ - Proper JWT token generation
+ - Correct token format and content
+ - Mock service integration
+ """
+ # Arrange: Create test account
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Act: Execute login
+ result = WebAppAuthService.login(account)
+
+ # Assert: Verify successful login
+ assert result is not None
+ assert result == "mock_jwt_token"
+
+ # Verify mock service was called correctly
+ mock_external_service_dependencies["passport_service"].return_value.issue.assert_called_once()
+ call_args = mock_external_service_dependencies["passport_service"].return_value.issue.call_args[0][0]
+
+ assert call_args["sub"] == "Web API Passport"
+ assert call_args["user_id"] == account.id
+ assert call_args["session_id"] == account.email
+ assert call_args["token_source"] == "webapp_login_token"
+ assert call_args["auth_type"] == "internal"
+ assert "exp" in call_args
+
+ def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful user retrieval through email.
+
+ This test verifies:
+ - Proper user retrieval by email
+ - Correct account return
+ - Database state consistency
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Act: Execute user retrieval
+ result = WebAppAuthService.get_user_through_email(account.email)
+
+ # Assert: Verify successful retrieval
+ assert result is not None
+ assert result.id == account.id
+ assert result.email == account.email
+ assert result.name == account.name
+ assert result.status == AccountStatus.ACTIVE.value
+
+ # Verify database state
+ from extensions.ext_database import db
+
+ db.session.refresh(result)
+ assert result.id is not None
+
+ def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test user retrieval with non-existent email.
+
+ This test verifies:
+ - Proper handling for non-existent users
+ - Correct return value (None)
+ """
+ # Arrange: Use non-existent email
+ fake = Faker()
+ non_existent_email = fake.email()
+
+ # Act: Execute user retrieval
+ result = WebAppAuthService.get_user_through_email(non_existent_email)
+
+ # Assert: Verify proper handling
+ assert result is None
+
+ def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test user retrieval with banned account.
+
+ This test verifies:
+ - Proper error handling for banned accounts
+ - Correct exception type and message
+ """
+ # Arrange: Create banned account
+ fake = Faker()
+
+ account = Account(
+ email=fake.email(),
+ name=fake.name(),
+ interface_language="en-US",
+ status=AccountStatus.BANNED.value,
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(account)
+ db.session.commit()
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(Unauthorized) as exc_info:
+ WebAppAuthService.get_user_through_email(account.email)
+
+ assert "Account is banned." in str(exc_info.value)
+
+ def test_send_email_code_login_email_with_account(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test sending email code login email with account.
+
+ This test verifies:
+ - Proper email code generation
+ - Token generation with correct data
+ - Mail task scheduling
+ - Mock service integration
+ """
+ # Arrange: Create test account
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Act: Execute email code login email sending
+ result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US")
+
+ # Assert: Verify successful email sending
+ assert result is not None
+ assert result == "mock_token"
+
+ # Verify mock services were called correctly
+ mock_external_service_dependencies["token_manager"].generate_token.assert_called_once()
+ mock_external_service_dependencies["mail_task"].delay.assert_called_once()
+
+ # Verify token generation parameters
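+ # call_args[1] is the keyword-argument dict captured by the mock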
+ token_call_args = mock_external_service_dependencies["token_manager"].generate_token.call_args
+ assert token_call_args[1]["account"] == account
+ assert token_call_args[1]["email"] == account.email
+ assert token_call_args[1]["token_type"] == "email_code_login"
+ assert "code" in token_call_args[1]["additional_data"]
+
+ # Verify mail task parameters
+ mail_call_args = mock_external_service_dependencies["mail_task"].delay.call_args
+ assert mail_call_args[1]["language"] == "en-US"
+ assert mail_call_args[1]["to"] == account.email
+ assert "code" in mail_call_args[1]
+
+ def test_send_email_code_login_email_with_email_only(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test sending email code login email with email only.
+
+ This test verifies:
+ - Proper email code generation without account
+ - Token generation with email only
+ - Mail task scheduling
+ - Mock service integration
+ """
+ # Arrange: Use test email
+ fake = Faker()
+ test_email = fake.email()
+
+ # Act: Execute email code login email sending
+ result = WebAppAuthService.send_email_code_login_email(email=test_email, language="zh-Hans")
+
+ # Assert: Verify successful email sending
+ assert result is not None
+ assert result == "mock_token"
+
+ # Verify mock services were called correctly
+ mock_external_service_dependencies["token_manager"].generate_token.assert_called_once()
+ mock_external_service_dependencies["mail_task"].delay.assert_called_once()
+
+ # Verify token generation parameters
+ token_call_args = mock_external_service_dependencies["token_manager"].generate_token.call_args
+ assert token_call_args[1]["account"] is None
+ assert token_call_args[1]["email"] == test_email
+ assert token_call_args[1]["token_type"] == "email_code_login"
+ assert "code" in token_call_args[1]["additional_data"]
+
+ # Verify mail task parameters
+ mail_call_args = mock_external_service_dependencies["mail_task"].delay.call_args
+ assert mail_call_args[1]["language"] == "zh-Hans"
+ assert mail_call_args[1]["to"] == test_email
+ assert "code" in mail_call_args[1]
+
+ def test_send_email_code_login_email_no_email_provided(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test sending email code login email without providing an email.
+
+ This test verifies:
+ - Proper error handling when no email is provided
+ - Correct exception type and message
+ """
+ # Arrange: No email provided
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebAppAuthService.send_email_code_login_email()
+
+ assert "Email must be provided." in str(exc_info.value)
+
+ def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful retrieval of email code login data.
+
+ This test verifies:
+ - Proper token data retrieval
+ - Correct data format
+ - Mock service integration
+ """
+ # Arrange: Setup mock return
+ expected_data = {"code": "123456", "email": "test@example.com"}
+ mock_external_service_dependencies["token_manager"].get_token_data.return_value = expected_data
+
+ # Act: Execute data retrieval
+ result = WebAppAuthService.get_email_code_login_data("mock_token")
+
+ # Assert: Verify successful retrieval
+ assert result is not None
+ assert result == expected_data
+ assert result["code"] == "123456"
+ assert result["email"] == "test@example.com"
+
+ # Verify mock service was called correctly
+ mock_external_service_dependencies["token_manager"].get_token_data.assert_called_once_with(
+ "mock_token", "email_code_login"
+ )
+
+ def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test email code login data retrieval when no data exists.
+
+ This test verifies:
+ - Proper handling when no token data exists
+ - Correct return value (None)
+ - Mock service integration
+ """
+ # Arrange: Setup mock return for no data
+ mock_external_service_dependencies["token_manager"].get_token_data.return_value = None
+
+ # Act: Execute data retrieval
+ result = WebAppAuthService.get_email_code_login_data("invalid_token")
+
+ # Assert: Verify proper handling
+ assert result is None
+
+ # Verify mock service was called correctly
+ mock_external_service_dependencies["token_manager"].get_token_data.assert_called_once_with(
+ "invalid_token", "email_code_login"
+ )
+
+ def test_revoke_email_code_login_token_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful revocation of email code login token.
+
+ This test verifies:
+ - Proper token revocation
+ - Mock service integration
+ """
+ # Arrange: no extra setup needed; the fixture's TokenManager mock records revoke_token calls automatically
+
+ # Act: Execute token revocation
+ WebAppAuthService.revoke_email_code_login_token("mock_token")
+
+ # Assert: Verify mock service was called correctly
+ mock_external_service_dependencies["token_manager"].revoke_token.assert_called_once_with(
+ "mock_token", "email_code_login"
+ )
+
+ def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful end user creation.
+
+ This test verifies:
+ - Proper end user creation with valid app code
+ - Correct database state after creation
+ - Proper relationship establishment
+ - Mock service integration
+ """
+ # Arrange: Create test data
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+ app, site = self._create_test_app_and_site(
+ db_session_with_containers, mock_external_service_dependencies, tenant
+ )
+
+ # Act: Execute end user creation
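+ # the service assigns the fixed identity "enterpriseuser" to enterprise web-app end users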
+ result = WebAppAuthService.create_end_user(site.code, "test@example.com")
+
+ # Assert: Verify successful creation
+ assert result is not None
+ assert result.tenant_id == app.tenant_id
+ assert result.app_id == app.id
+ assert result.type == "browser"
+ assert result.is_anonymous is False
+ assert result.session_id == "test@example.com"
+ assert result.name == "enterpriseuser"
+ assert result.external_user_id == "enterpriseuser"
+
+ # Verify database state
+ from extensions.ext_database import db
+
+ db.session.refresh(result)
+ assert result.id is not None
+ assert result.created_at is not None
+ assert result.updated_at is not None
+
+ def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test end user creation with non-existent site code.
+
+ This test verifies:
+ - Proper error handling for non-existent sites
+ - Correct exception type and message
+ """
+ # Arrange: Use non-existent site code
+ fake = Faker()
+ non_existent_code = fake.unique.lexify(text="??????")
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(NotFound) as exc_info:
+ WebAppAuthService.create_end_user(non_existent_code, "test@example.com")
+
+ assert "Site not found." in str(exc_info.value)
+
+ def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test end user creation when app is not found.
+
+ This test verifies:
+ - Proper error handling when app is missing
+ - Correct exception type and message
+ """
+ # Arrange: Create site without app
+ fake = Faker()
+ tenant = Tenant(
+ name=fake.company(),
+ status="normal",
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(tenant)
+ db.session.commit()
+
+ site = Site(
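+ # the nil UUID cannot match any app row, so the app lookup below fails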
+ app_id="00000000-0000-0000-0000-000000000000",
+ title=fake.company(),
+ code=fake.unique.lexify(text="??????"),
+ description=fake.text(max_nb_chars=100),
+ default_language="en-US",
+ status="normal",
+ customize_token_strategy="not_allow",
+ )
+ db.session.add(site)
+ db.session.commit()
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(NotFound) as exc_info:
+ WebAppAuthService.create_end_user(site.code, "test@example.com")
+
+ assert "App not found." in str(exc_info.value)
+
+ def test_is_app_require_permission_check_with_access_mode_private(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test permission check requirement for private access mode.
+
+ This test verifies:
+ - Proper permission check requirement for private mode
+ - Correct return value
+ """
+ # Arrange: Setup test with private access mode
+
+ # Act: Execute permission check requirement test
+ result = WebAppAuthService.is_app_require_permission_check(access_mode="private")
+
+ # Assert: Verify correct result
+ assert result is True
+
+ def test_is_app_require_permission_check_with_access_mode_public(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test permission check requirement for public access mode.
+
+ This test verifies:
+ - Proper permission check requirement for public mode
+ - Correct return value
+ """
+ # Arrange: Setup test with public access mode
+
+ # Act: Execute permission check requirement test
+ result = WebAppAuthService.is_app_require_permission_check(access_mode="public")
+
+ # Assert: Verify correct result
+ assert result is False
+
+ def test_is_app_require_permission_check_with_app_code(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test permission check requirement using app code.
+
+ This test verifies:
+ - Proper permission check requirement using app code
+ - Correct return value
+ - Mock service integration
+ """
+ # Arrange: Setup mock for app service
+ mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id"
+
+ # Act: Execute permission check requirement test
+ result = WebAppAuthService.is_app_require_permission_check(app_code="mock_app_code")
+
+ # Assert: Verify correct result
+ assert result is True
+
+ # Verify mock service was called correctly
+ mock_external_service_dependencies["app_service"].get_app_id_by_code.assert_called_once_with("mock_app_code")
+ mock_external_service_dependencies[
+ "enterprise_service"
+ ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id")
+
+ def test_is_app_require_permission_check_no_parameters(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test permission check requirement with no parameters.
+
+ This test verifies:
+ - Proper error handling when no parameters provided
+ - Correct exception type and message
+ """
+ # Arrange: No parameters provided
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebAppAuthService.is_app_require_permission_check()
+
+ assert "Either app_code or app_id must be provided." in str(exc_info.value)
+
+ def test_get_app_auth_type_with_access_mode_public(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test app authentication type for public access mode.
+
+ This test verifies:
+ - Proper authentication type determination for public mode
+ - Correct return value
+ """
+ # Arrange: Setup test with public access mode
+
+ # Act: Execute authentication type determination
+ result = WebAppAuthService.get_app_auth_type(access_mode="public")
+
+ # Assert: Verify correct result
+ assert result == WebAppAuthType.PUBLIC
+
+ def test_get_app_auth_type_with_access_mode_private(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test app authentication type for private access mode.
+
+ This test verifies:
+ - Proper authentication type determination for private mode
+ - Correct return value
+ """
+ # Arrange: Setup test with private access mode
+
+ # Act: Execute authentication type determination
+ result = WebAppAuthService.get_app_auth_type(access_mode="private")
+
+ # Assert: Verify correct result
+ assert result == WebAppAuthType.INTERNAL
+
+ def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test app authentication type using app code.
+
+ This test verifies:
+ - Proper authentication type determination using app code
+ - Correct return value
+ - Mock service integration
+ """
+ # Arrange: Setup mock for enterprise service
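+ # "sso_verified" is an access mode that get_app_auth_type maps to WebAppAuthType.EXTERNAL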
+ mock_webapp_auth = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})()
+ mock_external_service_dependencies[
+ "enterprise_service"
+ ].WebAppAuth.get_app_access_mode_by_code.return_value = mock_webapp_auth
+
+ # Act: Execute authentication type determination
+ result = WebAppAuthService.get_app_auth_type(app_code="mock_app_code")
+
+ # Assert: Verify correct result
+ assert result == WebAppAuthType.EXTERNAL
+
+ # Verify mock service was called correctly
+ mock_external_service_dependencies[
+ "enterprise_service"
+ ].WebAppAuth.get_app_access_mode_by_code.assert_called_once_with("mock_app_code")
+
+ def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test app authentication type with no parameters.
+
+ This test verifies:
+ - Proper error handling when no parameters provided
+ - Correct exception type and message
+ """
+ # Arrange: No parameters provided
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebAppAuthService.get_app_auth_type()
+
+ assert "Either app_code or access_mode must be provided." in str(exc_info.value)
diff --git a/api/tests/test_containers_integration_tests/services/test_website_service.py b/api/tests/test_containers_integration_tests/services/test_website_service.py
new file mode 100644
index 0000000000..ec2f1556af
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/services/test_website_service.py
@@ -0,0 +1,1437 @@
+from datetime import datetime
+from unittest.mock import MagicMock, patch
+
+import pytest
+from faker import Faker
+
+from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
+from services.website_service import (
+ CrawlOptions,
+ ScrapeRequest,
+ WebsiteCrawlApiRequest,
+ WebsiteCrawlStatusApiRequest,
+ WebsiteService,
+)
+
+
+class TestWebsiteService:
+ """Integration tests for WebsiteService using testcontainers."""
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.website_service.ApiKeyAuthService") as mock_api_key_auth_service,
+ patch("services.website_service.FirecrawlApp") as mock_firecrawl_app,
+ patch("services.website_service.WaterCrawlProvider") as mock_watercrawl_provider,
+ patch("services.website_service.requests") as mock_requests,
+ patch("services.website_service.redis_client") as mock_redis_client,
+ patch("services.website_service.storage") as mock_storage,
+ patch("services.website_service.encrypter") as mock_encrypter,
+ ):
+            # Set up default mock returns
+ mock_api_key_auth_service.get_auth_credentials.return_value = {
+ "config": {"api_key": "encrypted_api_key", "base_url": "https://api.example.com"}
+ }
+ mock_encrypter.decrypt_token.return_value = "decrypted_api_key"
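+            # Credentials are stored encrypted; the provider clients below should receive the decrypted key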
+
+ # Mock FirecrawlApp
+ mock_firecrawl_instance = MagicMock()
+ mock_firecrawl_instance.crawl_url.return_value = "test_job_id_123"
+ mock_firecrawl_instance.check_crawl_status.return_value = {
+ "status": "completed",
+ "total": 5,
+ "current": 5,
+ "data": [{"source_url": "https://example.com", "title": "Test Page"}],
+ }
+ mock_firecrawl_app.return_value = mock_firecrawl_instance
+
+ # Mock WaterCrawlProvider
+ mock_watercrawl_instance = MagicMock()
+ mock_watercrawl_instance.crawl_url.return_value = {"status": "active", "job_id": "watercrawl_job_123"}
+ mock_watercrawl_instance.get_crawl_status.return_value = {
+ "status": "completed",
+ "job_id": "watercrawl_job_123",
+ "total": 3,
+ "current": 3,
+ "data": [],
+ }
+ mock_watercrawl_instance.get_crawl_url_data.return_value = {
+ "title": "WaterCrawl Page",
+ "source_url": "https://example.com",
+ "description": "Test description",
+ "markdown": "# Test Content",
+ }
+ mock_watercrawl_instance.scrape_url.return_value = {
+ "title": "Scraped Page",
+ "content": "Test content",
+ "url": "https://example.com",
+ }
+ mock_watercrawl_provider.return_value = mock_watercrawl_instance
+
+ # Mock requests
+ mock_response = MagicMock()
+ mock_response.json.return_value = {"code": 200, "data": {"taskId": "jina_job_123"}}
+ mock_requests.get.return_value = mock_response
+ mock_requests.post.return_value = mock_response
+
+ # Mock Redis
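+            # setex caches a crawl start time; get returns one so completed-status checks can compute time_consuming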
+ mock_redis_client.setex.return_value = None
+ mock_redis_client.get.return_value = str(datetime.now().timestamp())
+ mock_redis_client.delete.return_value = None
+
+ # Mock Storage
+ mock_storage.exists.return_value = False
+ mock_storage.load_once.return_value = None
+
+ yield {
+ "api_key_auth_service": mock_api_key_auth_service,
+ "firecrawl_app": mock_firecrawl_app,
+ "watercrawl_provider": mock_watercrawl_provider,
+ "requests": mock_requests,
+ "redis_client": mock_redis_client,
+ "storage": mock_storage,
+ "encrypter": mock_encrypter,
+ }
+
+ def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Helper method to create a test account with proper tenant setup.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+
+ Returns:
+ Account: Created account instance
+ """
+ fake = Faker()
+
+ # Create account
+ account = Account(
+ email=fake.email(),
+ name=fake.name(),
+ interface_language="en-US",
+ status="active",
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(account)
+ db.session.commit()
+
+ # Create tenant for the account
+ tenant = Tenant(
+ name=fake.company(),
+ status="normal",
+ )
+ db.session.add(tenant)
+ db.session.commit()
+
+ # Create tenant-account join
+ join = TenantAccountJoin(
+ tenant_id=tenant.id,
+ account_id=account.id,
+ role=TenantAccountRole.OWNER.value,
+ current=True,
+ )
+ db.session.add(join)
+ db.session.commit()
+
+ # Set current tenant for account
+ account.current_tenant = tenant
+
+ return account
+
+ def test_document_create_args_validate_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful argument validation for document creation.
+
+ This test verifies:
+ - Valid arguments are accepted without errors
+ - All required fields are properly validated
+ - Optional fields are handled correctly
+ """
+ # Arrange: Prepare valid arguments
+ valid_args = {
+ "provider": "firecrawl",
+ "url": "https://example.com",
+ "options": {
+ "limit": 5,
+ "crawl_sub_pages": True,
+ "only_main_content": False,
+ "includes": "blog,news",
+ "excludes": "admin,private",
+ "max_depth": 3,
+ "use_sitemap": True,
+ },
+ }
+
+ # Act: Validate arguments
+ WebsiteService.document_create_args_validate(valid_args)
+
+ # Assert: No exception should be raised
+ # If we reach here, validation passed successfully
+
+ def test_document_create_args_validate_missing_provider(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test argument validation fails when provider is missing.
+
+ This test verifies:
+ - Missing provider raises ValueError
+ - Proper error message is provided
+ - Validation stops at first missing required field
+ """
+ # Arrange: Prepare arguments without provider
+ invalid_args = {"url": "https://example.com", "options": {"limit": 5, "crawl_sub_pages": True}}
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteService.document_create_args_validate(invalid_args)
+
+ assert "Provider is required" in str(exc_info.value)
+
+ def test_document_create_args_validate_missing_url(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test argument validation fails when URL is missing.
+
+ This test verifies:
+ - Missing URL raises ValueError
+ - Proper error message is provided
+ - Validation continues after provider check
+ """
+ # Arrange: Prepare arguments without URL
+ invalid_args = {"provider": "firecrawl", "options": {"limit": 5, "crawl_sub_pages": True}}
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteService.document_create_args_validate(invalid_args)
+
+ assert "URL is required" in str(exc_info.value)
+
+ def test_crawl_url_firecrawl_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful URL crawling with Firecrawl provider.
+
+ This test verifies:
+ - Firecrawl provider is properly initialized
+ - API credentials are retrieved and decrypted
+ - Crawl parameters are correctly formatted
+ - Job ID is returned with active status
+ - Redis cache is properly set
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request
+ api_request = WebsiteCrawlApiRequest(
+ provider="firecrawl",
+ url="https://example.com",
+ options={
+ "limit": 10,
+ "crawl_sub_pages": True,
+ "only_main_content": True,
+ "includes": "blog,news",
+ "excludes": "admin,private",
+ "max_depth": 2,
+ "use_sitemap": True,
+ },
+ )
+
+ # Act: Execute crawl operation
+ result = WebsiteService.crawl_url(api_request)
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["status"] == "active"
+ assert result["job_id"] == "test_job_id_123"
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "firecrawl"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+ mock_external_service_dependencies["firecrawl_app"].assert_called_once_with(
+ api_key="decrypted_api_key", base_url="https://api.example.com"
+ )
+
+ # Verify Redis cache was set
+ mock_external_service_dependencies["redis_client"].setex.assert_called_once()
+
+ def test_crawl_url_watercrawl_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful URL crawling with WaterCrawl provider.
+
+ This test verifies:
+ - WaterCrawl provider is properly initialized
+ - API credentials are retrieved and decrypted
+ - Crawl options are correctly passed to provider
+ - Provider returns expected response format
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request
+ api_request = WebsiteCrawlApiRequest(
+ provider="watercrawl",
+ url="https://example.com",
+ options={
+ "limit": 5,
+ "crawl_sub_pages": False,
+ "only_main_content": False,
+ "includes": None,
+ "excludes": None,
+ "max_depth": None,
+ "use_sitemap": False,
+ },
+ )
+
+ # Act: Execute crawl operation
+ result = WebsiteService.crawl_url(api_request)
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["status"] == "active"
+ assert result["job_id"] == "watercrawl_job_123"
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "watercrawl"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+ mock_external_service_dependencies["watercrawl_provider"].assert_called_once_with(
+ api_key="decrypted_api_key", base_url="https://api.example.com"
+ )
+
+ def test_crawl_url_jinareader_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful URL crawling with JinaReader provider.
+
+ This test verifies:
+ - JinaReader provider handles single page crawling
+ - API credentials are retrieved and decrypted
+ - HTTP requests are made with proper headers
+ - Response is properly parsed and returned
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request for single page crawling
+ api_request = WebsiteCrawlApiRequest(
+ provider="jinareader",
+ url="https://example.com",
+ options={
+ "limit": 1,
+ "crawl_sub_pages": False,
+ "only_main_content": True,
+ "includes": None,
+ "excludes": None,
+ "max_depth": None,
+ "use_sitemap": False,
+ },
+ )
+
+ # Act: Execute crawl operation
+ result = WebsiteService.crawl_url(api_request)
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["status"] == "active"
+ assert result["data"] is not None
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "jinareader"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+
+ # Verify HTTP request was made
+ mock_external_service_dependencies["requests"].get.assert_called_once_with(
+ "https://r.jina.ai/https://example.com",
+ headers={"Accept": "application/json", "Authorization": "Bearer decrypted_api_key"},
+ )
+
+ def test_crawl_url_invalid_provider(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test crawl operation fails with invalid provider.
+
+ This test verifies:
+ - Invalid provider raises ValueError
+ - Proper error message is provided
+ - Service handles unsupported providers gracefully
+ """
+ # Arrange: Create test account and prepare request with invalid provider
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request with invalid provider
+ api_request = WebsiteCrawlApiRequest(
+ provider="invalid_provider",
+ url="https://example.com",
+ options={"limit": 5, "crawl_sub_pages": False, "only_main_content": False},
+ )
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteService.crawl_url(api_request)
+
+ assert "Invalid provider" in str(exc_info.value)
+
+ def test_get_crawl_status_firecrawl_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful crawl status retrieval with Firecrawl provider.
+
+ This test verifies:
+ - Firecrawl status is properly retrieved
+ - API credentials are retrieved and decrypted
+ - Status data includes all required fields
+ - Redis cache is properly managed for completed jobs
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request
+ api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123")
+
+ # Act: Get crawl status
+ result = WebsiteService.get_crawl_status_typed(api_request)
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["status"] == "completed"
+ assert result["job_id"] == "test_job_id_123"
+ assert result["total"] == 5
+ assert result["current"] == 5
+ assert "data" in result
+ assert "time_consuming" in result
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "firecrawl"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+
+ # Verify Redis cache was accessed and cleaned up
+ mock_external_service_dependencies["redis_client"].get.assert_called_once()
+ mock_external_service_dependencies["redis_client"].delete.assert_called_once()
+
+ def test_get_crawl_status_watercrawl_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful crawl status retrieval with WaterCrawl provider.
+
+ This test verifies:
+ - WaterCrawl status is properly retrieved
+ - API credentials are retrieved and decrypted
+ - Provider returns expected status format
+ - All required status fields are present
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request
+ api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123")
+
+ # Act: Get crawl status
+ result = WebsiteService.get_crawl_status_typed(api_request)
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["status"] == "completed"
+ assert result["job_id"] == "watercrawl_job_123"
+ assert result["total"] == 3
+ assert result["current"] == 3
+ assert "data" in result
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "watercrawl"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+
+ def test_get_crawl_status_jinareader_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful crawl status retrieval with JinaReader provider.
+
+ This test verifies:
+ - JinaReader status is properly retrieved
+ - API credentials are retrieved and decrypted
+ - HTTP requests are made with proper parameters
+ - Status data is properly formatted and returned
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request
+ api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123")
+
+ # Act: Get crawl status
+ result = WebsiteService.get_crawl_status_typed(api_request)
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["status"] == "active"
+ assert result["job_id"] == "jina_job_123"
+ assert "total" in result
+ assert "current" in result
+ assert "data" in result
+ assert "time_consuming" in result
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "jinareader"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+
+ # Verify HTTP request was made
+ mock_external_service_dependencies["requests"].post.assert_called_once()
+
+ def test_get_crawl_status_invalid_provider(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test crawl status retrieval fails with invalid provider.
+
+ This test verifies:
+ - Invalid provider raises ValueError
+ - Proper error message is provided
+ - Service handles unsupported providers gracefully
+ """
+ # Arrange: Create test account and prepare request with invalid provider
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request with invalid provider
+ api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123")
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteService.get_crawl_status_typed(api_request)
+
+ assert "Invalid provider" in str(exc_info.value)
+
+ def test_get_crawl_status_missing_credentials(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test crawl status retrieval fails when credentials are missing.
+
+ This test verifies:
+ - Missing credentials raises ValueError
+ - Proper error message is provided
+ - Service handles authentication failures gracefully
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Mock missing credentials
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None
+
+ # Create API request
+ api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123")
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteService.get_crawl_status_typed(api_request)
+
+ assert "No valid credentials found for the provider" in str(exc_info.value)
+
+ def test_get_crawl_status_missing_api_key(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test crawl status retrieval fails when API key is missing from config.
+
+ This test verifies:
+ - Missing API key raises ValueError
+ - Proper error message is provided
+ - Service handles configuration failures gracefully
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Mock missing API key in config
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = {
+ "config": {"base_url": "https://api.example.com"}
+ }
+
+ # Create API request
+ api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123")
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteService.get_crawl_status_typed(api_request)
+
+ assert "API key not found in configuration" in str(exc_info.value)
+
+ def test_get_crawl_url_data_firecrawl_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful URL data retrieval with Firecrawl provider.
+
+ This test verifies:
+ - Firecrawl URL data is properly retrieved
+ - API credentials are retrieved and decrypted
+ - Data is returned for matching URL
+ - Storage fallback works when needed
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock storage to return existing data
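+        # The bytes mirror what the service persists for a finished crawl: a JSON array of page records keyed by source_url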
+ mock_external_service_dependencies["storage"].exists.return_value = True
+ mock_external_service_dependencies["storage"].load_once.return_value = (
+ b"["
+ b'{"source_url": "https://example.com", "title": "Test Page", '
+ b'"description": "Test Description", "markdown": "# Test Content"}'
+ b"]"
+ )
+
+ # Act: Get URL data
+ result = WebsiteService.get_crawl_url_data(
+ job_id="test_job_id_123",
+ provider="firecrawl",
+ url="https://example.com",
+ tenant_id=account.current_tenant.id,
+ )
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["source_url"] == "https://example.com"
+ assert result["title"] == "Test Page"
+ assert result["description"] == "Test Description"
+ assert result["markdown"] == "# Test Content"
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "firecrawl"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+
+ # Verify storage was accessed
+ mock_external_service_dependencies["storage"].exists.assert_called_once()
+ mock_external_service_dependencies["storage"].load_once.assert_called_once()
+
+ def test_get_crawl_url_data_watercrawl_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful URL data retrieval with WaterCrawl provider.
+
+ This test verifies:
+ - WaterCrawl URL data is properly retrieved
+ - API credentials are retrieved and decrypted
+ - Provider returns expected data format
+ - All required data fields are present
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Act: Get URL data
+ result = WebsiteService.get_crawl_url_data(
+ job_id="watercrawl_job_123",
+ provider="watercrawl",
+ url="https://example.com",
+ tenant_id=account.current_tenant.id,
+ )
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["title"] == "WaterCrawl Page"
+ assert result["source_url"] == "https://example.com"
+ assert result["description"] == "Test description"
+ assert result["markdown"] == "# Test Content"
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "watercrawl"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+
+ def test_get_crawl_url_data_jinareader_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful URL data retrieval with JinaReader provider.
+
+ This test verifies:
+ - JinaReader URL data is properly retrieved
+ - API credentials are retrieved and decrypted
+ - HTTP requests are made with proper parameters
+ - Data is properly formatted and returned
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock successful response for JinaReader
+ mock_response = MagicMock()
+ mock_response.json.return_value = {
+ "code": 200,
+ "data": {
+ "title": "JinaReader Page",
+ "url": "https://example.com",
+ "description": "Test description",
+ "content": "# Test Content",
+ },
+ }
+ mock_external_service_dependencies["requests"].get.return_value = mock_response
+
+        # Act: Get URL data with an empty job_id (single-page scraping)
+ result = WebsiteService.get_crawl_url_data(
+ job_id="", provider="jinareader", url="https://example.com", tenant_id=account.current_tenant.id
+ )
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["title"] == "JinaReader Page"
+ assert result["url"] == "https://example.com"
+ assert result["description"] == "Test description"
+ assert result["content"] == "# Test Content"
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "jinareader"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+
+ # Verify HTTP request was made
+ mock_external_service_dependencies["requests"].get.assert_called_once_with(
+ "https://r.jina.ai/https://example.com",
+ headers={"Accept": "application/json", "Authorization": "Bearer decrypted_api_key"},
+ )
+
+ def test_get_scrape_url_data_firecrawl_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful URL scraping with Firecrawl provider.
+
+ This test verifies:
+ - Firecrawl scraping is properly executed
+ - API credentials are retrieved and decrypted
+ - Scraping parameters are correctly passed
+ - Scraped data is returned in expected format
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock FirecrawlApp scraping response
+ mock_firecrawl_instance = MagicMock()
+ mock_firecrawl_instance.scrape_url.return_value = {
+ "title": "Scraped Page Title",
+ "content": "This is the scraped content",
+ "url": "https://example.com",
+ "description": "Page description",
+ }
+ mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance
+
+ # Act: Scrape URL
+ result = WebsiteService.get_scrape_url_data(
+ provider="firecrawl", url="https://example.com", tenant_id=account.current_tenant.id, only_main_content=True
+ )
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["title"] == "Scraped Page Title"
+ assert result["content"] == "This is the scraped content"
+ assert result["url"] == "https://example.com"
+ assert result["description"] == "Page description"
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "firecrawl"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+
+ # Verify FirecrawlApp was called with correct parameters
+ mock_external_service_dependencies["firecrawl_app"].assert_called_once_with(
+ api_key="decrypted_api_key", base_url="https://api.example.com"
+ )
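+        # The snake_case only_main_content flag surfaces as Firecrawl's camelCase onlyMainContent parameter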
+ mock_firecrawl_instance.scrape_url.assert_called_once_with(
+ url="https://example.com", params={"onlyMainContent": True}
+ )
+
+ def test_get_scrape_url_data_watercrawl_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful URL scraping with WaterCrawl provider.
+
+ This test verifies:
+ - WaterCrawl scraping is properly executed
+ - API credentials are retrieved and decrypted
+ - Provider returns expected scraping format
+ - All required data fields are present
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Act: Scrape URL
+ result = WebsiteService.get_scrape_url_data(
+ provider="watercrawl",
+ url="https://example.com",
+ tenant_id=account.current_tenant.id,
+ only_main_content=False,
+ )
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["title"] == "Scraped Page"
+ assert result["content"] == "Test content"
+ assert result["url"] == "https://example.com"
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "watercrawl"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+
+ # Verify WaterCrawlProvider was called with correct parameters
+ mock_external_service_dependencies["watercrawl_provider"].assert_called_once_with(
+ api_key="decrypted_api_key", base_url="https://api.example.com"
+ )
+
+ def test_get_scrape_url_data_invalid_provider(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test URL scraping fails with invalid provider.
+
+ This test verifies:
+ - Invalid provider raises ValueError
+ - Proper error message is provided
+ - Service handles unsupported providers gracefully
+ """
+ # Arrange: Create test account and prepare request with invalid provider
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteService.get_scrape_url_data(
+ provider="invalid_provider",
+ url="https://example.com",
+ tenant_id=account.current_tenant.id,
+ only_main_content=False,
+ )
+
+ assert "Invalid provider" in str(exc_info.value)
+
+ def test_crawl_options_include_exclude_paths(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test CrawlOptions include and exclude path methods.
+
+ This test verifies:
+ - Include paths are properly parsed from comma-separated string
+ - Exclude paths are properly parsed from comma-separated string
+ - Empty or None values are handled correctly
+ - Path lists are returned in expected format
+ """
+ # Arrange: Create CrawlOptions with various path configurations
+ options_with_paths = CrawlOptions(includes="blog,news,articles", excludes="admin,private,test")
+
+ options_without_paths = CrawlOptions(includes=None, excludes="")
+
+ # Act: Get include and exclude paths
+ include_paths = options_with_paths.get_include_paths()
+ exclude_paths = options_with_paths.get_exclude_paths()
+
+ empty_include_paths = options_without_paths.get_include_paths()
+ empty_exclude_paths = options_without_paths.get_exclude_paths()
+
+ # Assert: Verify path parsing
+ assert include_paths == ["blog", "news", "articles"]
+ assert exclude_paths == ["admin", "private", "test"]
+ assert empty_include_paths == []
+ assert empty_exclude_paths == []
+
+ def test_website_crawl_api_request_conversion(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test WebsiteCrawlApiRequest conversion to CrawlRequest.
+
+ This test verifies:
+ - API request is properly converted to internal CrawlRequest
+ - All options are correctly mapped
+ - Default values are applied when options are missing
+ - Conversion maintains data integrity
+ """
+ # Arrange: Create API request with various options
+ api_request = WebsiteCrawlApiRequest(
+ provider="firecrawl",
+ url="https://example.com",
+ options={
+ "limit": 10,
+ "crawl_sub_pages": True,
+ "only_main_content": True,
+ "includes": "blog,news",
+ "excludes": "admin,private",
+ "max_depth": 3,
+ "use_sitemap": False,
+ },
+ )
+
+ # Act: Convert to CrawlRequest
+ crawl_request = api_request.to_crawl_request()
+
+ # Assert: Verify conversion
+ assert crawl_request.url == "https://example.com"
+ assert crawl_request.provider == "firecrawl"
+ assert crawl_request.options.limit == 10
+ assert crawl_request.options.crawl_sub_pages is True
+ assert crawl_request.options.only_main_content is True
+ assert crawl_request.options.includes == "blog,news"
+ assert crawl_request.options.excludes == "admin,private"
+ assert crawl_request.options.max_depth == 3
+ assert crawl_request.options.use_sitemap is False
+
+ def test_website_crawl_api_request_from_args(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test WebsiteCrawlApiRequest creation from Flask arguments.
+
+ This test verifies:
+ - Request is properly created from parsed arguments
+ - Required fields are validated
+ - Optional fields are handled correctly
+ - Validation errors are properly raised
+ """
+ # Arrange: Prepare valid arguments
+ valid_args = {"provider": "watercrawl", "url": "https://example.com", "options": {"limit": 5}}
+
+ # Act: Create request from args
+ request = WebsiteCrawlApiRequest.from_args(valid_args)
+
+ # Assert: Verify request creation
+ assert request.provider == "watercrawl"
+ assert request.url == "https://example.com"
+ assert request.options == {"limit": 5}
+
+ # Test missing provider
+ invalid_args = {"url": "https://example.com", "options": {}}
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteCrawlApiRequest.from_args(invalid_args)
+ assert "Provider is required" in str(exc_info.value)
+
+ # Test missing URL
+ invalid_args = {"provider": "watercrawl", "options": {}}
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteCrawlApiRequest.from_args(invalid_args)
+ assert "URL is required" in str(exc_info.value)
+
+ # Test missing options
+ invalid_args = {"provider": "watercrawl", "url": "https://example.com"}
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteCrawlApiRequest.from_args(invalid_args)
+ assert "Options are required" in str(exc_info.value)
+
+ def test_crawl_url_jinareader_sub_pages_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful URL crawling with JinaReader provider for sub-pages.
+
+ This test verifies:
+ - JinaReader provider handles sub-page crawling correctly
+ - HTTP POST request is made with proper parameters
+ - Job ID is returned for multi-page crawling
+ - All required parameters are passed correctly
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request for sub-page crawling
+ api_request = WebsiteCrawlApiRequest(
+ provider="jinareader",
+ url="https://example.com",
+ options={
+ "limit": 5,
+ "crawl_sub_pages": True,
+ "only_main_content": False,
+ "includes": None,
+ "excludes": None,
+ "max_depth": None,
+ "use_sitemap": True,
+ },
+ )
+
+ # Act: Execute crawl operation
+ result = WebsiteService.crawl_url(api_request)
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["status"] == "active"
+ assert result["job_id"] == "jina_job_123"
+
+ # Verify external service interactions
+ mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with(
+ account.current_tenant.id, "website", "jinareader"
+ )
+ mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with(
+ tenant_id=account.current_tenant.id, token="encrypted_api_key"
+ )
+
+ # Verify HTTP POST request was made for sub-page crawling
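+            # maxPages and useSitemap are expected to mirror the request's limit and use_sitemap options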
+ mock_external_service_dependencies["requests"].post.assert_called_once_with(
+ "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
+ json={"url": "https://example.com", "maxPages": 5, "useSitemap": True},
+ headers={"Content-Type": "application/json", "Authorization": "Bearer decrypted_api_key"},
+ )
+
+ def test_crawl_url_jinareader_failed_response(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+        Test JinaReader crawling fails when the API returns an error.
+
+ This test verifies:
+ - Failed API response raises ValueError
+ - Proper error message is provided
+ - Service handles API failures gracefully
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock failed response
+ mock_failed_response = MagicMock()
+ mock_failed_response.json.return_value = {"code": 500, "error": "Internal server error"}
+ mock_external_service_dependencies["requests"].get.return_value = mock_failed_response
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request
+ api_request = WebsiteCrawlApiRequest(
+ provider="jinareader",
+ url="https://example.com",
+ options={"limit": 1, "crawl_sub_pages": False, "only_main_content": True},
+ )
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteService.crawl_url(api_request)
+
+ assert "Failed to crawl" in str(exc_info.value)
+
+ def test_get_crawl_status_firecrawl_active_job(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test Firecrawl status retrieval for active (not completed) job.
+
+ This test verifies:
+ - Active job status is properly returned
+ - Redis cache is not deleted for active jobs
+ - Time consuming is not calculated for active jobs
+ - All required status fields are present
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock active job status
+ mock_firecrawl_instance = MagicMock()
+ mock_firecrawl_instance.check_crawl_status.return_value = {
+ "status": "active",
+ "total": 10,
+ "current": 3,
+ "data": [],
+ }
+ mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance
+
+ # Mock current_user for the test
+ with patch("services.website_service.current_user") as mock_current_user:
+ mock_current_user.current_tenant_id = account.current_tenant.id
+
+ # Create API request
+ api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123")
+
+ # Act: Get crawl status
+ result = WebsiteService.get_crawl_status_typed(api_request)
+
+ # Assert: Verify active job status
+ assert result is not None
+ assert result["status"] == "active"
+ assert result["job_id"] == "active_job_123"
+ assert result["total"] == 10
+ assert result["current"] == 3
+ assert "data" in result
+ assert "time_consuming" not in result
+
+ # Verify Redis cache was not accessed for active jobs
+ mock_external_service_dependencies["redis_client"].get.assert_not_called()
+ mock_external_service_dependencies["redis_client"].delete.assert_not_called()
+
+ def test_get_crawl_url_data_firecrawl_storage_fallback(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test Firecrawl URL data retrieval with storage fallback.
+
+ This test verifies:
+ - Storage fallback works when storage has data
+ - API call is not made when storage has data
+ - Data is properly parsed from storage
+ - Correct URL data is returned
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock storage to return existing data
+ mock_external_service_dependencies["storage"].exists.return_value = True
+ mock_external_service_dependencies["storage"].load_once.return_value = (
+ b"["
+ b'{"source_url": "https://example.com/page1", '
+ b'"title": "Page 1", "description": "Description 1", "markdown": "# Page 1"}, '
+ b'{"source_url": "https://example.com/page2", "title": "Page 2", '
+ b'"description": "Description 2", "markdown": "# Page 2"}'
+ b"]"
+ )
+
+ # Act: Get URL data for specific URL
+ result = WebsiteService.get_crawl_url_data(
+ job_id="test_job_id_123",
+ provider="firecrawl",
+ url="https://example.com/page1",
+ tenant_id=account.current_tenant.id,
+ )
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["source_url"] == "https://example.com/page1"
+ assert result["title"] == "Page 1"
+ assert result["description"] == "Description 1"
+ assert result["markdown"] == "# Page 1"
+
+ # Verify storage was accessed
+ mock_external_service_dependencies["storage"].exists.assert_called_once()
+ mock_external_service_dependencies["storage"].load_once.assert_called_once()
+
+ def test_get_crawl_url_data_firecrawl_api_fallback(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test Firecrawl URL data retrieval with API fallback when storage is empty.
+
+ This test verifies:
+ - API fallback works when storage has no data
+ - FirecrawlApp is called to get data
+ - Completed job status is checked
+ - Data is returned from API response
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock storage to return no data
+ mock_external_service_dependencies["storage"].exists.return_value = False
+
+ # Mock FirecrawlApp for API fallback
+ mock_firecrawl_instance = MagicMock()
+ mock_firecrawl_instance.check_crawl_status.return_value = {
+ "status": "completed",
+ "data": [
+ {
+ "source_url": "https://example.com/api_page",
+ "title": "API Page",
+ "description": "API Description",
+ "markdown": "# API Content",
+ }
+ ],
+ }
+ mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance
+
+ # Act: Get URL data
+ result = WebsiteService.get_crawl_url_data(
+ job_id="test_job_id_123",
+ provider="firecrawl",
+ url="https://example.com/api_page",
+ tenant_id=account.current_tenant.id,
+ )
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["source_url"] == "https://example.com/api_page"
+ assert result["title"] == "API Page"
+ assert result["description"] == "API Description"
+ assert result["markdown"] == "# API Content"
+
+ # Verify API was called
+ mock_external_service_dependencies["firecrawl_app"].assert_called_once()
+
+ def test_get_crawl_url_data_firecrawl_incomplete_job(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test Firecrawl URL data retrieval fails for incomplete job.
+
+ This test verifies:
+ - Incomplete job raises ValueError
+ - Proper error message is provided
+ - Service handles incomplete jobs gracefully
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock storage to return no data
+ mock_external_service_dependencies["storage"].exists.return_value = False
+
+ # Mock incomplete job status
+ mock_firecrawl_instance = MagicMock()
+ mock_firecrawl_instance.check_crawl_status.return_value = {"status": "active", "data": []}
+ mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteService.get_crawl_url_data(
+ job_id="test_job_id_123",
+ provider="firecrawl",
+ url="https://example.com/page",
+ tenant_id=account.current_tenant.id,
+ )
+
+ assert "Crawl job is not completed" in str(exc_info.value)
+
+ def test_get_crawl_url_data_jinareader_with_job_id(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test JinaReader URL data retrieval with job ID for multi-page crawling.
+
+ This test verifies:
+ - JinaReader handles job ID-based data retrieval
+ - Status check is performed before data retrieval
+ - Processed data is properly formatted
+ - Correct URL data is returned
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock successful status response
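+        # The payload maps each processed URL to its scraped record; the service should pick the entry for the requested url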
+ mock_status_response = MagicMock()
+ mock_status_response.json.return_value = {
+ "code": 200,
+ "data": {
+ "status": "completed",
+ "processed": {
+ "https://example.com/page1": {
+ "data": {
+ "title": "Page 1",
+ "url": "https://example.com/page1",
+ "description": "Description 1",
+ "content": "# Content 1",
+ }
+ }
+ },
+ },
+ }
+ mock_external_service_dependencies["requests"].post.return_value = mock_status_response
+
+ # Act: Get URL data with job ID
+ result = WebsiteService.get_crawl_url_data(
+ job_id="jina_job_123",
+ provider="jinareader",
+ url="https://example.com/page1",
+ tenant_id=account.current_tenant.id,
+ )
+
+ # Assert: Verify successful operation
+ assert result is not None
+ assert result["title"] == "Page 1"
+ assert result["url"] == "https://example.com/page1"
+ assert result["description"] == "Description 1"
+ assert result["content"] == "# Content 1"
+
+        # Verify two HTTP POST requests were made while resolving the job
+ assert mock_external_service_dependencies["requests"].post.call_count == 2
+
+ def test_get_crawl_url_data_jinareader_incomplete_job(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test JinaReader URL data retrieval fails for incomplete job.
+
+ This test verifies:
+ - Incomplete job raises ValueError
+ - Proper error message is provided
+ - Service handles incomplete jobs gracefully
+ """
+ # Arrange: Create test account and prepare request
+ account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
+
+ # Mock incomplete job status
+ mock_status_response = MagicMock()
+ mock_status_response.json.return_value = {"code": 200, "data": {"status": "active", "processed": {}}}
+ mock_external_service_dependencies["requests"].post.return_value = mock_status_response
+
+ # Act & Assert: Verify proper error handling
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteService.get_crawl_url_data(
+ job_id="jina_job_123",
+ provider="jinareader",
+ url="https://example.com/page",
+ tenant_id=account.current_tenant.id,
+ )
+
+ assert "Crawl job is not completed" in str(exc_info.value)
+
+ def test_crawl_options_default_values(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test CrawlOptions default values and initialization.
+
+ This test verifies:
+ - Default values are properly set
+ - Optional fields can be None
+ - Boolean fields have correct defaults
+ - Integer fields have correct defaults
+ """
+ # Arrange: Create CrawlOptions with minimal parameters
+ options = CrawlOptions()
+
+ # Assert: Verify default values
+ assert options.limit == 1
+ assert options.crawl_sub_pages is False
+ assert options.only_main_content is False
+ assert options.includes is None
+ assert options.excludes is None
+ assert options.max_depth is None
+ assert options.use_sitemap is True
+
+ # Test with custom values
+ custom_options = CrawlOptions(
+ limit=10,
+ crawl_sub_pages=True,
+ only_main_content=True,
+ includes="blog,news",
+ excludes="admin",
+ max_depth=3,
+ use_sitemap=False,
+ )
+
+ assert custom_options.limit == 10
+ assert custom_options.crawl_sub_pages is True
+ assert custom_options.only_main_content is True
+ assert custom_options.includes == "blog,news"
+ assert custom_options.excludes == "admin"
+ assert custom_options.max_depth == 3
+ assert custom_options.use_sitemap is False
+
+ def test_website_crawl_status_api_request_from_args(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test WebsiteCrawlStatusApiRequest creation from Flask arguments.
+
+ This test verifies:
+ - Request is properly created from parsed arguments
+ - Required fields are validated
+ - Job ID is properly handled
+ - Validation errors are properly raised
+ """
+ # Arrange: Prepare valid arguments
+ valid_args = {"provider": "firecrawl"}
+ job_id = "test_job_123"
+
+ # Act: Create request from args
+ request = WebsiteCrawlStatusApiRequest.from_args(valid_args, job_id)
+
+ # Assert: Verify request creation
+ assert request.provider == "firecrawl"
+ assert request.job_id == "test_job_123"
+
+ # Test missing provider
+ invalid_args = {}
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteCrawlStatusApiRequest.from_args(invalid_args, job_id)
+ assert "Provider is required" in str(exc_info.value)
+
+ # Test missing job ID
+ with pytest.raises(ValueError) as exc_info:
+ WebsiteCrawlStatusApiRequest.from_args(valid_args, "")
+ assert "Job ID is required" in str(exc_info.value)
+
+ def test_scrape_request_initialization(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test ScrapeRequest dataclass initialization and properties.
+
+ This test verifies:
+ - ScrapeRequest is properly initialized
+ - All fields are correctly set
+ - Boolean field works correctly
+ - String fields are properly assigned
+ """
+ # Arrange: Create ScrapeRequest
+ request = ScrapeRequest(
+ provider="firecrawl", url="https://example.com", tenant_id="tenant_123", only_main_content=True
+ )
+
+ # Assert: Verify initialization
+ assert request.provider == "firecrawl"
+ assert request.url == "https://example.com"
+ assert request.tenant_id == "tenant_123"
+ assert request.only_main_content is True
+
+ # Test with different values
+ request2 = ScrapeRequest(
+ provider="watercrawl", url="https://test.com", tenant_id="tenant_456", only_main_content=False
+ )
+
+ assert request2.provider == "watercrawl"
+ assert request2.url == "https://test.com"
+ assert request2.tenant_id == "tenant_456"
+ assert request2.only_main_content is False
diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
index f26be6702a..ac3c8e45c9 100644
--- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
+++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
@@ -1,9 +1,8 @@
-import datetime
import uuid
from collections import OrderedDict
from typing import Any, NamedTuple
-from flask_restful import marshal
+from flask_restx import marshal
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
@@ -13,6 +12,7 @@ from controllers.console.app.workflow_draft_variable import (
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
+from libs.datetime_utils import naive_utc_now
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList
@@ -57,7 +57,7 @@ class TestWorkflowDraftVariableFields:
)
sys_var.id = str(uuid.uuid4())
- sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ sys_var.last_edited_at = naive_utc_now()
sys_var.visible = True
expected_without_value = OrderedDict(
@@ -88,7 +88,7 @@ class TestWorkflowDraftVariableFields:
)
node_var.id = str(uuid.uuid4())
- node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ node_var.last_edited_at = naive_utc_now()
expected_without_value: OrderedDict[str, Any] = OrderedDict(
{
diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py
index 880a0d4940..aadd366762 100644
--- a/api/tests/unit_tests/core/mcp/client/test_sse.py
+++ b/api/tests/unit_tests/core/mcp/client/test_sse.py
@@ -1,3 +1,4 @@
+import contextlib
import json
import queue
import threading
@@ -124,13 +125,10 @@ def test_sse_client_connection_validation():
mock_event_source.iter_sse.return_value = [endpoint_event]
# Test connection
- try:
+ with contextlib.suppress(Exception):
with sse_client(test_url) as (read_queue, write_queue):
assert read_queue is not None
assert write_queue is not None
- except Exception as e:
- # Connection might fail due to mocking, but we're testing the validation logic
- pass
def test_sse_client_error_handling():
@@ -178,7 +176,7 @@ def test_sse_client_timeout_configuration():
mock_event_source.iter_sse.return_value = []
mock_sse_connect.return_value.__enter__.return_value = mock_event_source
- try:
+ with contextlib.suppress(Exception):
with sse_client(
test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout
) as (read_queue, write_queue):
@@ -190,9 +188,6 @@ def test_sse_client_timeout_configuration():
assert call_args is not None
timeout_arg = call_args[1]["timeout"]
assert timeout_arg.read == custom_sse_timeout
- except Exception:
- # Connection might fail due to mocking, but we tested the configuration
- pass
def test_sse_transport_endpoint_validation():
@@ -251,12 +246,10 @@ def test_sse_client_queue_cleanup():
# Mock connection that raises an exception
mock_sse_connect.side_effect = Exception("Connection failed")
- try:
+ with contextlib.suppress(Exception):
with sse_client(test_url) as (rq, wq):
read_queue = rq
write_queue = wq
- except Exception:
- pass # Expected to fail
# Queues should be cleaned up even on exception
# Note: In real implementation, cleanup should put None to signal shutdown
@@ -283,11 +276,9 @@ def test_sse_client_headers_propagation():
mock_event_source.iter_sse.return_value = []
mock_sse_connect.return_value.__enter__.return_value = mock_event_source
- try:
+ with contextlib.suppress(Exception):
with sse_client(test_url, headers=custom_headers):
pass
- except Exception:
- pass # Expected due to mocking
# Verify headers were passed to client factory
mock_client_factory.assert_called_with(headers=custom_headers)
diff --git a/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py
new file mode 100644
index 0000000000..c10f7b89c3
--- /dev/null
+++ b/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py
@@ -0,0 +1,148 @@
+"""Tests for LLMUsage entity."""
+
+from decimal import Decimal
+
+from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
+
+
+class TestLLMUsage:
+ """Test cases for LLMUsage class."""
+
+ def test_from_metadata_with_all_tokens(self):
+ """Test from_metadata when all token types are provided."""
+ metadata: LLMUsageMetadata = {
+ "prompt_tokens": 100,
+ "completion_tokens": 50,
+ "total_tokens": 150,
+ "prompt_unit_price": 0.001,
+ "completion_unit_price": 0.002,
+ "total_price": 0.2,
+ "currency": "USD",
+ "latency": 1.5,
+ }
+
+ usage = LLMUsage.from_metadata(metadata)
+
+ assert usage.prompt_tokens == 100
+ assert usage.completion_tokens == 50
+ assert usage.total_tokens == 150
+ assert usage.prompt_unit_price == Decimal("0.001")
+ assert usage.completion_unit_price == Decimal("0.002")
+ assert usage.total_price == Decimal("0.2")
+ assert usage.currency == "USD"
+ assert usage.latency == 1.5
+
+ def test_from_metadata_with_prompt_tokens_only(self):
+ """Test from_metadata when only prompt_tokens is provided."""
+ metadata: LLMUsageMetadata = {
+ "prompt_tokens": 100,
+ "total_tokens": 100,
+ }
+
+ usage = LLMUsage.from_metadata(metadata)
+
+ assert usage.prompt_tokens == 100
+ assert usage.completion_tokens == 0
+ assert usage.total_tokens == 100
+
+ def test_from_metadata_with_completion_tokens_only(self):
+ """Test from_metadata when only completion_tokens is provided."""
+ metadata: LLMUsageMetadata = {
+ "completion_tokens": 50,
+ "total_tokens": 50,
+ }
+
+ usage = LLMUsage.from_metadata(metadata)
+
+ assert usage.prompt_tokens == 0
+ assert usage.completion_tokens == 50
+ assert usage.total_tokens == 50
+
+ def test_from_metadata_calculates_total_when_missing(self):
+ """Test from_metadata calculates total_tokens when not provided."""
+ metadata: LLMUsageMetadata = {
+ "prompt_tokens": 100,
+ "completion_tokens": 50,
+ }
+
+ usage = LLMUsage.from_metadata(metadata)
+
+ assert usage.prompt_tokens == 100
+ assert usage.completion_tokens == 50
+ assert usage.total_tokens == 150 # Should be calculated
+
+ def test_from_metadata_with_total_but_no_completion(self):
+        """
+        Test from_metadata when total_tokens is provided but completion_tokens is 0.
+        Regression test for issue #24360: prompt tokens must NOT be reassigned to completion_tokens.
+        """
+ metadata: LLMUsageMetadata = {
+ "prompt_tokens": 479,
+ "completion_tokens": 0,
+ "total_tokens": 521,
+ }
+
+ usage = LLMUsage.from_metadata(metadata)
+
+ # This is the key fix - prompt tokens should remain as prompt tokens
+ assert usage.prompt_tokens == 479
+ assert usage.completion_tokens == 0
+ assert usage.total_tokens == 521
+
+ def test_from_metadata_with_empty_metadata(self):
+ """Test from_metadata with empty metadata."""
+ metadata: LLMUsageMetadata = {}
+
+ usage = LLMUsage.from_metadata(metadata)
+
+ assert usage.prompt_tokens == 0
+ assert usage.completion_tokens == 0
+ assert usage.total_tokens == 0
+ assert usage.currency == "USD"
+ assert usage.latency == 0.0
+
+ def test_from_metadata_preserves_zero_completion_tokens(self):
+ """
+ Test that zero completion_tokens are preserved when explicitly set.
+ This is important for agent nodes that only use prompt tokens.
+ """
+ metadata: LLMUsageMetadata = {
+ "prompt_tokens": 1000,
+ "completion_tokens": 0,
+ "total_tokens": 1000,
+ "prompt_unit_price": 0.15,
+ "completion_unit_price": 0.60,
+ "prompt_price": 0.00015,
+ "completion_price": 0,
+ "total_price": 0.00015,
+ }
+
+ usage = LLMUsage.from_metadata(metadata)
+
+ assert usage.prompt_tokens == 1000
+ assert usage.completion_tokens == 0
+ assert usage.total_tokens == 1000
+ assert usage.prompt_price == Decimal("0.00015")
+ assert usage.completion_price == Decimal(0)
+ assert usage.total_price == Decimal("0.00015")
+
+ def test_from_metadata_with_decimal_values(self):
+        """Test that from_metadata coerces numeric-string prices to Decimal."""
+ metadata: LLMUsageMetadata = {
+ "prompt_tokens": 100,
+ "completion_tokens": 50,
+ "total_tokens": 150,
+ "prompt_unit_price": "0.001",
+ "completion_unit_price": "0.002",
+ "prompt_price": "0.1",
+ "completion_price": "0.1",
+ "total_price": "0.2",
+ }
+
+ usage = LLMUsage.from_metadata(metadata)
+
+ assert usage.prompt_unit_price == Decimal("0.001")
+ assert usage.completion_unit_price == Decimal("0.002")
+ assert usage.prompt_price == Decimal("0.1")
+ assert usage.completion_price == Decimal("0.1")
+ assert usage.total_price == Decimal("0.2")
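
The new test file pins down the contract of `LLMUsage.from_metadata`; the implementation itself is not part of this diff. A minimal sketch consistent with the assertions (missing token counts default to 0, `total_tokens` is computed only when absent so explicit totals like 521 survive, prices are coerced to `Decimal` whether given as floats or strings):

```python
from decimal import Decimal

def from_metadata_sketch(metadata: dict) -> dict:
    """Illustrative only; mirrors the behavior the tests above assert."""
    prompt = metadata.get("prompt_tokens", 0)
    completion = metadata.get("completion_tokens", 0)
    # Compute the total only when the provider did not supply one, so an
    # explicit total (e.g. 521 in the issue #24360 case) is never overwritten.
    total = metadata.get("total_tokens", prompt + completion)
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": total,
        # Decimal(str(...)) accepts ints, floats, and numeric strings alike.
        "total_price": Decimal(str(metadata.get("total_price", 0))),
        "currency": metadata.get("currency", "USD"),
        "latency": metadata.get("latency", 0.0),
    }
```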
diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
index f6d22690d1..8abed0a3f9 100644
--- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
@@ -164,7 +164,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
)
assert isinstance(prompt_messages[3].content, list)
assert len(prompt_messages[3].content) == 2
- assert prompt_messages[3].content[1].data == files[0].remote_url
+ assert prompt_messages[3].content[0].data == files[0].remote_url
@pytest.fixture
diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py
index 450501c256..e7733b2317 100644
--- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py
@@ -5,7 +5,6 @@ These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""
-from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
@@ -13,6 +12,7 @@ import pytest
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
+from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
@@ -56,7 +56,7 @@ def sample_workflow_execution():
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
@@ -199,7 +199,7 @@ class TestCeleryWorkflowExecutionRepository:
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
exec2 = WorkflowExecution.new(
id_=str(uuid4()),
@@ -208,7 +208,7 @@ class TestCeleryWorkflowExecutionRepository:
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input2": "value2"},
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
# Save both executions
@@ -235,7 +235,7 @@ class TestCeleryWorkflowExecutionRepository:
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
repo.save(execution)
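
Every `datetime.now(UTC).replace(tzinfo=None)` call site in these test files is folded into the `naive_utc_now` helper imported from `libs.datetime_utils`. Judging from the one-for-one substitution, the helper is presumably equivalent to:

```python
from datetime import UTC, datetime

def naive_utc_now() -> datetime:
    # Current UTC time with tzinfo stripped, matching columns that store
    # naive timestamps.
    return datetime.now(UTC).replace(tzinfo=None)
```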
diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py
index b38d994f03..0c6fdc8f92 100644
--- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py
@@ -5,7 +5,6 @@ These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""
-from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
@@ -18,6 +17,7 @@ from core.workflow.entities.workflow_node_execution import (
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
+from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@@ -65,7 +65,7 @@ def sample_workflow_node_execution():
title="Test Node",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
- created_at=datetime.now(UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
)
@@ -263,7 +263,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
title="Node 1",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
- created_at=datetime.now(UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
)
exec2 = WorkflowNodeExecution(
id=str(uuid4()),
@@ -276,7 +276,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
title="Node 2",
inputs={"input2": "value2"},
status=WorkflowNodeExecutionStatus.RUNNING,
- created_at=datetime.now(UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
)
# Save both executions
@@ -314,7 +314,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
title="Node 2",
inputs={},
status=WorkflowNodeExecutionStatus.RUNNING,
- created_at=datetime.now(UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
)
exec2 = WorkflowNodeExecution(
id=str(uuid4()),
@@ -327,7 +327,7 @@ class TestCeleryWorkflowNodeExecutionRepository:
title="Node 1",
inputs={},
status=WorkflowNodeExecutionStatus.RUNNING,
- created_at=datetime.now(UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
)
# Save in random order
diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py
index 5146e82e8f..30f51902ef 100644
--- a/api/tests/unit_tests/core/repositories/test_factory.py
+++ b/api/tests/unit_tests/core/repositories/test_factory.py
@@ -2,19 +2,19 @@
Unit tests for the RepositoryFactory.
This module tests the factory pattern implementation for creating repository instances
-based on configuration, including error handling and validation.
+based on configuration, including error handling.
"""
from unittest.mock import MagicMock, patch
import pytest
-from pytest_mock import MockerFixture
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from libs.module_loading import import_string
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@@ -23,98 +23,30 @@ from models.workflow import WorkflowNodeExecutionTriggeredFrom
class TestRepositoryFactory:
"""Test cases for RepositoryFactory."""
- def test_import_class_success(self):
+ def test_import_string_success(self):
"""Test successful class import."""
# Test importing a real class
class_path = "unittest.mock.MagicMock"
- result = DifyCoreRepositoryFactory._import_class(class_path)
+ result = import_string(class_path)
assert result is MagicMock
- def test_import_class_invalid_path(self):
+ def test_import_string_invalid_path(self):
"""Test import with invalid module path."""
- with pytest.raises(RepositoryImportError) as exc_info:
- DifyCoreRepositoryFactory._import_class("invalid.module.path")
- assert "Cannot import repository class" in str(exc_info.value)
+ with pytest.raises(ImportError) as exc_info:
+ import_string("invalid.module.path")
+ assert "No module named" in str(exc_info.value)
- def test_import_class_invalid_class_name(self):
+ def test_import_string_invalid_class_name(self):
"""Test import with invalid class name."""
- with pytest.raises(RepositoryImportError) as exc_info:
- DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass")
- assert "Cannot import repository class" in str(exc_info.value)
+ with pytest.raises(ImportError) as exc_info:
+ import_string("unittest.mock.NonExistentClass")
+ assert "does not define" in str(exc_info.value)
- def test_import_class_malformed_path(self):
+ def test_import_string_malformed_path(self):
"""Test import with malformed path (no dots)."""
- with pytest.raises(RepositoryImportError) as exc_info:
- DifyCoreRepositoryFactory._import_class("invalidpath")
- assert "Cannot import repository class" in str(exc_info.value)
-
- def test_validate_repository_interface_success(self):
- """Test successful interface validation."""
-
- # Create a mock class that implements the required methods
- class MockRepository:
- def save(self):
- pass
-
- def get_by_id(self):
- pass
-
- # Create a mock interface class
- class MockInterface:
- def save(self):
- pass
-
- def get_by_id(self):
- pass
-
- # Should not raise an exception when all methods are present
- DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
-
- def test_validate_repository_interface_missing_methods(self):
- """Test interface validation with missing methods."""
-
- # Create a mock class that's missing required methods
- class IncompleteRepository:
- def save(self):
- pass
-
- # Missing get_by_id method
-
- # Create a mock interface that requires both methods
- class MockInterface:
- def save(self):
- pass
-
- def get_by_id(self):
- pass
-
- def missing_method(self):
- pass
-
- with pytest.raises(RepositoryImportError) as exc_info:
- DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
- assert "does not implement required methods" in str(exc_info.value)
-
- def test_validate_repository_interface_with_private_methods(self):
- """Test that private methods are ignored during interface validation."""
-
- class MockRepository:
- def save(self):
- pass
-
- def _private_method(self):
- pass
-
- # Create a mock interface with private methods
- class MockInterface:
- def save(self):
- pass
-
- def _private_method(self):
- pass
-
- # Should not raise exception - private methods should be ignored
- DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
+ with pytest.raises(ImportError) as exc_info:
+ import_string("invalidpath")
+ assert "doesn't look like a module path" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_success(self, mock_config):
@@ -133,11 +65,8 @@ class TestRepositoryFactory:
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
- # Mock the validation methods
- with (
- patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
- patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
- ):
+ # Mock import_string
+ with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
@@ -170,34 +99,7 @@ class TestRepositoryFactory:
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
- assert "Cannot import repository class" in str(exc_info.value)
-
- @patch("core.repositories.factory.dify_config")
- def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
- """Test WorkflowExecutionRepository creation with validation error."""
- # Setup mock configuration
- mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
-
- mock_session_factory = MagicMock(spec=sessionmaker)
- mock_user = MagicMock(spec=Account)
-
- # Mock the import to succeed but validation to fail
- mock_repository_class = MagicMock()
- mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
- mocker.patch.object(
- DifyCoreRepositoryFactory,
- "_validate_repository_interface",
- side_effect=RepositoryImportError("Interface validation failed"),
- )
-
- with pytest.raises(RepositoryImportError) as exc_info:
- DifyCoreRepositoryFactory.create_workflow_execution_repository(
- session_factory=mock_session_factory,
- user=mock_user,
- app_id="test-app-id",
- triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
- )
- assert "Interface validation failed" in str(exc_info.value)
+ assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
@@ -212,11 +114,8 @@ class TestRepositoryFactory:
mock_repository_class = MagicMock()
mock_repository_class.side_effect = Exception("Instantiation failed")
- # Mock the validation methods to succeed
- with (
- patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
- patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
- ):
+ # Mock import_string to return a failing class
+ with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
@@ -243,11 +142,8 @@ class TestRepositoryFactory:
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
- # Mock the validation methods
- with (
- patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
- patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
- ):
+ # Mock import_string
+ with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
@@ -280,34 +176,7 @@ class TestRepositoryFactory:
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
- assert "Cannot import repository class" in str(exc_info.value)
-
- @patch("core.repositories.factory.dify_config")
- def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
- """Test WorkflowNodeExecutionRepository creation with validation error."""
- # Setup mock configuration
- mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
-
- mock_session_factory = MagicMock(spec=sessionmaker)
- mock_user = MagicMock(spec=EndUser)
-
- # Mock the import to succeed but validation to fail
- mock_repository_class = MagicMock()
- mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
- mocker.patch.object(
- DifyCoreRepositoryFactory,
- "_validate_repository_interface",
- side_effect=RepositoryImportError("Interface validation failed"),
- )
-
- with pytest.raises(RepositoryImportError) as exc_info:
- DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
- session_factory=mock_session_factory,
- user=mock_user,
- app_id="test-app-id",
- triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
- )
- assert "Interface validation failed" in str(exc_info.value)
+ assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
@@ -322,11 +191,8 @@ class TestRepositoryFactory:
mock_repository_class = MagicMock()
mock_repository_class.side_effect = Exception("Instantiation failed")
- # Mock the validation methods to succeed
- with (
- patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
- patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
- ):
+ # Mock import_string to return a failing class
+ with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
@@ -359,11 +225,8 @@ class TestRepositoryFactory:
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
- # Mock the validation methods
- with (
- patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
- patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
- ):
+ # Mock import_string
+ with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_engine, # Using Engine instead of sessionmaker
user=mock_user,
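
The factory tests now exercise `libs.module_loading.import_string` in place of the removed `_import_class`/`_validate_repository_interface` pair. The error messages asserted above ("No module named", "does not define", "doesn't look like a module path") match the Django-style `import_string` convention, which the module name also echoes; a sketch under that assumption:

```python
from importlib import import_module

def import_string(dotted_path: str):
    """Import a class or attribute from a dotted module path (sketch)."""
    try:
        module_path, class_name = dotted_path.rsplit(".", 1)
    except ValueError as err:
        raise ImportError(f"{dotted_path} doesn't look like a module path") from err
    # import_module itself raises "No module named ..." for bad modules.
    module = import_module(module_path)
    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise ImportError(
            f'Module "{module_path}" does not define a "{class_name}" attribute/class'
        ) from err
```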
diff --git a/api/tests/unit_tests/core/tools/utils/test_encryption.py b/api/tests/unit_tests/core/tools/utils/test_encryption.py
new file mode 100644
index 0000000000..6425ab0b8d
--- /dev/null
+++ b/api/tests/unit_tests/core/tools/utils/test_encryption.py
@@ -0,0 +1,181 @@
+import copy
+from unittest.mock import patch
+
+import pytest
+
+from core.entities.provider_entities import BasicProviderConfig
+from core.tools.utils.encryption import ProviderConfigEncrypter
+
+
+# ---------------------------
+# A no-op cache
+# ---------------------------
+class NoopCache:
+ """Simple cache stub: always returns None, does nothing for set/delete."""
+
+ def get(self):
+ return None
+
+ def set(self, config):
+ pass
+
+ def delete(self):
+ pass
+
+
+@pytest.fixture
+def secret_field() -> BasicProviderConfig:
+ """A SECRET_INPUT field named 'password'."""
+ return BasicProviderConfig(
+ name="password",
+ type=BasicProviderConfig.Type.SECRET_INPUT,
+ )
+
+
+@pytest.fixture
+def normal_field() -> BasicProviderConfig:
+ """A TEXT_INPUT field named 'username'."""
+ return BasicProviderConfig(
+ name="username",
+ type=BasicProviderConfig.Type.TEXT_INPUT,
+ )
+
+
+@pytest.fixture
+def encrypter_obj(secret_field, normal_field):
+ """
+ Build ProviderConfigEncrypter with:
+ - tenant_id = tenant123
+ - one secret field (password) and one normal field (username)
+ - NoopCache as cache
+ """
+ return ProviderConfigEncrypter(
+ tenant_id="tenant123",
+ config=[secret_field, normal_field],
+ provider_config_cache=NoopCache(),
+ )
+
+
+# ============================================================
+# ProviderConfigEncrypter.encrypt()
+# ============================================================
+
+
+def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj):
+ """
+ Secret field should be encrypted, non-secret field unchanged.
+ Verify encrypt_token called only for secret field.
+ Also check deep copy (input not modified).
+ """
+ data_in = {"username": "alice", "password": "plain_pwd"}
+ data_copy = copy.deepcopy(data_in)
+
+ with patch("core.tools.utils.encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt:
+ out = encrypter_obj.encrypt(data_in)
+
+ assert out["username"] == "alice"
+ assert out["password"] == "CIPHERTEXT"
+ mock_encrypt.assert_called_once_with("tenant123", "plain_pwd")
+ assert data_in == data_copy # deep copy semantics
+
+
+def test_encrypt_missing_secret_key_is_ok(encrypter_obj):
+ """If secret field missing in input, no error and no encryption called."""
+ with patch("core.tools.utils.encryption.encrypter.encrypt_token") as mock_encrypt:
+ out = encrypter_obj.encrypt({"username": "alice"})
+ assert out["username"] == "alice"
+ mock_encrypt.assert_not_called()
+
+
+# ============================================================
+# ProviderConfigEncrypter.mask_tool_credentials()
+# ============================================================
+
+
+@pytest.mark.parametrize(
+ ("raw", "prefix", "suffix"),
+ [
+ ("longsecret", "lo", "et"),
+ ("abcdefg", "ab", "fg"),
+ ("1234567", "12", "67"),
+ ],
+)
+def test_mask_tool_credentials_long_secret(encrypter_obj, raw, prefix, suffix):
+ """
+ For length > 6: keep first 2 and last 2, mask middle with '*'.
+ """
+ data_in = {"username": "alice", "password": raw}
+ data_copy = copy.deepcopy(data_in)
+
+ out = encrypter_obj.mask_tool_credentials(data_in)
+ masked = out["password"]
+
+ assert masked.startswith(prefix)
+ assert masked.endswith(suffix)
+ assert "*" in masked
+ assert len(masked) == len(raw)
+ assert data_in == data_copy # deep copy semantics
+
+
+@pytest.mark.parametrize("raw", ["", "1", "12", "123", "123456"])
+def test_mask_tool_credentials_short_secret(encrypter_obj, raw):
+ """
+ For length <= 6: fully mask with '*' of same length.
+ """
+ out = encrypter_obj.mask_tool_credentials({"password": raw})
+ assert out["password"] == ("*" * len(raw))
+
+
+def test_mask_tool_credentials_missing_key_noop(encrypter_obj):
+ """If secret key missing, leave other fields unchanged."""
+ data_in = {"username": "alice"}
+ data_copy = copy.deepcopy(data_in)
+
+ out = encrypter_obj.mask_tool_credentials(data_in)
+ assert out["username"] == "alice"
+ assert data_in == data_copy
+
+
+# ============================================================
+# ProviderConfigEncrypter.decrypt()
+# ============================================================
+
+
+def test_decrypt_normal_flow(encrypter_obj):
+ """
+ Normal decrypt flow:
+ - decrypt_token called for secret field
+ - secret replaced with decrypted value
+ - non-secret unchanged
+ """
+ data_in = {"username": "alice", "password": "ENC"}
+ data_copy = copy.deepcopy(data_in)
+
+ with patch("core.tools.utils.encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt:
+ out = encrypter_obj.decrypt(data_in)
+
+ assert out["username"] == "alice"
+ assert out["password"] == "PLAIN"
+ mock_decrypt.assert_called_once_with("tenant123", "ENC")
+ assert data_in == data_copy # deep copy semantics
+
+
+@pytest.mark.parametrize("empty_val", ["", None])
+def test_decrypt_skip_empty_values(encrypter_obj, empty_val):
+ """Skip decrypt if value is empty or None, keep original."""
+ with patch("core.tools.utils.encryption.encrypter.decrypt_token") as mock_decrypt:
+ out = encrypter_obj.decrypt({"password": empty_val})
+
+ mock_decrypt.assert_not_called()
+ assert out["password"] == empty_val
+
+
+def test_decrypt_swallow_exception_and_keep_original(encrypter_obj):
+ """
+ If decrypt_token raises, exception should be swallowed,
+ and original value preserved.
+ """
+ with patch("core.tools.utils.encryption.encrypter.decrypt_token", side_effect=Exception("boom")):
+ out = encrypter_obj.decrypt({"password": "ENC_ERR"})
+
+ assert out["password"] == "ENC_ERR"
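
Taken together, the tests fix the masking rule of `mask_tool_credentials`: secrets longer than six characters keep their first and last two characters, anything shorter is masked entirely. A standalone sketch of that rule (the real method additionally deep-copies the input and only masks fields declared as SECRET_INPUT):

```python
def mask_secret(value: str) -> str:
    # > 6 chars: keep a 2-char prefix and suffix, star out the middle;
    # otherwise star out the whole value, preserving its length.
    if len(value) > 6:
        return value[:2] + "*" * (len(value) - 4) + value[-2:]
    return "*" * len(value)

assert mask_secret("longsecret") == "lo******et"
assert mask_secret("123456") == "******"
```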
diff --git a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py
index c17308baad..20f753786d 100644
--- a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py
+++ b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py
@@ -1,6 +1,293 @@
-from core.tools.utils.web_reader_tool import get_image_upload_file_ids
+import pytest
+
+from core.tools.utils.web_reader_tool import (
+ extract_using_readabilipy,
+ get_image_upload_file_ids,
+ get_url,
+ page_result,
+)
+
+
+class FakeResponse:
+ """Minimal fake response object for ssrf_proxy / cloudscraper."""
+
+ def __init__(self, *, status_code=200, headers=None, content=b"", text=""):
+ self.status_code = status_code
+ self.headers = headers or {}
+ self.content = content
+ self.text = text if text else content.decode("utf-8", errors="ignore")
+
+
+# ---------------------------
+# Tests: page_result
+# ---------------------------
+@pytest.mark.parametrize(
+ ("text", "cursor", "maxlen", "expected"),
+ [
+ ("abcdef", 0, 3, "abc"),
+ ("abcdef", 2, 10, "cdef"), # maxlen beyond end
+ ("abcdef", 6, 5, ""), # cursor at end
+ ("abcdef", 7, 5, ""), # cursor beyond end
+ ("", 0, 5, ""), # empty text
+ ],
+)
+def test_page_result(text, cursor, maxlen, expected):
+ assert page_result(text, cursor, maxlen) == expected
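
These five cases fully determine `page_result` as a simple slice; a one-line sketch consistent with all of them (parameter names are assumptions):

```python
def page_result(text: str, cursor: int, max_length: int) -> str:
    # Slicing past the end of a string yields "" rather than raising.
    return text[cursor : cursor + max_length]
```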
+
+
+# ---------------------------
+# Tests: get_url
+# ---------------------------
+@pytest.fixture
+def stub_support_types(monkeypatch):
+ """Stub supported content types list."""
+ import core.tools.utils.web_reader_tool as mod
+
+ # e.g. binary types supported by ExtractProcessor
+ monkeypatch.setattr(mod.extract_processor, "SUPPORT_URL_CONTENT_TYPES", ["application/pdf", "text/plain"])
+ return mod
+
+
+def test_get_url_unsupported_content_type(monkeypatch, stub_support_types):
+ # HEAD 200 but content-type not supported and not text/html
+ def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+ return FakeResponse(
+ status_code=200,
+ headers={"Content-Type": "image/png"}, # not supported
+ )
+
+ monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head)
+
+ result = get_url("https://x.test/file.png")
+ assert result == "Unsupported content-type [image/png] of URL."
+
+
+def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_support_types):
+ """
+ When content-type is in SUPPORT_URL_CONTENT_TYPES,
+ should call ExtractProcessor.load_from_url and return its text.
+ """
+ calls = {"load": 0}
+
+ def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+ return FakeResponse(
+ status_code=200,
+ headers={"Content-Type": "application/pdf"},
+ )
+
+ def fake_load_from_url(url, return_text=False):
+ calls["load"] += 1
+ assert return_text is True
+ return "PDF extracted text"
+
+ monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head)
+ monkeypatch.setattr(stub_support_types.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url))
+
+ result = get_url("https://x.test/doc.pdf")
+ assert calls["load"] == 1
+ assert result == "PDF extracted text"
+
+
+def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_support_types):
+    """200 + text/html → GET, chardet detects the encoding, readability returns an article that is rendered via FULL_TEMPLATE."""
+
+ def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+ return FakeResponse(status_code=200, headers={"Content-Type": "text/html"})
+
+ def fake_get(url, headers=None, follow_redirects=True, timeout=None):
+        html = b"<html><body>hello</body></html>"
+ return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}, content=html)
+
+ # chardet.detect returns utf-8
+ import core.tools.utils.web_reader_tool as mod
+
+ monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
+ monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
+ monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
+
+ # readability → a dict that maps to Article, then FULL_TEMPLATE
+ def fake_simple_json_from_html_string(html, use_readability=True):
+ return {
+ "title": "My Title",
+ "byline": "Bob",
+ "plain_text": [{"type": "text", "text": "Hello world"}],
+ }
+
+ monkeypatch.setattr(mod, "simple_json_from_html_string", fake_simple_json_from_html_string)
+
+ out = get_url("https://x.test/page")
+ assert "TITLE: My Title" in out
+ assert "AUTHOR: Bob" in out
+ assert "Hello world" in out
+
+
+def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_support_types):
+ """If readability returns no text, should return empty string."""
+
+ def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+ return FakeResponse(status_code=200, headers={"Content-Type": "text/html"})
+
+ def fake_get(url, headers=None, follow_redirects=True, timeout=None):
+ return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}, content=b"")
+
+ import core.tools.utils.web_reader_tool as mod
+
+ monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
+ monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
+ monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
+ # readability returns empty plain_text
+ monkeypatch.setattr(mod, "simple_json_from_html_string", lambda html, use_readability=True: {"plain_text": []})
+
+ out = get_url("https://x.test/empty")
+ assert out == ""
+
+
+def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types):
+    """HEAD 403 → fall back to cloudscraper.create_scraper().get(), then proceed with the HTML flow."""
+
+ def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+ return FakeResponse(status_code=403, headers={})
+
+ # cloudscraper.create_scraper() → object with .get()
+ class FakeScraper:
+ def __init__(self):
+            pass
+
+ def get(self, url, headers=None, follow_redirects=True, timeout=None):
+ # mimic html 200
+ html = b"hi"
+ return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}, content=html)
+
+ import core.tools.utils.web_reader_tool as mod
+
+ monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
+ monkeypatch.setattr(mod.cloudscraper, "create_scraper", lambda: FakeScraper())
+ monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
+ monkeypatch.setattr(
+ mod,
+ "simple_json_from_html_string",
+ lambda html, use_readability=True: {"title": "T", "byline": "A", "plain_text": [{"type": "text", "text": "X"}]},
+ )
+
+ out = get_url("https://x.test/403")
+ assert "TITLE: T" in out
+ assert "AUTHOR: A" in out
+ assert "X" in out
+
+
+def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types):
+    """HEAD returns a status other than 200/403 → should return the status-code message directly."""
+
+ def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+ return FakeResponse(status_code=500)
+
+ import core.tools.utils.web_reader_tool as mod
+
+ monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
+
+ out = get_url("https://x.test/fail")
+ assert out == "URL returned status code 500."
+
+
+def test_get_url_content_disposition_filename_detection(monkeypatch, stub_support_types):
+ """
+    If HEAD returns 200 with no Content-Type but the Content-Disposition filename suggests a supported type,
+    the URL should route to ExtractProcessor.load_from_url.
+ """
+ calls = {"load": 0}
+
+ def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+ return FakeResponse(status_code=200, headers={"Content-Disposition": 'attachment; filename="doc.pdf"'})
+
+ def fake_load_from_url(url, return_text=False):
+ calls["load"] += 1
+ return "From ExtractProcessor via filename"
+
+ import core.tools.utils.web_reader_tool as mod
+
+ monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
+ monkeypatch.setattr(mod.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url))
+
+ out = get_url("https://x.test/fname")
+ assert calls["load"] == 1
+ assert out == "From ExtractProcessor via filename"
+
+
+def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_support_types):
+ """
+    If chardet returns an encoding but content.decode raises, get_url should fall back to response.text.
+ """
+
+ def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+ return FakeResponse(status_code=200, headers={"Content-Type": "text/html"})
+
+ # Return bytes that will raise with the chosen encoding
+ def fake_get(url, headers=None, follow_redirects=True, timeout=None):
+ return FakeResponse(
+ status_code=200,
+ headers={"Content-Type": "text/html"},
+ content=b"\xff\xfe\xfa", # likely to fail under utf-8
+ text="fallback text",
+ )
+
+ import core.tools.utils.web_reader_tool as mod
+
+ monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
+ monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
+ monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
+ monkeypatch.setattr(
+ mod,
+ "simple_json_from_html_string",
+ lambda html, use_readability=True: {"title": "", "byline": "", "plain_text": [{"type": "text", "text": "ok"}]},
+ )
+
+ out = get_url("https://x.test/enc-fallback")
+ assert "ok" in out
+
+
+# ---------------------------
+# Tests: extract_using_readabilipy
+# ---------------------------
+
+
+def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch):
+ # stub readabilipy.simple_json_from_html_string
+ def fake_simple_json_from_html_string(html, use_readability=True):
+ return {
+ "title": "Hello",
+ "byline": "Alice",
+ "plain_text": [{"type": "text", "text": "world"}],
+ }
+
+ import core.tools.utils.web_reader_tool as mod
+
+ monkeypatch.setattr(mod, "simple_json_from_html_string", fake_simple_json_from_html_string)
+
+ article = extract_using_readabilipy("...")
+ assert article.title == "Hello"
+ assert article.author == "Alice"
+ assert isinstance(article.text, list)
+ assert article.text
+ assert article.text[0]["text"] == "world"
+
+
+def test_extract_using_readabilipy_defaults_when_missing(monkeypatch):
+ def fake_simple_json_from_html_string(html, use_readability=True):
+ return {} # all missing
+
+ import core.tools.utils.web_reader_tool as mod
+
+ monkeypatch.setattr(mod, "simple_json_from_html_string", fake_simple_json_from_html_string)
+
+ article = extract_using_readabilipy("...")
+ assert article.title == ""
+ assert article.author == ""
+ assert article.text == []
+
+
+# ---------------------------
+# Tests: get_image_upload_file_ids
+# ---------------------------
def test_get_image_upload_file_ids():
# should extract id from https + file-preview
content = ""
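
Read as a group, the new `get_url` tests document a dispatch order: HEAD first; hard-fail on any status other than 200/403; fall back to cloudscraper on 403; route supported binary types (detected via Content-Type or the Content-Disposition filename) to `ExtractProcessor.load_from_url`; run text/html through chardet plus readability; reject everything else. A compressed, self-contained sketch of that decision table, with `dispatch` and its return labels as hypothetical stand-ins distilled from the assertions above:

```python
def dispatch(status_code: int, content_type: str, filename: str) -> str:
    """Hypothetical decision table distilled from the get_url tests."""
    supported = ("application/pdf", "text/plain")  # stand-in for SUPPORT_URL_CONTENT_TYPES
    if status_code not in (200, 403):  # 403 continues via the cloudscraper fallback
        return f"URL returned status code {status_code}."
    if content_type in supported or filename.endswith((".pdf", ".txt")):
        return "extract_processor"  # ExtractProcessor.load_from_url path
    if content_type.startswith("text/html"):
        return "readability"  # chardet + simple_json_from_html_string path
    return f"Unsupported content-type [{content_type}] of URL."

assert dispatch(500, "", "") == "URL returned status code 500."
assert dispatch(200, "image/png", "") == "Unsupported content-type [image/png] of URL."
assert dispatch(200, "application/pdf", "") == "extract_processor"
assert dispatch(200, "", "doc.pdf") == "extract_processor"
assert dispatch(200, "text/html", "") == "readability"
```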
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
index 137e8b889d..8b1b9a55bc 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
@@ -1,6 +1,5 @@
import uuid
from collections.abc import Generator
-from datetime import UTC, datetime
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
@@ -15,6 +14,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.system_variable import SystemVariable
+from libs.datetime_utils import naive_utc_now
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
@@ -29,7 +29,7 @@ def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngine
def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
- route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
+ route_node_state = RouteNodeState(node_id=next_node_id, start_at=naive_utc_now())
parallel_id = graph.node_parallel_mapping.get(next_node_id)
parallel_start_node_id = None
@@ -68,7 +68,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
)
route_node_state.status = RouteNodeState.Status.SUCCESS
- route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None)
+ route_node_state.finished_at = naive_utc_now()
yield NodeRunSucceededEvent(
id=node_execution_id,
node_id=next_node_id,
diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
index 4866db1fdb..1d2eba1e71 100644
--- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
+++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
@@ -1,5 +1,4 @@
import json
-from datetime import UTC, datetime
from unittest.mock import MagicMock
import pytest
@@ -23,6 +22,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
+from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import Workflow, WorkflowRun
@@ -145,8 +145,8 @@ def real_workflow():
workflow.graph = json.dumps(graph_data)
workflow.features = json.dumps({"file_upload": {"enabled": False}})
workflow.created_by = "test-user-id"
- workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
- workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ workflow.created_at = naive_utc_now()
+ workflow.updated_at = naive_utc_now()
workflow._environment_variables = "{}"
workflow._conversation_variables = "{}"
@@ -169,7 +169,7 @@ def real_workflow_run():
workflow_run.outputs = json.dumps({"answer": "test answer"})
workflow_run.created_by_role = CreatorUserRole.ACCOUNT
workflow_run.created_by = "test-user-id"
- workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
+ workflow_run.created_at = naive_utc_now()
return workflow_run
@@ -211,7 +211,7 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
# Pre-populate the cache with the workflow execution
@@ -245,7 +245,7 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
# Pre-populate the cache with the workflow execution
@@ -282,7 +282,7 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
# Pre-populate the cache with the workflow execution
@@ -335,7 +335,7 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
# Pre-populate the cache with the workflow execution
@@ -366,7 +366,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
event.process_data = {"process": "test process"}
event.outputs = {"output": "test output"}
event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}
- event.start_at = datetime.now(UTC).replace(tzinfo=None)
+ event.start_at = naive_utc_now()
# Create a real node execution
@@ -379,7 +379,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
- created_at=datetime.now(UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
)
# Pre-populate the cache with the node execution
@@ -409,7 +409,7 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
# Pre-populate the cache with the workflow execution
@@ -443,7 +443,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
event.process_data = {"process": "test process"}
event.outputs = {"output": "test output"}
event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}
- event.start_at = datetime.now(UTC).replace(tzinfo=None)
+ event.start_at = naive_utc_now()
event.error = "Test error message"
# Create a real node execution
@@ -457,7 +457,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
- created_at=datetime.now(UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
)
# Pre-populate the cache with the node execution
diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py
index 5bc77ad0ef..4c61320c29 100644
--- a/api/tests/unit_tests/models/test_workflow.py
+++ b/api/tests/unit_tests/models/test_workflow.py
@@ -9,7 +9,6 @@ from core.file.models import File
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from core.variables.segments import IntegerSegment, Segment
from factories.variable_factory import build_segment
-from models.model import EndUser
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
@@ -43,14 +42,9 @@ def test_environment_variables():
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
)
- # Mock current_user as an EndUser
- mock_user = mock.Mock(spec=EndUser)
- mock_user.tenant_id = "tenant_id"
-
with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
- mock.patch("models.workflow.current_user", mock_user),
):
# Set the environment_variables property of the Workflow instance
variables = [variable1, variable2, variable3, variable4]
@@ -90,14 +84,9 @@ def test_update_environment_variables():
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
)
- # Mock current_user as an EndUser
- mock_user = mock.Mock(spec=EndUser)
- mock_user.tenant_id = "tenant_id"
-
with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
- mock.patch("models.workflow.current_user", mock_user),
):
variables = [variable1, variable2, variable3, variable4]
@@ -136,14 +125,9 @@ def test_to_dict():
# Create some EnvironmentVariable instances
- # Mock current_user as an EndUser
- mock_user = mock.Mock(spec=EndUser)
- mock_user.tenant_id = "tenant_id"
-
with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
- mock.patch("models.workflow.current_user", mock_user),
):
# Set the environment_variables property of the Workflow instance
workflow.environment_variables = [
diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py
index dc09aca5b2..1881ceac26 100644
--- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py
+++ b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py
@@ -93,16 +93,15 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
with (
patch("services.dataset_service.DocumentService.get_document") as mock_get_doc,
patch("extensions.ext_database.db.session") as mock_db,
- patch("services.dataset_service.datetime") as mock_datetime,
+ patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
):
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
- mock_datetime.datetime.now.return_value = current_time
- mock_datetime.UTC = datetime.UTC
+ mock_naive_utc_now.return_value = current_time
yield {
"get_document": mock_get_doc,
"db_session": mock_db,
- "datetime": mock_datetime,
+ "naive_utc_now": mock_naive_utc_now,
"current_time": current_time,
}
@@ -120,21 +119,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
assert document.enabled == True
assert document.disabled_at is None
assert document.disabled_by is None
- assert document.updated_at == current_time.replace(tzinfo=None)
+ assert document.updated_at == current_time
def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime):
"""Helper method to verify document was disabled correctly."""
assert document.enabled == False
- assert document.disabled_at == current_time.replace(tzinfo=None)
+ assert document.disabled_at == current_time
assert document.disabled_by == user_id
- assert document.updated_at == current_time.replace(tzinfo=None)
+ assert document.updated_at == current_time
def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime):
"""Helper method to verify document was archived correctly."""
assert document.archived == True
- assert document.archived_at == current_time.replace(tzinfo=None)
+ assert document.archived_at == current_time
assert document.archived_by == user_id
- assert document.updated_at == current_time.replace(tzinfo=None)
+ assert document.updated_at == current_time
def _assert_document_unarchived(self, document: Mock):
"""Helper method to verify document was unarchived correctly."""
@@ -430,7 +429,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
# Verify document attributes were updated correctly
self._assert_document_unarchived(archived_doc)
- assert archived_doc.updated_at == mock_document_service_dependencies["current_time"].replace(tzinfo=None)
+ assert archived_doc.updated_at == mock_document_service_dependencies["current_time"]
# Verify Redis cache was set (because document is enabled)
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
@@ -495,9 +494,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
# Verify document was unarchived
self._assert_document_unarchived(archived_disabled_doc)
- assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"].replace(
- tzinfo=None
- )
+ assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"]
# Verify no Redis cache was set (document is disabled)
redis_mock.setex.assert_not_called()
diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py
index c4c7579e83..0fc36510b9 100644
--- a/api/tests/unit_tests/services/test_metadata_bug_complete.py
+++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py
@@ -1,7 +1,7 @@
from unittest.mock import Mock, patch
import pytest
-from flask_restful import reqparse
+from flask_restx import reqparse
from werkzeug.exceptions import BadRequest
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py
index ef4d05c1d9..7f6344f942 100644
--- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py
+++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py
@@ -1,7 +1,7 @@
from unittest.mock import Mock, patch
import pytest
-from flask_restful import reqparse
+from flask_restx import reqparse
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
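
The only change to these two metadata test modules is the import swap: flask-restx is a maintained fork of flask-restful and ships an API-compatible `reqparse` module, so the tests run unmodified. For illustration, the parser API both files rely on is along these lines (argument names here are illustrative):

```python
from flask_restx import reqparse

parser = reqparse.RequestParser()
# nullable=False is the behavior the "metadata nullable bug" tests exercise.
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
```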
diff --git a/api/uv.lock b/api/uv.lock
index faf87fa698..45b020e1dd 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -741,6 +741,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c9/af/0dcccc7fdcdf170f9a1585e5e96b6fb0ba1749ef6be8c89a6202284759bd/celery-5.5.3-py3-none-any.whl", hash = "sha256:0b5761a07057acee94694464ca482416b959568904c9dfa41ce8413a7d65d525", size = 438775, upload-time = "2025-06-01T11:08:09.94Z" },
]
+[[package]]
+name = "celery-types"
+version = "0.23.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/e9/d1/0823e71c281e4ad0044e278cf1577d1a68e05f2809424bf94e1614925c5d/celery_types-0.23.0.tar.gz", hash = "sha256:402ed0555aea3cd5e1e6248f4632e4f18eec8edb2435173f9e6dc08449fa101e", size = 31479, upload-time = "2025-03-03T23:56:51.547Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/6f/8b/92bb54dd74d145221c3854aa245c84f4dc04cc9366147496182cec8e88e3/celery_types-0.23.0-py3-none-any.whl", hash = "sha256:0cc495b8d7729891b7e070d0ec8d4906d2373209656a6e8b8276fe1ed306af9a", size = 50189, upload-time = "2025-03-03T23:56:50.458Z" },
+]
+
[[package]]
name = "certifi"
version = "2025.6.15"
@@ -1254,7 +1266,7 @@ dependencies = [
{ name = "flask-login" },
{ name = "flask-migrate" },
{ name = "flask-orjson" },
- { name = "flask-restful" },
+ { name = "flask-restx" },
{ name = "flask-sqlalchemy" },
{ name = "gevent" },
{ name = "gmpy2" },
@@ -1326,6 +1338,7 @@ dependencies = [
[package.dev-dependencies]
dev = [
{ name = "boto3-stubs" },
+ { name = "celery-types" },
{ name = "coverage" },
{ name = "dotenv-linter" },
{ name = "faker" },
@@ -1436,13 +1449,13 @@ requires-dist = [
{ name = "cachetools", specifier = "~=5.3.0" },
{ name = "celery", specifier = "~=5.5.2" },
{ name = "chardet", specifier = "~=5.1.0" },
- { name = "flask", specifier = "~=3.1.0" },
+ { name = "flask", specifier = "~=3.1.2" },
{ name = "flask-compress", specifier = "~=1.17" },
{ name = "flask-cors", specifier = "~=6.0.0" },
{ name = "flask-login", specifier = "~=0.6.3" },
{ name = "flask-migrate", specifier = "~=4.0.7" },
{ name = "flask-orjson", specifier = "~=2.0.0" },
- { name = "flask-restful", specifier = "~=0.3.10" },
+ { name = "flask-restx", specifier = ">=1.3.0" },
{ name = "flask-sqlalchemy", specifier = "~=3.1.1" },
{ name = "gevent", specifier = "~=24.11.1" },
{ name = "gmpy2", specifier = "~=2.2.1" },
@@ -1514,12 +1527,13 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
{ name = "boto3-stubs", specifier = ">=1.38.20" },
+ { name = "celery-types", specifier = ">=0.23.0" },
{ name = "coverage", specifier = "~=7.2.4" },
{ name = "dotenv-linter", specifier = "~=0.5.0" },
{ name = "faker", specifier = "~=32.1.0" },
{ name = "hypothesis", specifier = ">=6.131.15" },
{ name = "lxml-stubs", specifier = "~=0.5.1" },
- { name = "mypy", specifier = "~=1.16.0" },
+ { name = "mypy", specifier = "~=1.17.1" },
{ name = "pandas-stubs", specifier = "~=2.2.3" },
{ name = "pytest", specifier = "~=8.3.2" },
{ name = "pytest-benchmark", specifier = "~=4.0.0" },
@@ -1790,7 +1804,7 @@ wheels = [
[[package]]
name = "flask"
-version = "3.1.1"
+version = "3.1.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "blinker" },
@@ -1800,9 +1814,9 @@ dependencies = [
{ name = "markupsafe" },
{ name = "werkzeug" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/c0/de/e47735752347f4128bcf354e0da07ef311a78244eba9e3dc1d4a5ab21a98/flask-3.1.1.tar.gz", hash = "sha256:284c7b8f2f58cb737f0cf1c30fd7eaf0ccfcde196099d24ecede3fc2005aa59e", size = 753440, upload-time = "2025-05-13T15:01:17.447Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/dc/6d/cfe3c0fcc5e477df242b98bfe186a4c34357b4847e87ecaef04507332dab/flask-3.1.2.tar.gz", hash = "sha256:bf656c15c80190ed628ad08cdfd3aaa35beb087855e2f494910aa3774cc4fd87", size = 720160, upload-time = "2025-08-19T21:03:21.205Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/3d/68/9d4508e893976286d2ead7f8f571314af6c2037af34853a30fd769c02e9d/flask-3.1.1-py3-none-any.whl", hash = "sha256:07aae2bb5eaf77993ef57e357491839f5fd9f4dc281593a81a9e4d79a24f295c", size = 103305, upload-time = "2025-05-13T15:01:15.591Z" },
+ { url = "https://files.pythonhosted.org/packages/ec/f9/7f9263c5695f4bd0023734af91bedb2ff8209e8de6ead162f35d8dc762fd/flask-3.1.2-py3-none-any.whl", hash = "sha256:ca1d8112ec8a6158cc29ea4858963350011b5c846a414cdb7a954aa9e967d03c", size = 103308, upload-time = "2025-08-19T21:03:19.499Z" },
]
[[package]]
@@ -1875,18 +1889,20 @@ wheels = [
]
[[package]]
-name = "flask-restful"
-version = "0.3.10"
+name = "flask-restx"
+version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aniso8601" },
{ name = "flask" },
+ { name = "importlib-resources" },
+ { name = "jsonschema" },
{ name = "pytz" },
- { name = "six" },
+ { name = "werkzeug" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/c0/ce/a0a133db616ea47f78a41e15c4c68b9f08cab3df31eb960f61899200a119/Flask-RESTful-0.3.10.tar.gz", hash = "sha256:fe4af2ef0027df8f9b4f797aba20c5566801b6ade995ac63b588abf1a59cec37", size = 110453, upload-time = "2023-05-21T03:58:55.781Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/45/4c/2e7d84e2b406b47cf3bf730f521efe474977b404ee170d8ea68dc37e6733/flask-restx-1.3.0.tar.gz", hash = "sha256:4f3d3fa7b6191fcc715b18c201a12cd875176f92ba4acc61626ccfd571ee1728", size = 2814072, upload-time = "2023-12-10T14:48:55.575Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/d7/7b/f0b45f0df7d2978e5ae51804bb5939b7897b2ace24306009da0cc34d8d1f/Flask_RESTful-0.3.10-py2.py3-none-any.whl", hash = "sha256:1cf93c535172f112e080b0d4503a8d15f93a48c88bdd36dd87269bdaf405051b", size = 26217, upload-time = "2023-05-21T03:58:54.004Z" },
+ { url = "https://files.pythonhosted.org/packages/a5/bf/1907369f2a7ee614dde5152ff8f811159d357e77962aa3f8c2e937f63731/flask_restx-1.3.0-py2.py3-none-any.whl", hash = "sha256:636c56c3fb3f2c1df979e748019f084a938c4da2035a3e535a4673e4fc177691", size = 2798683, upload-time = "2023-12-10T14:48:53.293Z" },
]
[[package]]
@@ -3272,28 +3288,28 @@ wheels = [
[[package]]
name = "mypy"
-version = "1.16.1"
+version = "1.17.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mypy-extensions" },
{ name = "pathspec" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/81/69/92c7fa98112e4d9eb075a239caa4ef4649ad7d441545ccffbd5e34607cbb/mypy-1.16.1.tar.gz", hash = "sha256:6bd00a0a2094841c5e47e7374bb42b83d64c527a502e3334e1173a0c24437bab", size = 3324747, upload-time = "2025-06-16T16:51:35.145Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/8e/22/ea637422dedf0bf36f3ef238eab4e455e2a0dcc3082b5cc067615347ab8e/mypy-1.17.1.tar.gz", hash = "sha256:25e01ec741ab5bb3eec8ba9cdb0f769230368a22c959c4937360efb89b7e9f01", size = 3352570, upload-time = "2025-07-31T07:54:19.204Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/9a/61/ec1245aa1c325cb7a6c0f8570a2eee3bfc40fa90d19b1267f8e50b5c8645/mypy-1.16.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:472e4e4c100062488ec643f6162dd0d5208e33e2f34544e1fc931372e806c0cc", size = 10890557, upload-time = "2025-06-16T16:37:21.421Z" },
- { url = "https://files.pythonhosted.org/packages/6b/bb/6eccc0ba0aa0c7a87df24e73f0ad34170514abd8162eb0c75fd7128171fb/mypy-1.16.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ea16e2a7d2714277e349e24d19a782a663a34ed60864006e8585db08f8ad1782", size = 10012921, upload-time = "2025-06-16T16:51:28.659Z" },
- { url = "https://files.pythonhosted.org/packages/5f/80/b337a12e2006715f99f529e732c5f6a8c143bb58c92bb142d5ab380963a5/mypy-1.16.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08e850ea22adc4d8a4014651575567b0318ede51e8e9fe7a68f25391af699507", size = 11802887, upload-time = "2025-06-16T16:50:53.627Z" },
- { url = "https://files.pythonhosted.org/packages/d9/59/f7af072d09793d581a745a25737c7c0a945760036b16aeb620f658a017af/mypy-1.16.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22d76a63a42619bfb90122889b903519149879ddbf2ba4251834727944c8baca", size = 12531658, upload-time = "2025-06-16T16:33:55.002Z" },
- { url = "https://files.pythonhosted.org/packages/82/c4/607672f2d6c0254b94a646cfc45ad589dd71b04aa1f3d642b840f7cce06c/mypy-1.16.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2c7ce0662b6b9dc8f4ed86eb7a5d505ee3298c04b40ec13b30e572c0e5ae17c4", size = 12732486, upload-time = "2025-06-16T16:37:03.301Z" },
- { url = "https://files.pythonhosted.org/packages/b6/5e/136555ec1d80df877a707cebf9081bd3a9f397dedc1ab9750518d87489ec/mypy-1.16.1-cp311-cp311-win_amd64.whl", hash = "sha256:211287e98e05352a2e1d4e8759c5490925a7c784ddc84207f4714822f8cf99b6", size = 9479482, upload-time = "2025-06-16T16:47:37.48Z" },
- { url = "https://files.pythonhosted.org/packages/b4/d6/39482e5fcc724c15bf6280ff5806548c7185e0c090712a3736ed4d07e8b7/mypy-1.16.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:af4792433f09575d9eeca5c63d7d90ca4aeceda9d8355e136f80f8967639183d", size = 11066493, upload-time = "2025-06-16T16:47:01.683Z" },
- { url = "https://files.pythonhosted.org/packages/e6/e5/26c347890efc6b757f4d5bb83f4a0cf5958b8cf49c938ac99b8b72b420a6/mypy-1.16.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:66df38405fd8466ce3517eda1f6640611a0b8e70895e2a9462d1d4323c5eb4b9", size = 10081687, upload-time = "2025-06-16T16:48:19.367Z" },
- { url = "https://files.pythonhosted.org/packages/44/c7/b5cb264c97b86914487d6a24bd8688c0172e37ec0f43e93b9691cae9468b/mypy-1.16.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:44e7acddb3c48bd2713994d098729494117803616e116032af192871aed80b79", size = 11839723, upload-time = "2025-06-16T16:49:20.912Z" },
- { url = "https://files.pythonhosted.org/packages/15/f8/491997a9b8a554204f834ed4816bda813aefda31cf873bb099deee3c9a99/mypy-1.16.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ab5eca37b50188163fa7c1b73c685ac66c4e9bdee4a85c9adac0e91d8895e15", size = 12722980, upload-time = "2025-06-16T16:37:40.929Z" },
- { url = "https://files.pythonhosted.org/packages/df/f0/2bd41e174b5fd93bc9de9a28e4fb673113633b8a7f3a607fa4a73595e468/mypy-1.16.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:dedb6229b2c9086247e21a83c309754b9058b438704ad2f6807f0d8227f6ebdd", size = 12903328, upload-time = "2025-06-16T16:34:35.099Z" },
- { url = "https://files.pythonhosted.org/packages/61/81/5572108a7bec2c46b8aff7e9b524f371fe6ab5efb534d38d6b37b5490da8/mypy-1.16.1-cp312-cp312-win_amd64.whl", hash = "sha256:1f0435cf920e287ff68af3d10a118a73f212deb2ce087619eb4e648116d1fe9b", size = 9562321, upload-time = "2025-06-16T16:48:58.823Z" },
- { url = "https://files.pythonhosted.org/packages/cf/d3/53e684e78e07c1a2bf7105715e5edd09ce951fc3f47cf9ed095ec1b7a037/mypy-1.16.1-py3-none-any.whl", hash = "sha256:5fc2ac4027d0ef28d6ba69a0343737a23c4d1b83672bf38d1fe237bdc0643b37", size = 2265923, upload-time = "2025-06-16T16:48:02.366Z" },
+ { url = "https://files.pythonhosted.org/packages/46/cf/eadc80c4e0a70db1c08921dcc220357ba8ab2faecb4392e3cebeb10edbfa/mypy-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad37544be07c5d7fba814eb370e006df58fed8ad1ef33ed1649cb1889ba6ff58", size = 10921009, upload-time = "2025-07-31T07:53:23.037Z" },
+ { url = "https://files.pythonhosted.org/packages/5d/c1/c869d8c067829ad30d9bdae051046561552516cfb3a14f7f0347b7d973ee/mypy-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:064e2ff508e5464b4bd807a7c1625bc5047c5022b85c70f030680e18f37273a5", size = 10047482, upload-time = "2025-07-31T07:53:26.151Z" },
+ { url = "https://files.pythonhosted.org/packages/98/b9/803672bab3fe03cee2e14786ca056efda4bb511ea02dadcedde6176d06d0/mypy-1.17.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:70401bbabd2fa1aa7c43bb358f54037baf0586f41e83b0ae67dd0534fc64edfd", size = 11832883, upload-time = "2025-07-31T07:53:47.948Z" },
+ { url = "https://files.pythonhosted.org/packages/88/fb/fcdac695beca66800918c18697b48833a9a6701de288452b6715a98cfee1/mypy-1.17.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e92bdc656b7757c438660f775f872a669b8ff374edc4d18277d86b63edba6b8b", size = 12566215, upload-time = "2025-07-31T07:54:04.031Z" },
+ { url = "https://files.pythonhosted.org/packages/7f/37/a932da3d3dace99ee8eb2043b6ab03b6768c36eb29a02f98f46c18c0da0e/mypy-1.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c1fdf4abb29ed1cb091cf432979e162c208a5ac676ce35010373ff29247bcad5", size = 12751956, upload-time = "2025-07-31T07:53:36.263Z" },
+ { url = "https://files.pythonhosted.org/packages/8c/cf/6438a429e0f2f5cab8bc83e53dbebfa666476f40ee322e13cac5e64b79e7/mypy-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:ff2933428516ab63f961644bc49bc4cbe42bbffb2cd3b71cc7277c07d16b1a8b", size = 9507307, upload-time = "2025-07-31T07:53:59.734Z" },
+ { url = "https://files.pythonhosted.org/packages/17/a2/7034d0d61af8098ec47902108553122baa0f438df8a713be860f7407c9e6/mypy-1.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:69e83ea6553a3ba79c08c6e15dbd9bfa912ec1e493bf75489ef93beb65209aeb", size = 11086295, upload-time = "2025-07-31T07:53:28.124Z" },
+ { url = "https://files.pythonhosted.org/packages/14/1f/19e7e44b594d4b12f6ba8064dbe136505cec813549ca3e5191e40b1d3cc2/mypy-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b16708a66d38abb1e6b5702f5c2c87e133289da36f6a1d15f6a5221085c6403", size = 10112355, upload-time = "2025-07-31T07:53:21.121Z" },
+ { url = "https://files.pythonhosted.org/packages/5b/69/baa33927e29e6b4c55d798a9d44db5d394072eef2bdc18c3e2048c9ed1e9/mypy-1.17.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89e972c0035e9e05823907ad5398c5a73b9f47a002b22359b177d40bdaee7056", size = 11875285, upload-time = "2025-07-31T07:53:55.293Z" },
+ { url = "https://files.pythonhosted.org/packages/90/13/f3a89c76b0a41e19490b01e7069713a30949d9a6c147289ee1521bcea245/mypy-1.17.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03b6d0ed2b188e35ee6d5c36b5580cffd6da23319991c49ab5556c023ccf1341", size = 12737895, upload-time = "2025-07-31T07:53:43.623Z" },
+ { url = "https://files.pythonhosted.org/packages/23/a1/c4ee79ac484241301564072e6476c5a5be2590bc2e7bfd28220033d2ef8f/mypy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c837b896b37cd103570d776bda106eabb8737aa6dd4f248451aecf53030cdbeb", size = 12931025, upload-time = "2025-07-31T07:54:17.125Z" },
+ { url = "https://files.pythonhosted.org/packages/89/b8/7409477be7919a0608900e6320b155c72caab4fef46427c5cc75f85edadd/mypy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:665afab0963a4b39dff7c1fa563cc8b11ecff7910206db4b2e64dd1ba25aed19", size = 9584664, upload-time = "2025-07-31T07:54:12.842Z" },
+ { url = "https://files.pythonhosted.org/packages/1d/f3/8fcd2af0f5b806f6cf463efaffd3c9548a28f84220493ecd38d127b6b66d/mypy-1.17.1-py3-none-any.whl", hash = "sha256:a9f52c0351c21fe24c21d8c0eb1f62967b262d6729393397b6f443c3b773c3b9", size = 2283411, upload-time = "2025-07-31T07:53:24.664Z" },
]
[[package]]
diff --git a/docker/.env.example b/docker/.env.example
index 826b7b9fe6..711898016e 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -215,6 +215,8 @@ DB_DATABASE=dify
# The size of the database connection pool.
# The default is 30 connections, which can be appropriately increased.
SQLALCHEMY_POOL_SIZE=30
+# The maximum number of connections allowed to overflow beyond the pool size. The default is 10.
+SQLALCHEMY_MAX_OVERFLOW=10
# Database connection pool recycling time, the default is 3600 seconds.
SQLALCHEMY_POOL_RECYCLE=3600
# Whether to print SQL, default is false.
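The new SQLALCHEMY_MAX_OVERFLOW setting complements SQLALCHEMY_POOL_SIZE: the pool keeps up to pool_size persistent connections, and under burst load the engine may open up to max_overflow additional connections, which are discarded once returned. A minimal sketch of how these variables map onto SQLAlchemy engine options, assuming they are read straight from the environment; the DSN is a placeholder:

```python
# Illustrative mapping of the pool env vars onto SQLAlchemy engine options.
# At peak the engine may hold pool_size + max_overflow open connections
# (30 + 10 with these defaults); overflow connections are closed on return.
import os
from sqlalchemy import create_engine

engine = create_engine(
    "postgresql://user:password@localhost:5432/dify",  # placeholder DSN
    pool_size=int(os.environ.get("SQLALCHEMY_POOL_SIZE", "30")),
    max_overflow=int(os.environ.get("SQLALCHEMY_MAX_OVERFLOW", "10")),
    pool_recycle=int(os.environ.get("SQLALCHEMY_POOL_RECYCLE", "3600")),
)
```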
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 0c352e4658..d3b75d93af 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -57,6 +57,7 @@ x-shared-env: &shared-api-worker-env
DB_PORT: ${DB_PORT:-5432}
DB_DATABASE: ${DB_DATABASE:-dify}
SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30}
+ SQLALCHEMY_MAX_OVERFLOW: ${SQLALCHEMY_MAX_OVERFLOW:-10}
SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600}
SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false}
SQLALCHEMY_POOL_PRE_PING: ${SQLALCHEMY_POOL_PRE_PING:-false}
diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx
index f1bb5d9156..0d41691dfd 100644
--- a/web/app/(commonLayout)/datasets/template/template.en.mdx
+++ b/web/app/(commonLayout)/datasets/template/template.en.mdx
@@ -1858,10 +1858,10 @@ ___
title="Request"
tag="DELETE"
label="/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}"
- targetCode={`curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \\\n--header 'Authorization: Bearer {api_key}'`}
+ targetCode={`curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \\\n--header 'Authorization: Bearer {api_key}'`}
>
```bash {{ title: 'cURL' }}
- curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \
+ curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \
--header 'Authorization: Bearer {api_key}'
```
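This hunk (mirrored for the Japanese docs below) restores the /documents/{document_id} path segment that the copyable targetCode and the cURL sample had dropped, bringing both in line with the label above them. For reference, the corrected endpoint called from Python; a sketch only, with base URL, key, and IDs as placeholders:

```python
# Sketch of the corrected child-chunk DELETE call (placeholders throughout).
import requests

api_base_url = "https://api.dify.ai/v1"  # placeholder base URL
api_key = "your-api-key"                 # placeholder credential
dataset_id, document_id = "DATASET_ID", "DOCUMENT_ID"
segment_id, child_chunk_id = "SEGMENT_ID", "CHILD_CHUNK_ID"

# Note the /documents/{document_id} segment restored by this change.
url = (
    f"{api_base_url}/datasets/{dataset_id}/documents/{document_id}"
    f"/segments/{segment_id}/child_chunks/{child_chunk_id}"
)
resp = requests.delete(url, headers={"Authorization": f"Bearer {api_key}"})
resp.raise_for_status()
```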
diff --git a/web/app/(commonLayout)/datasets/template/template.ja.mdx b/web/app/(commonLayout)/datasets/template/template.ja.mdx
index 3011cecbc1..5c7a752c11 100644
--- a/web/app/(commonLayout)/datasets/template/template.ja.mdx
+++ b/web/app/(commonLayout)/datasets/template/template.ja.mdx
@@ -1614,10 +1614,10 @@ ___
title="リクエスト"
tag="DELETE"
label="/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}"
- targetCode={`curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \\\n--header 'Authorization: Bearer {api_key}'`}
+ targetCode={`curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \\\n--header 'Authorization: Bearer {api_key}'`}
>
```bash {{ title: 'cURL' }}
- curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \
+ curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \
--header 'Authorization: Bearer {api_key}'
```
diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx
index cf32339b8a..c3ff45d6a6 100644
--- a/web/app/components/app-sidebar/index.tsx
+++ b/web/app/components/app-sidebar/index.tsx
@@ -107,7 +107,7 @@ const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigati
)}