diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index c25bde87b0..39a653953e 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -1,6 +1,6 @@ #!/bin/bash -npm add -g pnpm@10.13.1 +npm add -g pnpm@10.15.0 cd web && pnpm install pipx install uv diff --git a/README.md b/README.md index 7e566a0b2f..90da1d3def 100644 --- a/README.md +++ b/README.md @@ -107,74 +107,6 @@ Monitor and analyze application logs and performance over time. You could contin **7. Backend-as-a-Service**: All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. -## Feature Comparison - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
- ## Using Dify - **Cloud
** diff --git a/README_AR.md b/README_AR.md index 044ced98ed..2451757ab5 100644 --- a/README_AR.md +++ b/README_AR.md @@ -68,74 +68,6 @@ **7.الواجهة الخلفية (Backend) كخدمة**: تأتي جميع عروض Dify مع APIs مطابقة، حتى يمكنك دمج Dify بسهولة في منطق أعمالك الخاص. -## مقارنة الميزات - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
الميزةDify.AILangChainFlowiseOpenAI Assistants API
نهج البرمجةموجّه لـ تطبيق + واجهة برمجة تطبيق (API)برمجة Pythonموجه لتطبيقواجهة برمجة تطبيق (API)
LLMs المدعومةتنوع غنيتنوع غنيتنوع غنيفقط OpenAI
محرك RAG
الوكيل
سير العمل
الملاحظة
ميزات الشركات (SSO / مراقبة الوصول)
نشر محلي
- ## استخدام Dify - **سحابة
** diff --git a/README_BN.md b/README_BN.md index f5a19ab434..ef24dea171 100644 --- a/README_BN.md +++ b/README_BN.md @@ -106,74 +106,6 @@ LLM ফাংশন কলিং বা ReAct উপর ভিত্তি ক **7. ব্যাকএন্ড-অ্যাজ-এ-সার্ভিস**: ডিফাই-এর সমস্ত অফার সংশ্লিষ্ট API-সহ আছে, যাতে আপনি অনায়াসে ডিফাইকে আপনার নিজস্ব বিজনেস লজিকে ইন্টেগ্রেট করতে পারেন। -## বৈশিষ্ট্য তুলনা - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
বৈশিষ্ট্যDify.AILangChainFlowiseOpenAI Assistants API
প্রোগ্রামিং পদ্ধতিAPI + App-orientedPython CodeApp-orientedAPI-oriented
সাপোর্টেড LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG ইঞ্জিন
এজেন্ট
ওয়ার্কফ্লো
অবজার্ভেবল
এন্টারপ্রাইজ ফিচার (SSO/Access control)
লোকাল ডেপ্লয়মেন্ট
- ## ডিফাই-এর ব্যবহার - **ক্লাউড
** diff --git a/README_CN.md b/README_CN.md index 1c40098034..2949b38867 100644 --- a/README_CN.md +++ b/README_CN.md @@ -80,74 +80,6 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI **7. 后端即服务**: 所有 Dify 的功能都带有相应的 API,因此您可以轻松地将 Dify 集成到自己的业务逻辑中。 -## 功能比较 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
功能Dify.AILangChainFlowiseOpenAI Assistant API
编程方法API + 应用程序导向Python 代码应用程序导向API 导向
支持的 LLMs丰富多样丰富多样丰富多样仅限 OpenAI
RAG 引擎
Agent
工作流
可观测性
企业功能(SSO/访问控制)
本地部署
- ## 使用 Dify - **云
** diff --git a/README_DE.md b/README_DE.md index 88c36019e3..a593a12abf 100644 --- a/README_DE.md +++ b/README_DE.md @@ -106,74 +106,6 @@ Sie können Agenten basierend auf LLM Function Calling oder ReAct definieren und **7. Backend-as-a-Service**: Alle Dify-Angebote kommen mit entsprechenden APIs, sodass Sie Dify mühelos in Ihre eigene Geschäftslogik integrieren können. -## Vergleich der Merkmale - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
- ## Dify verwenden - **Cloud
** diff --git a/README_ES.md b/README_ES.md index bc3b25f2d1..c7a18dc675 100644 --- a/README_ES.md +++ b/README_ES.md @@ -79,74 +79,6 @@ Supervisa y analiza registros de aplicaciones y rendimiento a lo largo del tiemp **7. Backend como servicio**: Todas las ofertas de Dify vienen con APIs correspondientes, por lo que podrías integrar Dify sin esfuerzo en tu propia lógica empresarial. -## Comparación de características - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
CaracterísticaDify.AILangChainFlowiseAPI de Asistentes de OpenAI
Enfoque de programaciónAPI + orientado a la aplicaciónCódigo PythonOrientado a la aplicaciónOrientado a la API
LLMs admitidosGran variedadGran variedadGran variedadSolo OpenAI
Motor RAG
Agente
Flujo de trabajo
Observabilidad
Característica empresarial (SSO/Control de acceso)
Implementación local
- ## Usando Dify - **Nube
** diff --git a/README_FR.md b/README_FR.md index 7521753100..316d50c929 100644 --- a/README_FR.md +++ b/README_FR.md @@ -79,74 +79,6 @@ Surveillez et analysez les journaux d'application et les performances au fil du **7. Backend-as-a-Service** : Toutes les offres de Dify sont accompagnées d'API correspondantes, vous permettant d'intégrer facilement Dify dans votre propre logique métier. -## Comparaison des fonctionnalités - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FonctionnalitéDify.AILangChainFlowiseOpenAI Assistants API
Approche de programmationAPI + ApplicationCode PythonApplicationAPI
LLMs pris en chargeGrande variétéGrande variétéGrande variétéUniquement OpenAI
Moteur RAG
Agent
Flux de travail
Observabilité
Fonctionnalité d'entreprise (SSO/Contrôle d'accès)
Déploiement local
- ## Utiliser Dify - **Cloud
** diff --git a/README_JA.md b/README_JA.md index 3427a86b79..785706a88a 100644 --- a/README_JA.md +++ b/README_JA.md @@ -80,74 +80,6 @@ LLM Function CallingやReActに基づくエージェントの定義が可能で **7. Backend-as-a-Service**: すべての機能はAPIを提供されており、Difyを自分のビジネスロジックに簡単に統合できます。 -## 機能比較 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
機能Dify.AILangChainFlowiseOpenAI Assistants API
プログラミングアプローチAPI + アプリ指向Pythonコードアプリ指向API指向
サポートされているLLMバラエティ豊かバラエティ豊かバラエティ豊かOpenAIのみ
RAGエンジン
エージェント
ワークフロー
観測性
エンタープライズ機能(SSO/アクセス制御)
ローカル展開
- ## Difyの使用方法 - **クラウド
** diff --git a/README_KL.md b/README_KL.md index 252a2b6db5..93da9a6140 100644 --- a/README_KL.md +++ b/README_KL.md @@ -79,74 +79,6 @@ Monitor and analyze application logs and performance over time. You could contin **7. Backend-as-a-Service**: All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. -## Feature Comparison - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
- ## Using Dify - **Cloud
** diff --git a/README_KR.md b/README_KR.md index 278e3f6c33..3b58339e12 100644 --- a/README_KR.md +++ b/README_KR.md @@ -73,74 +73,6 @@ LLM 함수 호출 또는 ReAct를 기반으로 에이전트를 정의하고 에 **7. Backend-as-a-Service**: Dify의 모든 제품에는 해당 API가 함께 제공되므로 Dify를 자신의 비즈니스 로직에 쉽게 통합할 수 있습니다. -## 기능 비교 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
기능Dify.AILangChainFlowiseOpenAI Assistants API
프로그래밍 접근 방식API + 앱 중심Python 코드앱 중심API 중심
지원되는 LLMs다양한 종류다양한 종류다양한 종류OpenAI 전용
RAG 엔진
에이전트
워크플로우
가시성
기업용 기능 (SSO/접근 제어)
로컬 배포
- ## Dify 사용하기 - **클라우드
** diff --git a/README_PT.md b/README_PT.md index 8bff880728..ec2e4245f6 100644 --- a/README_PT.md +++ b/README_PT.md @@ -79,74 +79,6 @@ Monitore e analise os registros e o desempenho do aplicativo ao longo do tempo. **7. Backend como Serviço**: Todas os recursos do Dify vêm com APIs correspondentes, permitindo que você integre o Dify sem esforço na lógica de negócios da sua empresa. -## Comparação de recursos - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
RecursoDify.AILangChainFlowiseOpenAI Assistants API
Abordagem de ProgramaçãoOrientada a API + AplicativoCódigo PythonOrientada a AplicativoOrientada a API
LLMs SuportadosVariedade RicaVariedade RicaVariedade RicaApenas OpenAI
RAG Engine
Agente
Workflow
Observabilidade
Recursos Empresariais (SSO/Controle de Acesso)
Implantação Local
- ## Usando o Dify - **Nuvem
** diff --git a/README_SI.md b/README_SI.md index be8c6320fb..c20dc3484f 100644 --- a/README_SI.md +++ b/README_SI.md @@ -103,74 +103,6 @@ Spremljajte in analizirajte dnevnike aplikacij in učinkovitost skozi čas. Pozi **7. Backend-as-a-Service**: AVse ponudbe Difyja so opremljene z ustreznimi API-ji, tako da lahko Dify brez težav integrirate v svojo poslovno logiko. -## Primerjava Funkcij - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FunkcijaDify.AILangChainFlowiseOpenAI Assistants API
Programski pristopAPI + usmerjeno v aplikacijePython kodaUsmerjeno v aplikacijeUsmerjeno v API
Podprti LLM-jiBogata izbiraBogata izbiraBogata izbiraSamo OpenAI
RAG pogon
Agent
Potek dela
Spremljanje
Funkcija za podjetja (SSO/nadzor dostopa)
Lokalna namestitev
- ## Uporaba Dify - **Cloud
** diff --git a/README_TR.md b/README_TR.md index e54b1f4589..510b112e68 100644 --- a/README_TR.md +++ b/README_TR.md @@ -74,74 +74,6 @@ Uygulama loglarını ve performans metriklerini zaman içinde izleme ve analiz e **7. Hizmet Olarak Backend**: Dify'ın tüm özellikleri ilgili API'lerle birlikte gelir, böylece Dify'ı kendi iş mantığınıza kolayca entegre edebilirsiniz. -## Özellik karşılaştırması - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
ÖzellikDify.AILangChainFlowiseOpenAI Assistants API
Programlama YaklaşımıAPI + Uygulama odaklıPython KoduUygulama odaklıAPI odaklı
Desteklenen LLM'lerZengin ÇeşitlilikZengin ÇeşitlilikZengin ÇeşitlilikYalnızca OpenAI
RAG Motoru
Ajan
İş Akışı
Gözlemlenebilirlik
Kurumsal Özellikler (SSO/Erişim kontrolü)
Yerel Dağıtım
- ## Dify'ı Kullanma - **Cloud
** diff --git a/README_TW.md b/README_TW.md index c41434771c..35a01fa16a 100644 --- a/README_TW.md +++ b/README_TW.md @@ -106,74 +106,6 @@ docker compose up -d **7. 後端即服務**: Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify 整合到您自己的業務邏輯中。 -## 功能比較 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
功能Dify.AILangChainFlowiseOpenAI Assistants API
程式設計方法API + 應用導向Python 代碼應用導向API 導向
支援的 LLM 模型豐富多樣豐富多樣豐富多樣僅限 OpenAI
RAG 引擎
代理功能
工作流程
可觀察性
企業級功能 (SSO/存取控制)
本地部署
- ## 使用 Dify - **雲端服務
** diff --git a/README_VI.md b/README_VI.md index 8c5c333e8f..f161b20f9d 100644 --- a/README_VI.md +++ b/README_VI.md @@ -74,74 +74,6 @@ Giám sát và phân tích nhật ký và hiệu suất ứng dụng theo thời **7. Backend-as-a-Service**: Tất cả các dịch vụ của Dify đều đi kèm với các API tương ứng, vì vậy bạn có thể dễ dàng tích hợp Dify vào logic kinh doanh của riêng mình. -## So sánh tính năng - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Tính năngDify.AILangChainFlowiseOpenAI Assistants API
Phương pháp lập trìnhHướng API + Ứng dụngMã PythonHướng ứng dụngHướng API
LLMs được hỗ trợĐa dạng phong phúĐa dạng phong phúĐa dạng phong phúChỉ OpenAI
RAG Engine
Agent
Quy trình làm việc
Khả năng quan sát
Tính năng doanh nghiệp (SSO/Kiểm soát truy cập)
Triển khai cục bộ
- ## Sử dụng Dify - **Cloud
** diff --git a/api/README.md b/api/README.md index 5571fdd0fd..8309a0e69b 100644 --- a/api/README.md +++ b/api/README.md @@ -80,7 +80,7 @@ 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. ```bash -uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage +uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation ``` Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal: @@ -97,8 +97,16 @@ uv run celery -A app.celery beat uv sync --dev ``` -1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml` +1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md) - ```bash - uv run -P api bash dev/pytest/pytest_all_tests.sh + ```cli + uv run --project api pytest # Run all tests + uv run --project api pytest tests/unit_tests/ # Unit tests only + uv run --project api pytest tests/integration_tests/ # Integration tests + + # Code quality + ./dev/reformat # Run all formatters and linters + uv run --project api ruff check --fix ./ # Fix linting issues + uv run --project api ruff format ./ # Format code + uv run --project api mypy . # Type checking ``` diff --git a/api/app_factory.py b/api/app_factory.py index 032d6b17fc..8a0417dd72 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -5,6 +5,8 @@ from configs import dify_config from contexts.wrapper import RecyclableContextVar from dify_app import DifyApp +logger = logging.getLogger(__name__) + # ---------------------------- # Application Factory Function @@ -32,7 +34,7 @@ def create_app() -> DifyApp: initialize_extensions(app) end_time = time.perf_counter() if dify_config.DEBUG: - logging.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2)) + logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2)) return app @@ -93,14 +95,14 @@ def initialize_extensions(app: DifyApp): is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True if not is_enabled: if dify_config.DEBUG: - logging.info("Skipped %s", short_name) + logger.info("Skipped %s", short_name) continue start_time = time.perf_counter() ext.init_app(app) end_time = time.perf_counter() if dify_config.DEBUG: - logging.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2)) + logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2)) def create_migrations_app(): diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 3466eea1f6..df9de825de 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from libs.helper import AppIconUrlField @@ -10,6 +10,12 @@ parameters__system_parameters = { "workflow_file_upload_limit": fields.Integer, } + +def build_system_parameters_model(api_or_ns: Api | Namespace): + """Build the system parameters model for the API or Namespace.""" + return api_or_ns.model("SystemParameters", parameters__system_parameters) + + parameters_fields = { "opening_statement": fields.String, "suggested_questions": fields.Raw, @@ -25,6 +31,14 @@ parameters_fields = { "system_parameters": fields.Nested(parameters__system_parameters), } + +def build_parameters_model(api_or_ns: Api | Namespace): + """Build the parameters model for the API or Namespace.""" + copied_fields = parameters_fields.copy() + copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns)) + return api_or_ns.model("Parameters", copied_fields) + + site_fields = { "title": fields.String, "chat_color_theme": fields.String, @@ -41,3 +55,8 @@ site_fields = { "show_workflow_steps": fields.Boolean, "use_icon_as_answer_icon": fields.Boolean, } + + +def build_site_model(api_or_ns: Api | Namespace): + """Build the site model for the API or Namespace.""" + return api_or_ns.model("Site", site_fields) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index d373d7c72f..b94a9f4ee4 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -84,7 +84,6 @@ from .datasets import ( external, hit_testing, metadata, - upload_file, website, ) from .datasets.rag_pipeline import ( diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 8a55197fb6..7e5c28200a 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 30c890c301..401e88709a 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,8 +1,8 @@ from typing import Any, Optional -import flask_restful +import flask_restx from flask_login import current_user -from flask_restful import Resource, fields, marshal_with +from flask_restx import Resource, fields, marshal_with from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -40,7 +40,7 @@ def _get_resource(resource_id, tenant_id, resource_model): ).scalar_one_or_none() if resource is None: - flask_restful.abort(404, message=f"{resource_model.__name__} not found.") + flask_restx.abort(404, message=f"{resource_model.__name__} not found.") return resource @@ -81,7 +81,7 @@ class BaseApiKeyListResource(Resource): ) if current_key_count >= self.max_keys: - flask_restful.abort( + flask_restx.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", code="max_keys_exceeded", @@ -126,7 +126,7 @@ class BaseApiKeyResource(Resource): ) if key is None: - flask_restful.abort(404, message="API key not found") + flask_restx.abort(404, message="API key not found") db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c228743fa5..c6cb6f6e3a 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index d433415894..a964154207 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 2caa908d4a..37d23ccd9f 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -2,7 +2,7 @@ from typing import Literal from flask import request from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.common.errors import NoFileUploadedError, TooManyFilesError diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 1cc13d669c..a6eb86122d 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -2,7 +2,7 @@ import uuid from typing import cast from flask_login import current_user -from flask_restful import Resource, inputs, marshal, marshal_with, reqparse +from flask_restx import Resource, inputs, marshal, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, abort diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 9ffb94e9f9..aee93a8814 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,7 +1,7 @@ from typing import cast from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 665cf1aede..ea1869a587 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError import services diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index ad94112f05..bd5e7d0924 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -2,7 +2,7 @@ import logging import flask_login from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 6ddae6fad5..06f0218771 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -2,8 +2,8 @@ from datetime import datetime import pytz # pip install pytz from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, marshal_with, reqparse +from flask_restx.inputs import int_range from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload from werkzeug.exceptions import Forbidden, NotFound @@ -24,6 +24,8 @@ from libs.helper import DatetimeString from libs.login import login_required from models import Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode +from services.conversation_service import ConversationService +from services.errors.conversation import ConversationNotExistsError class CompletionConversationApi(Resource): @@ -46,7 +48,9 @@ class CompletionConversationApi(Resource): parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") + query = db.select(Conversation).where( + Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) + ) if args["keyword"]: query = query.join(Message, Message.conversation_id == Conversation.id).where( @@ -119,18 +123,11 @@ class CompletionConversationDetailApi(Resource): raise Forbidden() conversation_id = str(conversation_id) - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() - ) - - if not conversation: + try: + ConversationService.delete(app_model, conversation_id, current_user) + except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - conversation.is_deleted = True - db.session.commit() - return {"result": "success"}, 204 @@ -171,7 +168,7 @@ class ChatConversationApi(Resource): .subquery() ) - query = db.select(Conversation).where(Conversation.app_id == app_model.id) + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) if args["keyword"]: keyword_filter = f"%{args['keyword']}%" @@ -284,18 +281,11 @@ class ChatConversationDetailApi(Resource): raise Forbidden() conversation_id = str(conversation_id) - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() - ) - - if not conversation: + try: + ConversationService.delete(app_model, conversation_id, current_user) + except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - conversation.is_deleted = True - db.session.commit() - return {"result": "success"}, 204 diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index d49f433ba1..5ca4c33f87 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 57dc1267d5..497fd53df7 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.app.error import ( diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 2344fd5acb..541803e539 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,7 +2,7 @@ import json from enum import StrEnum from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 680ac4a64c..57cc825fe9 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,8 +1,8 @@ import logging from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, fields, marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 029138fb6b..52ff9b923d 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -3,7 +3,7 @@ from typing import cast from flask import request from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 978c02412c..74c2867c2f 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import BadRequest from controllers.console import api diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 03418f1dd2..778ce92da6 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 343b7acd7b..27e405af38 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -5,7 +5,7 @@ import pytz import sqlalchemy as sa from flask import jsonify from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index c58301b300..8dcffb1666 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from typing import cast from flask import abort, request -from flask_restful import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, inputs, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 310146a5e7..8d8cdc93cf 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,6 +1,6 @@ from dateutil.parser import isoparse -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, marshal_with, reqparse +from flask_restx.inputs import int_range from sqlalchemy.orm import Session from controllers.console import api diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 414c07ef50..4e625db24d 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -2,7 +2,7 @@ import logging from typing import Any, NoReturn from flask import Response -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 9099700213..dccbfd8648 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,8 +1,8 @@ from typing import cast from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, marshal_with, reqparse +from flask_restx.inputs import int_range from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 7f80afd83b..7cef175c14 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -5,7 +5,7 @@ import pytz import sqlalchemy as sa from flask import jsonify from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 2562fb5eb8..e82e403ec2 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from constants.languages import supported_language from controllers.console import api diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index b8c3c8f012..796e6916cc 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 4940b48754..d4cf20549a 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -3,7 +3,7 @@ import logging import requests from flask import current_app, redirect, request from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource from werkzeug.exceptions import Forbidden from configs import dify_config diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 3bbe3177fc..ede0696854 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,7 +2,7 @@ import base64 import secrets from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 5f2a24322d..a5ad6a1cd7 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -2,7 +2,7 @@ from typing import cast import flask_login from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse import services from configs import dify_config @@ -221,7 +221,7 @@ class EmailCodeLoginApi(Resource): email=user_email, name=user_email, interface_language=languages[0] ) except WorkSpaceNotAllowedCreateError: - return NotAllowedCreateWorkspace() + raise NotAllowedCreateWorkspace() except AccountRegisterError as are: raise AccountInFreezeError() except WorkspacesLimitExceededError: diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 4a6cb99390..3c76394cf9 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -3,7 +3,7 @@ from typing import Optional import requests from flask import current_app, redirect, request -from flask_restful import Resource +from flask_restx import Resource from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 4b0c82ae6c..8ebb745a60 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py index 9679632ac7..4bc073f679 100644 --- a/api/controllers/console/billing/compliance.py +++ b/api/controllers/console/billing/compliance.py @@ -1,6 +1,6 @@ from flask import request from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from libs.helper import extract_remote_ip from libs.login import login_required diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 88b24c6985..41b4638822 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -4,7 +4,7 @@ from typing import cast from flask import request from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index cbc234deb7..7080220eb5 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,7 +1,7 @@ -import flask_restful +import flask_restx from flask import request from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound import services @@ -600,7 +600,7 @@ class DatasetApiKeyApi(Resource): ) if current_key_count >= self.max_keys: - flask_restful.abort( + flask_restx.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", code="max_keys_exceeded", @@ -640,7 +640,7 @@ class DatasetApiDeleteApi(Resource): ) if key is None: - flask_restful.abort(404, message="API key not found") + flask_restx.abort(404, message="API key not found") db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index d4e67409fc..e99821eb4c 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -5,7 +5,7 @@ from typing import Literal, cast from flask import request from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, marshal, marshal_with, reqparse from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 8c429044d7..463fd2d7ec 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -2,7 +2,7 @@ import uuid from flask import request from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_restx import Resource, marshal, reqparse from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound @@ -584,7 +584,12 @@ class ChildChunkUpdateApi(Resource): child_chunk_id = str(child_chunk_id) child_chunk = ( db.session.query(ChildChunk) - .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .where( + ChildChunk.id == str(child_chunk_id), + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.segment_id == segment.id, + ChildChunk.document_id == document_id, + ) .first() ) if not child_chunk: @@ -633,7 +638,12 @@ class ChildChunkUpdateApi(Resource): child_chunk_id = str(child_chunk_id) child_chunk = ( db.session.query(ChildChunk) - .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .where( + ChildChunk.id == str(child_chunk_id), + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.segment_id == segment.id, + ChildChunk.document_id == document_id, + ) .first() ) if not child_chunk: diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index cf9081e154..043f39f623 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,6 +1,6 @@ from flask import request from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_restx import Resource, marshal, reqparse from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index fba5d4c0f3..2ad192571b 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restx import Resource from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 3b4c076863..304674db5f 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services.dataset_service diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 1b5570285d..6aa309f930 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -1,7 +1,7 @@ from typing import Literal from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/datasets/upload_file.py b/api/controllers/console/datasets/upload_file.py deleted file mode 100644 index 2afdaf7f2b..0000000000 --- a/api/controllers/console/datasets/upload_file.py +++ /dev/null @@ -1,62 +0,0 @@ -from flask_login import current_user -from flask_restful import Resource -from werkzeug.exceptions import NotFound - -from controllers.console import api -from controllers.console.wraps import ( - account_initialization_required, - setup_required, -) -from core.file import helpers as file_helpers -from extensions.ext_database import db -from models.dataset import Dataset -from models.model import UploadFile -from services.dataset_service import DocumentService - - -class UploadFileApi(Resource): - @setup_required - @account_initialization_required - def get(self, dataset_id, document_id): - """Get upload file.""" - # check dataset - dataset_id = str(dataset_id) - dataset = ( - db.session.query(Dataset) - .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id) - .first() - ) - if not dataset: - raise NotFound("Dataset not found.") - # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset.id, document_id) - if not document: - raise NotFound("Document not found.") - # check upload file - if document.data_source_type != "upload_file": - raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.") - data_source_info = document.data_source_info_dict - if data_source_info and "upload_file_id" in data_source_info: - file_id = data_source_info["upload_file_id"] - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() - if not upload_file: - raise NotFound("UploadFile not found.") - else: - raise ValueError("Upload file id not found in document data source info.") - - url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": url, - "download_url": f"{url}&as_attachment=true", - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at.timestamp(), - }, 200 - - -api.add_resource(UploadFileApi, "/datasets//documents//upload-file") diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index fcdc91ec67..bdaa268462 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index d564a00a76..2a4d5be82f 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -65,7 +65,7 @@ class ChatAudioApi(InstalledAppResource): class ChatTextApi(InstalledAppResource): def post(self, installed_app): - from flask_restful import reqparse + from flask_restx import reqparse app_model = installed_app.app try: diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 4842fefc57..b444a2a197 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,7 +1,7 @@ import logging from flask_login import current_user -from flask_restful import reqparse +from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index d7c161cc6d..a8d46954b5 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,6 +1,6 @@ from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import marshal_with, reqparse +from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index ad62bd6e08..3ccedd654b 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -3,7 +3,7 @@ from typing import Any from flask import request from flask_login import current_user -from flask_restful import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, inputs, marshal_with, reqparse from sqlalchemy import and_ from werkzeug.exceptions import BadRequest, Forbidden, NotFound diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index de95a9e7b0..6df3bca762 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,8 +1,8 @@ import logging from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound from controllers.console.app.error import ( diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index a1280d91d1..c368744759 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restx import marshal_with from controllers.common import fields from controllers.console import api diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index ce85f495aa..62f9350b71 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from constants.languages import languages from controllers.console import api diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 339e7007a0..5353dbcad5 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,6 +1,6 @@ from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import fields, marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 3f625e6609..3d872fc1fc 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restx import reqparse from werkzeug.exceptions import InternalServerError from controllers.console.app.error import ( diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index de97fb149e..e86103184a 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,7 +1,7 @@ from functools import wraps from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource from werkzeug.exceptions import NotFound from controllers.console.explore.error import AppAccessDeniedError diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 07a241ef86..e157041c35 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from constants import HIDDEN_VALUE from controllers.console import api diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 70ab4ff865..6236832d39 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource from libs.login import login_required from services.feature_service import FeatureService diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index a87d270e9c..101a49a32e 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -2,7 +2,7 @@ from typing import Literal from flask import request from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_restx import Resource, marshal_with from werkzeug.exceptions import Forbidden import services diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index b19e331d2e..2a37b1708a 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,7 +1,7 @@ import os from flask import session -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index cd28cc946e..1a53a2347e 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restx import Resource from controllers.console import api diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index c356113c40..73014cfc97 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -3,7 +3,7 @@ from typing import cast import httpx from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse import services from controllers.common import helpers diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e1f19a87a3..8e230496f0 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index cb5dedca21..c45e7dbb26 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,11 +1,11 @@ from flask import request from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required -from fields.tag_fields import tag_fields +from fields.tag_fields import dataset_tag_fields from libs.login import login_required from models.model import Tag from services.tag_service import TagService @@ -21,7 +21,7 @@ class TagListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(tag_fields) + @marshal_with(dataset_tag_fields) def get(self): tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 894785abc8..96cf627b65 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,7 +2,7 @@ import json import logging import requests -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from packaging import version from configs import dify_config diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 3f6d6bf54f..5b2828dbab 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -3,7 +3,7 @@ from datetime import datetime import pytz from flask import request from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 88c37767e3..08bab6fcb5 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index eb53dcb16e..96e873d42b 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index b4eb5e246b..2a54511bf0 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index f7424923b9..f018fada3a 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -2,7 +2,7 @@ from urllib import parse from flask import request from flask_login import current_user -from flask_restful import Resource, abort, marshal_with, reqparse +from flask_restx import Resource, abort, marshal_with, reqparse import services from configs import dify_config diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index ff0fcbda6e..281783b3d7 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -2,7 +2,7 @@ import io from flask import send_file from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 514d1084c4..b8dddb91dd 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,7 +1,7 @@ import logging from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index 09846d5c94..fd5421fa64 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -2,7 +2,7 @@ import io from flask import request, send_file from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from configs import dify_config diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 8c8b73b45d..854ba7ac45 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -3,7 +3,7 @@ from urllib.parse import urlparse from flask import make_response, redirect, request, send_file from flask_login import current_user -from flask_restful import ( +from flask_restx import ( Resource, reqparse, ) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index f4f0078da7..fb89f6bbbd 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -2,7 +2,7 @@ import logging from flask import request from flask_login import current_user -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy import select from werkzeug.exceptions import Unauthorized diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index d4c3245708..821ad220a2 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -1,9 +1,20 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi -bp = Blueprint("files", __name__) -api = ExternalApi(bp) +bp = Blueprint("files", __name__, url_prefix="/files") +api = ExternalApi( + bp, + version="1.0", + title="Files API", + description="API for file operations including upload and preview", + doc="/docs", # Enable Swagger UI at /files/docs +) + +files_ns = Namespace("files", description="File operations", path="/") from . import image_preview, tool_files, upload + +api.add_namespace(files_ns) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 91f7b27d1d..48baac6556 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,16 +1,17 @@ from urllib.parse import quote from flask import Response, request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import NotFound import services from controllers.common.errors import UnsupportedFileTypeError -from controllers.files import api +from controllers.files import files_ns from services.account_service import TenantService from services.file_service import FileService +@files_ns.route("//image-preview") class ImagePreviewApi(Resource): """ Deprecated @@ -39,6 +40,7 @@ class ImagePreviewApi(Resource): return Response(generator, mimetype=mimetype) +@files_ns.route("//file-preview") class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) @@ -94,6 +96,7 @@ class FilePreviewApi(Resource): return response +@files_ns.route("/workspaces//webapp-logo") class WorkspaceWebappLogoApi(Resource): def get(self, workspace_id): workspace_id = str(workspace_id) @@ -112,8 +115,3 @@ class WorkspaceWebappLogoApi(Resource): raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) - - -api.add_resource(ImagePreviewApi, "/files//image-preview") -api.add_resource(FilePreviewApi, "/files//file-preview") -api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces//webapp-logo") diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index d9c4e50511..faa9b733c2 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,17 +1,18 @@ from urllib.parse import quote from flask import Response -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden, NotFound from controllers.common.errors import UnsupportedFileTypeError -from controllers.files import api +from controllers.files import files_ns from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager from models import db as global_db -class ToolFilePreviewApi(Resource): +@files_ns.route("/tools/.") +class ToolFileApi(Resource): def get(self, file_id, extension): file_id = str(file_id) @@ -52,6 +53,3 @@ class ToolFilePreviewApi(Resource): response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" return response - - -api.add_resource(ToolFilePreviewApi, "/files/tools/.") diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index bcc72d131c..7a2b3b0428 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -1,7 +1,9 @@ from mimetypes import guess_extension +from typing import Optional -from flask import request -from flask_restful import Resource, marshal_with +from flask_restx import Resource, reqparse +from flask_restx.api import HTTPStatus +from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden import services @@ -10,39 +12,76 @@ from controllers.common.errors import ( UnsupportedFileTypeError, ) from controllers.console.wraps import setup_required -from controllers.files import api +from controllers.files import files_ns from controllers.inner_api.plugin.wraps import get_user from core.file.helpers import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from fields.file_fields import file_fields +from fields.file_fields import build_file_model + +# Define parser for both documentation and validation +upload_parser = reqparse.RequestParser() +upload_parser.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload") +upload_parser.add_argument( + "timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification" +) +upload_parser.add_argument( + "nonce", type=str, required=True, location="args", help="Random string for signature verification" +) +upload_parser.add_argument( + "sign", type=str, required=True, location="args", help="HMAC signature for request validation" +) +upload_parser.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier") +upload_parser.add_argument("user_id", type=str, required=False, location="args", help="User identifier") +@files_ns.route("/upload/for-plugin") class PluginUploadFileApi(Resource): @setup_required - @marshal_with(file_fields) + @files_ns.expect(upload_parser) + @files_ns.doc("upload_plugin_file") + @files_ns.doc(description="Upload a file for plugin usage with signature verification") + @files_ns.doc( + responses={ + 201: "File uploaded successfully", + 400: "Invalid request parameters", + 403: "Forbidden - Invalid signature or missing parameters", + 413: "File too large", + 415: "Unsupported file type", + } + ) + @files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED) def post(self): - # get file from request - file = request.files["file"] + """Upload a file for plugin usage. - timestamp = request.args.get("timestamp") - nonce = request.args.get("nonce") - sign = request.args.get("sign") - tenant_id = request.args.get("tenant_id") - if not tenant_id: - raise Forbidden("Invalid request.") + Accepts a file upload with signature verification for security. + The file must be accompanied by valid timestamp, nonce, and signature parameters. - user_id = request.args.get("user_id") + Returns: + dict: File metadata including ID, URLs, and properties + int: HTTP status code (201 for success) + + Raises: + Forbidden: Invalid signature or missing required parameters + FileTooLargeError: File exceeds size limit + UnsupportedFileTypeError: File type not supported + """ + # Parse and validate all arguments + args = upload_parser.parse_args() + + file: FileStorage = args["file"] + timestamp: str = args["timestamp"] + nonce: str = args["nonce"] + sign: str = args["sign"] + tenant_id: str = args["tenant_id"] + user_id: Optional[str] = args.get("user_id") user = get_user(tenant_id, user_id) - filename = file.filename - mimetype = file.mimetype + filename: Optional[str] = file.filename + mimetype: Optional[str] = file.mimetype if not filename or not mimetype: raise Forbidden("Invalid request.") - if not timestamp or not nonce or not sign: - raise Forbidden("Invalid request.") - if not verify_plugin_file_signature( filename=filename, mimetype=mimetype, @@ -88,6 +127,3 @@ class PluginUploadFileApi(Resource): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - - -api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin") diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py index 7b96f88f51..80bbc360de 100644 --- a/api/controllers/inner_api/mail.py +++ b/api/controllers/inner_api/mail.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console.wraps import setup_required from controllers.inner_api import api diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 5dfe41eb6b..9b8d9457f0 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restx import Resource from controllers.console.wraps import setup_required from controllers.inner_api import api diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index b533614d4d..89b4ac7506 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -4,7 +4,7 @@ from typing import Optional from flask import current_app, request from flask_login import user_logged_in -from flask_restful import reqparse +from flask_restx import reqparse from pydantic import BaseModel from sqlalchemy.orm import Session diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 77568b75f1..1c26416080 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,6 +1,6 @@ import json -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console.wraps import setup_required from controllers.inner_api import api diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py index 1b3e0a5621..c344ffad08 100644 --- a/api/controllers/mcp/__init__.py +++ b/api/controllers/mcp/__init__.py @@ -1,8 +1,20 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi bp = Blueprint("mcp", __name__, url_prefix="/mcp") -api = ExternalApi(bp) + +api = ExternalApi( + bp, + version="1.0", + title="MCP API", + description="API for Model Context Protocol operations", + doc="/docs", # Enable Swagger UI at /mcp/docs +) + +mcp_ns = Namespace("mcp", description="MCP operations", path="/") from . import mcp + +api.add_namespace(mcp_ns) diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 87d678796f..fc19749011 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,8 +1,10 @@ -from flask_restful import Resource, reqparse +from typing import Optional, Union + +from flask_restx import Resource, reqparse from pydantic import ValidationError from controllers.console.app.mcp_server import AppMCPServerStatus -from controllers.mcp import api +from controllers.mcp import mcp_ns from core.app.app_config.entities import VariableEntity from core.mcp import types from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler @@ -13,22 +15,58 @@ from libs import helper from models.model import App, AppMCPServer, AppMode +def int_or_str(value): + """Validate that a value is either an integer or string.""" + if isinstance(value, (int, str)): + return value + else: + return None + + +# Define parser for both documentation and validation +mcp_request_parser = reqparse.RequestParser() +mcp_request_parser.add_argument( + "jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')" +) +mcp_request_parser.add_argument("method", type=str, required=True, location="json", help="The method to invoke") +mcp_request_parser.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method") +mcp_request_parser.add_argument( + "id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses" +) + + +@mcp_ns.route("/server//mcp") class MCPAppApi(Resource): - def post(self, server_code): - def int_or_str(value): - if isinstance(value, (int, str)): - return value - else: - return None + @mcp_ns.expect(mcp_request_parser) + @mcp_ns.doc("handle_mcp_request") + @mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server") + @mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"}) + @mcp_ns.doc( + responses={ + 200: "MCP response successfully processed", + 400: "Invalid MCP request or parameters", + 404: "Server or app not found", + } + ) + def post(self, server_code: str): + """Handle MCP requests for a specific server. - parser = reqparse.RequestParser() - parser.add_argument("jsonrpc", type=str, required=True, location="json") - parser.add_argument("method", type=str, required=True, location="json") - parser.add_argument("params", type=dict, required=False, location="json") - parser.add_argument("id", type=int_or_str, required=False, location="json") - args = parser.parse_args() + Processes JSON-RPC formatted requests according to the Model Context Protocol specification. + Validates the server status and associated app before processing the request. - request_id = args.get("id") + Args: + server_code: Unique identifier for the MCP server + + Returns: + dict: JSON-RPC response from the MCP handler + + Raises: + ValidationError: Invalid request format or parameters + """ + # Parse and validate all arguments + args = mcp_request_parser.parse_args() + + request_id: Optional[Union[int, str]] = args.get("id") server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() if not server: @@ -99,6 +137,3 @@ class MCPAppApi(Resource): mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form) response = mcp_server_handler.handle() return helper.compact_generate_response(response) - - -api.add_resource(MCPAppApi, "/server//mcp") diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index b26f29d98d..763345d723 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -1,11 +1,23 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi bp = Blueprint("service_api", __name__, url_prefix="/v1") -api = ExternalApi(bp) + +api = ExternalApi( + bp, + version="1.0", + title="Service API", + description="API for application services", + doc="/docs", # Enable Swagger UI at /v1/docs +) + +service_api_ns = Namespace("service_api", description="Service operations", path="/") from . import index from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .workspace import models + +api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 23446bb702..6bc94af8c1 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,28 +1,51 @@ from typing import Literal from flask import request -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_restx import Api, Namespace, Resource, fields, reqparse +from flask_restx.api import HTTPStatus from werkzeug.exceptions import Forbidden -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client -from fields.annotation_fields import ( - annotation_fields, -) +from fields.annotation_fields import annotation_fields, build_annotation_model from libs.login import current_user from models.model import App from services.annotation_service import AppAnnotationService +# Define parsers for annotation API +annotation_create_parser = reqparse.RequestParser() +annotation_create_parser.add_argument("question", required=True, type=str, location="json", help="Annotation question") +annotation_create_parser.add_argument("answer", required=True, type=str, location="json", help="Annotation answer") +annotation_reply_action_parser = reqparse.RequestParser() +annotation_reply_action_parser.add_argument( + "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching" +) +annotation_reply_action_parser.add_argument( + "embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name" +) +annotation_reply_action_parser.add_argument( + "embedding_model_name", required=True, type=str, location="json", help="Embedding model name" +) + + +@service_api_ns.route("/apps/annotation-reply/") class AnnotationReplyActionApi(Resource): + @service_api_ns.expect(annotation_reply_action_parser) + @service_api_ns.doc("annotation_reply_action") + @service_api_ns.doc(description="Enable or disable annotation reply feature") + @service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"}) + @service_api_ns.doc( + responses={ + 200: "Action completed successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token def post(self, app_model: App, action: Literal["enable", "disable"]): - parser = reqparse.RequestParser() - parser.add_argument("score_threshold", required=True, type=float, location="json") - parser.add_argument("embedding_provider_name", required=True, type=str, location="json") - parser.add_argument("embedding_model_name", required=True, type=str, location="json") - args = parser.parse_args() + """Enable or disable annotation reply feature.""" + args = annotation_reply_action_parser.parse_args() if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_model.id) elif action == "disable": @@ -30,9 +53,21 @@ class AnnotationReplyActionApi(Resource): return result, 200 +@service_api_ns.route("/apps/annotation-reply//status/") class AnnotationReplyActionStatusApi(Resource): + @service_api_ns.doc("get_annotation_reply_action_status") + @service_api_ns.doc(description="Get the status of an annotation reply action job") + @service_api_ns.doc(params={"action": "Action type", "job_id": "Job ID"}) + @service_api_ns.doc( + responses={ + 200: "Job status retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Job not found", + } + ) @validate_app_token def get(self, app_model: App, job_id, action): + """Get the status of an annotation reply action job.""" job_id = str(job_id) app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" cache_result = redis_client.get(app_annotation_job_key) @@ -48,60 +83,111 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 +# Define annotation list response model +annotation_list_fields = { + "data": fields.List(fields.Nested(annotation_fields)), + "has_more": fields.Boolean, + "limit": fields.Integer, + "total": fields.Integer, + "page": fields.Integer, +} + + +def build_annotation_list_model(api_or_ns: Api | Namespace): + """Build the annotation list model for the API or Namespace.""" + copied_annotation_list_fields = annotation_list_fields.copy() + copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) + return api_or_ns.model("AnnotationList", copied_annotation_list_fields) + + +@service_api_ns.route("/apps/annotations") class AnnotationListApi(Resource): + @service_api_ns.doc("list_annotations") + @service_api_ns.doc(description="List annotations for the application") + @service_api_ns.doc( + responses={ + 200: "Annotations retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token + @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns)) def get(self, app_model: App): + """List annotations for the application.""" page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) keyword = request.args.get("keyword", default="", type=str) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) - response = { - "data": marshal(annotation_list, annotation_fields), + return { + "data": annotation_list, "has_more": len(annotation_list) == limit, "limit": limit, "total": total, "page": page, } - return response, 200 + @service_api_ns.expect(annotation_create_parser) + @service_api_ns.doc("create_annotation") + @service_api_ns.doc(description="Create a new annotation") + @service_api_ns.doc( + responses={ + 201: "Annotation created successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token - @marshal_with(annotation_fields) + @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App): - parser = reqparse.RequestParser() - parser.add_argument("question", required=True, type=str, location="json") - parser.add_argument("answer", required=True, type=str, location="json") - args = parser.parse_args() + """Create a new annotation.""" + args = annotation_create_parser.parse_args() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) - return annotation + return annotation, 201 +@service_api_ns.route("/apps/annotations/") class AnnotationUpdateDeleteApi(Resource): + @service_api_ns.expect(annotation_create_parser) + @service_api_ns.doc("update_annotation") + @service_api_ns.doc(description="Update an existing annotation") + @service_api_ns.doc(params={"annotation_id": "Annotation ID"}) + @service_api_ns.doc( + responses={ + 200: "Annotation updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Annotation not found", + } + ) @validate_app_token - @marshal_with(annotation_fields) + @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id): + """Update an existing annotation.""" if not current_user.is_editor: raise Forbidden() annotation_id = str(annotation_id) - parser = reqparse.RequestParser() - parser.add_argument("question", required=True, type=str, location="json") - parser.add_argument("answer", required=True, type=str, location="json") - args = parser.parse_args() + args = annotation_create_parser.parse_args() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) return annotation + @service_api_ns.doc("delete_annotation") + @service_api_ns.doc(description="Delete an annotation") + @service_api_ns.doc(params={"annotation_id": "Annotation ID"}) + @service_api_ns.doc( + responses={ + 204: "Annotation deleted successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Annotation not found", + } + ) @validate_app_token def delete(self, app_model: App, annotation_id): + """Delete an annotation.""" if not current_user.is_editor: raise Forbidden() annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) return {"result": "success"}, 204 - - -api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/") -api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply//status/") -api.add_resource(AnnotationListApi, "/apps/annotations") -api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/") diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 89222d5e83..2dbeed1d68 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,7 +1,7 @@ -from flask_restful import Resource, marshal_with +from flask_restx import Resource -from controllers.common import fields -from controllers.service_api import api +from controllers.common.fields import build_parameters_model +from controllers.service_api import service_api_ns from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict @@ -9,13 +9,26 @@ from models.model import App, AppMode from services.app_service import AppService +@service_api_ns.route("/parameters") class AppParameterApi(Resource): """Resource for app variables.""" + @service_api_ns.doc("get_app_parameters") + @service_api_ns.doc(description="Retrieve application input parameters and configuration") + @service_api_ns.doc( + responses={ + 200: "Parameters retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Application not found", + } + ) @validate_app_token - @marshal_with(fields.parameters_fields) + @service_api_ns.marshal_with(build_parameters_model(service_api_ns)) def get(self, app_model: App): - """Retrieve app parameters.""" + """Retrieve app parameters. + + Returns the input form parameters and configuration for the application. + """ if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: @@ -35,17 +48,43 @@ class AppParameterApi(Resource): return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) +@service_api_ns.route("/meta") class AppMetaApi(Resource): + @service_api_ns.doc("get_app_meta") + @service_api_ns.doc(description="Get application metadata") + @service_api_ns.doc( + responses={ + 200: "Metadata retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Application not found", + } + ) @validate_app_token def get(self, app_model: App): - """Get app meta""" + """Get app metadata. + + Returns metadata about the application including configuration and settings. + """ return AppService().get_app_meta(app_model) +@service_api_ns.route("/info") class AppInfoApi(Resource): + @service_api_ns.doc("get_app_info") + @service_api_ns.doc(description="Get basic application information") + @service_api_ns.doc( + responses={ + 200: "Application info retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Application not found", + } + ) @validate_app_token def get(self, app_model: App): - """Get app information""" + """Get app information. + + Returns basic information about the application including name, description, tags, and mode. + """ tags = [tag.name for tag in app_model.tags] return { "name": app_model.name, @@ -54,8 +93,3 @@ class AppInfoApi(Resource): "mode": app_model.mode, "author_name": app_model.author_name, } - - -api.add_resource(AppParameterApi, "/parameters") -api.add_resource(AppMetaApi, "/meta") -api.add_resource(AppInfoApi, "/info") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 848863cf1b..61b3020a5f 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,11 +1,11 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -30,9 +30,26 @@ from services.errors.audio import ( ) +@service_api_ns.route("/audio-to-text") class AudioApi(Resource): + @service_api_ns.doc("audio_to_text") + @service_api_ns.doc(description="Convert audio to text using speech-to-text") + @service_api_ns.doc( + responses={ + 200: "Audio successfully transcribed", + 400: "Bad request - no audio or invalid audio", + 401: "Unauthorized - invalid API token", + 413: "Audio file too large", + 415: "Unsupported audio type", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): + """Convert audio to text using speech-to-text. + + Accepts an audio file upload and returns the transcribed text. + """ file = request.files["file"] try: @@ -65,16 +82,35 @@ class AudioApi(Resource): raise InternalServerError() +# Define parser for text-to-audio API +text_to_audio_parser = reqparse.RequestParser() +text_to_audio_parser.add_argument("message_id", type=str, required=False, location="json", help="Message ID") +text_to_audio_parser.add_argument("voice", type=str, location="json", help="Voice to use for TTS") +text_to_audio_parser.add_argument("text", type=str, location="json", help="Text to convert to audio") +text_to_audio_parser.add_argument("streaming", type=bool, location="json", help="Enable streaming response") + + +@service_api_ns.route("/text-to-audio") class TextApi(Resource): + @service_api_ns.expect(text_to_audio_parser) + @service_api_ns.doc("text_to_audio") + @service_api_ns.doc(description="Convert text to audio using text-to-speech") + @service_api_ns.doc( + responses={ + 200: "Text successfully converted to audio", + 400: "Bad request - invalid parameters", + 401: "Unauthorized - invalid API token", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) def post(self, app_model: App, end_user: EndUser): + """Convert text to audio using text-to-speech. + + Converts the provided text to audio using the specified voice. + """ try: - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=str, required=False, location="json") - parser.add_argument("voice", type=str, location="json") - parser.add_argument("text", type=str, location="json") - parser.add_argument("streaming", type=bool, location="json") - args = parser.parse_args() + args = text_to_audio_parser.parse_args() message_id = args.get("message_id", None) text = args.get("text", None) @@ -108,7 +144,3 @@ class TextApi(Resource): except Exception as e: logging.exception("internal server error.") raise InternalServerError() - - -api.add_resource(AudioApi, "/audio-to-text") -api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index ea57f04850..dddb75d593 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,11 +1,11 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( AppUnavailableError, CompletionRequestError, @@ -33,21 +33,68 @@ from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError +# Define parser for completion API +completion_parser = reqparse.RequestParser() +completion_parser.add_argument( + "inputs", type=dict, required=True, location="json", help="Input parameters for completion" +) +completion_parser.add_argument("query", type=str, location="json", default="", help="The query string") +completion_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") +completion_parser.add_argument( + "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" +) +completion_parser.add_argument( + "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source" +) +# Define parser for chat API +chat_parser = reqparse.RequestParser() +chat_parser.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") +chat_parser.add_argument("query", type=str, required=True, location="json", help="The chat query") +chat_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") +chat_parser.add_argument( + "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" +) +chat_parser.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID") +chat_parser.add_argument( + "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source" +) +chat_parser.add_argument( + "auto_generate_name", + type=bool, + required=False, + default=True, + location="json", + help="Auto generate conversation name", +) +chat_parser.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat") + + +@service_api_ns.route("/completion-messages") class CompletionApi(Resource): + @service_api_ns.expect(completion_parser) + @service_api_ns.doc("create_completion") + @service_api_ns.doc(description="Create a completion for the given prompt") + @service_api_ns.doc( + responses={ + 200: "Completion created successfully", + 400: "Bad request - invalid parameters", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): + """Create a completion for the given prompt. + + This endpoint generates a completion based on the provided inputs and query. + Supports both blocking and streaming response modes. + """ if app_model.mode != "completion": raise AppUnavailableError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, location="json", default="") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") - - args = parser.parse_args() + args = completion_parser.parse_args() external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id @@ -88,9 +135,21 @@ class CompletionApi(Resource): raise InternalServerError() +@service_api_ns.route("/completion-messages//stop") class CompletionStopApi(Resource): + @service_api_ns.doc("stop_completion") + @service_api_ns.doc(description="Stop a running completion task") + @service_api_ns.doc(params={"task_id": "The ID of the task to stop"}) + @service_api_ns.doc( + responses={ + 200: "Task stopped successfully", + 401: "Unauthorized - invalid API token", + 404: "Task not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) - def post(self, app_model: App, end_user: EndUser, task_id): + def post(self, app_model: App, end_user: EndUser, task_id: str): + """Stop a running completion task.""" if app_model.mode != "completion": raise AppUnavailableError() @@ -99,23 +158,33 @@ class CompletionStopApi(Resource): return {"result": "success"}, 200 +@service_api_ns.route("/chat-messages") class ChatApi(Resource): + @service_api_ns.expect(chat_parser) + @service_api_ns.doc("create_chat_message") + @service_api_ns.doc(description="Send a message in a chat conversation") + @service_api_ns.doc( + responses={ + 200: "Message sent successfully", + 400: "Bad request - invalid parameters or workflow issues", + 401: "Unauthorized - invalid API token", + 404: "Conversation or workflow not found", + 429: "Rate limit exceeded", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): + """Send a message in a chat conversation. + + This endpoint handles chat messages for chat, agent chat, and advanced chat applications. + Supports conversation management and both blocking and streaming response modes. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, required=True, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("conversation_id", type=uuid_value, location="json") - parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") - parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") - parser.add_argument("workflow_id", type=str, required=False, location="json") - args = parser.parse_args() + args = chat_parser.parse_args() external_trace_id = get_external_trace_id(request) if external_trace_id: @@ -159,9 +228,21 @@ class ChatApi(Resource): raise InternalServerError() +@service_api_ns.route("/chat-messages//stop") class ChatStopApi(Resource): + @service_api_ns.doc("stop_chat_message") + @service_api_ns.doc(description="Stop a running chat message generation") + @service_api_ns.doc(params={"task_id": "The ID of the task to stop"}) + @service_api_ns.doc( + responses={ + 200: "Task stopped successfully", + 401: "Unauthorized - invalid API token", + 404: "Task not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) - def post(self, app_model: App, end_user: EndUser, task_id): + def post(self, app_model: App, end_user: EndUser, task_id: str): + """Stop a running chat message generation.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -169,9 +250,3 @@ class ChatStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {"result": "success"}, 200 - - -api.add_resource(CompletionApi, "/completion-messages") -api.add_resource(CompletionStopApi, "/completion-messages//stop") -api.add_resource(ChatApi, "/chat-messages") -api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 073307ac4a..4860bf3a79 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,48 +1,97 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, reqparse +from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( - conversation_delete_fields, - conversation_infinite_scroll_pagination_fields, - simple_conversation_fields, + build_conversation_delete_model, + build_conversation_infinite_scroll_pagination_model, + build_simple_conversation_model, ) from fields.conversation_variable_fields import ( - conversation_variable_fields, - conversation_variable_infinite_scroll_pagination_fields, + build_conversation_variable_infinite_scroll_pagination_model, + build_conversation_variable_model, ) from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService +# Define parsers for conversation APIs +conversation_list_parser = reqparse.RequestParser() +conversation_list_parser.add_argument( + "last_id", type=uuid_value, location="args", help="Last conversation ID for pagination" +) +conversation_list_parser.add_argument( + "limit", + type=int_range(1, 100), + required=False, + default=20, + location="args", + help="Number of conversations to return", +) +conversation_list_parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + help="Sort order for conversations", +) +conversation_rename_parser = reqparse.RequestParser() +conversation_rename_parser.add_argument("name", type=str, required=False, location="json", help="New conversation name") +conversation_rename_parser.add_argument( + "auto_generate", type=bool, required=False, default=False, location="json", help="Auto-generate conversation name" +) + +conversation_variables_parser = reqparse.RequestParser() +conversation_variables_parser.add_argument( + "last_id", type=uuid_value, location="args", help="Last variable ID for pagination" +) +conversation_variables_parser.add_argument( + "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of variables to return" +) + +conversation_variable_update_parser = reqparse.RequestParser() +# using lambda is for passing the already-typed value without modification +# if no lambda, it will be converted to string +# the string cannot be converted using json.loads +conversation_variable_update_parser.add_argument( + "value", required=True, location="json", type=lambda x: x, help="New value for the conversation variable" +) + + +@service_api_ns.route("/conversations") class ConversationApi(Resource): + @service_api_ns.expect(conversation_list_parser) + @service_api_ns.doc("list_conversations") + @service_api_ns.doc(description="List all conversations for the current user") + @service_api_ns.doc( + responses={ + 200: "Conversations retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Last conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @marshal_with(conversation_infinite_scroll_pagination_fields) + @service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser): + """List all conversations for the current user. + + Supports pagination using last_id and limit parameters. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - parser.add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - args = parser.parse_args() + args = conversation_list_parser.parse_args() try: with Session(db.engine) as session: @@ -59,10 +108,22 @@ class ConversationApi(Resource): raise NotFound("Last Conversation Not Exists.") +@service_api_ns.route("/conversations/") class ConversationDetailApi(Resource): + @service_api_ns.doc("delete_conversation") + @service_api_ns.doc(description="Delete a specific conversation") + @service_api_ns.doc(params={"c_id": "Conversation ID"}) + @service_api_ns.doc( + responses={ + 204: "Conversation deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @marshal_with(conversation_delete_fields) + @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204) def delete(self, app_model: App, end_user: EndUser, c_id): + """Delete a specific conversation.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -76,20 +137,30 @@ class ConversationDetailApi(Resource): return {"result": "success"}, 204 +@service_api_ns.route("/conversations//name") class ConversationRenameApi(Resource): + @service_api_ns.expect(conversation_rename_parser) + @service_api_ns.doc("rename_conversation") + @service_api_ns.doc(description="Rename a conversation or auto-generate a name") + @service_api_ns.doc(params={"c_id": "Conversation ID"}) + @service_api_ns.doc( + responses={ + 200: "Conversation renamed successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @marshal_with(simple_conversation_fields) + @service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns)) def post(self, app_model: App, end_user: EndUser, c_id): + """Rename a conversation or auto-generate a name.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, location="json") - parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") - args = parser.parse_args() + args = conversation_rename_parser.parse_args() try: return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) @@ -97,10 +168,26 @@ class ConversationRenameApi(Resource): raise NotFound("Conversation Not Exists.") +@service_api_ns.route("/conversations//variables") class ConversationVariablesApi(Resource): + @service_api_ns.expect(conversation_variables_parser) + @service_api_ns.doc("list_conversation_variables") + @service_api_ns.doc(description="List all variables for a conversation") + @service_api_ns.doc(params={"c_id": "Conversation ID"}) + @service_api_ns.doc( + responses={ + 200: "Variables retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @marshal_with(conversation_variable_infinite_scroll_pagination_fields) + @service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser, c_id): + """List all variables for a conversation. + + Conversational variables are only available for chat applications. + """ # conversational variable only for chat app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -108,10 +195,7 @@ class ConversationVariablesApi(Resource): conversation_id = str(c_id) - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - args = parser.parse_args() + args = conversation_variables_parser.parse_args() try: return ConversationService.get_conversational_variable( @@ -121,11 +205,28 @@ class ConversationVariablesApi(Resource): raise NotFound("Conversation Not Exists.") +@service_api_ns.route("/conversations//variables/") class ConversationVariableDetailApi(Resource): + @service_api_ns.expect(conversation_variable_update_parser) + @service_api_ns.doc("update_conversation_variable") + @service_api_ns.doc(description="Update a conversation variable's value") + @service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"}) + @service_api_ns.doc( + responses={ + 200: "Variable updated successfully", + 400: "Bad request - type mismatch", + 401: "Unauthorized - invalid API token", + 404: "Conversation or variable not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @marshal_with(conversation_variable_fields) + @service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns)) def put(self, app_model: App, end_user: EndUser, c_id, variable_id): - """Update a conversation variable's value""" + """Update a conversation variable's value. + + Allows updating the value of a specific conversation variable. + The value must match the variable's expected type. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -133,12 +234,7 @@ class ConversationVariableDetailApi(Resource): conversation_id = str(c_id) variable_id = str(variable_id) - parser = reqparse.RequestParser() - # using lambda is for passing the already-typed value without modification - # if no lambda, it will be converted to string - # the string cannot be converted using json.loads - parser.add_argument("value", required=True, location="json", type=lambda x: x) - args = parser.parse_args() + args = conversation_variable_update_parser.parse_args() try: return ConversationService.update_conversation_variable( @@ -150,15 +246,3 @@ class ConversationVariableDetailApi(Resource): raise NotFound("Conversation Variable Not Exists.") except services.errors.conversation.ConversationVariableTypeMismatchError as e: raise BadRequest(str(e)) - - -api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="conversation_name") -api.add_resource(ConversationApi, "/conversations") -api.add_resource(ConversationDetailApi, "/conversations/", endpoint="conversation_detail") -api.add_resource(ConversationVariablesApi, "/conversations//variables", endpoint="conversation_variables") -api.add_resource( - ConversationVariableDetailApi, - "/conversations//variables/", - endpoint="conversation_variable_detail", - methods=["PUT"], -) diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 37153ca5db..05f27545b3 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,5 +1,6 @@ from flask import request -from flask_restful import Resource, marshal_with +from flask_restx import Resource +from flask_restx.api import HTTPStatus import services from controllers.common.errors import ( @@ -9,17 +10,33 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token -from fields.file_fields import file_fields +from fields.file_fields import build_file_model from models.model import App, EndUser from services.file_service import FileService +@service_api_ns.route("/files/upload") class FileApi(Resource): + @service_api_ns.doc("upload_file") + @service_api_ns.doc(description="Upload a file for use in conversations") + @service_api_ns.doc( + responses={ + 201: "File uploaded successfully", + 400: "Bad request - no file or invalid file", + 401: "Unauthorized - invalid API token", + 413: "File too large", + 415: "Unsupported file type", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) - @marshal_with(file_fields) + @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App, end_user: EndUser): + """Upload a file for use in conversations. + + Accepts a single file upload via multipart/form-data. + """ # check file if "file" not in request.files: raise NoFileUploadedError() @@ -47,6 +64,3 @@ class FileApi(Resource): raise UnsupportedFileTypeError() return upload_file, 201 - - -api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py index 57141033d1..84d80ea101 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -2,9 +2,9 @@ import logging from urllib.parse import quote from flask import Response -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( FileAccessDeniedError, FileNotFoundError, @@ -17,6 +17,14 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile logger = logging.getLogger(__name__) +# Define parser for file preview API +file_preview_parser = reqparse.RequestParser() +file_preview_parser.add_argument( + "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment" +) + + +@service_api_ns.route("/files//preview") class FilePreviewApi(Resource): """ Service API File Preview endpoint @@ -25,33 +33,30 @@ class FilePreviewApi(Resource): Files can only be accessed if they belong to messages within the requesting app's context. """ + @service_api_ns.expect(file_preview_parser) + @service_api_ns.doc("preview_file") + @service_api_ns.doc(description="Preview or download a file uploaded via Service API") + @service_api_ns.doc(params={"file_id": "UUID of the file to preview"}) + @service_api_ns.doc( + responses={ + 200: "File retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - file access denied", + 404: "File not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) def get(self, app_model: App, end_user: EndUser, file_id: str): """ - Preview/Download a file that was uploaded via Service API + Preview/Download a file that was uploaded via Service API. - Args: - app_model: The authenticated app model - end_user: The authenticated end user (optional) - file_id: UUID of the file to preview - - Query Parameters: - user: Optional user identifier - as_attachment: Boolean, whether to download as attachment (default: false) - - Returns: - Stream response with file content - - Raises: - FileNotFoundError: File does not exist - FileAccessDeniedError: File access denied (not owned by app) + Provides secure file preview/download functionality. + Files can only be accessed if they belong to messages within the requesting app's context. """ file_id = str(file_id) # Parse query parameters - parser = reqparse.RequestParser() - parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") - args = parser.parse_args() + args = file_preview_parser.parse_args() # Validate file ownership and get file objects message_file, upload_file = self._validate_file_ownership(file_id, app_model.id) @@ -180,7 +185,3 @@ class FilePreviewApi(Resource): response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour return response - - -# Register the API endpoint -api.add_resource(FilePreviewApi, "/files//preview") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index a4f95cb1cb..ad3fac7009 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,17 +1,17 @@ import json import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Api, Namespace, Resource, fields, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom -from fields.conversation_fields import message_file_fields -from fields.message_fields import agent_thought_fields, feedback_fields +from fields.conversation_fields import build_message_file_model +from fields.message_fields import build_agent_thought_model, build_feedback_model from fields.raws import FilesContainedField from libs.helper import TimestampField, uuid_value from models.model import App, AppMode, EndUser @@ -22,8 +22,37 @@ from services.errors.message import ( ) from services.message_service import MessageService +# Define parsers for message APIs +message_list_parser = reqparse.RequestParser() +message_list_parser.add_argument( + "conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID" +) +message_list_parser.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination") +message_list_parser.add_argument( + "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of messages to return" +) -class MessageListApi(Resource): +message_feedback_parser = reqparse.RequestParser() +message_feedback_parser.add_argument( + "rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating" +) +message_feedback_parser.add_argument("content", type=str, location="json", help="Feedback content") + +feedback_list_parser = reqparse.RequestParser() +feedback_list_parser.add_argument("page", type=int, default=1, location="args", help="Page number") +feedback_list_parser.add_argument( + "limit", type=int_range(1, 101), required=False, default=20, location="args", help="Number of feedbacks per page" +) + + +def build_message_model(api_or_ns: Api | Namespace): + """Build the message model for the API or Namespace.""" + # First build the nested models + feedback_model = build_feedback_model(api_or_ns) + agent_thought_model = build_agent_thought_model(api_or_ns) + message_file_model = build_message_file_model(api_or_ns) + + # Then build the message fields with nested models message_fields = { "id": fields.String, "conversation_id": fields.String, @@ -31,37 +60,58 @@ class MessageListApi(Resource): "inputs": FilesContainedField, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), - "message_files": fields.List(fields.Nested(message_file_fields)), - "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "message_files": fields.List(fields.Nested(message_file_model)), + "feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True), "retriever_resources": fields.Raw( attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", []) if obj.message_metadata else [] ), "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), "status": fields.String, "error": fields.String, } + return api_or_ns.model("Message", message_fields) + + +def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): + """Build the message infinite scroll pagination model for the API or Namespace.""" + # Build the nested message model first + message_model = build_message_model(api_or_ns) message_infinite_scroll_pagination_fields = { "limit": fields.Integer, "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_fields)), + "data": fields.List(fields.Nested(message_model)), } + return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields) + +@service_api_ns.route("/messages") +class MessageListApi(Resource): + @service_api_ns.expect(message_list_parser) + @service_api_ns.doc("list_messages") + @service_api_ns.doc(description="List messages in a conversation") + @service_api_ns.doc( + responses={ + 200: "Messages retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation or first message not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @marshal_with(message_infinite_scroll_pagination_fields) + @service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser): + """List messages in a conversation. + + Retrieves messages with pagination support using first_id. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") - parser.add_argument("first_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - args = parser.parse_args() + args = message_list_parser.parse_args() try: return MessageService.pagination_by_first_id( @@ -73,15 +123,28 @@ class MessageListApi(Resource): raise NotFound("First Message Not Exists.") +@service_api_ns.route("/messages//feedbacks") class MessageFeedbackApi(Resource): + @service_api_ns.expect(message_feedback_parser) + @service_api_ns.doc("create_message_feedback") + @service_api_ns.doc(description="Submit feedback for a message") + @service_api_ns.doc(params={"message_id": "Message ID"}) + @service_api_ns.doc( + responses={ + 200: "Feedback submitted successfully", + 401: "Unauthorized - invalid API token", + 404: "Message not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, message_id): + """Submit feedback for a message. + + Allows users to rate messages as like/dislike and provide optional feedback content. + """ message_id = str(message_id) - parser = reqparse.RequestParser() - parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - parser.add_argument("content", type=str, location="json") - args = parser.parse_args() + args = message_feedback_parser.parse_args() try: MessageService.create_feedback( @@ -97,21 +160,48 @@ class MessageFeedbackApi(Resource): return {"result": "success"} +@service_api_ns.route("/app/feedbacks") class AppGetFeedbacksApi(Resource): + @service_api_ns.expect(feedback_list_parser) + @service_api_ns.doc("get_app_feedbacks") + @service_api_ns.doc(description="Get all feedbacks for the application") + @service_api_ns.doc( + responses={ + 200: "Feedbacks retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token def get(self, app_model: App): - """Get All Feedbacks of an app""" - parser = reqparse.RequestParser() - parser.add_argument("page", type=int, default=1, location="args") - parser.add_argument("limit", type=int_range(1, 101), required=False, default=20, location="args") - args = parser.parse_args() + """Get all feedbacks for the application. + + Returns paginated list of all feedback submitted for messages in this app. + """ + args = feedback_list_parser.parse_args() feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"]) return {"data": feedbacks} +@service_api_ns.route("/messages//suggested") class MessageSuggestedApi(Resource): + @service_api_ns.doc("get_suggested_questions") + @service_api_ns.doc(description="Get suggested follow-up questions for a message") + @service_api_ns.doc(params={"message_id": "Message ID"}) + @service_api_ns.doc( + responses={ + 200: "Suggested questions retrieved successfully", + 400: "Suggested questions feature is disabled", + 401: "Unauthorized - invalid API token", + 404: "Message not found", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) def get(self, app_model: App, end_user: EndUser, message_id): + """Get suggested follow-up questions for a message. + + Returns AI-generated follow-up questions based on the message content. + """ message_id = str(message_id) app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -130,9 +220,3 @@ class MessageSuggestedApi(Resource): raise InternalServerError() return {"result": "success", "data": questions} - - -api.add_resource(MessageListApi, "/messages") -api.add_resource(MessageFeedbackApi, "/messages//feedbacks") -api.add_resource(MessageSuggestedApi, "/messages//suggested") -api.add_resource(AppGetFeedbacksApi, "/app/feedbacks") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index c157b39f6b..9f8324a84e 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,30 +1,41 @@ -from flask_restful import Resource, marshal_with +from flask_restx import Resource from werkzeug.exceptions import Forbidden -from controllers.common import fields -from controllers.service_api import api +from controllers.common.fields import build_site_model +from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_database import db from models.account import TenantStatus from models.model import App, Site +@service_api_ns.route("/site") class AppSiteApi(Resource): """Resource for app sites.""" + @service_api_ns.doc("get_app_site") + @service_api_ns.doc(description="Get application site configuration") + @service_api_ns.doc( + responses={ + 200: "Site configuration retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - site not found or tenant archived", + } + ) @validate_app_token - @marshal_with(fields.site_fields) + @service_api_ns.marshal_with(build_site_model(service_api_ns)) def get(self, app_model: App): - """Retrieve app site info.""" + """Retrieve app site info. + + Returns the site configuration for the application including theme, icons, and text. + """ site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise Forbidden() + assert app_model.tenant if app_model.tenant.status == TenantStatus.ARCHIVE: raise Forbidden() return site - - -api.add_resource(AppSiteApi, "/site") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index cd8a5f03ac..19e2e67d7f 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -2,12 +2,12 @@ import logging from dateutil.parser import isoparse from flask import request -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Api, Namespace, Resource, fields, reqparse +from flask_restx.inputs import int_range from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( CompletionRequestError, NotWorkflowAppError, @@ -28,7 +28,7 @@ from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from extensions.ext_database import db -from fields.workflow_app_log_fields import workflow_app_log_pagination_fields +from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model from libs import helper from libs.helper import TimestampField from models.model import App, AppMode, EndUser @@ -40,6 +40,34 @@ from services.workflow_app_service import WorkflowAppService logger = logging.getLogger(__name__) +# Define parsers for workflow APIs +workflow_run_parser = reqparse.RequestParser() +workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") +workflow_run_parser.add_argument("files", type=list, required=False, location="json") +workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + +workflow_log_parser = reqparse.RequestParser() +workflow_log_parser.add_argument("keyword", type=str, location="args") +workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") +workflow_log_parser.add_argument("created_at__before", type=str, location="args") +workflow_log_parser.add_argument("created_at__after", type=str, location="args") +workflow_log_parser.add_argument( + "created_by_end_user_session_id", + type=str, + location="args", + required=False, + default=None, +) +workflow_log_parser.add_argument( + "created_by_account", + type=str, + location="args", + required=False, + default=None, +) +workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") +workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") + workflow_run_fields = { "id": fields.String, "workflow_id": fields.String, @@ -55,12 +83,29 @@ workflow_run_fields = { } +def build_workflow_run_model(api_or_ns: Api | Namespace): + """Build the workflow run model for the API or Namespace.""" + return api_or_ns.model("WorkflowRun", workflow_run_fields) + + +@service_api_ns.route("/workflows/run/") class WorkflowRunDetailApi(Resource): + @service_api_ns.doc("get_workflow_run_detail") + @service_api_ns.doc(description="Get workflow run details") + @service_api_ns.doc(params={"workflow_run_id": "Workflow run ID"}) + @service_api_ns.doc( + responses={ + 200: "Workflow run details retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Workflow run not found", + } + ) @validate_app_token - @marshal_with(workflow_run_fields) + @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns)) def get(self, app_model: App, workflow_run_id: str): - """ - Get a workflow task running detail + """Get a workflow task running detail. + + Returns detailed information about a specific workflow run. """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: @@ -78,21 +123,33 @@ class WorkflowRunDetailApi(Resource): return workflow_run +@service_api_ns.route("/workflows/run") class WorkflowRunApi(Resource): + @service_api_ns.expect(workflow_run_parser) + @service_api_ns.doc("run_workflow") + @service_api_ns.doc(description="Execute a workflow") + @service_api_ns.doc( + responses={ + 200: "Workflow executed successfully", + 400: "Bad request - invalid parameters or workflow issues", + 401: "Unauthorized - invalid API token", + 404: "Workflow not found", + 429: "Rate limit exceeded", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): - """ - Run workflow + """Execute a workflow. + + Runs a workflow with the provided inputs and returns the results. + Supports both blocking and streaming response modes. """ app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - args = parser.parse_args() + args = workflow_run_parser.parse_args() external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id @@ -121,21 +178,33 @@ class WorkflowRunApi(Resource): raise InternalServerError() +@service_api_ns.route("/workflows//run") class WorkflowRunByIdApi(Resource): + @service_api_ns.expect(workflow_run_parser) + @service_api_ns.doc("run_workflow_by_id") + @service_api_ns.doc(description="Execute a specific workflow by ID") + @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"}) + @service_api_ns.doc( + responses={ + 200: "Workflow executed successfully", + 400: "Bad request - invalid parameters or workflow issues", + 401: "Unauthorized - invalid API token", + 404: "Workflow not found", + 429: "Rate limit exceeded", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, workflow_id: str): - """ - Run specific workflow by ID + """Run specific workflow by ID. + + Executes a specific workflow version identified by its ID. """ app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - args = parser.parse_args() + args = workflow_run_parser.parse_args() # Add workflow_id to args for AppGenerateService args["workflow_id"] = workflow_id @@ -174,12 +243,21 @@ class WorkflowRunByIdApi(Resource): raise InternalServerError() +@service_api_ns.route("/workflows/tasks//stop") class WorkflowTaskStopApi(Resource): + @service_api_ns.doc("stop_workflow_task") + @service_api_ns.doc(description="Stop a running workflow task") + @service_api_ns.doc(params={"task_id": "Task ID to stop"}) + @service_api_ns.doc( + responses={ + 200: "Task stopped successfully", + 401: "Unauthorized - invalid API token", + 404: "Task not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id: str): - """ - Stop workflow task - """ + """Stop a running workflow task.""" app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() @@ -189,35 +267,25 @@ class WorkflowTaskStopApi(Resource): return {"result": "success"} +@service_api_ns.route("/workflows/logs") class WorkflowAppLogApi(Resource): + @service_api_ns.expect(workflow_log_parser) + @service_api_ns.doc("get_workflow_logs") + @service_api_ns.doc(description="Get workflow execution logs") + @service_api_ns.doc( + responses={ + 200: "Logs retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token - @marshal_with(workflow_app_log_pagination_fields) + @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns)) def get(self, app_model: App): + """Get workflow app logs. + + Returns paginated workflow execution logs with filtering options. """ - Get workflow app logs - """ - parser = reqparse.RequestParser() - parser.add_argument("keyword", type=str, location="args") - parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") - parser.add_argument("created_at__before", type=str, location="args") - parser.add_argument("created_at__after", type=str, location="args") - parser.add_argument( - "created_by_end_user_session_id", - type=str, - location="args", - required=False, - default=None, - ) - parser.add_argument( - "created_by_account", - type=str, - location="args", - required=False, - default=None, - ) - parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") - parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") - args = parser.parse_args() + args = workflow_log_parser.parse_args() args.status = WorkflowExecutionStatus(args.status) if args.status else None if args.created_at__before: @@ -243,10 +311,3 @@ class WorkflowAppLogApi(Resource): ) return workflow_app_log_pagination - - -api.add_resource(WorkflowRunApi, "/workflows/run") -api.add_resource(WorkflowRunDetailApi, "/workflows/run/") -api.add_resource(WorkflowRunByIdApi, "/workflows//run") -api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") -api.add_resource(WorkflowAppLogApi, "/workflows/logs") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 35b1efeff6..c486b0480b 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,11 +1,11 @@ from typing import Literal from flask import request -from flask_restful import marshal, marshal_with, reqparse +from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, NotFound import services.dataset_service -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.wraps import ( DatasetApiResource, @@ -16,7 +16,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields -from fields.tag_fields import tag_fields +from fields.tag_fields import build_dataset_tag_fields from libs.login import current_user from models.dataset import Dataset, DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -36,12 +36,171 @@ def _validate_description_length(description): return description +# Define parsers for dataset operations +dataset_create_parser = reqparse.RequestParser() +dataset_create_parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, +) +dataset_create_parser.add_argument( + "description", + type=_validate_description_length, + nullable=True, + required=False, + default="", +) +dataset_create_parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + help="Invalid indexing technique.", +) +dataset_create_parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + required=False, + nullable=False, +) +dataset_create_parser.add_argument( + "external_knowledge_api_id", + type=str, + nullable=True, + required=False, + default="_validate_name", +) +dataset_create_parser.add_argument( + "provider", + type=str, + nullable=True, + required=False, + default="vendor", +) +dataset_create_parser.add_argument( + "external_knowledge_id", + type=str, + nullable=True, + required=False, +) +dataset_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") +dataset_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") +dataset_create_parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") + +dataset_update_parser = reqparse.RequestParser() +dataset_update_parser.add_argument( + "name", + nullable=False, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, +) +dataset_update_parser.add_argument( + "description", location="json", store_missing=False, type=_validate_description_length +) +dataset_update_parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", +) +dataset_update_parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", +) +dataset_update_parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") +dataset_update_parser.add_argument( + "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." +) +dataset_update_parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") +dataset_update_parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") +dataset_update_parser.add_argument( + "external_retrieval_model", + type=dict, + required=False, + nullable=True, + location="json", + help="Invalid external retrieval model.", +) +dataset_update_parser.add_argument( + "external_knowledge_id", + type=str, + required=False, + nullable=True, + location="json", + help="Invalid external knowledge id.", +) +dataset_update_parser.add_argument( + "external_knowledge_api_id", + type=str, + required=False, + nullable=True, + location="json", + help="Invalid external knowledge api id.", +) + +tag_create_parser = reqparse.RequestParser() +tag_create_parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 50 characters.", + type=lambda x: x + if x and 1 <= len(x) <= 50 + else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), +) + +tag_update_parser = reqparse.RequestParser() +tag_update_parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 50 characters.", + type=lambda x: x + if x and 1 <= len(x) <= 50 + else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), +) +tag_update_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) + +tag_delete_parser = reqparse.RequestParser() +tag_delete_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) + +tag_binding_parser = reqparse.RequestParser() +tag_binding_parser.add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." +) +tag_binding_parser.add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." +) + +tag_unbinding_parser = reqparse.RequestParser() +tag_unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") +tag_unbinding_parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + + +@service_api_ns.route("/datasets") class DatasetListApi(DatasetApiResource): """Resource for datasets.""" + @service_api_ns.doc("list_datasets") + @service_api_ns.doc(description="List all datasets") + @service_api_ns.doc( + responses={ + 200: "Datasets retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) # provider = request.args.get("provider", default="vendor") @@ -76,65 +235,20 @@ class DatasetListApi(DatasetApiResource): response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 + @service_api_ns.expect(dataset_create_parser) + @service_api_ns.doc("create_dataset") + @service_api_ns.doc(description="Create a new dataset") + @service_api_ns.doc( + responses={ + 200: "Dataset created successfully", + 401: "Unauthorized - invalid API token", + 400: "Bad request - invalid parameters", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id): """Resource for creating datasets.""" - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - required=True, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - parser.add_argument( - "description", - type=_validate_description_length, - nullable=True, - required=False, - default="", - ) - parser.add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help="Invalid indexing technique.", - ) - parser.add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - required=False, - nullable=False, - ) - parser.add_argument( - "external_knowledge_api_id", - type=str, - nullable=True, - required=False, - default="_validate_name", - ) - parser.add_argument( - "provider", - type=str, - nullable=True, - required=False, - default="vendor", - ) - parser.add_argument( - "external_knowledge_id", - type=str, - nullable=True, - required=False, - ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") - parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - - args = parser.parse_args() + args = dataset_create_parser.parse_args() if args.get("embedding_model_provider"): DatasetService.check_embedding_model_setting( @@ -174,9 +288,21 @@ class DatasetListApi(DatasetApiResource): return marshal(dataset, dataset_detail_fields), 200 +@service_api_ns.route("/datasets/") class DatasetApi(DatasetApiResource): """Resource for dataset.""" + @service_api_ns.doc("get_dataset") + @service_api_ns.doc(description="Get a specific dataset by ID") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Dataset retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Dataset not found", + } + ) def get(self, _, dataset_id): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -216,6 +342,18 @@ class DatasetApi(DatasetApiResource): return data, 200 + @service_api_ns.expect(dataset_update_parser) + @service_api_ns.doc("update_dataset") + @service_api_ns.doc(description="Update an existing dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Dataset updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, _, dataset_id): dataset_id_str = str(dataset_id) @@ -223,63 +361,7 @@ class DatasetApi(DatasetApiResource): if dataset is None: raise NotFound("Dataset not found.") - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) - parser.add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help="Invalid indexing technique.", - ) - parser.add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - ) - parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") - parser.add_argument( - "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." - ) - parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") - parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") - - parser.add_argument( - "external_retrieval_model", - type=dict, - required=False, - nullable=True, - location="json", - help="Invalid external retrieval model.", - ) - - parser.add_argument( - "external_knowledge_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge id.", - ) - - parser.add_argument( - "external_knowledge_api_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge api id.", - ) - args = parser.parse_args() + args = dataset_update_parser.parse_args() data = request.get_json() # check embedding model setting @@ -327,6 +409,17 @@ class DatasetApi(DatasetApiResource): return result_data, 200 + @service_api_ns.doc("delete_dataset") + @service_api_ns.doc(description="Delete a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 204: "Dataset deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + 409: "Conflict - dataset is in use", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, _, dataset_id): """ @@ -357,9 +450,27 @@ class DatasetApi(DatasetApiResource): raise DatasetInUseError() +@service_api_ns.route("/datasets//documents/status/") class DocumentStatusApi(DatasetApiResource): """Resource for batch document status operations.""" + @service_api_ns.doc("update_document_status") + @service_api_ns.doc(description="Batch update document status") + @service_api_ns.doc( + params={ + "dataset_id": "Dataset ID", + "action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'", + } + ) + @service_api_ns.doc( + responses={ + 200: "Document status updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Dataset not found", + 400: "Bad request - invalid action", + } + ) def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): """ Batch update document status. @@ -407,53 +518,65 @@ class DocumentStatusApi(DatasetApiResource): return {"result": "success"}, 200 +@service_api_ns.route("/datasets/tags") class DatasetTagsApi(DatasetApiResource): + @service_api_ns.doc("list_dataset_tags") + @service_api_ns.doc(description="Get all knowledge type tags") + @service_api_ns.doc( + responses={ + 200: "Tags retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_dataset_token - @marshal_with(tag_fields) + @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def get(self, _, dataset_id): """Get all knowledge type tags.""" tags = TagService.get_tags("knowledge", current_user.current_tenant_id) return tags, 200 + @service_api_ns.expect(tag_create_parser) + @service_api_ns.doc("create_dataset_tag") + @service_api_ns.doc(description="Add a knowledge type tag") + @service_api_ns.doc( + responses={ + 200: "Tag created successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) + @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @validate_dataset_token def post(self, _, dataset_id): """Add a knowledge type tag.""" if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=DatasetTagsApi._validate_tag_name, - ) - - args = parser.parse_args() + args = tag_create_parser.parse_args() args["type"] = "knowledge" tag = TagService.save_tags(args) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} - return response, 200 + @service_api_ns.expect(tag_update_parser) + @service_api_ns.doc("update_dataset_tag") + @service_api_ns.doc(description="Update a knowledge type tag") + @service_api_ns.doc( + responses={ + 200: "Tag updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) + @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @validate_dataset_token def patch(self, _, dataset_id): if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=DatasetTagsApi._validate_tag_name, - ) - parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) - args = parser.parse_args() + args = tag_update_parser.parse_args() args["type"] = "knowledge" tag = TagService.update_tags(args, args.get("tag_id")) @@ -463,66 +586,88 @@ class DatasetTagsApi(DatasetApiResource): return response, 200 + @service_api_ns.expect(tag_delete_parser) + @service_api_ns.doc("delete_dataset_tag") + @service_api_ns.doc(description="Delete a knowledge type tag") + @service_api_ns.doc( + responses={ + 204: "Tag deleted successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) @validate_dataset_token def delete(self, _, dataset_id): """Delete a knowledge type tag.""" if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) - args = parser.parse_args() + args = tag_delete_parser.parse_args() TagService.delete_tag(args.get("tag_id")) return 204 - @staticmethod - def _validate_tag_name(name): - if not name or len(name) < 1 or len(name) > 50: - raise ValueError("Name must be between 1 to 50 characters.") - return name - +@service_api_ns.route("/datasets/tags/binding") class DatasetTagBindingApi(DatasetApiResource): + @service_api_ns.expect(tag_binding_parser) + @service_api_ns.doc("bind_dataset_tags") + @service_api_ns.doc(description="Bind tags to a dataset") + @service_api_ns.doc( + responses={ + 204: "Tags bound successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) @validate_dataset_token def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." - ) - parser.add_argument( - "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." - ) - - args = parser.parse_args() + args = tag_binding_parser.parse_args() args["type"] = "knowledge" TagService.save_tag_binding(args) return 204 +@service_api_ns.route("/datasets/tags/unbinding") class DatasetTagUnbindingApi(DatasetApiResource): + @service_api_ns.expect(tag_unbinding_parser) + @service_api_ns.doc("unbind_dataset_tag") + @service_api_ns.doc(description="Unbind a tag from a dataset") + @service_api_ns.doc( + responses={ + 204: "Tag unbound successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) @validate_dataset_token def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - - args = parser.parse_args() + args = tag_unbinding_parser.parse_args() args["type"] = "knowledge" TagService.delete_tag_binding(args) return 204 +@service_api_ns.route("/datasets//tags") class DatasetTagsBindingStatusApi(DatasetApiResource): + @service_api_ns.doc("get_dataset_tags_binding_status") + @service_api_ns.doc(description="Get tags bound to a specific dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Tags retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_dataset_token def get(self, _, *args, **kwargs): """Get all knowledge type tags.""" @@ -531,12 +676,3 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] response = {"data": tags_list, "total": len(tags)} return response, 200 - - -api.add_resource(DatasetListApi, "/datasets") -api.add_resource(DatasetApi, "/datasets/") -api.add_resource(DocumentStatusApi, "/datasets//documents/status/") -api.add_resource(DatasetTagsApi, "/datasets/tags") -api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding") -api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding") -api.add_resource(DatasetTagsBindingStatusApi, "/datasets//tags") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d0354f7851..43232229c8 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,7 +1,7 @@ import json from flask import request -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse from sqlalchemy import desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -13,7 +13,7 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.dataset.error import ( ArchivedDocumentImmutableError, @@ -34,32 +34,64 @@ from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService +# Define parsers for document operations +document_text_create_parser = reqparse.RequestParser() +document_text_create_parser.add_argument("name", type=str, required=True, nullable=False, location="json") +document_text_create_parser.add_argument("text", type=str, required=True, nullable=False, location="json") +document_text_create_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") +document_text_create_parser.add_argument("original_document_id", type=str, required=False, location="json") +document_text_create_parser.add_argument( + "doc_form", type=str, default="text_model", required=False, nullable=False, location="json" +) +document_text_create_parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" +) +document_text_create_parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" +) +document_text_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") +document_text_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") +document_text_create_parser.add_argument( + "embedding_model_provider", type=str, required=False, nullable=True, location="json" +) +document_text_update_parser = reqparse.RequestParser() +document_text_update_parser.add_argument("name", type=str, required=False, nullable=True, location="json") +document_text_update_parser.add_argument("text", type=str, required=False, nullable=True, location="json") +document_text_update_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") +document_text_update_parser.add_argument( + "doc_form", type=str, default="text_model", required=False, nullable=False, location="json" +) +document_text_update_parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" +) +document_text_update_parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") + + +@service_api_ns.route( + "/datasets//document/create_by_text", + "/datasets//document/create-by-text", +) class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" + @service_api_ns.expect(document_text_create_parser) + @service_api_ns.doc("create_document_by_text") + @service_api_ns.doc(description="Create a new document by providing text content") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Document created successfully", + 401: "Unauthorized - invalid API token", + 400: "Bad request - invalid parameters", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, nullable=False, location="json") - parser.add_argument("text", type=str, required=True, nullable=False, location="json") - parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") - parser.add_argument("original_document_id", type=str, required=False, location="json") - parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - parser.add_argument( - "doc_language", type=str, default="English", required=False, nullable=False, location="json" - ) - parser.add_argument( - "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" - ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") - parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - - args = parser.parse_args() + args = document_text_create_parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -117,23 +149,29 @@ class DocumentAddByTextApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route( + "/datasets//documents//update_by_text", + "/datasets//documents//update-by-text", +) class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" + @service_api_ns.expect(document_text_update_parser) + @service_api_ns.doc("update_document_by_text") + @service_api_ns.doc(description="Update an existing document by providing text content") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Document not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, nullable=True, location="json") - parser.add_argument("text", type=str, required=False, nullable=True, location="json") - parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") - parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - parser.add_argument( - "doc_language", type=str, default="English", required=False, nullable=False, location="json" - ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") - args = parser.parse_args() + args = document_text_update_parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -187,9 +225,23 @@ class DocumentUpdateByTextApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route( + "/datasets//document/create_by_file", + "/datasets//document/create-by-file", +) class DocumentAddByFileApi(DatasetApiResource): """Resource for documents.""" + @service_api_ns.doc("create_document_by_file") + @service_api_ns.doc(description="Create a new document by uploading a file") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Document created successfully", + 401: "Unauthorized - invalid API token", + 400: "Bad request - invalid file or parameters", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") @@ -281,9 +333,23 @@ class DocumentAddByFileApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route( + "/datasets//documents//update_by_file", + "/datasets//documents//update-by-file", +) class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" + @service_api_ns.doc("update_document_by_file") + @service_api_ns.doc(description="Update an existing document by uploading a file") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Document not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): @@ -358,7 +424,18 @@ class DocumentUpdateByFileApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route("/datasets//documents") class DocumentListApi(DatasetApiResource): + @service_api_ns.doc("list_documents") + @service_api_ns.doc(description="List all documents in a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Documents retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -391,7 +468,18 @@ class DocumentListApi(DatasetApiResource): return response +@service_api_ns.route("/datasets//documents//indexing-status") class DocumentIndexingStatusApi(DatasetApiResource): + @service_api_ns.doc("get_document_indexing_status") + @service_api_ns.doc(description="Get indexing status for documents in a batch") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "batch": "Batch ID"}) + @service_api_ns.doc( + responses={ + 200: "Indexing status retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or documents not found", + } + ) def get(self, tenant_id, dataset_id, batch): dataset_id = str(dataset_id) batch = str(batch) @@ -440,9 +528,21 @@ class DocumentIndexingStatusApi(DatasetApiResource): return data +@service_api_ns.route("/datasets//documents/") class DocumentApi(DatasetApiResource): METADATA_CHOICES = {"all", "only", "without"} + @service_api_ns.doc("get_document") + @service_api_ns.doc(description="Get a specific document by ID") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Document not found", + } + ) def get(self, tenant_id, dataset_id, document_id): dataset_id = str(dataset_id) document_id = str(document_id) @@ -534,6 +634,17 @@ class DocumentApi(DatasetApiResource): return response + @service_api_ns.doc("delete_document") + @service_api_ns.doc(description="Delete a document") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 204: "Document deleted successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - document is archived", + 404: "Document not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, document_id): """Delete document.""" @@ -564,28 +675,3 @@ class DocumentApi(DatasetApiResource): raise DocumentIndexingError("Cannot delete document during indexing.") return 204 - - -api.add_resource( - DocumentAddByTextApi, - "/datasets//document/create_by_text", - "/datasets//document/create-by-text", -) -api.add_resource( - DocumentAddByFileApi, - "/datasets//document/create_by_file", - "/datasets//document/create-by-file", -) -api.add_resource( - DocumentUpdateByTextApi, - "/datasets//documents//update_by_text", - "/datasets//documents//update-by-text", -) -api.add_resource( - DocumentUpdateByFileApi, - "/datasets//documents//update_by_file", - "/datasets//documents//update-by-file", -) -api.add_resource(DocumentApi, "/datasets//documents/") -api.add_resource(DocumentListApi, "/datasets//documents") -api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 52e9bca5da..d81287d56f 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,11 +1,26 @@ from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +@service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve") class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): + @service_api_ns.doc("dataset_hit_testing") + @service_api_ns.doc(description="Perform hit testing on a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Hit testing results", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): + """Perform hit testing on a dataset. + + Tests retrieval performance for the specified dataset. + """ dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) @@ -13,6 +28,3 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) - - -api.add_resource(HitTestingApi, "/datasets//hit-testing", "/datasets//retrieve") diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 75a0b18285..9defe6af03 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -1,10 +1,10 @@ from typing import Literal from flask_login import current_user # type: ignore -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse from werkzeug.exceptions import NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from fields.dataset_fields import dataset_metadata_fields from services.dataset_service import DatasetService @@ -14,14 +14,43 @@ from services.entities.knowledge_entities.knowledge_entities import ( ) from services.metadata_service import MetadataService +# Define parsers for metadata APIs +metadata_create_parser = reqparse.RequestParser() +metadata_create_parser.add_argument( + "type", type=str, required=True, nullable=False, location="json", help="Metadata type" +) +metadata_create_parser.add_argument( + "name", type=str, required=True, nullable=False, location="json", help="Metadata name" +) +metadata_update_parser = reqparse.RequestParser() +metadata_update_parser.add_argument( + "name", type=str, required=True, nullable=False, location="json", help="New metadata name" +) + +document_metadata_parser = reqparse.RequestParser() +document_metadata_parser.add_argument( + "operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data" +) + + +@service_api_ns.route("/datasets//metadata") class DatasetMetadataCreateServiceApi(DatasetApiResource): + @service_api_ns.expect(metadata_create_parser) + @service_api_ns.doc("create_dataset_metadata") + @service_api_ns.doc(description="Create metadata for a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 201: "Metadata created successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + """Create metadata for a dataset.""" + args = metadata_create_parser.parse_args() metadata_args = MetadataArgs(**args) dataset_id_str = str(dataset_id) @@ -33,7 +62,18 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) return marshal(metadata, dataset_metadata_fields), 201 + @service_api_ns.doc("get_dataset_metadata") + @service_api_ns.doc(description="Get all metadata for a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Metadata retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) def get(self, tenant_id, dataset_id): + """Get all metadata for a dataset.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -41,12 +81,23 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): return MetadataService.get_dataset_metadatas(dataset), 200 +@service_api_ns.route("/datasets//metadata/") class DatasetMetadataServiceApi(DatasetApiResource): + @service_api_ns.expect(metadata_update_parser) + @service_api_ns.doc("update_dataset_metadata") + @service_api_ns.doc(description="Update metadata name") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) + @service_api_ns.doc( + responses={ + 200: "Metadata updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or metadata not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id, dataset_id, metadata_id): - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + """Update metadata name.""" + args = metadata_update_parser.parse_args() dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -58,8 +109,19 @@ class DatasetMetadataServiceApi(DatasetApiResource): metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) return marshal(metadata, dataset_metadata_fields), 200 + @service_api_ns.doc("delete_dataset_metadata") + @service_api_ns.doc(description="Delete metadata") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) + @service_api_ns.doc( + responses={ + 204: "Metadata deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or metadata not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, metadata_id): + """Delete metadata.""" dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -71,15 +133,37 @@ class DatasetMetadataServiceApi(DatasetApiResource): return 204 +@service_api_ns.route("/datasets/metadata/built-in") class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): + @service_api_ns.doc("get_built_in_fields") + @service_api_ns.doc(description="Get all built-in metadata fields") + @service_api_ns.doc( + responses={ + 200: "Built-in fields retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) def get(self, tenant_id): + """Get all built-in metadata fields.""" built_in_fields = MetadataService.get_built_in_fields() return {"fields": built_in_fields}, 200 +@service_api_ns.route("/datasets//metadata/built-in/") class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): + @service_api_ns.doc("toggle_built_in_field") + @service_api_ns.doc(description="Enable or disable built-in metadata field") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "action": "Action to perform: 'enable' or 'disable'"}) + @service_api_ns.doc( + responses={ + 200: "Action completed successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]): + """Enable or disable built-in metadata field.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -93,29 +177,31 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): return 200 +@service_api_ns.route("/datasets//documents/metadata") class DocumentMetadataEditServiceApi(DatasetApiResource): + @service_api_ns.expect(document_metadata_parser) + @service_api_ns.doc("update_documents_metadata") + @service_api_ns.doc(description="Update metadata for multiple documents") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Documents metadata updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): + """Update metadata for multiple documents.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - parser = reqparse.RequestParser() - parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") - args = parser.parse_args() + args = document_metadata_parser.parse_args() metadata_args = MetadataOperationData(**args) MetadataService.update_documents_metadata(dataset, metadata_args) return 200 - - -api.add_resource(DatasetMetadataCreateServiceApi, "/datasets//metadata") -api.add_resource(DatasetMetadataServiceApi, "/datasets//metadata/") -api.add_resource(DatasetMetadataBuiltInFieldServiceApi, "/datasets/metadata/built-in") -api.add_resource( - DatasetMetadataBuiltInFieldActionServiceApi, "/datasets//metadata/built-in/" -) -api.add_resource(DocumentMetadataEditServiceApi, "/datasets//documents/metadata") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 31f862dc8f..f5e2010ca4 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,9 +1,9 @@ from flask import request from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse from werkzeug.exceptions import NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.wraps import ( DatasetApiResource, @@ -19,34 +19,59 @@ from fields.segment_fields import child_chunk_fields, segment_fields from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs -from services.errors.chunk import ( - ChildChunkDeleteIndexError, - ChildChunkIndexingError, -) -from services.errors.chunk import ( - ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError, -) -from services.errors.chunk import ( - ChildChunkIndexingError as ChildChunkIndexingServiceError, -) +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError +from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError +from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError + +# Define parsers for segment operations +segment_create_parser = reqparse.RequestParser() +segment_create_parser.add_argument("segments", type=list, required=False, nullable=True, location="json") + +segment_list_parser = reqparse.RequestParser() +segment_list_parser.add_argument("status", type=str, action="append", default=[], location="args") +segment_list_parser.add_argument("keyword", type=str, default=None, location="args") + +segment_update_parser = reqparse.RequestParser() +segment_update_parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") + +child_chunk_create_parser = reqparse.RequestParser() +child_chunk_create_parser.add_argument("content", type=str, required=True, nullable=False, location="json") + +child_chunk_list_parser = reqparse.RequestParser() +child_chunk_list_parser.add_argument("limit", type=int, default=20, location="args") +child_chunk_list_parser.add_argument("keyword", type=str, default=None, location="args") +child_chunk_list_parser.add_argument("page", type=int, default=1, location="args") + +child_chunk_update_parser = reqparse.RequestParser() +child_chunk_update_parser.add_argument("content", type=str, required=True, nullable=False, location="json") +@service_api_ns.route("/datasets//documents//segments") class SegmentApi(DatasetApiResource): """Resource for segments.""" + @service_api_ns.expect(segment_create_parser) + @service_api_ns.doc("create_segments") + @service_api_ns.doc(description="Create segments in a document") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Segments created successfully", + 400: "Bad request - segments data is missing", + 401: "Unauthorized - invalid API token", + 404: "Dataset or document not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id): + def post(self, tenant_id: str, dataset_id: str, document_id: str): """Create single segment.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") @@ -71,9 +96,7 @@ class SegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args - parser = reqparse.RequestParser() - parser.add_argument("segments", type=list, required=False, nullable=True, location="json") - args = parser.parse_args() + args = segment_create_parser.parse_args() if args["segments"] is not None: for args_item in args["segments"]: SegmentService.segment_create_args_validate(args_item, document) @@ -82,18 +105,26 @@ class SegmentApi(DatasetApiResource): else: return {"error": "Segments is required"}, 400 - def get(self, tenant_id, dataset_id, document_id): + @service_api_ns.expect(segment_list_parser) + @service_api_ns.doc("list_segments") + @service_api_ns.doc(description="List segments in a document") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Segments retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or document not found", + } + ) + def get(self, tenant_id: str, dataset_id: str, document_id: str): """Get segments.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") @@ -114,10 +145,7 @@ class SegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - parser = reqparse.RequestParser() - parser.add_argument("status", type=str, action="append", default=[], location="args") - parser.add_argument("keyword", type=str, default=None, location="args") - args = parser.parse_args() + args = segment_list_parser.parse_args() segments, total = SegmentService.get_segments( document_id=document_id, @@ -140,43 +168,62 @@ class SegmentApi(DatasetApiResource): return response, 200 +@service_api_ns.route("/datasets//documents//segments/") class DatasetSegmentApi(DatasetApiResource): + @service_api_ns.doc("delete_segment") + @service_api_ns.doc(description="Delete a specific segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to delete"} + ) + @service_api_ns.doc( + responses={ + 204: "Segment deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id, dataset_id, document_id, segment_id): + def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) return 204 + @service_api_ns.expect(segment_update_parser) + @service_api_ns.doc("update_segment") + @service_api_ns.doc(description="Update a specific segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to update"} + ) + @service_api_ns.doc( + responses={ + 200: "Segment updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id, segment_id): + def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") @@ -197,37 +244,39 @@ class DatasetSegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") # validate args - parser = reqparse.RequestParser() - parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") - args = parser.parse_args() + args = segment_update_parser.parse_args() updated_segment = SegmentService.update_segment( SegmentUpdateArgs(**args["segment"]), segment, document, dataset ) return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 - def get(self, tenant_id, dataset_id, document_id, segment_id): + @service_api_ns.doc("get_segment") + @service_api_ns.doc(description="Get a specific segment by ID") + @service_api_ns.doc( + responses={ + 200: "Segment retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) + def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -235,29 +284,41 @@ class DatasetSegmentApi(DatasetApiResource): return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 +@service_api_ns.route( + "/datasets//documents//segments//child_chunks" +) class ChildChunkApi(DatasetApiResource): """Resource for child chunks.""" + @service_api_ns.expect(child_chunk_create_parser) + @service_api_ns.doc("create_child_chunk") + @service_api_ns.doc(description="Create a new child chunk for a segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"} + ) + @service_api_ns.doc( + responses={ + 200: "Child chunk created successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id, segment_id): + def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): """Create child chunk.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -280,43 +341,46 @@ class ChildChunkApi(DatasetApiResource): raise ProviderNotInitializeError(ex.description) # validate args - parser = reqparse.RequestParser() - parser.add_argument("content", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + args = child_chunk_create_parser.parse_args() try: - child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) + child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 - def get(self, tenant_id, dataset_id, document_id, segment_id): + @service_api_ns.expect(child_chunk_list_parser) + @service_api_ns.doc("list_child_chunks") + @service_api_ns.doc(description="List child chunks for a segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"} + ) + @service_api_ns.doc( + responses={ + 200: "Child chunks retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) + def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): """Get child chunks.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") - parser = reqparse.RequestParser() - parser.add_argument("limit", type=int, default=20, location="args") - parser.add_argument("keyword", type=str, default=None, location="args") - parser.add_argument("page", type=int, default=1, location="args") - args = parser.parse_args() + args = child_chunk_list_parser.parse_args() page = args["page"] limit = min(args["limit"], 100) @@ -333,40 +397,63 @@ class ChildChunkApi(DatasetApiResource): }, 200 +@service_api_ns.route( + "/datasets//documents//segments//child_chunks/" +) class DatasetChildChunkApi(DatasetApiResource): """Resource for updating child chunks.""" + @service_api_ns.doc("delete_child_chunk") + @service_api_ns.doc(description="Delete a specific child chunk") + @service_api_ns.doc( + params={ + "dataset_id": "Dataset ID", + "document_id": "Document ID", + "segment_id": "Parent segment ID", + "child_chunk_id": "Child chunk ID to delete", + } + ) + @service_api_ns.doc( + responses={ + 204: "Child chunk deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, segment, or child chunk not found", + } + ) @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): + def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): """Delete child chunk.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") + # validate segment belongs to the specified document + if segment.document_id != document_id: + raise NotFound("Document not found.") + # check child chunk - child_chunk_id = str(child_chunk_id) child_chunk = SegmentService.get_child_chunk_by_id( child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id ) if not child_chunk: raise NotFound("Child chunk not found.") + # validate child chunk belongs to the specified segment + if child_chunk.segment_id != segment.id: + raise NotFound("Child chunk not found.") + try: SegmentService.delete_child_chunk(child_chunk, dataset) except ChildChunkDeleteIndexServiceError as e: @@ -374,14 +461,30 @@ class DatasetChildChunkApi(DatasetApiResource): return 204 + @service_api_ns.expect(child_chunk_update_parser) + @service_api_ns.doc("update_child_chunk") + @service_api_ns.doc(description="Update a specific child chunk") + @service_api_ns.doc( + params={ + "dataset_id": "Dataset ID", + "document_id": "Document ID", + "segment_id": "Parent segment ID", + "child_chunk_id": "Child chunk ID to update", + } + ) + @service_api_ns.doc( + responses={ + 200: "Child chunk updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, segment, or child chunk not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): + def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): """Update child chunk.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -396,6 +499,10 @@ class DatasetChildChunkApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") + # validate segment belongs to the specified document + if segment.document_id != document_id: + raise NotFound("Segment not found.") + # get child chunk child_chunk = SegmentService.get_child_chunk_by_id( child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id @@ -403,29 +510,16 @@ class DatasetChildChunkApi(DatasetApiResource): if not child_chunk: raise NotFound("Child chunk not found.") + # validate child chunk belongs to the specified segment + if child_chunk.segment_id != segment.id: + raise NotFound("Child chunk not found.") + # validate args - parser = reqparse.RequestParser() - parser.add_argument("content", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + args = child_chunk_update_parser.parse_args() try: - child_chunk = SegmentService.update_child_chunk( - args.get("content"), child_chunk, segment, document, dataset - ) + child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 - - -api.add_resource(SegmentApi, "/datasets//documents//segments") -api.add_resource( - DatasetSegmentApi, "/datasets//documents//segments/" -) -api.add_resource( - ChildChunkApi, "/datasets//documents//segments//child_chunks" -) -api.add_resource( - DatasetChildChunkApi, - "/datasets//documents//segments//child_chunks/", -) diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py index 3b4721b5b0..27b36a6402 100644 --- a/api/controllers/service_api/dataset/upload_file.py +++ b/api/controllers/service_api/dataset/upload_file.py @@ -1,6 +1,6 @@ from werkzeug.exceptions import NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import ( DatasetApiResource, ) @@ -11,9 +11,23 @@ from models.model import UploadFile from services.dataset_service import DocumentService +@service_api_ns.route("/datasets//documents//upload-file") class UploadFileApi(DatasetApiResource): + @service_api_ns.doc("get_upload_file") + @service_api_ns.doc(description="Get upload file information and download URL") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Upload file information retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or upload file not found", + } + ) def get(self, tenant_id, dataset_id, document_id): - """Get upload file.""" + """Get upload file information and download URL. + + Returns information about an uploaded file including its download URL. + """ # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -49,6 +63,3 @@ class UploadFileApi(DatasetApiResource): "created_by": upload_file.created_by, "created_at": upload_file.created_at.timestamp(), }, 200 - - -api.add_resource(UploadFileApi, "/datasets//documents//upload-file") diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index 9bb5df4c4e..a9d2d6fadc 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -1,9 +1,10 @@ -from flask_restful import Resource +from flask_restx import Resource from configs import dify_config -from controllers.service_api import api +from controllers.service_api import service_api_ns +@service_api_ns.route("/") class IndexApi(Resource): def get(self): return { @@ -11,6 +12,3 @@ class IndexApi(Resource): "api_version": "v1", "server_version": dify_config.project.version, } - - -api.add_resource(IndexApi, "/") diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 3f18474674..536cf81a2f 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -1,21 +1,32 @@ from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token from core.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService +@service_api_ns.route("/workspaces/current/models/model-types/") class ModelProviderAvailableModelApi(Resource): + @service_api_ns.doc("get_available_models") + @service_api_ns.doc(description="Get available models by model type") + @service_api_ns.doc(params={"model_type": "Type of model to retrieve"}) + @service_api_ns.doc( + responses={ + 200: "Models retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_dataset_token def get(self, _, model_type): + """Get available models by model type. + + Returns a list of available models for the specified model type. + """ tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) return jsonable_encoder({"data": models}) - - -api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/") diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index da81cc8bc3..8aac3de4c3 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -7,7 +7,7 @@ from typing import Optional from flask import current_app, request from flask_login import user_logged_in # type: ignore -from flask_restful import Resource +from flask_restx import Resource from pydantic import BaseModel from sqlalchemy import select, update from sqlalchemy.orm import Session diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 197859e8f3..0680903635 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,5 +1,7 @@ +import logging + from flask import request -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import Unauthorized from controllers.common import fields @@ -87,8 +89,11 @@ class AppWebAuthPermission(Resource): decoded = PassportService().verify(tk) user_id = decoded.get("user_id", "visitor") - except Exception as e: - pass + except Unauthorized: + raise + except Exception: + logging.exception("Unexpected error during auth verification") + raise features = FeatureService.get_system_features() if not features.webapp_auth.enabled: diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2919ca9af4..241d0874db 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -65,7 +65,7 @@ class AudioApi(WebApiResource): class TextApi(WebApiResource): def post(self, app_model: App, end_user): - from flask_restful import reqparse + from flask_restx import reqparse try: parser = reqparse.RequestParser() diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index fd3b9aa804..c19afee9b7 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 98cea3974f..cea8e442f3 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,5 +1,5 @@ -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import marshal_with, reqparse +from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 0563ed2238..478b3d2e31 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restx import Resource from controllers.web import api from services.feature_service import FeatureService diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 0c30435825..b05e2a2e65 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import marshal_with +from flask_restx import marshal_with import services from controllers.common.errors import ( diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 0da8d65efc..d436657f06 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -2,7 +2,7 @@ import base64 import secrets from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 01c4f4a262..d4eafd532b 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from jwt import InvalidTokenError # type: ignore import services diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 7bb81cd0d3..f348221d80 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import fields, marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound from controllers.web import api diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index acd3a8b539..1ac20e6531 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -2,7 +2,7 @@ import uuid from datetime import UTC, datetime, timedelta from flask import request -from flask_restful import Resource +from flask_restx import Resource from sqlalchemy import func, select from werkzeug.exceptions import NotFound, Unauthorized diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 4e19716c3d..930b9d96e9 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,7 @@ import urllib.parse import httpx -from flask_restful import marshal_with, reqparse +from flask_restx import marshal_with, reqparse import services from controllers.common import helpers diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index d7188ef0b3..a0912499ff 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,5 +1,5 @@ -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import fields, marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound from controllers.web import api diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 3c133499b7..b2a887a0de 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,4 @@ -from flask_restful import fields, marshal_with +from flask_restx import fields, marshal_with from werkzeug.exceptions import Forbidden from configs import dify_config diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 590fd3f2c7..331587cc28 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restx import reqparse from werkzeug.exceptions import InternalServerError from controllers.web import api diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index ae6f14a689..94fa5d5626 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -2,7 +2,7 @@ from datetime import UTC, datetime from functools import wraps from flask import request -from flask_restful import Resource +from flask_restx import Resource from sqlalchemy import select from werkzeug.exceptions import BadRequest, NotFound, Unauthorized diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 565fb42478..6cb1077126 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -197,7 +197,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): final_answer = scratchpad.action.action_input else: final_answer = f"{scratchpad.action.action_input}" - except json.JSONDecodeError: + except TypeError: final_answer = f"{scratchpad.action.action_input}" else: function_call_state = True diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 4e6fe60e57..9eb853aa74 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -126,8 +126,8 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_call_inputs = json.dumps( {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False ) - except json.JSONDecodeError: - # ensure ascii to avoid encoding error + except TypeError: + # fallback: force ASCII to handle non-serializable objects tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if chunk.delta.message and chunk.delta.message.content: @@ -153,8 +153,8 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_call_inputs = json.dumps( {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False ) - except json.JSONDecodeError: - # ensure ascii to avoid encoding error + except TypeError: + # fallback: force ASCII to handle non-serializable objects tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if result.usage: diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index fe7abddf87..fdeed43226 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -52,6 +52,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from libs.datetime_utils import naive_utc_now from models import ( Account, CreatorUserRole, @@ -409,7 +410,7 @@ class WorkflowResponseConverter: if event.error is None else WorkflowNodeExecutionStatus.FAILED, error=None, - elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(), + elapsed_time=(naive_utc_now() - event.start_at).total_seconds(), total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, execution_metadata=event.metadata, finished_at=int(time.time()), @@ -488,7 +489,7 @@ class WorkflowResponseConverter: if event.error is None else WorkflowNodeExecutionStatus.FAILED, error=None, - elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(), + elapsed_time=(naive_utc_now() - event.start_at).total_seconds(), total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, execution_metadata=event.metadata, finished_at=int(time.time()), diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 4d2d590e07..b659bf924d 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,4 +1,3 @@ -import datetime import json import logging from collections import defaultdict @@ -29,6 +28,7 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.plugin.entities.plugin import ModelProviderID from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.provider import ( LoadBalancingModelConfig, Provider, @@ -261,7 +261,7 @@ class ProviderConfiguration(BaseModel): if provider_record: provider_record.encrypted_config = json.dumps(credentials) provider_record.is_valid = True - provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + provider_record.updated_at = naive_utc_now() db.session.commit() else: provider_record = Provider() @@ -426,7 +426,7 @@ class ProviderConfiguration(BaseModel): if provider_model_record: provider_model_record.encrypted_config = json.dumps(credentials) provider_model_record.is_valid = True - provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + provider_model_record.updated_at = naive_utc_now() db.session.commit() else: provider_model_record = ProviderModel() @@ -501,7 +501,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.enabled = True - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + model_setting.updated_at = naive_utc_now() db.session.commit() else: model_setting = ProviderModelSetting() @@ -526,7 +526,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.enabled = False - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + model_setting.updated_at = naive_utc_now() db.session.commit() else: model_setting = ProviderModelSetting() @@ -599,7 +599,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.load_balancing_enabled = True - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + model_setting.updated_at = naive_utc_now() db.session.commit() else: model_setting = ProviderModelSetting() @@ -638,7 +638,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.load_balancing_enabled = False - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + model_setting.updated_at = naive_utc_now() db.session.commit() else: model_setting = ProviderModelSetting() diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index afb77d248e..b1680bf395 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -1,5 +1,4 @@ import concurrent.futures -import datetime import json import logging import re @@ -34,6 +33,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper +from libs.datetime_utils import naive_utc_now from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile @@ -87,7 +87,7 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() except ObjectDeletedError: logging.warning("Document deleted, document id: %s", dataset_document.id) @@ -95,7 +95,7 @@ class IndexingRunner: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() def run_in_splitting_status(self, dataset_document: DatasetDocument): @@ -150,13 +150,13 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() def run_in_indexing_status(self, dataset_document: DatasetDocument): @@ -225,13 +225,13 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() def indexing_estimate( @@ -401,7 +401,7 @@ class IndexingRunner: after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), - DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DatasetDocument.parsing_completed_at: naive_utc_now(), }, ) @@ -584,7 +584,7 @@ class IndexingRunner: after_indexing_status="completed", extra_update_params={ DatasetDocument.tokens: tokens, - DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DatasetDocument.completed_at: naive_utc_now(), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.error: None, }, @@ -609,7 +609,7 @@ class IndexingRunner: { DocumentSegment.status: "completed", DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.completed_at: naive_utc_now(), } ) @@ -640,7 +640,7 @@ class IndexingRunner: { DocumentSegment.status: "completed", DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.completed_at: naive_utc_now(), } ) @@ -728,7 +728,7 @@ class IndexingRunner: doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) # update document status to indexing - cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + cur_time = naive_utc_now() self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="indexing", @@ -743,7 +743,7 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.indexing_at: naive_utc_now(), }, ) pass diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index 0e1277bc86..dc6032e405 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from decimal import Decimal from enum import StrEnum -from typing import Any, Optional +from typing import Any, Optional, TypedDict, Union from pydantic import BaseModel, Field @@ -20,6 +20,26 @@ class LLMMode(StrEnum): CHAT = "chat" +class LLMUsageMetadata(TypedDict, total=False): + """ + TypedDict for LLM usage metadata. + All fields are optional. + """ + + prompt_tokens: int + completion_tokens: int + total_tokens: int + prompt_unit_price: Union[float, str] + completion_unit_price: Union[float, str] + total_price: Union[float, str] + currency: str + prompt_price_unit: Union[float, str] + completion_price_unit: Union[float, str] + prompt_price: Union[float, str] + completion_price: Union[float, str] + latency: float + + class LLMUsage(ModelUsage): """ Model class for llm usage. @@ -56,23 +76,27 @@ class LLMUsage(ModelUsage): ) @classmethod - def from_metadata(cls, metadata: dict) -> LLMUsage: + def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage: """ Create LLMUsage instance from metadata dictionary with default values. Args: - metadata: Dictionary containing usage metadata + metadata: TypedDict containing usage metadata Returns: LLMUsage instance with values from metadata or defaults """ - total_tokens = metadata.get("total_tokens", 0) + prompt_tokens = metadata.get("prompt_tokens", 0) completion_tokens = metadata.get("completion_tokens", 0) - if total_tokens > 0 and completion_tokens == 0: - completion_tokens = total_tokens + total_tokens = metadata.get("total_tokens", 0) + + # If total_tokens is not provided but prompt and completion tokens are, + # calculate total_tokens + if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0): + total_tokens = prompt_tokens + completion_tokens return cls( - prompt_tokens=metadata.get("prompt_tokens", 0), + prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))), diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index d64f366e0e..112f07844c 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -101,7 +101,7 @@ class MilvusVector(BaseVector): if "Zilliz Cloud" in milvus_version: return True # For standard Milvus installations, check version number - return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version + return version.parse(milvus_version) >= version.parse("2.5.0") except Exception as e: logger.warning("Failed to check Milvus version: %s. Disabling hybrid search.", str(e)) return False diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index 8efe105bbf..556d03940e 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -152,7 +152,7 @@ class OceanBaseVector(BaseVector): ob_full_version = result.fetchone()[0] ob_version = ob_full_version.split()[1] logger.debug("Current OceanBase version is %s", ob_version) - return version.parse(ob_version).base_version >= version.parse("4.3.5.1").base_version + return version.parse(ob_version) >= version.parse("4.3.5.1") except Exception as e: logger.warning("Failed to check OceanBase version: %s. Disabling hybrid search.", str(e)) return False diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index a3b35458df..7cc554c74d 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -34,9 +34,8 @@ class ExcelExtractor(BaseExtractor): for sheet_name in wb.sheetnames: sheet = wb[sheet_name] data = sheet.values - try: - cols = next(data) - except StopIteration: + cols = next(data, None) + if cols is None: continue df = pd.DataFrame(data, columns=cols) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 0eff7c186a..f3b162e3d3 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,6 +1,5 @@ """Abstract interface for document loader implementations.""" -import datetime import logging import mimetypes import os @@ -19,6 +18,7 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_storage import storage +from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import UploadFile @@ -117,10 +117,10 @@ class WordExtractor(BaseExtractor): mime_type=mime_type or "", created_by=self.user_id, created_by_role=CreatorUserRole.ACCOUNT, - created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + created_at=naive_utc_now(), used=True, used_by=self.user_id, - used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + used_at=naive_utc_now(), ) db.session.add(upload_file) diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 09c775f3a6..854c122331 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -5,10 +5,7 @@ This module provides a Django-like settings system for repository implementation allowing users to configure different repository backends through string paths. """ -import importlib -import inspect -import logging -from typing import Protocol, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -16,12 +13,11 @@ from sqlalchemy.orm import sessionmaker from configs import dify_config from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom -logger = logging.getLogger(__name__) - class RepositoryImportError(Exception): """Raised when a repository implementation cannot be imported or instantiated.""" @@ -37,96 +33,6 @@ class DifyCoreRepositoryFactory: are specified as module paths (e.g., 'module.submodule.ClassName'). """ - @staticmethod - def _import_class(class_path: str) -> type: - """ - Import a class from a module path string. - - Args: - class_path: Full module path to the class (e.g., 'module.submodule.ClassName') - - Returns: - The imported class - - Raises: - RepositoryImportError: If the class cannot be imported - """ - try: - module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - repo_class = getattr(module, class_name) - assert isinstance(repo_class, type) - return repo_class - except (ValueError, ImportError, AttributeError) as e: - raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e - - @staticmethod - def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore - """ - Validate that a class implements the expected repository interface. - - Args: - repository_class: The class to validate - expected_interface: The expected interface/protocol - - Raises: - RepositoryImportError: If the class doesn't implement the interface - """ - # Check if the class has all required methods from the protocol - required_methods = [ - method - for method in dir(expected_interface) - if not method.startswith("_") and callable(getattr(expected_interface, method, None)) - ] - - missing_methods = [] - for method_name in required_methods: - if not hasattr(repository_class, method_name): - missing_methods.append(method_name) - - if missing_methods: - raise RepositoryImportError( - f"Repository class '{repository_class.__name__}' does not implement required methods " - f"{missing_methods} from interface '{expected_interface.__name__}'" - ) - - @staticmethod - def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None: - """ - Validate that a repository class constructor accepts required parameters. - Args: - repository_class: The class to validate - required_params: List of required parameter names - Raises: - RepositoryImportError: If the constructor doesn't accept required parameters - """ - - try: - # MyPy may flag the line below with the following error: - # - # > Accessing "__init__" on an instance is unsound, since - # > instance.__init__ could be from an incompatible subclass. - # - # Despite this, we need to ensure that the constructor of `repository_class` - # has a compatible signature. - signature = inspect.signature(repository_class.__init__) # type: ignore[misc] - param_names = list(signature.parameters.keys()) - - # Remove 'self' parameter - if "self" in param_names: - param_names.remove("self") - - missing_params = [param for param in required_params if param not in param_names] - if missing_params: - raise RepositoryImportError( - f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: " - f"{missing_params}. Expected parameters: {required_params}" - ) - except Exception as e: - raise RepositoryImportError( - f"Failed to validate constructor signature for '{repository_class.__name__}': {e}" - ) from e - @classmethod def create_workflow_execution_repository( cls, @@ -151,24 +57,16 @@ class DifyCoreRepositoryFactory: RepositoryImportError: If the configured repository cannot be created """ class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY - logger.debug("Creating WorkflowExecutionRepository from: %s", class_path) try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, WorkflowExecutionRepository) - - # All repository types now use the same constructor parameters + repository_class = import_string(class_path) return repository_class( # type: ignore[no-any-return] session_factory=session_factory, user=user, app_id=app_id, triggered_from=triggered_from, ) - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create WorkflowExecutionRepository") + except (ImportError, Exception) as e: raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e @classmethod @@ -195,24 +93,16 @@ class DifyCoreRepositoryFactory: RepositoryImportError: If the configured repository cannot be created """ class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY - logger.debug("Creating WorkflowNodeExecutionRepository from: %s", class_path) try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository) - - # All repository types now use the same constructor parameters + repository_class = import_string(class_path) return repository_class( # type: ignore[no-any-return] session_factory=session_factory, user=user, app_id=app_id, triggered_from=triggered_from, ) - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create WorkflowNodeExecutionRepository") + except (ImportError, Exception) as e: raise RepositoryImportError( f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}" ) from e diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index 9d70dd0ab6..354d77673c 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -6,12 +6,14 @@ implementation details like tenant_id, app_id, etc. """ from collections.abc import Mapping -from datetime import UTC, datetime +from datetime import datetime from enum import StrEnum from typing import Any, Optional from pydantic import BaseModel, Field +from libs.datetime_utils import naive_utc_now + class WorkflowType(StrEnum): """ @@ -61,7 +63,7 @@ class WorkflowExecution(BaseModel): Calculate elapsed time in seconds. If workflow is not finished, use current time. """ - end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None) + end_time = self.finished_at or naive_utc_now() return (end_time - self.started_at).total_seconds() @classmethod diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index f2d9c98936..a4ddfafab5 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -1,5 +1,5 @@ import uuid -from datetime import UTC, datetime +from datetime import datetime from enum import Enum from typing import Optional @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from libs.datetime_utils import naive_utc_now class RouteNodeState(BaseModel): @@ -71,7 +72,7 @@ class RouteNodeState(BaseModel): raise Exception(f"Invalid route status {run_result.status}") self.node_run_result = run_result - self.finished_at = datetime.now(UTC).replace(tzinfo=None) + self.finished_at = naive_utc_now() class RuntimeRouteState(BaseModel): @@ -89,7 +90,7 @@ class RuntimeRouteState(BaseModel): :param node_id: node id """ - state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) + state = RouteNodeState(node_id=node_id, start_at=naive_utc_now()) self.node_state_mapping[state.id] = state return state diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index e30853b760..833e118388 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -6,7 +6,6 @@ import uuid from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy -from datetime import UTC, datetime from typing import Any, Optional, cast from flask import Flask, current_app @@ -51,6 +50,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from libs.datetime_utils import naive_utc_now from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom from models.workflow import WorkflowType @@ -640,7 +640,7 @@ class GraphEngine: while should_continue_retry and retries <= max_retries: try: # run node - retry_start_at = datetime.now(UTC).replace(tzinfo=None) + retry_start_at = naive_utc_now() # yield control to other threads time.sleep(0.001) event_stream = node.run() diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 2b6382a8a6..144f036aa4 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -13,8 +13,9 @@ from core.agent.strategy.plugin import PluginAgentStrategy from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.request import InvokeCredentials from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.plugin import PluginInstaller @@ -558,7 +559,7 @@ class AgentNode(BaseNode): assert isinstance(message.message, ToolInvokeMessage.JsonMessage) if node_type == NodeType.AGENT: msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) - llm_usage = LLMUsage.from_metadata(msg_metadata) + llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) agent_execution_metadata = { WorkflowNodeExecutionMetadataKey(key): value for key, value in msg_metadata.items() @@ -692,7 +693,13 @@ class AgentNode(BaseNode): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, + outputs={ + "text": text, + "usage": jsonable_encoder(llm_usage), + "files": ArrayFileSegment(value=files), + "json": json_output, + **variables, + }, metadata={ **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index c9f7fa1221..a5a578a6ff 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -12,6 +12,7 @@ from json_repair import repair_json from configs import dify_config from core.file import file_manager +from core.file.enums import FileTransferMethod from core.helper import ssrf_proxy from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.entities.variable_pool import VariablePool @@ -228,7 +229,9 @@ class Executor: files: dict[str, list[tuple[str | None, bytes, str]]] = {} for key, files_in_segment in files_list: for file in files_in_segment: - if file.related_id is not None: + if file.related_id is not None or ( + file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None + ): file_tuple = ( file.filename, file_manager.download(file), diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index def1e1cfa3..7f591a3ea9 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -4,7 +4,7 @@ import time import uuid from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, wait -from datetime import UTC, datetime +from datetime import datetime from queue import Empty, Queue from typing import TYPE_CHECKING, Any, Optional, cast @@ -41,6 +41,7 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from factories.variable_factory import build_segment +from libs.datetime_utils import naive_utc_now from libs.flask_utils import preserve_flask_contexts from .exc import ( @@ -179,7 +180,7 @@ class IterationNode(BaseNode): thread_pool_id=self.thread_pool_id, ) - start_at = datetime.now(UTC).replace(tzinfo=None) + start_at = naive_utc_now() yield IterationRunStartedEvent( iteration_id=self.id, @@ -428,7 +429,7 @@ class IterationNode(BaseNode): """ run single iteration """ - iter_start_at = datetime.now(UTC).replace(tzinfo=None) + iter_start_at = naive_utc_now() try: rst = graph_engine.run() @@ -505,7 +506,7 @@ class IterationNode(BaseNode): variable_pool.add([self.node_id, "index"], next_index) if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (naive_utc_now() - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -526,7 +527,7 @@ class IterationNode(BaseNode): if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (naive_utc_now() - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -602,7 +603,7 @@ class IterationNode(BaseNode): if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (naive_utc_now() - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 0966c87a1d..2441e30c87 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from datetime import UTC, datetime from typing import Optional, cast from sqlalchemy import select, update @@ -20,6 +19,7 @@ from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegme from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.nodes.llm.entities import ModelConfig +from libs.datetime_utils import naive_utc_now from models import db from models.model import Conversation from models.provider import Provider, ProviderType @@ -149,7 +149,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs ) .values( quota_used=Provider.quota_used + used_quota, - last_used=datetime.now(tz=UTC).replace(tzinfo=None), + last_used=naive_utc_now(), ) ) session.execute(stmt) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 9a288c6133..b2ab943129 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -2,7 +2,7 @@ import json import logging import time from collections.abc import Generator, Mapping, Sequence -from datetime import UTC, datetime +from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, Optional, cast from configs import dify_config @@ -36,6 +36,7 @@ from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.utils.condition.processor import ConditionProcessor from factories.variable_factory import TypeMismatchError, build_segment_with_type +from libs.datetime_utils import naive_utc_now if TYPE_CHECKING: from core.workflow.entities.variable_pool import VariablePool @@ -143,7 +144,7 @@ class LoopNode(BaseNode): thread_pool_id=self.thread_pool_id, ) - start_at = datetime.now(UTC).replace(tzinfo=None) + start_at = naive_utc_now() condition_processor = ConditionProcessor() # Start Loop event @@ -171,7 +172,7 @@ class LoopNode(BaseNode): try: check_break_result = False for i in range(loop_count): - loop_start_time = datetime.now(UTC).replace(tzinfo=None) + loop_start_time = naive_utc_now() # run single loop loop_result = yield from self._run_single_loop( graph_engine=graph_engine, @@ -185,7 +186,7 @@ class LoopNode(BaseNode): start_at=start_at, inputs=inputs, ) - loop_end_time = datetime.now(UTC).replace(tzinfo=None) + loop_end_time = naive_utc_now() single_loop_variable = {} for key, selector in loop_variable_selectors.items(): diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index da147fe895..e21092349e 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ --max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ - -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage} + -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py index 00a66f50ad..bbc913b7cf 100644 --- a/api/events/event_handlers/clean_when_document_deleted.py +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -8,4 +8,6 @@ def handle(sender, **kwargs): dataset_id = kwargs.get("dataset_id") doc_form = kwargs.get("doc_form") file_id = kwargs.get("file_id") + assert dataset_id is not None + assert doc_form is not None clean_document_task.delay(document_id, dataset_id, doc_form, file_id) diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 2ed42c71ea..f01dd58900 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -188,7 +188,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation] # Use SQLAlchemy's context manager for transaction management # This automatically handles commit/rollback - with Session(db.engine) as session: + with Session(db.engine) as session, session.begin(): # Use a single transaction for all updates for update_operation in updates_to_perform: filters = update_operation.filters diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index a4d013ffc0..1024fd9ce6 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -29,7 +29,6 @@ def init_app(app: DifyApp): methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) - app.register_blueprint(web_bp) CORS( @@ -40,10 +39,13 @@ def init_app(app: DifyApp): methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) - app.register_blueprint(console_app_bp) - CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) + CORS( + files_bp, + allow_headers=["Content-Type"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) app.register_blueprint(files_bp) app.register_blueprint(inner_api_bp) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 00e0bd9a16..fb5352ca8f 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -3,8 +3,8 @@ from datetime import timedelta from typing import Any, Optional import pytz -from celery import Celery, Task # type: ignore -from celery.schedules import crontab # type: ignore +from celery import Celery, Task +from celery.schedules import crontab from configs import dify_config from dify_app import DifyApp @@ -66,7 +66,6 @@ def init_app(app: DifyApp) -> Celery: task_cls=FlaskTask, broker=dify_config.CELERY_BROKER_URL, backend=dify_config.CELERY_BACKEND, - task_ignore_result=True, ) celery_app.conf.update( @@ -77,6 +76,7 @@ def init_app(app: DifyApp) -> Celery: worker_task_log_format=dify_config.LOG_FORMAT, worker_hijack_root_logger=False, timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), + task_ignore_result=True, ) # Apply SSL configuration if enabled diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 9b18e25eaa..9e5c71fb1d 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -20,6 +20,10 @@ login_manager = flask_login.LoginManager() @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" + # Skip authentication for documentation endpoints + if request.path.endswith("/docs") or request.path.endswith("/swagger.json"): + return None + auth_header = request.headers.get("Authorization", "") auth_token: str | None = None if auth_header: diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index 3fd9633e79..544a2dc625 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -8,7 +8,7 @@ import sys from typing import Union import flask -from celery.signals import worker_init # type: ignore +from celery.signals import worker_init from flask_login import user_loaded_from_request, user_logged_in # type: ignore from configs import dify_config diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 379dcc6d16..38835d5ac7 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from libs.helper import TimestampField @@ -11,6 +11,12 @@ annotation_fields = { # 'account': fields.Nested(simple_account_fields, allow_null=True) } + +def build_annotation_model(api_or_ns: Api | Namespace): + """Build the annotation model for the API or Namespace.""" + return api_or_ns.model("Annotation", annotation_fields) + + annotation_list_fields = { "data": fields.List(fields.Nested(annotation_fields)), } diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index a85d4a34db..a2dda1dc15 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -1,10 +1,10 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField class HiddenAPIKey(fields.Raw): - def output(self, key, obj): + def output(self, key, obj, **kwargs): api_key = obj.api_key # If the length of the api_key is less than 8 characters, show the first and last characters if len(api_key) <= 8: diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 1a5fcabf97..1f14d663b8 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,6 +1,6 @@ import json -from flask_restful import fields +from flask_restx import fields from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 370e8a5a58..ecc267cf38 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -45,6 +45,12 @@ message_file_fields = { "upload_file_id": fields.String(default=None), } + +def build_message_file_model(api_or_ns: Api | Namespace): + """Build the message file fields for the API or Namespace.""" + return api_or_ns.model("MessageFile", message_file_fields) + + agent_thought_fields = { "id": fields.String, "chain_id": fields.String, @@ -209,3 +215,22 @@ conversation_infinite_scroll_pagination_fields = { "has_more": fields.Boolean, "data": fields.List(fields.Nested(simple_conversation_fields)), } + + +def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): + """Build the conversation infinite scroll pagination model for the API or Namespace.""" + simple_conversation_model = build_simple_conversation_model(api_or_ns) + + copied_fields = conversation_infinite_scroll_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model)) + return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields) + + +def build_conversation_delete_model(api_or_ns: Api | Namespace): + """Build the conversation delete model for the API or Namespace.""" + return api_or_ns.model("ConversationDelete", conversation_delete_fields) + + +def build_simple_conversation_model(api_or_ns: Api | Namespace): + """Build the simple conversation model for the API or Namespace.""" + return api_or_ns.model("SimpleConversation", simple_conversation_fields) diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index c5a0c9a49d..7d5e311591 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from libs.helper import TimestampField @@ -27,3 +27,19 @@ conversation_variable_infinite_scroll_pagination_fields = { "has_more": fields.Boolean, "data": fields.List(fields.Nested(conversation_variable_fields)), } + + +def build_conversation_variable_model(api_or_ns: Api | Namespace): + """Build the conversation variable model for the API or Namespace.""" + return api_or_ns.model("ConversationVariable", conversation_variable_fields) + + +def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): + """Build the conversation variable infinite scroll pagination model for the API or Namespace.""" + # Build the nested variable model first + conversation_variable_model = build_conversation_variable_model(api_or_ns) + + copied_fields = conversation_variable_infinite_scroll_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(conversation_variable_model)) + + return api_or_ns.model("ConversationVariableInfiniteScrollPagination", copied_fields) diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 071071376f..93f6e447dc 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 79a4f1c6de..f639fb2ea9 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 7fd43e8dbe..9be59f7454 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from fields.dataset_fields import dataset_fields from libs.helper import TimestampField diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 99e529f9d1..ea43e3b5fd 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields simple_end_user_fields = { "id": fields.String, @@ -6,3 +6,7 @@ simple_end_user_fields = { "is_anonymous": fields.Boolean, "session_id": fields.String, } + + +def build_simple_end_user_model(api_or_ns: Api | Namespace): + return api_or_ns.model("SimpleEndUser", simple_end_user_fields) diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 8b4839ef97..dd359e2f5f 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from libs.helper import TimestampField @@ -11,6 +11,19 @@ upload_config_fields = { "workflow_file_upload_limit": fields.Integer, } + +def build_upload_config_model(api_or_ns: Api | Namespace): + """Build the upload config model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("UploadConfig", upload_config_fields) + + file_fields = { "id": fields.String, "name": fields.String, @@ -22,12 +35,37 @@ file_fields = { "preview_url": fields.String, } + +def build_file_model(api_or_ns: Api | Namespace): + """Build the file model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("File", file_fields) + + remote_file_info_fields = { "file_type": fields.String(attribute="file_type"), "file_length": fields.Integer(attribute="file_length"), } +def build_remote_file_info_model(api_or_ns: Api | Namespace): + """Build the remote file info model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("RemoteFileInfo", remote_file_info_fields) + + file_fields_with_signed_url = { "id": fields.String, "name": fields.String, @@ -38,3 +76,15 @@ file_fields_with_signed_url = { "created_by": fields.String, "created_at": TimestampField, } + + +def build_file_with_signed_url_model(api_or_ns: Api | Namespace): + """Build the file with signed URL model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url) diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 9d67999ea4..75bdff1803 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index e0b3e340f6..16dd26a10e 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 8007b7e052..08e38a6931 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,8 +1,17 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from libs.helper import AvatarUrlField, TimestampField -simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String} +simple_account_fields = { + "id": fields.String, + "name": fields.String, + "email": fields.String, +} + + +def build_simple_account_model(api_or_ns: Api | Namespace): + return api_or_ns.model("SimpleAccount", simple_account_fields) + account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index e6aebd810f..a419da2e18 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,11 +1,19 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from fields.conversation_fields import message_file_fields from libs.helper import TimestampField from .raws import FilesContainedField -feedback_fields = {"rating": fields.String} +feedback_fields = { + "rating": fields.String, +} + + +def build_feedback_model(api_or_ns: Api | Namespace): + """Build the feedback model for the API or Namespace.""" + return api_or_ns.model("Feedback", feedback_fields) + agent_thought_fields = { "id": fields.String, @@ -21,6 +29,12 @@ agent_thought_fields = { "files": fields.List(fields.String), } + +def build_agent_thought_model(api_or_ns: Api | Namespace): + """Build the agent thought model for the API or Namespace.""" + return api_or_ns.model("AgentThought", agent_thought_fields) + + retriever_resource_fields = { "id": fields.String, "message_id": fields.String, diff --git a/api/fields/raws.py b/api/fields/raws.py index 15ec16ab13..9bc6a12c78 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from core.file import File diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 4126c24598..2ff917d6bc 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index 9af4fc57dd..d5b7c86a04 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,3 +1,12 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields -tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} +dataset_tag_fields = { + "id": fields.String, + "name": fields.String, + "type": fields.String, + "binding_count": fields.String, +} + + +def build_dataset_tag_fields(api_or_ns: Api | Namespace): + return api_or_ns.model("DataSetTag", dataset_tag_fields) diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 823c99ec6b..243efd817c 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,8 +1,8 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields -from fields.end_user_fields import simple_end_user_fields -from fields.member_fields import simple_account_fields -from fields.workflow_run_fields import workflow_run_for_log_fields +from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields +from fields.member_fields import build_simple_account_model, simple_account_fields +from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields from libs.helper import TimestampField workflow_app_log_partial_fields = { @@ -15,6 +15,24 @@ workflow_app_log_partial_fields = { "created_at": TimestampField, } + +def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace): + """Build the workflow app log partial model for the API or Namespace.""" + workflow_run_model = build_workflow_run_for_log_model(api_or_ns) + simple_account_model = build_simple_account_model(api_or_ns) + simple_end_user_model = build_simple_end_user_model(api_or_ns) + + copied_fields = workflow_app_log_partial_fields.copy() + copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True) + copied_fields["created_by_account"] = fields.Nested( + simple_account_model, attribute="created_by_account", allow_null=True + ) + copied_fields["created_by_end_user"] = fields.Nested( + simple_end_user_model, attribute="created_by_end_user", allow_null=True + ) + return api_or_ns.model("WorkflowAppLogPartial", copied_fields) + + workflow_app_log_pagination_fields = { "page": fields.Integer, "limit": fields.Integer, @@ -22,3 +40,13 @@ workflow_app_log_pagination_fields = { "has_more": fields.Boolean, "data": fields.List(fields.Nested(workflow_app_log_partial_fields)), } + + +def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace): + """Build the workflow app log pagination model for the API or Namespace.""" + # Build the nested partial model first + workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns) + + copied_fields = workflow_app_log_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model)) + return api_or_ns.model("WorkflowAppLogPagination", copied_fields) diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 9207ad7ab7..1f7185618c 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from core.helper import encrypter from core.variables import SecretVariable, SegmentType, Variable diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index a106728e9c..6462d8ce5a 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -17,6 +17,11 @@ workflow_run_for_log_fields = { "exceptions_count": fields.Integer, } + +def build_workflow_run_for_log_model(api_or_ns: Api | Namespace): + return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) + + workflow_run_for_list_fields = { "id": fields.String, "version": fields.String, diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 2070df3e55..95d13cd0e6 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -1,119 +1,111 @@ import re import sys +from collections.abc import Mapping from typing import Any from flask import current_app, got_request_exception -from flask_restful import Api, http_status_message -from werkzeug.datastructures import Headers +from flask_restx import Api from werkzeug.exceptions import HTTPException +from werkzeug.http import HTTP_STATUS_CODES from core.errors.error import AppInvokeQuotaExceededError -class ExternalApi(Api): - def handle_error(self, e): - """Error handler for the API transforms a raised exception into a Flask - response, with the appropriate HTTP status code and body. +def http_status_message(code): + return HTTP_STATUS_CODES.get(code, "") - :param e: the raised Exception object - :type e: Exception - """ +def register_external_error_handlers(api: Api) -> None: + @api.errorhandler(HTTPException) + def handle_http_exception(e: HTTPException): got_request_exception.send(current_app, exception=e) - headers = Headers() - if isinstance(e, HTTPException): - if e.response is not None: - resp = e.get_response() - return resp + # If Werkzeug already prepared a Response, just use it. + if getattr(e, "response", None) is not None: + return e.response - status_code = e.code - default_data = { - "code": re.sub(r"(?= 500: - exc_info: Any = sys.exc_info() - if exc_info[1] is None: - exc_info = None - current_app.log_exception(exc_info) - - if status_code == 406 and self.default_mediatype is None: - # if we are handling NotAcceptable (406), make sure that - # make_response uses a representation we support as the - # default mediatype (so that make_response doesn't throw - # another NotAcceptable error). - supported_mediatypes = list(self.representations.keys()) # only supported application/json - fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain" - data = {"code": "not_acceptable", "message": data.get("message")} - resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype) + # Payload per status + if status_code == 406 and api.default_mediatype is None: + data = {"code": "not_acceptable", "message": default_data["message"], "status": status_code} + return data, status_code, headers elif status_code == 400: - if isinstance(data.get("message"), dict): - param_key, param_value = list(data.get("message", {}).items())[0] - data = {"code": "invalid_param", "message": param_value, "params": param_key} + msg = default_data["message"] + if isinstance(msg, Mapping) and msg: + # Convert param errors like {"field": "reason"} into a friendly shape + param_key, param_value = next(iter(msg.items())) + data = { + "code": "invalid_param", + "message": str(param_value), + "params": param_key, + "status": status_code, + } else: - if "code" not in data: - data["code"] = "unknown" - - resp = self.make_response(data, status_code, headers) + data = {**default_data} + data.setdefault("code", "unknown") + return data, status_code, headers else: - if "code" not in data: - data["code"] = "unknown" + data = {**default_data} + data.setdefault("code", "unknown") + # If you need WWW-Authenticate for 401, add it to headers + if status_code == 401: + headers["WWW-Authenticate"] = 'Bearer realm="api"' + return data, status_code, headers - resp = self.make_response(data, status_code, headers) + @api.errorhandler(ValueError) + def handle_value_error(e: ValueError): + got_request_exception.send(current_app, exception=e) + status_code = 400 + data = {"code": "invalid_param", "message": str(e), "status": status_code} + return data, status_code - if status_code == 401: - resp = self.unauthorized(resp) - return resp + @api.errorhandler(AppInvokeQuotaExceededError) + def handle_quota_exceeded(e: AppInvokeQuotaExceededError): + got_request_exception.send(current_app, exception=e) + status_code = 429 + data = {"code": "too_many_requests", "message": str(e), "status": status_code} + return data, status_code + + @api.errorhandler(Exception) + def handle_general_exception(e: Exception): + got_request_exception.send(current_app, exception=e) + + status_code = 500 + data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)}) + + # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response) + if not isinstance(data, Mapping): + data = {"message": str(e)} + + data.setdefault("code", "unknown") + data.setdefault("status", status_code) + + # Log stack + exc_info: Any = sys.exc_info() + if exc_info[1] is None: + exc_info = None + current_app.log_exception(exc_info) + + return data, status_code + + +class ExternalApi(Api): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + register_external_error_handlers(self) diff --git a/api/libs/helper.py b/api/libs/helper.py index b36f972e19..70986fedd3 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast from zoneinfo import available_timezones from flask import Response, stream_with_context -from flask_restful import fields +from flask_restx import fields from pydantic import BaseModel from configs import dify_config @@ -57,7 +57,7 @@ def run(script): class AppIconUrlField(fields.Raw): - def output(self, key, obj): + def output(self, key, obj, **kwargs): if obj is None: return None @@ -72,7 +72,7 @@ class AppIconUrlField(fields.Raw): class AvatarUrlField(fields.Raw): - def output(self, key, obj): + def output(self, key, obj, **kwargs): if obj is None: return None diff --git a/api/libs/module_loading.py b/api/libs/module_loading.py new file mode 100644 index 0000000000..616d072a1b --- /dev/null +++ b/api/libs/module_loading.py @@ -0,0 +1,55 @@ +""" +Module loading utilities similar to Django's module_loading. + +Reference implementation from Django: +https://github.com/django/django/blob/main/django/utils/module_loading.py +""" + +import sys +from importlib import import_module +from typing import Any + + +def cached_import(module_path: str, class_name: str) -> Any: + """ + Import a module and return the named attribute/class from it, with caching. + + Args: + module_path: The module path to import from + class_name: The attribute/class name to retrieve + + Returns: + The imported attribute/class + """ + if not ( + (module := sys.modules.get(module_path)) + and (spec := getattr(module, "__spec__", None)) + and getattr(spec, "_initializing", False) is False + ): + module = import_module(module_path) + return getattr(module, class_name) + + +def import_string(dotted_path: str) -> Any: + """ + Import a dotted module path and return the attribute/class designated by + the last name in the path. Raise ImportError if the import failed. + + Args: + dotted_path: Full module path to the class (e.g., 'module.submodule.ClassName') + + Returns: + The imported class or attribute + + Raises: + ImportError: If the module or attribute cannot be imported + """ + try: + module_path, class_name = dotted_path.rsplit(".", 1) + except ValueError as err: + raise ImportError(f"{dotted_path} doesn't look like a module path") from err + + try: + return cached_import(module_path, class_name) + except AttributeError as err: + raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') from err diff --git a/api/models/task.py b/api/models/task.py index ab700c553c..9a52fcfb41 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Optional import sqlalchemy as sa -from celery import states # type: ignore +from celery import states from sqlalchemy import DateTime, String from sqlalchemy.orm import Mapped, mapped_column diff --git a/api/mypy.ini b/api/mypy.ini index 3a6a54afe1..44a01068e9 100644 --- a/api/mypy.ini +++ b/api/mypy.ini @@ -12,8 +12,11 @@ exclude = (?x)( [mypy-flask_login] ignore_missing_imports=True -[mypy-flask_restful] +[mypy-flask_restx] ignore_missing_imports=True -[mypy-flask_restful.inputs] +[mypy-flask_restx.api] +ignore_missing_imports=True + +[mypy-flask_restx.inputs] ignore_missing_imports=True diff --git a/api/pyproject.toml b/api/pyproject.toml index cf5ad8e7d2..6aa4746d2f 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -19,7 +19,6 @@ dependencies = [ "flask-login~=0.6.3", "flask-migrate~=4.0.7", "flask-orjson~=2.0.0", - "flask-restful~=0.3.10", "flask-sqlalchemy~=3.1.1", "gevent~=24.11.1", "gmpy2~=2.2.1", @@ -88,6 +87,7 @@ dependencies = [ "sseclient-py>=1.8.0", "httpx-sse>=0.4.0", "sendgrid~=6.12.3", + "flask-restx>=1.3.0", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -110,7 +110,7 @@ dev = [ "dotenv-linter~=0.5.0", "faker~=32.1.0", "lxml-stubs~=0.5.1", - "mypy~=1.16.0", + "mypy~=1.17.1", "ruff~=0.12.3", "pytest~=8.3.2", "pytest-benchmark~=4.0.0", @@ -164,6 +164,7 @@ dev = [ "scipy-stubs>=1.15.3.0", "types-python-http-client>=3.3.7.20240910", "types-redis>=4.6.0.20241004", + "celery-types>=0.23.0", ] ############################################################ diff --git a/api/repositories/factory.py b/api/repositories/factory.py index 1f0320054c..0be9c8908c 100644 --- a/api/repositories/factory.py +++ b/api/repositories/factory.py @@ -5,17 +5,14 @@ This factory is specifically designed for DifyAPI repositories that handle service-layer operations with dependency injection patterns. """ -import logging - from sqlalchemy.orm import sessionmaker from configs import dify_config from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError +from libs.module_loading import import_string from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository from repositories.api_workflow_run_repository import APIWorkflowRunRepository -logger = logging.getLogger(__name__) - class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): """ @@ -50,17 +47,9 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository) - # Service repository requires session_maker parameter - cls._validate_constructor_signature(repository_class, ["session_maker"]) - + repository_class = import_string(class_path) return repository_class(session_maker=session_maker) # type: ignore[no-any-return] - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository") + except (ImportError, Exception) as e: raise RepositoryImportError( f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" ) from e @@ -87,15 +76,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, APIWorkflowRunRepository) - # Service repository requires session_maker parameter - cls._validate_constructor_signature(repository_class, ["session_maker"]) - + repository_class = import_string(class_path) return repository_class(session_maker=session_maker) # type: ignore[no-any-return] - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create APIWorkflowRunRepository") + except (ImportError, Exception) as e: raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/services/account_service.py b/api/services/account_service.py index 1cce8e67a4..0bb903fbbc 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -425,7 +425,7 @@ class AccountService: cls, account: Optional[Account] = None, email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", ): account_email = account.email if account else email if account_email is None: @@ -452,12 +452,14 @@ class AccountService: account: Optional[Account] = None, email: Optional[str] = None, old_email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", phase: Optional[str] = None, ): account_email = account.email if account else email if account_email is None: raise ValueError("Email must be provided.") + if not phase: + raise ValueError("phase must be provided.") if cls.change_email_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import EmailChangeRateLimitExceededError @@ -480,7 +482,7 @@ class AccountService: cls, account: Optional[Account] = None, email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", ): account_email = account.email if account else email if account_email is None: @@ -496,7 +498,7 @@ class AccountService: cls, account: Optional[Account] = None, email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", workspace_name: Optional[str] = "", ): account_email = account.email if account else email @@ -509,6 +511,7 @@ class AccountService: raise OwnerTransferRateLimitExceededError() code, token = cls.generate_owner_transfer_token(account_email, account) + workspace_name = workspace_name or "" send_owner_transfer_confirm_task.delay( language=language, @@ -524,13 +527,14 @@ class AccountService: cls, account: Optional[Account] = None, email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", workspace_name: Optional[str] = "", - new_owner_email: Optional[str] = "", + new_owner_email: str = "", ): account_email = account.email if account else email if account_email is None: raise ValueError("Email must be provided.") + workspace_name = workspace_name or "" send_old_owner_transfer_notify_email_task.delay( language=language, @@ -544,12 +548,13 @@ class AccountService: cls, account: Optional[Account] = None, email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", workspace_name: Optional[str] = "", ): account_email = account.email if account else email if account_email is None: raise ValueError("Email must be provided.") + workspace_name = workspace_name or "" send_new_owner_transfer_notify_email_task.delay( language=language, @@ -633,7 +638,10 @@ class AccountService: @classmethod def send_email_code_login_email( - cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: str = "en-US", ): email = account.email if account else email if email is None: @@ -1260,10 +1268,11 @@ class RegisterService: raise AccountAlreadyInTenantError("Account already in tenant.") token = cls.generate_invite_token(tenant, account) + language = account.interface_language or "en-US" # send email send_invite_member_mail_task.delay( - language=account.interface_language, + language=language, to=email, token=token, inviter_name=inviter.name if inviter else "Dify", diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 1a0fdfa420..45b246af1e 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,4 +1,3 @@ -import datetime import uuid from typing import cast @@ -10,6 +9,7 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation from services.feature_service import FeatureService from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task @@ -473,7 +473,7 @@ class AppAnnotationService: raise NotFound("App annotation not found") annotation_setting.score_threshold = args["score_threshold"] annotation_setting.updated_user_id = current_user.id - annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + annotation_setting.updated_at = naive_utc_now() db.session.add(annotation_setting) db.session.commit() diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 712ef4c601..ac603d3cc9 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,4 +1,5 @@ import contextlib +import logging from collections.abc import Callable, Sequence from typing import Any, Optional, Union @@ -23,6 +24,9 @@ from services.errors.conversation import ( LastConversationNotExistsError, ) from services.errors.message import MessageNotExistsError +from tasks.delete_conversation_task import delete_conversation_related_data + +logger = logging.getLogger(__name__) class ConversationService: @@ -175,11 +179,21 @@ class ConversationService: @classmethod def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - conversation = cls.get_conversation(app_model, conversation_id, user) + try: + logger.info( + "Initiating conversation deletion for app_name %s, conversation_id: %s", + app_model.name, + conversation_id, + ) - conversation.is_deleted = True - conversation.updated_at = naive_utc_now() - db.session.commit() + db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) + db.session.commit() + + delete_conversation_related_data.delay(conversation_id) + + except Exception as e: + db.session.rollback() + raise e @classmethod def get_conversational_variable( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 79f8efa360..859900c009 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1424,7 +1424,7 @@ class DocumentService: ) if document: document.dataset_process_rule_id = dataset_process_rule.id # type: ignore - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.updated_at = naive_utc_now() document.created_from = created_from document.doc_form = knowledge_config.doc_form document.doc_language = knowledge_config.doc_language @@ -1993,7 +1993,7 @@ class DocumentService: document.parsing_completed_at = None document.cleaning_completed_at = None document.splitting_completed_at = None - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.updated_at = naive_utc_now() document.created_from = created_from document.doc_form = document_data.doc_form db.session.add(document) @@ -2353,7 +2353,7 @@ class DocumentService: Returns: dict: Update information or None if no update needed """ - now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + now = naive_utc_now() if action == "enable": return DocumentService._prepare_enable_update(document, now) @@ -2481,8 +2481,8 @@ class SegmentService: word_count=len(content), tokens=tokens, status="completed", - indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + indexing_at=naive_utc_now(), + completed_at=naive_utc_now(), created_by=current_user.id, ) if document.doc_form == "qa_model": @@ -2502,7 +2502,7 @@ class SegmentService: except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False - segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment_document.disabled_at = naive_utc_now() segment_document.status = "error" segment_document.error = str(e) db.session.commit() @@ -2558,8 +2558,8 @@ class SegmentService: tokens=tokens, keywords=segment_item.get("keywords", []), status="completed", - indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + indexing_at=naive_utc_now(), + completed_at=naive_utc_now(), created_by=current_user.id, ) if document.doc_form == "qa_model": @@ -2586,7 +2586,7 @@ class SegmentService: logging.exception("create segment index failed") for segment_document in segment_data_list: segment_document.enabled = False - segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment_document.disabled_at = naive_utc_now() segment_document.status = "error" segment_document.error = str(e) db.session.commit() @@ -2603,7 +2603,7 @@ class SegmentService: if segment.enabled != action: if not action: segment.enabled = action - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_at = naive_utc_now() segment.disabled_by = current_user.id db.session.add(segment) db.session.commit() @@ -2701,10 +2701,10 @@ class SegmentService: segment.word_count = len(content) segment.tokens = tokens segment.status = "completed" - segment.indexing_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - segment.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.indexing_at = naive_utc_now() + segment.completed_at = naive_utc_now() segment.updated_by = current_user.id - segment.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.updated_at = naive_utc_now() segment.enabled = True segment.disabled_at = None segment.disabled_by = None @@ -2757,7 +2757,7 @@ class SegmentService: except Exception as e: logging.exception("update segment index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_at = naive_utc_now() segment.status = "error" segment.error = str(e) db.session.commit() @@ -2785,13 +2785,9 @@ class SegmentService: @classmethod def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): - # Check if segment_ids is not empty to avoid WHERE false condition - if not segment_ids or len(segment_ids) == 0: - return - index_node_ids = ( - db.session.query(DocumentSegment) - .with_entities(DocumentSegment.index_node_id) - .where( + segments = ( + db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) + .filter( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, @@ -2799,7 +2795,15 @@ class SegmentService: ) .all() ) - index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] + + if not segments: + return + + index_node_ids = [seg.index_node_id for seg in segments] + total_words = sum(seg.word_count for seg in segments) + + document.word_count -= total_words + db.session.add(document) delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete() @@ -2859,7 +2863,7 @@ class SegmentService: if cache_result is not None: continue segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_at = naive_utc_now() segment.disabled_by = current_user.id db.session.add(segment) real_deal_segment_ids.append(segment.id) @@ -2949,7 +2953,7 @@ class SegmentService: child_chunk.content = child_chunk_update_args.content child_chunk.word_count = len(child_chunk.content) child_chunk.updated_by = current_user.id - child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + child_chunk.updated_at = naive_utc_now() child_chunk.type = "customized" update_child_chunks.append(child_chunk) else: @@ -3006,7 +3010,7 @@ class SegmentService: child_chunk.content = content child_chunk.word_count = len(content) child_chunk.updated_by = current_user.id - child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + child_chunk.updated_at = naive_utc_now() child_chunk.type = "customized" db.session.add(child_chunk) VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) diff --git a/api/services/file_service.py b/api/services/file_service.py index e234c2f325..4c0a0f451c 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,4 +1,3 @@ -import datetime import hashlib import os import uuid @@ -18,6 +17,7 @@ from core.file import helpers as file_helpers from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage +from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models.account import Account from models.enums import CreatorUserRole @@ -80,7 +80,7 @@ class FileService: mime_type=mimetype, created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER), created_by=user.id, - created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + created_at=naive_utc_now(), used=False, hash=hashlib.sha3_256(content).hexdigest(), source_url=source_url, @@ -131,10 +131,10 @@ class FileService: mime_type="text/plain", created_by=current_user.id, created_by_role=CreatorUserRole.ACCOUNT, - created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + created_at=naive_utc_now(), used=True, used_by=current_user.id, - used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + used_at=naive_utc_now(), ) db.session.add(upload_file) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 2a83588f41..fd222f59d3 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -1,5 +1,4 @@ import copy -import datetime import logging from typing import Optional @@ -8,6 +7,7 @@ from flask_login import current_user from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ( @@ -69,7 +69,7 @@ class MetadataService: old_name = metadata.name metadata.name = name metadata.updated_by = current_user.id - metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + metadata.updated_at = naive_utc_now() # update related documents dataset_metadata_bindings = ( diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index fe28aa006e..f8dd70c790 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,4 +1,3 @@ -import datetime import json import logging from json import JSONDecodeError @@ -17,6 +16,7 @@ from core.model_runtime.entities.provider_entities import ( from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.provider_manager import ProviderManager from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.provider import LoadBalancingModelConfig logger = logging.getLogger(__name__) @@ -371,7 +371,7 @@ class ModelLoadBalancingService: load_balancing_config.name = name load_balancing_config.enabled = enabled - load_balancing_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + load_balancing_config.updated_at = naive_utc_now() db.session.commit() self._clear_credentials_cache(tenant_id, config_id) diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index a9df8d0d73..8d21335c86 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -63,7 +63,7 @@ class WebAppAuthService: @classmethod def send_email_code_login_email( - cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US" ): email = account.email if account else email if email is None: diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index b52f4924ba..9f01bcb668 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -1,5 +1,4 @@ import dataclasses -import datetime import logging from collections.abc import Mapping, Sequence from enum import StrEnum @@ -23,6 +22,7 @@ from core.workflow.nodes.variable_assigner.common.helpers import get_updated_var from core.workflow.variable_loader import VariableLoader from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable +from libs.datetime_utils import naive_utc_now from models import App, Conversation from models.enums import DraftVariableType from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable @@ -231,7 +231,7 @@ class WorkflowDraftVariableService: variable.set_name(name) if value is not None: variable.set_value(value) - variable.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + variable.last_edited_at = naive_utc_now() self._session.flush() return variable diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index c5ee4ce3f9..8834229e16 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -1,15 +1,15 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import DatasetAutoDisableLog, DocumentSegment from models.dataset import Document as DatasetDocument @@ -95,7 +95,7 @@ def add_document_to_index_task(dataset_document_id: str): DocumentSegment.enabled: True, DocumentSegment.disabled_at: None, DocumentSegment.disabled_by: None, - DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.updated_at: naive_utc_now(), } ) db.session.commit() @@ -107,7 +107,7 @@ def add_document_to_index_task(dataset_document_id: str): except Exception as e: logging.exception("add document to index failed") dataset_document.enabled = False - dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.disabled_at = naive_utc_now() dataset_document.indexing_status = "error" dataset_document.error = str(e) db.session.commit() diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index e436f00133..5bf8e7c33e 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 47dc3ee90e..fd33feea16 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index f016400e16..1894031a80 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 0076113ce8..a8375dfa26 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 44c65c0783..9ffaf81af6 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -1,14 +1,14 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService @@ -72,7 +72,7 @@ def enable_annotation_reply_task( annotation_setting.score_threshold = score_threshold annotation_setting.collection_binding_id = dataset_collection_binding.id annotation_setting.updated_user_id = user_id - annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + annotation_setting.updated_at = naive_utc_now() db.session.add(annotation_setting) else: new_app_annotation_setting = AppAnnotationSetting( diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 5f11d5aa00..337434b768 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index e64a799146..ed47b62e1b 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index dee43cd854..50293f38a7 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -1,4 +1,3 @@ -import datetime import logging import tempfile import time @@ -7,7 +6,7 @@ from pathlib import Path import click import pandas as pd -from celery import shared_task # type: ignore +from celery import shared_task from sqlalchemy import func from sqlalchemy.orm import Session @@ -17,6 +16,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from models.model import UploadFile from services.vector_service import VectorService @@ -123,9 +123,9 @@ def batch_create_segment_to_index_task( word_count=len(content), tokens=tokens, created_by=user_id, - indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + indexing_at=naive_utc_now(), status="completed", - completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + completed_at=naive_utc_now(), ) if dataset_document.doc_form == "qa_model": segment_document.answer = segment["answer"] diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 7b940847c9..3d3fadbd0a 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 5479ba8e8f..c18329a9c2 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -3,7 +3,7 @@ import time from typing import Optional import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index bf1a92f038..3ad6257cda 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 543a512851..db2f69596d 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -1,15 +1,15 @@ -import datetime import logging import time from typing import Optional import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment @@ -41,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] db.session.query(DocumentSegment).filter_by(id=segment.id).update( { DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.indexing_at: naive_utc_now(), } ) db.session.commit() @@ -79,7 +79,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] db.session.query(DocumentSegment).filter_by(id=segment.id).update( { DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.completed_at: naive_utc_now(), } ) db.session.commit() @@ -89,7 +89,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] except Exception as e: logging.exception("create segment to index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_at = naive_utc_now() segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 5ab377c232..512ea1048a 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -3,7 +3,7 @@ import time from typing import Literal import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py index ef50adf8d5..29f5a2450d 100644 --- a/api/tasks/delete_account_task.py +++ b/api/tasks/delete_account_task.py @@ -1,6 +1,6 @@ import logging -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_database import db from models.account import Account diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py new file mode 100644 index 0000000000..4279dd2c17 --- /dev/null +++ b/api/tasks/delete_conversation_task.py @@ -0,0 +1,68 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from extensions.ext_database import db +from models import ConversationVariable +from models.model import Message, MessageAnnotation, MessageFeedback +from models.tools import ToolConversationVariables, ToolFile +from models.web import PinnedConversation + + +@shared_task(queue="conversation") +def delete_conversation_related_data(conversation_id: str) -> None: + """ + Delete related data conversation in correct order from datatbase to respect foreign key constraints + + Args: + conversation_id: conversation Id + """ + + logging.info( + click.style(f"Starting to delete conversation data from db for conversation_id {conversation_id}", fg="green") + ) + start_at = time.perf_counter() + + try: + db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + db.session.query(ToolConversationVariables).where( + ToolConversationVariables.conversation_id == conversation_id + ).delete(synchronize_session=False) + + db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False) + + db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False) + + db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}", + fg="green", + ) + ) + + except Exception as e: + logging.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id) + db.session.rollback() + raise e + finally: + db.session.close() diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index da12355d23..f091085fb8 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index fa4ec15f8a..c813a9dca6 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index f033f05084..252321ba83 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 484f225c22..6d79ce8a5f 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -1,14 +1,14 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -57,7 +57,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): # check the page is updated if last_edited_time != page_edited_time: document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() db.session.commit() # delete all document segment and index diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 728db2e2dc..c414b01d0e 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 053c0c5f41..31bbc8b570 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -1,13 +1,13 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -31,7 +31,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): return document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() db.session.commit() # delete all document segment and index diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index faa7e2b8d0..f3850b7e3b 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -1,14 +1,14 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService @@ -55,7 +55,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.stopped_at = naive_utc_now() db.session.add(document) db.session.commit() return @@ -86,7 +86,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index f801c9d9ee..a4bcc043e3 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -1,15 +1,15 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment @@ -89,7 +89,7 @@ def enable_segment_to_index_task(segment_id: str): except Exception as e: logging.exception("enable segment to index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_at = naive_utc_now() segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 777380631f..1db984f0d3 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -1,15 +1,15 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -103,7 +103,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i { "error": str(e), "status": "error", - "disabled_at": datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + "disabled_at": naive_utc_now(), "enabled": False, } ) diff --git a/api/tasks/mail_account_deletion_task.py b/api/tasks/mail_account_deletion_task.py index 38b5ca1800..43ddbfc03b 100644 --- a/api/tasks/mail_account_deletion_task.py +++ b/api/tasks/mail_account_deletion_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service diff --git a/api/tasks/mail_change_mail_task.py b/api/tasks/mail_change_mail_task.py index 054053558d..a56109705a 100644 --- a/api/tasks/mail_change_mail_task.py +++ b/api/tasks/mail_change_mail_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py index a82ab55384..53ea3709cd 100644 --- a/api/tasks/mail_email_code_login.py +++ b/api/tasks/mail_email_code_login.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service diff --git a/api/tasks/mail_inner_task.py b/api/tasks/mail_inner_task.py index 101f7ebaa4..cad4657bc8 100644 --- a/api/tasks/mail_inner_task.py +++ b/api/tasks/mail_inner_task.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping import click -from celery import shared_task # type: ignore +from celery import shared_task from flask import render_template_string from extensions.ext_mail import mail diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index ff351f08af..f4f7f58416 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from configs import dify_config from extensions.ext_mail import mail diff --git a/api/tasks/mail_owner_transfer_task.py b/api/tasks/mail_owner_transfer_task.py index 3856bf294a..db7158e786 100644 --- a/api/tasks/mail_owner_transfer_task.py +++ b/api/tasks/mail_owner_transfer_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index b01af7827b..066d648530 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index c7e0047664..a4ef60b13c 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -1,7 +1,7 @@ import json import logging -from celery import shared_task # type: ignore +from celery import shared_task from flask import current_app from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index 9ea6aa6214..ec0b534546 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -2,7 +2,7 @@ import traceback import typing import click -from celery import shared_task # type: ignore +from celery import shared_task from core.helper import marketplace from core.helper.marketplace import MarketplacePluginDeclaration diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index ff489340cd..998fc6b32d 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 828c52044f..3d623c09d1 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -4,7 +4,7 @@ from collections.abc import Callable import click import sqlalchemy as sa -from celery import shared_task # type: ignore +from celery import shared_task from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker @@ -370,8 +370,8 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: with db.engine.begin() as conn: # Get a batch of draft variable IDs query_sql = """ - SELECT id FROM workflow_draft_variables - WHERE app_id = :app_id + SELECT id FROM workflow_draft_variables + WHERE app_id = :app_id LIMIT :batch_size """ result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size}) @@ -382,7 +382,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: # Delete the batch delete_sql = """ - DELETE FROM workflow_draft_variables + DELETE FROM workflow_draft_variables WHERE id IN :ids """ deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 524130a297..6356b1c46c 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -1,13 +1,13 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Document, DocumentSegment @@ -54,9 +54,9 @@ def remove_document_from_index_task(document_id: str): db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update( { DocumentSegment.enabled: False, - DocumentSegment.disabled_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.disabled_at: naive_utc_now(), DocumentSegment.disabled_by: document.disabled_by, - DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.updated_at: naive_utc_now(), } ) db.session.commit() diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 26b41aff2e..67af857f40 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -1,14 +1,14 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService @@ -51,7 +51,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.stopped_at = naive_utc_now() db.session.add(document) db.session.commit() redis_client.delete(retry_indexing_cache_key) @@ -79,7 +79,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() db.session.add(document) db.session.commit() @@ -89,7 +89,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.stopped_at = naive_utc_now() db.session.add(document) db.session.commit() logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index f112a97d2f..ad782f9b88 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -1,14 +1,14 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService @@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.stopped_at = naive_utc_now() db.session.add(document) db.session.commit() redis_client.delete(sync_indexing_cache_key) @@ -72,7 +72,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() db.session.add(document) db.session.commit() @@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.stopped_at = naive_utc_now() db.session.add(document) db.session.commit() logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index 2f9fb628ca..77ddf83023 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -8,7 +8,7 @@ improving performance by offloading storage operations to background workers. import json import logging -from celery import shared_task # type: ignore[import-untyped] +from celery import shared_task from sqlalchemy import select from sqlalchemy.orm import sessionmaker diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index dfc8a33564..16356086cf 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -8,7 +8,7 @@ improving performance by offloading storage operations to background workers. import json import logging -from celery import shared_task # type: ignore[import-untyped] +from celery import shared_task from sqlalchemy import select from sqlalchemy.orm import sessionmaker diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index 83f4d70ce9..2f0f38e0b8 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -1,5 +1,5 @@ from flask import Flask, request -from flask_restful import Api, Resource +from flask_restx import Api, Resource app = Flask(__name__) api = Api(app) diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 3d7be0df7d..415e65ce51 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -1639,7 +1639,7 @@ class TestTenantService: email = fake.email() name = fake.name() password = fake.password(length=12) - invalid_action = fake.word() + invalid_action = "invalid_action_that_doesnt_exist" # Setup mocks mock_external_service_dependencies[ "feature_service" diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 8816698af8..92d93d601e 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -410,18 +410,18 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotations with specific keywords - unique_keyword = fake.word() + unique_keyword = f"unique_{fake.uuid4()[:8]}" annotation_args = { "question": f"Question with {unique_keyword} keyword", "answer": f"Answer with {unique_keyword} keyword", } AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) - # Create another annotation without the keyword other_args = { - "question": "Question without keyword", - "answer": "Answer without keyword", + "question": "Different question without special term", + "answer": "Different answer without special content", } + AppAnnotationService.insert_app_annotation_directly(other_args, app.id) # Search with keyword diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py new file mode 100644 index 0000000000..6d6f1dab72 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -0,0 +1,574 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from core.app.entities.app_invoke_entities import InvokeFrom +from models.account import Account +from models.model import Conversation, EndUser +from models.web import PinnedConversation +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.web_conversation_service import WebConversationService + + +class TestWebConversationService: + """Integration tests for WebConversationService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns for app service + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Mock ModelManager for model configuration + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app with realistic data + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + return app, account + + def _create_test_end_user(self, db_session_with_containers, app): + """ + Helper method to create a test end user for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: App instance + + Returns: + EndUser: Created end user instance + """ + fake = Faker() + + end_user = EndUser( + session_id=fake.uuid4(), + app_id=app.id, + type="normal", + is_anonymous=False, + tenant_id=app.tenant_id, + ) + + from extensions.ext_database import db + + db.session.add(end_user) + db.session.commit() + + return end_user + + def _create_test_conversation(self, db_session_with_containers, app, user, fake): + """ + Helper method to create a test conversation for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: App instance + user: User instance (Account or EndUser) + fake: Faker instance + + Returns: + Conversation: Created conversation instance + """ + conversation = Conversation( + app_id=app.id, + app_model_config_id=app.app_model_config_id, + model_provider="openai", + model_id="gpt-3.5-turbo", + mode="chat", + name=fake.sentence(nb_words=3), + summary=fake.text(max_nb_chars=100), + inputs={}, + introduction=fake.text(max_nb_chars=200), + system_instruction=fake.text(max_nb_chars=300), + system_instruction_tokens=50, + status="normal", + invoke_from=InvokeFrom.WEB_APP.value, + from_source="console" if isinstance(user, Account) else "api", + from_end_user_id=user.id if isinstance(user, EndUser) else None, + from_account_id=user.id if isinstance(user, Account) else None, + dialogue_count=0, + is_deleted=False, + ) + + from extensions.ext_database import db + + db.session.add(conversation) + db.session.commit() + + return conversation + + def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful pagination by last ID with basic parameters. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple conversations + conversations = [] + for i in range(5): + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + conversations.append(conversation) + + # Test pagination without pinned filter + result = WebConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=None, + limit=3, + invoke_from=InvokeFrom.WEB_APP, + pinned=None, + sort_by="-updated_at", + ) + + # Verify results + assert result.limit == 3 + assert len(result.data) == 3 + assert result.has_more is True + + # Verify conversations are in descending order by updated_at + assert result.data[0].updated_at >= result.data[1].updated_at + assert result.data[1].updated_at >= result.data[2].updated_at + + def test_pagination_by_last_id_with_pinned_filter( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test pagination by last ID with pinned conversation filter. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create conversations + conversations = [] + for i in range(5): + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + conversations.append(conversation) + + # Pin some conversations + pinned_conversation1 = PinnedConversation( + app_id=app.id, + conversation_id=conversations[0].id, + created_by_role="account", + created_by=account.id, + ) + pinned_conversation2 = PinnedConversation( + app_id=app.id, + conversation_id=conversations[2].id, + created_by_role="account", + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(pinned_conversation1) + db.session.add(pinned_conversation2) + db.session.commit() + + # Test pagination with pinned filter + result = WebConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=None, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + pinned=True, + sort_by="-updated_at", + ) + + # Verify only pinned conversations are returned + assert result.limit == 10 + assert len(result.data) == 2 + assert result.has_more is False + + # Verify the returned conversations are the pinned ones + returned_ids = [conv.id for conv in result.data] + expected_ids = [conversations[0].id, conversations[2].id] + assert set(returned_ids) == set(expected_ids) + + def test_pagination_by_last_id_with_unpinned_filter( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test pagination by last ID with unpinned conversation filter. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create conversations + conversations = [] + for i in range(5): + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + conversations.append(conversation) + + # Pin one conversation + pinned_conversation = PinnedConversation( + app_id=app.id, + conversation_id=conversations[0].id, + created_by_role="account", + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(pinned_conversation) + db.session.commit() + + # Test pagination with unpinned filter + result = WebConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=None, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + pinned=False, + sort_by="-updated_at", + ) + + # Verify unpinned conversations are returned (should be 4 out of 5) + assert result.limit == 10 + assert len(result.data) == 4 + assert result.has_more is False + + # Verify the pinned conversation is not in the results + returned_ids = [conv.id for conv in result.data] + assert conversations[0].id not in returned_ids + + # Verify all other conversations are in the results + expected_unpinned_ids = [conv.id for conv in conversations[1:]] + assert set(returned_ids) == set(expected_unpinned_ids) + + def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful pinning of a conversation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Pin the conversation + WebConversationService.pin(app, conversation.id, account) + + # Verify the conversation was pinned + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is not None + assert pinned_conversation.app_id == app.id + assert pinned_conversation.conversation_id == conversation.id + assert pinned_conversation.created_by_role == "account" + assert pinned_conversation.created_by == account.id + + def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test pinning a conversation that is already pinned (should not create duplicate). + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Pin the conversation first time + WebConversationService.pin(app, conversation.id, account) + + # Pin the conversation again + WebConversationService.pin(app, conversation.id, account) + + # Verify only one pinned conversation record exists + from extensions.ext_database import db + + pinned_conversations = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .all() + ) + + assert len(pinned_conversations) == 1 + + def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test pinning a conversation with an end user. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an end user + end_user = self._create_test_end_user(db_session_with_containers, app) + + # Create a conversation for the end user + conversation = self._create_test_conversation(db_session_with_containers, app, end_user, fake) + + # Pin the conversation + WebConversationService.pin(app, conversation.id, end_user) + + # Verify the conversation was pinned + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "end_user", + PinnedConversation.created_by == end_user.id, + ) + .first() + ) + + assert pinned_conversation is not None + assert pinned_conversation.app_id == app.id + assert pinned_conversation.conversation_id == conversation.id + assert pinned_conversation.created_by_role == "end_user" + assert pinned_conversation.created_by == end_user.id + + def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful unpinning of a conversation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Pin the conversation first + WebConversationService.pin(app, conversation.id, account) + + # Verify it was pinned + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is not None + + # Unpin the conversation + WebConversationService.unpin(app, conversation.id, account) + + # Verify it was unpinned + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is None + + def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test unpinning a conversation that is not pinned (should not cause error). + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Try to unpin a conversation that was never pinned + WebConversationService.unpin(app, conversation.id, account) + + # Verify no pinned conversation record exists + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is None + + def test_pagination_by_last_id_user_required_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test that pagination_by_last_id raises ValueError when user is None. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Test with None user + with pytest.raises(ValueError, match="User is required"): + WebConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=None, + last_id=None, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + pinned=None, + sort_by="-updated_at", + ) + + def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test that pin method returns early when user is None. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Try to pin with None user + WebConversationService.pin(app, conversation.id, None) + + # Verify no pinned conversation was created + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + ) + .first() + ) + + assert pinned_conversation is None + + def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test that unpin method returns early when user is None. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Pin the conversation first + WebConversationService.pin(app, conversation.id, account) + + # Verify it was pinned + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is not None + + # Try to unpin with None user + WebConversationService.unpin(app, conversation.id, None) + + # Verify the conversation is still pinned + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is not None diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py new file mode 100644 index 0000000000..666b083ba6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -0,0 +1,877 @@ +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import NotFound, Unauthorized + +from libs.password import hash_password +from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import App, Site +from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError +from services.webapp_auth_service import WebAppAuthService, WebAppAuthType + + +class TestWebAppAuthService: + """Integration tests for WebAppAuthService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.webapp_auth_service.PassportService") as mock_passport_service, + patch("services.webapp_auth_service.TokenManager") as mock_token_manager, + patch("services.webapp_auth_service.send_email_code_login_mail_task") as mock_mail_task, + patch("services.webapp_auth_service.AppService") as mock_app_service, + patch("services.webapp_auth_service.EnterpriseService") as mock_enterprise_service, + ): + # Setup default mock returns + mock_passport_service.return_value.issue.return_value = "mock_jwt_token" + mock_token_manager.generate_token.return_value = "mock_token" + mock_token_manager.get_token_data.return_value = {"code": "123456"} + mock_mail_task.delay.return_value = None + mock_app_service.get_app_id_by_code.return_value = "mock_app_id" + mock_enterprise_service.WebAppAuth.get_app_access_mode_by_id.return_value = type( + "MockWebAppAuth", (), {"access_mode": "private"} + )() + mock_enterprise_service.WebAppAuth.get_app_access_mode_by_code.return_value = type( + "MockWebAppAuth", (), {"access_mode": "private"} + )() + + yield { + "passport_service": mock_passport_service, + "token_manager": mock_token_manager, + "mail_task": mock_mail_task, + "app_service": mock_app_service, + "enterprise_service": mock_enterprise_service, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account with password for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant, password) - Created account, tenant and password + """ + fake = Faker() + password = fake.password(length=12) + + # Create account with password + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + # Hash password + salt = b"test_salt_16_bytes" + password_hash = hash_password(password, salt) + + # Convert to base64 for storage + import base64 + + account.password = base64.b64encode(password_hash).decode() + account.password_salt = base64.b64encode(salt).decode() + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant, password + + def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant): + """ + Helper method to create a test app and site for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant: Tenant instance to associate with + + Returns: + tuple: (app, site) - Created app and site instances + """ + fake = Faker() + + # Create app + app = App( + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + enable_site=True, + enable_api=True, + ) + + from extensions.ext_database import db + + db.session.add(app) + db.session.commit() + + # Create site + site = Site( + app_id=app.id, + title=fake.company(), + code=fake.unique.lexify(text="??????"), + description=fake.text(max_nb_chars=100), + default_language="en-US", + status="normal", + customize_token_strategy="not_allow", + ) + db.session.add(site) + db.session.commit() + + return app, site + + def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful authentication with valid email and password. + + This test verifies: + - Proper authentication with valid credentials + - Correct account return + - Database state consistency + """ + # Arrange: Create test data + account, tenant, password = self._create_test_account_with_password( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute authentication + result = WebAppAuthService.authenticate(account.email, password) + + # Assert: Verify successful authentication + assert result is not None + assert result.id == account.id + assert result.email == account.email + assert result.name == account.name + assert result.status == AccountStatus.ACTIVE.value + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + assert result.password is not None + assert result.password_salt is not None + + def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with non-existent email. + + This test verifies: + - Proper error handling for non-existent accounts + - Correct exception type and message + """ + # Arrange: Use non-existent email + fake = Faker() + non_existent_email = fake.email() + + # Act & Assert: Verify proper error handling + with pytest.raises(AccountNotFoundError): + WebAppAuthService.authenticate(non_existent_email, "any_password") + + def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with banned account. + + This test verifies: + - Proper error handling for banned accounts + - Correct exception type and message + """ + # Arrange: Create banned account + fake = Faker() + password = fake.password(length=12) + + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status=AccountStatus.BANNED.value, + ) + + # Hash password + salt = b"test_salt_16_bytes" + password_hash = hash_password(password, salt) + + # Convert to base64 for storage + import base64 + + account.password = base64.b64encode(password_hash).decode() + account.password_salt = base64.b64encode(salt).decode() + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Act & Assert: Verify proper error handling + with pytest.raises(AccountLoginError) as exc_info: + WebAppAuthService.authenticate(account.email, password) + + assert "Account is banned." in str(exc_info.value) + + def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with invalid password. + + This test verifies: + - Proper error handling for invalid passwords + - Correct exception type and message + """ + # Arrange: Create account with password + account, tenant, correct_password = self._create_test_account_with_password( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act & Assert: Verify proper error handling with wrong password + with pytest.raises(AccountPasswordError) as exc_info: + WebAppAuthService.authenticate(account.email, "wrong_password") + + assert "Invalid email or password." in str(exc_info.value) + + def test_authenticate_account_without_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test authentication for account without password. + + This test verifies: + - Proper error handling for accounts without password + - Correct exception type and message + """ + # Arrange: Create account without password + fake = Faker() + + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Act & Assert: Verify proper error handling + with pytest.raises(AccountPasswordError) as exc_info: + WebAppAuthService.authenticate(account.email, "any_password") + + assert "Invalid email or password." in str(exc_info.value) + + def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful login and JWT token generation. + + This test verifies: + - Proper JWT token generation + - Correct token format and content + - Mock service integration + """ + # Arrange: Create test account + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute login + result = WebAppAuthService.login(account) + + # Assert: Verify successful login + assert result is not None + assert result == "mock_jwt_token" + + # Verify mock service was called correctly + mock_external_service_dependencies["passport_service"].return_value.issue.assert_called_once() + call_args = mock_external_service_dependencies["passport_service"].return_value.issue.call_args[0][0] + + assert call_args["sub"] == "Web API Passport" + assert call_args["user_id"] == account.id + assert call_args["session_id"] == account.email + assert call_args["token_source"] == "webapp_login_token" + assert call_args["auth_type"] == "internal" + assert "exp" in call_args + + def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful user retrieval through email. + + This test verifies: + - Proper user retrieval by email + - Correct account return + - Database state consistency + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute user retrieval + result = WebAppAuthService.get_user_through_email(account.email) + + # Assert: Verify successful retrieval + assert result is not None + assert result.id == account.id + assert result.email == account.email + assert result.name == account.name + assert result.status == AccountStatus.ACTIVE.value + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + + def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test user retrieval with non-existent email. + + This test verifies: + - Proper handling for non-existent users + - Correct return value (None) + """ + # Arrange: Use non-existent email + fake = Faker() + non_existent_email = fake.email() + + # Act: Execute user retrieval + result = WebAppAuthService.get_user_through_email(non_existent_email) + + # Assert: Verify proper handling + assert result is None + + def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test user retrieval with banned account. + + This test verifies: + - Proper error handling for banned accounts + - Correct exception type and message + """ + # Arrange: Create banned account + fake = Faker() + + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status=AccountStatus.BANNED.value, + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Act & Assert: Verify proper error handling + with pytest.raises(Unauthorized) as exc_info: + WebAppAuthService.get_user_through_email(account.email) + + assert "Account is banned." in str(exc_info.value) + + def test_send_email_code_login_email_with_account( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test sending email code login email with account. + + This test verifies: + - Proper email code generation + - Token generation with correct data + - Mail task scheduling + - Mock service integration + """ + # Arrange: Create test account + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute email code login email sending + result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US") + + # Assert: Verify successful email sending + assert result is not None + assert result == "mock_token" + + # Verify mock services were called correctly + mock_external_service_dependencies["token_manager"].generate_token.assert_called_once() + mock_external_service_dependencies["mail_task"].delay.assert_called_once() + + # Verify token generation parameters + token_call_args = mock_external_service_dependencies["token_manager"].generate_token.call_args + assert token_call_args[1]["account"] == account + assert token_call_args[1]["email"] == account.email + assert token_call_args[1]["token_type"] == "email_code_login" + assert "code" in token_call_args[1]["additional_data"] + + # Verify mail task parameters + mail_call_args = mock_external_service_dependencies["mail_task"].delay.call_args + assert mail_call_args[1]["language"] == "en-US" + assert mail_call_args[1]["to"] == account.email + assert "code" in mail_call_args[1] + + def test_send_email_code_login_email_with_email_only( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test sending email code login email with email only. + + This test verifies: + - Proper email code generation without account + - Token generation with email only + - Mail task scheduling + - Mock service integration + """ + # Arrange: Use test email + fake = Faker() + test_email = fake.email() + + # Act: Execute email code login email sending + result = WebAppAuthService.send_email_code_login_email(email=test_email, language="zh-Hans") + + # Assert: Verify successful email sending + assert result is not None + assert result == "mock_token" + + # Verify mock services were called correctly + mock_external_service_dependencies["token_manager"].generate_token.assert_called_once() + mock_external_service_dependencies["mail_task"].delay.assert_called_once() + + # Verify token generation parameters + token_call_args = mock_external_service_dependencies["token_manager"].generate_token.call_args + assert token_call_args[1]["account"] is None + assert token_call_args[1]["email"] == test_email + assert token_call_args[1]["token_type"] == "email_code_login" + assert "code" in token_call_args[1]["additional_data"] + + # Verify mail task parameters + mail_call_args = mock_external_service_dependencies["mail_task"].delay.call_args + assert mail_call_args[1]["language"] == "zh-Hans" + assert mail_call_args[1]["to"] == test_email + assert "code" in mail_call_args[1] + + def test_send_email_code_login_email_no_email_provided( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test sending email code login email without providing email. + + This test verifies: + - Proper error handling when no email is provided + - Correct exception type and message + """ + # Arrange: No email provided + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebAppAuthService.send_email_code_login_email() + + assert "Email must be provided." in str(exc_info.value) + + def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of email code login data. + + This test verifies: + - Proper token data retrieval + - Correct data format + - Mock service integration + """ + # Arrange: Setup mock return + expected_data = {"code": "123456", "email": "test@example.com"} + mock_external_service_dependencies["token_manager"].get_token_data.return_value = expected_data + + # Act: Execute data retrieval + result = WebAppAuthService.get_email_code_login_data("mock_token") + + # Assert: Verify successful retrieval + assert result is not None + assert result == expected_data + assert result["code"] == "123456" + assert result["email"] == "test@example.com" + + # Verify mock service was called correctly + mock_external_service_dependencies["token_manager"].get_token_data.assert_called_once_with( + "mock_token", "email_code_login" + ) + + def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test email code login data retrieval when no data exists. + + This test verifies: + - Proper handling when no token data exists + - Correct return value (None) + - Mock service integration + """ + # Arrange: Setup mock return for no data + mock_external_service_dependencies["token_manager"].get_token_data.return_value = None + + # Act: Execute data retrieval + result = WebAppAuthService.get_email_code_login_data("invalid_token") + + # Assert: Verify proper handling + assert result is None + + # Verify mock service was called correctly + mock_external_service_dependencies["token_manager"].get_token_data.assert_called_once_with( + "invalid_token", "email_code_login" + ) + + def test_revoke_email_code_login_token_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful revocation of email code login token. + + This test verifies: + - Proper token revocation + - Mock service integration + """ + # Arrange: Setup mock + + # Act: Execute token revocation + WebAppAuthService.revoke_email_code_login_token("mock_token") + + # Assert: Verify mock service was called correctly + mock_external_service_dependencies["token_manager"].revoke_token.assert_called_once_with( + "mock_token", "email_code_login" + ) + + def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful end user creation. + + This test verifies: + - Proper end user creation with valid app code + - Correct database state after creation + - Proper relationship establishment + - Mock service integration + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + app, site = self._create_test_app_and_site( + db_session_with_containers, mock_external_service_dependencies, tenant + ) + + # Act: Execute end user creation + result = WebAppAuthService.create_end_user(site.code, "test@example.com") + + # Assert: Verify successful creation + assert result is not None + assert result.tenant_id == app.tenant_id + assert result.app_id == app.id + assert result.type == "browser" + assert result.is_anonymous is False + assert result.session_id == "test@example.com" + assert result.name == "enterpriseuser" + assert result.external_user_id == "enterpriseuser" + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + assert result.created_at is not None + assert result.updated_at is not None + + def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test end user creation with non-existent site code. + + This test verifies: + - Proper error handling for non-existent sites + - Correct exception type and message + """ + # Arrange: Use non-existent site code + fake = Faker() + non_existent_code = fake.unique.lexify(text="??????") + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + WebAppAuthService.create_end_user(non_existent_code, "test@example.com") + + assert "Site not found." in str(exc_info.value) + + def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test end user creation when app is not found. + + This test verifies: + - Proper error handling when app is missing + - Correct exception type and message + """ + # Arrange: Create site without app + fake = Faker() + tenant = Tenant( + name=fake.company(), + status="normal", + ) + + from extensions.ext_database import db + + db.session.add(tenant) + db.session.commit() + + site = Site( + app_id="00000000-0000-0000-0000-000000000000", + title=fake.company(), + code=fake.unique.lexify(text="??????"), + description=fake.text(max_nb_chars=100), + default_language="en-US", + status="normal", + customize_token_strategy="not_allow", + ) + db.session.add(site) + db.session.commit() + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + WebAppAuthService.create_end_user(site.code, "test@example.com") + + assert "App not found." in str(exc_info.value) + + def test_is_app_require_permission_check_with_access_mode_private( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test permission check requirement for private access mode. + + This test verifies: + - Proper permission check requirement for private mode + - Correct return value + - Mock service integration + """ + # Arrange: Setup test with private access mode + + # Act: Execute permission check requirement test + result = WebAppAuthService.is_app_require_permission_check(access_mode="private") + + # Assert: Verify correct result + assert result is True + + def test_is_app_require_permission_check_with_access_mode_public( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test permission check requirement for public access mode. + + This test verifies: + - Proper permission check requirement for public mode + - Correct return value + - Mock service integration + """ + # Arrange: Setup test with public access mode + + # Act: Execute permission check requirement test + result = WebAppAuthService.is_app_require_permission_check(access_mode="public") + + # Assert: Verify correct result + assert result is False + + def test_is_app_require_permission_check_with_app_code( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test permission check requirement using app code. + + This test verifies: + - Proper permission check requirement using app code + - Correct return value + - Mock service integration + """ + # Arrange: Setup mock for app service + mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id" + + # Act: Execute permission check requirement test + result = WebAppAuthService.is_app_require_permission_check(app_code="mock_app_code") + + # Assert: Verify correct result + assert result is True + + # Verify mock service was called correctly + mock_external_service_dependencies["app_service"].get_app_id_by_code.assert_called_once_with("mock_app_code") + mock_external_service_dependencies[ + "enterprise_service" + ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id") + + def test_is_app_require_permission_check_no_parameters( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test permission check requirement with no parameters. + + This test verifies: + - Proper error handling when no parameters provided + - Correct exception type and message + """ + # Arrange: No parameters provided + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebAppAuthService.is_app_require_permission_check() + + assert "Either app_code or app_id must be provided." in str(exc_info.value) + + def test_get_app_auth_type_with_access_mode_public( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test app authentication type for public access mode. + + This test verifies: + - Proper authentication type determination for public mode + - Correct return value + - Mock service integration + """ + # Arrange: Setup test with public access mode + + # Act: Execute authentication type determination + result = WebAppAuthService.get_app_auth_type(access_mode="public") + + # Assert: Verify correct result + assert result == WebAppAuthType.PUBLIC + + def test_get_app_auth_type_with_access_mode_private( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test app authentication type for private access mode. + + This test verifies: + - Proper authentication type determination for private mode + - Correct return value + - Mock service integration + """ + # Arrange: Setup test with private access mode + + # Act: Execute authentication type determination + result = WebAppAuthService.get_app_auth_type(access_mode="private") + + # Assert: Verify correct result + assert result == WebAppAuthType.INTERNAL + + def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app authentication type using app code. + + This test verifies: + - Proper authentication type determination using app code + - Correct return value + - Mock service integration + """ + # Arrange: Setup mock for enterprise service + mock_webapp_auth = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})() + mock_external_service_dependencies[ + "enterprise_service" + ].WebAppAuth.get_app_access_mode_by_code.return_value = mock_webapp_auth + + # Act: Execute authentication type determination + result = WebAppAuthService.get_app_auth_type(app_code="mock_app_code") + + # Assert: Verify correct result + assert result == WebAppAuthType.EXTERNAL + + # Verify mock service was called correctly + mock_external_service_dependencies[ + "enterprise_service" + ].WebAppAuth.get_app_access_mode_by_code.assert_called_once_with("mock_app_code") + + def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app authentication type with no parameters. + + This test verifies: + - Proper error handling when no parameters provided + - Correct exception type and message + """ + # Arrange: No parameters provided + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebAppAuthService.get_app_auth_type() + + assert "Either app_code or access_mode must be provided." in str(exc_info.value) diff --git a/api/tests/test_containers_integration_tests/services/test_website_service.py b/api/tests/test_containers_integration_tests/services/test_website_service.py new file mode 100644 index 0000000000..ec2f1556af --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_website_service.py @@ -0,0 +1,1437 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from services.website_service import ( + CrawlOptions, + ScrapeRequest, + WebsiteCrawlApiRequest, + WebsiteCrawlStatusApiRequest, + WebsiteService, +) + + +class TestWebsiteService: + """Integration tests for WebsiteService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.website_service.ApiKeyAuthService") as mock_api_key_auth_service, + patch("services.website_service.FirecrawlApp") as mock_firecrawl_app, + patch("services.website_service.WaterCrawlProvider") as mock_watercrawl_provider, + patch("services.website_service.requests") as mock_requests, + patch("services.website_service.redis_client") as mock_redis_client, + patch("services.website_service.storage") as mock_storage, + patch("services.website_service.encrypter") as mock_encrypter, + ): + # Setup default mock returns + mock_api_key_auth_service.get_auth_credentials.return_value = { + "config": {"api_key": "encrypted_api_key", "base_url": "https://api.example.com"} + } + mock_encrypter.decrypt_token.return_value = "decrypted_api_key" + + # Mock FirecrawlApp + mock_firecrawl_instance = MagicMock() + mock_firecrawl_instance.crawl_url.return_value = "test_job_id_123" + mock_firecrawl_instance.check_crawl_status.return_value = { + "status": "completed", + "total": 5, + "current": 5, + "data": [{"source_url": "https://example.com", "title": "Test Page"}], + } + mock_firecrawl_app.return_value = mock_firecrawl_instance + + # Mock WaterCrawlProvider + mock_watercrawl_instance = MagicMock() + mock_watercrawl_instance.crawl_url.return_value = {"status": "active", "job_id": "watercrawl_job_123"} + mock_watercrawl_instance.get_crawl_status.return_value = { + "status": "completed", + "job_id": "watercrawl_job_123", + "total": 3, + "current": 3, + "data": [], + } + mock_watercrawl_instance.get_crawl_url_data.return_value = { + "title": "WaterCrawl Page", + "source_url": "https://example.com", + "description": "Test description", + "markdown": "# Test Content", + } + mock_watercrawl_instance.scrape_url.return_value = { + "title": "Scraped Page", + "content": "Test content", + "url": "https://example.com", + } + mock_watercrawl_provider.return_value = mock_watercrawl_instance + + # Mock requests + mock_response = MagicMock() + mock_response.json.return_value = {"code": 200, "data": {"taskId": "jina_job_123"}} + mock_requests.get.return_value = mock_response + mock_requests.post.return_value = mock_response + + # Mock Redis + mock_redis_client.setex.return_value = None + mock_redis_client.get.return_value = str(datetime.now().timestamp()) + mock_redis_client.delete.return_value = None + + # Mock Storage + mock_storage.exists.return_value = False + mock_storage.load_once.return_value = None + + yield { + "api_key_auth_service": mock_api_key_auth_service, + "firecrawl_app": mock_firecrawl_app, + "watercrawl_provider": mock_watercrawl_provider, + "requests": mock_requests, + "redis_client": mock_redis_client, + "storage": mock_storage, + "encrypter": mock_encrypter, + } + + def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account with proper tenant setup. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + Account: Created account instance + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account + + def test_document_create_args_validate_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful argument validation for document creation. + + This test verifies: + - Valid arguments are accepted without errors + - All required fields are properly validated + - Optional fields are handled correctly + """ + # Arrange: Prepare valid arguments + valid_args = { + "provider": "firecrawl", + "url": "https://example.com", + "options": { + "limit": 5, + "crawl_sub_pages": True, + "only_main_content": False, + "includes": "blog,news", + "excludes": "admin,private", + "max_depth": 3, + "use_sitemap": True, + }, + } + + # Act: Validate arguments + WebsiteService.document_create_args_validate(valid_args) + + # Assert: No exception should be raised + # If we reach here, validation passed successfully + + def test_document_create_args_validate_missing_provider( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test argument validation fails when provider is missing. + + This test verifies: + - Missing provider raises ValueError + - Proper error message is provided + - Validation stops at first missing required field + """ + # Arrange: Prepare arguments without provider + invalid_args = {"url": "https://example.com", "options": {"limit": 5, "crawl_sub_pages": True}} + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.document_create_args_validate(invalid_args) + + assert "Provider is required" in str(exc_info.value) + + def test_document_create_args_validate_missing_url( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test argument validation fails when URL is missing. + + This test verifies: + - Missing URL raises ValueError + - Proper error message is provided + - Validation continues after provider check + """ + # Arrange: Prepare arguments without URL + invalid_args = {"provider": "firecrawl", "options": {"limit": 5, "crawl_sub_pages": True}} + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.document_create_args_validate(invalid_args) + + assert "URL is required" in str(exc_info.value) + + def test_crawl_url_firecrawl_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful URL crawling with Firecrawl provider. + + This test verifies: + - Firecrawl provider is properly initialized + - API credentials are retrieved and decrypted + - Crawl parameters are correctly formatted + - Job ID is returned with active status + - Redis cache is properly set + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + fake = Faker() + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlApiRequest( + provider="firecrawl", + url="https://example.com", + options={ + "limit": 10, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "blog,news", + "excludes": "admin,private", + "max_depth": 2, + "use_sitemap": True, + }, + ) + + # Act: Execute crawl operation + result = WebsiteService.crawl_url(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "active" + assert result["job_id"] == "test_job_id_123" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "firecrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + mock_external_service_dependencies["firecrawl_app"].assert_called_once_with( + api_key="decrypted_api_key", base_url="https://api.example.com" + ) + + # Verify Redis cache was set + mock_external_service_dependencies["redis_client"].setex.assert_called_once() + + def test_crawl_url_watercrawl_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful URL crawling with WaterCrawl provider. + + This test verifies: + - WaterCrawl provider is properly initialized + - API credentials are retrieved and decrypted + - Crawl options are correctly passed to provider + - Provider returns expected response format + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlApiRequest( + provider="watercrawl", + url="https://example.com", + options={ + "limit": 5, + "crawl_sub_pages": False, + "only_main_content": False, + "includes": None, + "excludes": None, + "max_depth": None, + "use_sitemap": False, + }, + ) + + # Act: Execute crawl operation + result = WebsiteService.crawl_url(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "active" + assert result["job_id"] == "watercrawl_job_123" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "watercrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + mock_external_service_dependencies["watercrawl_provider"].assert_called_once_with( + api_key="decrypted_api_key", base_url="https://api.example.com" + ) + + def test_crawl_url_jinareader_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful URL crawling with JinaReader provider. + + This test verifies: + - JinaReader provider handles single page crawling + - API credentials are retrieved and decrypted + - HTTP requests are made with proper headers + - Response is properly parsed and returned + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request for single page crawling + api_request = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={ + "limit": 1, + "crawl_sub_pages": False, + "only_main_content": True, + "includes": None, + "excludes": None, + "max_depth": None, + "use_sitemap": False, + }, + ) + + # Act: Execute crawl operation + result = WebsiteService.crawl_url(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "active" + assert result["data"] is not None + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "jinareader" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify HTTP request was made + mock_external_service_dependencies["requests"].get.assert_called_once_with( + "https://r.jina.ai/https://example.com", + headers={"Accept": "application/json", "Authorization": "Bearer decrypted_api_key"}, + ) + + def test_crawl_url_invalid_provider(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test crawl operation fails with invalid provider. + + This test verifies: + - Invalid provider raises ValueError + - Proper error message is provided + - Service handles unsupported providers gracefully + """ + # Arrange: Create test account and prepare request with invalid provider + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request with invalid provider + api_request = WebsiteCrawlApiRequest( + provider="invalid_provider", + url="https://example.com", + options={"limit": 5, "crawl_sub_pages": False, "only_main_content": False}, + ) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.crawl_url(api_request) + + assert "Invalid provider" in str(exc_info.value) + + def test_get_crawl_status_firecrawl_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful crawl status retrieval with Firecrawl provider. + + This test verifies: + - Firecrawl status is properly retrieved + - API credentials are retrieved and decrypted + - Status data includes all required fields + - Redis cache is properly managed for completed jobs + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") + + # Act: Get crawl status + result = WebsiteService.get_crawl_status_typed(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "completed" + assert result["job_id"] == "test_job_id_123" + assert result["total"] == 5 + assert result["current"] == 5 + assert "data" in result + assert "time_consuming" in result + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "firecrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify Redis cache was accessed and cleaned up + mock_external_service_dependencies["redis_client"].get.assert_called_once() + mock_external_service_dependencies["redis_client"].delete.assert_called_once() + + def test_get_crawl_status_watercrawl_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful crawl status retrieval with WaterCrawl provider. + + This test verifies: + - WaterCrawl status is properly retrieved + - API credentials are retrieved and decrypted + - Provider returns expected status format + - All required status fields are present + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123") + + # Act: Get crawl status + result = WebsiteService.get_crawl_status_typed(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "completed" + assert result["job_id"] == "watercrawl_job_123" + assert result["total"] == 3 + assert result["current"] == 3 + assert "data" in result + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "watercrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + def test_get_crawl_status_jinareader_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful crawl status retrieval with JinaReader provider. + + This test verifies: + - JinaReader status is properly retrieved + - API credentials are retrieved and decrypted + - HTTP requests are made with proper parameters + - Status data is properly formatted and returned + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123") + + # Act: Get crawl status + result = WebsiteService.get_crawl_status_typed(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "active" + assert result["job_id"] == "jina_job_123" + assert "total" in result + assert "current" in result + assert "data" in result + assert "time_consuming" in result + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "jinareader" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify HTTP request was made + mock_external_service_dependencies["requests"].post.assert_called_once() + + def test_get_crawl_status_invalid_provider(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test crawl status retrieval fails with invalid provider. + + This test verifies: + - Invalid provider raises ValueError + - Proper error message is provided + - Service handles unsupported providers gracefully + """ + # Arrange: Create test account and prepare request with invalid provider + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request with invalid provider + api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123") + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_crawl_status_typed(api_request) + + assert "Invalid provider" in str(exc_info.value) + + def test_get_crawl_status_missing_credentials(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test crawl status retrieval fails when credentials are missing. + + This test verifies: + - Missing credentials raises ValueError + - Proper error message is provided + - Service handles authentication failures gracefully + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Mock missing credentials + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_crawl_status_typed(api_request) + + assert "No valid credentials found for the provider" in str(exc_info.value) + + def test_get_crawl_status_missing_api_key(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test crawl status retrieval fails when API key is missing from config. + + This test verifies: + - Missing API key raises ValueError + - Proper error message is provided + - Service handles configuration failures gracefully + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Mock missing API key in config + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = { + "config": {"base_url": "https://api.example.com"} + } + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_crawl_status_typed(api_request) + + assert "API key not found in configuration" in str(exc_info.value) + + def test_get_crawl_url_data_firecrawl_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful URL data retrieval with Firecrawl provider. + + This test verifies: + - Firecrawl URL data is properly retrieved + - API credentials are retrieved and decrypted + - Data is returned for matching URL + - Storage fallback works when needed + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock storage to return existing data + mock_external_service_dependencies["storage"].exists.return_value = True + mock_external_service_dependencies["storage"].load_once.return_value = ( + b"[" + b'{"source_url": "https://example.com", "title": "Test Page", ' + b'"description": "Test Description", "markdown": "# Test Content"}' + b"]" + ) + + # Act: Get URL data + result = WebsiteService.get_crawl_url_data( + job_id="test_job_id_123", + provider="firecrawl", + url="https://example.com", + tenant_id=account.current_tenant.id, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["source_url"] == "https://example.com" + assert result["title"] == "Test Page" + assert result["description"] == "Test Description" + assert result["markdown"] == "# Test Content" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "firecrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify storage was accessed + mock_external_service_dependencies["storage"].exists.assert_called_once() + mock_external_service_dependencies["storage"].load_once.assert_called_once() + + def test_get_crawl_url_data_watercrawl_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful URL data retrieval with WaterCrawl provider. + + This test verifies: + - WaterCrawl URL data is properly retrieved + - API credentials are retrieved and decrypted + - Provider returns expected data format + - All required data fields are present + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Act: Get URL data + result = WebsiteService.get_crawl_url_data( + job_id="watercrawl_job_123", + provider="watercrawl", + url="https://example.com", + tenant_id=account.current_tenant.id, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["title"] == "WaterCrawl Page" + assert result["source_url"] == "https://example.com" + assert result["description"] == "Test description" + assert result["markdown"] == "# Test Content" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "watercrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + def test_get_crawl_url_data_jinareader_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful URL data retrieval with JinaReader provider. + + This test verifies: + - JinaReader URL data is properly retrieved + - API credentials are retrieved and decrypted + - HTTP requests are made with proper parameters + - Data is properly formatted and returned + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock successful response for JinaReader + mock_response = MagicMock() + mock_response.json.return_value = { + "code": 200, + "data": { + "title": "JinaReader Page", + "url": "https://example.com", + "description": "Test description", + "content": "# Test Content", + }, + } + mock_external_service_dependencies["requests"].get.return_value = mock_response + + # Act: Get URL data without job_id (single page scraping) + result = WebsiteService.get_crawl_url_data( + job_id="", provider="jinareader", url="https://example.com", tenant_id=account.current_tenant.id + ) + + # Assert: Verify successful operation + assert result is not None + assert result["title"] == "JinaReader Page" + assert result["url"] == "https://example.com" + assert result["description"] == "Test description" + assert result["content"] == "# Test Content" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "jinareader" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify HTTP request was made + mock_external_service_dependencies["requests"].get.assert_called_once_with( + "https://r.jina.ai/https://example.com", + headers={"Accept": "application/json", "Authorization": "Bearer decrypted_api_key"}, + ) + + def test_get_scrape_url_data_firecrawl_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful URL scraping with Firecrawl provider. + + This test verifies: + - Firecrawl scraping is properly executed + - API credentials are retrieved and decrypted + - Scraping parameters are correctly passed + - Scraped data is returned in expected format + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock FirecrawlApp scraping response + mock_firecrawl_instance = MagicMock() + mock_firecrawl_instance.scrape_url.return_value = { + "title": "Scraped Page Title", + "content": "This is the scraped content", + "url": "https://example.com", + "description": "Page description", + } + mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance + + # Act: Scrape URL + result = WebsiteService.get_scrape_url_data( + provider="firecrawl", url="https://example.com", tenant_id=account.current_tenant.id, only_main_content=True + ) + + # Assert: Verify successful operation + assert result is not None + assert result["title"] == "Scraped Page Title" + assert result["content"] == "This is the scraped content" + assert result["url"] == "https://example.com" + assert result["description"] == "Page description" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "firecrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify FirecrawlApp was called with correct parameters + mock_external_service_dependencies["firecrawl_app"].assert_called_once_with( + api_key="decrypted_api_key", base_url="https://api.example.com" + ) + mock_firecrawl_instance.scrape_url.assert_called_once_with( + url="https://example.com", params={"onlyMainContent": True} + ) + + def test_get_scrape_url_data_watercrawl_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful URL scraping with WaterCrawl provider. + + This test verifies: + - WaterCrawl scraping is properly executed + - API credentials are retrieved and decrypted + - Provider returns expected scraping format + - All required data fields are present + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Act: Scrape URL + result = WebsiteService.get_scrape_url_data( + provider="watercrawl", + url="https://example.com", + tenant_id=account.current_tenant.id, + only_main_content=False, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["title"] == "Scraped Page" + assert result["content"] == "Test content" + assert result["url"] == "https://example.com" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "watercrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify WaterCrawlProvider was called with correct parameters + mock_external_service_dependencies["watercrawl_provider"].assert_called_once_with( + api_key="decrypted_api_key", base_url="https://api.example.com" + ) + + def test_get_scrape_url_data_invalid_provider(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test URL scraping fails with invalid provider. + + This test verifies: + - Invalid provider raises ValueError + - Proper error message is provided + - Service handles unsupported providers gracefully + """ + # Arrange: Create test account and prepare request with invalid provider + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_scrape_url_data( + provider="invalid_provider", + url="https://example.com", + tenant_id=account.current_tenant.id, + only_main_content=False, + ) + + assert "Invalid provider" in str(exc_info.value) + + def test_crawl_options_include_exclude_paths(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test CrawlOptions include and exclude path methods. + + This test verifies: + - Include paths are properly parsed from comma-separated string + - Exclude paths are properly parsed from comma-separated string + - Empty or None values are handled correctly + - Path lists are returned in expected format + """ + # Arrange: Create CrawlOptions with various path configurations + options_with_paths = CrawlOptions(includes="blog,news,articles", excludes="admin,private,test") + + options_without_paths = CrawlOptions(includes=None, excludes="") + + # Act: Get include and exclude paths + include_paths = options_with_paths.get_include_paths() + exclude_paths = options_with_paths.get_exclude_paths() + + empty_include_paths = options_without_paths.get_include_paths() + empty_exclude_paths = options_without_paths.get_exclude_paths() + + # Assert: Verify path parsing + assert include_paths == ["blog", "news", "articles"] + assert exclude_paths == ["admin", "private", "test"] + assert empty_include_paths == [] + assert empty_exclude_paths == [] + + def test_website_crawl_api_request_conversion(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test WebsiteCrawlApiRequest conversion to CrawlRequest. + + This test verifies: + - API request is properly converted to internal CrawlRequest + - All options are correctly mapped + - Default values are applied when options are missing + - Conversion maintains data integrity + """ + # Arrange: Create API request with various options + api_request = WebsiteCrawlApiRequest( + provider="firecrawl", + url="https://example.com", + options={ + "limit": 10, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "blog,news", + "excludes": "admin,private", + "max_depth": 3, + "use_sitemap": False, + }, + ) + + # Act: Convert to CrawlRequest + crawl_request = api_request.to_crawl_request() + + # Assert: Verify conversion + assert crawl_request.url == "https://example.com" + assert crawl_request.provider == "firecrawl" + assert crawl_request.options.limit == 10 + assert crawl_request.options.crawl_sub_pages is True + assert crawl_request.options.only_main_content is True + assert crawl_request.options.includes == "blog,news" + assert crawl_request.options.excludes == "admin,private" + assert crawl_request.options.max_depth == 3 + assert crawl_request.options.use_sitemap is False + + def test_website_crawl_api_request_from_args(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test WebsiteCrawlApiRequest creation from Flask arguments. + + This test verifies: + - Request is properly created from parsed arguments + - Required fields are validated + - Optional fields are handled correctly + - Validation errors are properly raised + """ + # Arrange: Prepare valid arguments + valid_args = {"provider": "watercrawl", "url": "https://example.com", "options": {"limit": 5}} + + # Act: Create request from args + request = WebsiteCrawlApiRequest.from_args(valid_args) + + # Assert: Verify request creation + assert request.provider == "watercrawl" + assert request.url == "https://example.com" + assert request.options == {"limit": 5} + + # Test missing provider + invalid_args = {"url": "https://example.com", "options": {}} + with pytest.raises(ValueError) as exc_info: + WebsiteCrawlApiRequest.from_args(invalid_args) + assert "Provider is required" in str(exc_info.value) + + # Test missing URL + invalid_args = {"provider": "watercrawl", "options": {}} + with pytest.raises(ValueError) as exc_info: + WebsiteCrawlApiRequest.from_args(invalid_args) + assert "URL is required" in str(exc_info.value) + + # Test missing options + invalid_args = {"provider": "watercrawl", "url": "https://example.com"} + with pytest.raises(ValueError) as exc_info: + WebsiteCrawlApiRequest.from_args(invalid_args) + assert "Options are required" in str(exc_info.value) + + def test_crawl_url_jinareader_sub_pages_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful URL crawling with JinaReader provider for sub-pages. + + This test verifies: + - JinaReader provider handles sub-page crawling correctly + - HTTP POST request is made with proper parameters + - Job ID is returned for multi-page crawling + - All required parameters are passed correctly + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request for sub-page crawling + api_request = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={ + "limit": 5, + "crawl_sub_pages": True, + "only_main_content": False, + "includes": None, + "excludes": None, + "max_depth": None, + "use_sitemap": True, + }, + ) + + # Act: Execute crawl operation + result = WebsiteService.crawl_url(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "active" + assert result["job_id"] == "jina_job_123" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "jinareader" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify HTTP POST request was made for sub-page crawling + mock_external_service_dependencies["requests"].post.assert_called_once_with( + "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", + json={"url": "https://example.com", "maxPages": 5, "useSitemap": True}, + headers={"Content-Type": "application/json", "Authorization": "Bearer decrypted_api_key"}, + ) + + def test_crawl_url_jinareader_failed_response(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test JinaReader crawling fails when API returns error. + + This test verifies: + - Failed API response raises ValueError + - Proper error message is provided + - Service handles API failures gracefully + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock failed response + mock_failed_response = MagicMock() + mock_failed_response.json.return_value = {"code": 500, "error": "Internal server error"} + mock_external_service_dependencies["requests"].get.return_value = mock_failed_response + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={"limit": 1, "crawl_sub_pages": False, "only_main_content": True}, + ) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.crawl_url(api_request) + + assert "Failed to crawl" in str(exc_info.value) + + def test_get_crawl_status_firecrawl_active_job( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Firecrawl status retrieval for active (not completed) job. + + This test verifies: + - Active job status is properly returned + - Redis cache is not deleted for active jobs + - Time consuming is not calculated for active jobs + - All required status fields are present + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock active job status + mock_firecrawl_instance = MagicMock() + mock_firecrawl_instance.check_crawl_status.return_value = { + "status": "active", + "total": 10, + "current": 3, + "data": [], + } + mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123") + + # Act: Get crawl status + result = WebsiteService.get_crawl_status_typed(api_request) + + # Assert: Verify active job status + assert result is not None + assert result["status"] == "active" + assert result["job_id"] == "active_job_123" + assert result["total"] == 10 + assert result["current"] == 3 + assert "data" in result + assert "time_consuming" not in result + + # Verify Redis cache was not accessed for active jobs + mock_external_service_dependencies["redis_client"].get.assert_not_called() + mock_external_service_dependencies["redis_client"].delete.assert_not_called() + + def test_get_crawl_url_data_firecrawl_storage_fallback( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Firecrawl URL data retrieval with storage fallback. + + This test verifies: + - Storage fallback works when storage has data + - API call is not made when storage has data + - Data is properly parsed from storage + - Correct URL data is returned + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock storage to return existing data + mock_external_service_dependencies["storage"].exists.return_value = True + mock_external_service_dependencies["storage"].load_once.return_value = ( + b"[" + b'{"source_url": "https://example.com/page1", ' + b'"title": "Page 1", "description": "Description 1", "markdown": "# Page 1"}, ' + b'{"source_url": "https://example.com/page2", "title": "Page 2", ' + b'"description": "Description 2", "markdown": "# Page 2"}' + b"]" + ) + + # Act: Get URL data for specific URL + result = WebsiteService.get_crawl_url_data( + job_id="test_job_id_123", + provider="firecrawl", + url="https://example.com/page1", + tenant_id=account.current_tenant.id, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["source_url"] == "https://example.com/page1" + assert result["title"] == "Page 1" + assert result["description"] == "Description 1" + assert result["markdown"] == "# Page 1" + + # Verify storage was accessed + mock_external_service_dependencies["storage"].exists.assert_called_once() + mock_external_service_dependencies["storage"].load_once.assert_called_once() + + def test_get_crawl_url_data_firecrawl_api_fallback( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Firecrawl URL data retrieval with API fallback when storage is empty. + + This test verifies: + - API fallback works when storage has no data + - FirecrawlApp is called to get data + - Completed job status is checked + - Data is returned from API response + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock storage to return no data + mock_external_service_dependencies["storage"].exists.return_value = False + + # Mock FirecrawlApp for API fallback + mock_firecrawl_instance = MagicMock() + mock_firecrawl_instance.check_crawl_status.return_value = { + "status": "completed", + "data": [ + { + "source_url": "https://example.com/api_page", + "title": "API Page", + "description": "API Description", + "markdown": "# API Content", + } + ], + } + mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance + + # Act: Get URL data + result = WebsiteService.get_crawl_url_data( + job_id="test_job_id_123", + provider="firecrawl", + url="https://example.com/api_page", + tenant_id=account.current_tenant.id, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["source_url"] == "https://example.com/api_page" + assert result["title"] == "API Page" + assert result["description"] == "API Description" + assert result["markdown"] == "# API Content" + + # Verify API was called + mock_external_service_dependencies["firecrawl_app"].assert_called_once() + + def test_get_crawl_url_data_firecrawl_incomplete_job( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Firecrawl URL data retrieval fails for incomplete job. + + This test verifies: + - Incomplete job raises ValueError + - Proper error message is provided + - Service handles incomplete jobs gracefully + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock storage to return no data + mock_external_service_dependencies["storage"].exists.return_value = False + + # Mock incomplete job status + mock_firecrawl_instance = MagicMock() + mock_firecrawl_instance.check_crawl_status.return_value = {"status": "active", "data": []} + mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_crawl_url_data( + job_id="test_job_id_123", + provider="firecrawl", + url="https://example.com/page", + tenant_id=account.current_tenant.id, + ) + + assert "Crawl job is not completed" in str(exc_info.value) + + def test_get_crawl_url_data_jinareader_with_job_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test JinaReader URL data retrieval with job ID for multi-page crawling. + + This test verifies: + - JinaReader handles job ID-based data retrieval + - Status check is performed before data retrieval + - Processed data is properly formatted + - Correct URL data is returned + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock successful status response + mock_status_response = MagicMock() + mock_status_response.json.return_value = { + "code": 200, + "data": { + "status": "completed", + "processed": { + "https://example.com/page1": { + "data": { + "title": "Page 1", + "url": "https://example.com/page1", + "description": "Description 1", + "content": "# Content 1", + } + } + }, + }, + } + mock_external_service_dependencies["requests"].post.return_value = mock_status_response + + # Act: Get URL data with job ID + result = WebsiteService.get_crawl_url_data( + job_id="jina_job_123", + provider="jinareader", + url="https://example.com/page1", + tenant_id=account.current_tenant.id, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["title"] == "Page 1" + assert result["url"] == "https://example.com/page1" + assert result["description"] == "Description 1" + assert result["content"] == "# Content 1" + + # Verify HTTP requests were made + assert mock_external_service_dependencies["requests"].post.call_count == 2 + + def test_get_crawl_url_data_jinareader_incomplete_job( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test JinaReader URL data retrieval fails for incomplete job. + + This test verifies: + - Incomplete job raises ValueError + - Proper error message is provided + - Service handles incomplete jobs gracefully + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock incomplete job status + mock_status_response = MagicMock() + mock_status_response.json.return_value = {"code": 200, "data": {"status": "active", "processed": {}}} + mock_external_service_dependencies["requests"].post.return_value = mock_status_response + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_crawl_url_data( + job_id="jina_job_123", + provider="jinareader", + url="https://example.com/page", + tenant_id=account.current_tenant.id, + ) + + assert "Crawl job is not completed" in str(exc_info.value) + + def test_crawl_options_default_values(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test CrawlOptions default values and initialization. + + This test verifies: + - Default values are properly set + - Optional fields can be None + - Boolean fields have correct defaults + - Integer fields have correct defaults + """ + # Arrange: Create CrawlOptions with minimal parameters + options = CrawlOptions() + + # Assert: Verify default values + assert options.limit == 1 + assert options.crawl_sub_pages is False + assert options.only_main_content is False + assert options.includes is None + assert options.excludes is None + assert options.max_depth is None + assert options.use_sitemap is True + + # Test with custom values + custom_options = CrawlOptions( + limit=10, + crawl_sub_pages=True, + only_main_content=True, + includes="blog,news", + excludes="admin", + max_depth=3, + use_sitemap=False, + ) + + assert custom_options.limit == 10 + assert custom_options.crawl_sub_pages is True + assert custom_options.only_main_content is True + assert custom_options.includes == "blog,news" + assert custom_options.excludes == "admin" + assert custom_options.max_depth == 3 + assert custom_options.use_sitemap is False + + def test_website_crawl_status_api_request_from_args( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test WebsiteCrawlStatusApiRequest creation from Flask arguments. + + This test verifies: + - Request is properly created from parsed arguments + - Required fields are validated + - Job ID is properly handled + - Validation errors are properly raised + """ + # Arrange: Prepare valid arguments + valid_args = {"provider": "firecrawl"} + job_id = "test_job_123" + + # Act: Create request from args + request = WebsiteCrawlStatusApiRequest.from_args(valid_args, job_id) + + # Assert: Verify request creation + assert request.provider == "firecrawl" + assert request.job_id == "test_job_123" + + # Test missing provider + invalid_args = {} + with pytest.raises(ValueError) as exc_info: + WebsiteCrawlStatusApiRequest.from_args(invalid_args, job_id) + assert "Provider is required" in str(exc_info.value) + + # Test missing job ID + with pytest.raises(ValueError) as exc_info: + WebsiteCrawlStatusApiRequest.from_args(valid_args, "") + assert "Job ID is required" in str(exc_info.value) + + def test_scrape_request_initialization(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test ScrapeRequest dataclass initialization and properties. + + This test verifies: + - ScrapeRequest is properly initialized + - All fields are correctly set + - Boolean field works correctly + - String fields are properly assigned + """ + # Arrange: Create ScrapeRequest + request = ScrapeRequest( + provider="firecrawl", url="https://example.com", tenant_id="tenant_123", only_main_content=True + ) + + # Assert: Verify initialization + assert request.provider == "firecrawl" + assert request.url == "https://example.com" + assert request.tenant_id == "tenant_123" + assert request.only_main_content is True + + # Test with different values + request2 = ScrapeRequest( + provider="watercrawl", url="https://test.com", tenant_id="tenant_456", only_main_content=False + ) + + assert request2.provider == "watercrawl" + assert request2.url == "https://test.com" + assert request2.tenant_id == "tenant_456" + assert request2.only_main_content is False diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index f26be6702a..ac3c8e45c9 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -1,9 +1,8 @@ -import datetime import uuid from collections import OrderedDict from typing import Any, NamedTuple -from flask_restful import marshal +from flask_restx import marshal from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, @@ -13,6 +12,7 @@ from controllers.console.app.workflow_draft_variable import ( ) from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment +from libs.datetime_utils import naive_utc_now from models.workflow import WorkflowDraftVariable from services.workflow_draft_variable_service import WorkflowDraftVariableList @@ -57,7 +57,7 @@ class TestWorkflowDraftVariableFields: ) sys_var.id = str(uuid.uuid4()) - sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + sys_var.last_edited_at = naive_utc_now() sys_var.visible = True expected_without_value = OrderedDict( @@ -88,7 +88,7 @@ class TestWorkflowDraftVariableFields: ) node_var.id = str(uuid.uuid4()) - node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + node_var.last_edited_at = naive_utc_now() expected_without_value: OrderedDict[str, Any] = OrderedDict( { diff --git a/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py new file mode 100644 index 0000000000..c10f7b89c3 --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py @@ -0,0 +1,148 @@ +"""Tests for LLMUsage entity.""" + +from decimal import Decimal + +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata + + +class TestLLMUsage: + """Test cases for LLMUsage class.""" + + def test_from_metadata_with_all_tokens(self): + """Test from_metadata when all token types are provided.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "prompt_unit_price": 0.001, + "completion_unit_price": 0.002, + "total_price": 0.2, + "currency": "USD", + "latency": 1.5, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 + assert usage.prompt_unit_price == Decimal("0.001") + assert usage.completion_unit_price == Decimal("0.002") + assert usage.total_price == Decimal("0.2") + assert usage.currency == "USD" + assert usage.latency == 1.5 + + def test_from_metadata_with_prompt_tokens_only(self): + """Test from_metadata when only prompt_tokens is provided.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "total_tokens": 100, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 100 + + def test_from_metadata_with_completion_tokens_only(self): + """Test from_metadata when only completion_tokens is provided.""" + metadata: LLMUsageMetadata = { + "completion_tokens": 50, + "total_tokens": 50, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 50 + + def test_from_metadata_calculates_total_when_missing(self): + """Test from_metadata calculates total_tokens when not provided.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "completion_tokens": 50, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 # Should be calculated + + def test_from_metadata_with_total_but_no_completion(self): + """ + Test from_metadata when total_tokens is provided but completion_tokens is 0. + This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens. + """ + metadata: LLMUsageMetadata = { + "prompt_tokens": 479, + "completion_tokens": 0, + "total_tokens": 521, + } + + usage = LLMUsage.from_metadata(metadata) + + # This is the key fix - prompt tokens should remain as prompt tokens + assert usage.prompt_tokens == 479 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 521 + + def test_from_metadata_with_empty_metadata(self): + """Test from_metadata with empty metadata.""" + metadata: LLMUsageMetadata = {} + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 0 + assert usage.currency == "USD" + assert usage.latency == 0.0 + + def test_from_metadata_preserves_zero_completion_tokens(self): + """ + Test that zero completion_tokens are preserved when explicitly set. + This is important for agent nodes that only use prompt tokens. + """ + metadata: LLMUsageMetadata = { + "prompt_tokens": 1000, + "completion_tokens": 0, + "total_tokens": 1000, + "prompt_unit_price": 0.15, + "completion_unit_price": 0.60, + "prompt_price": 0.00015, + "completion_price": 0, + "total_price": 0.00015, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 1000 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 1000 + assert usage.prompt_price == Decimal("0.00015") + assert usage.completion_price == Decimal(0) + assert usage.total_price == Decimal("0.00015") + + def test_from_metadata_with_decimal_values(self): + """Test from_metadata handles decimal values correctly.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "prompt_unit_price": "0.001", + "completion_unit_price": "0.002", + "prompt_price": "0.1", + "completion_price": "0.1", + "total_price": "0.2", + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_unit_price == Decimal("0.001") + assert usage.completion_unit_price == Decimal("0.002") + assert usage.prompt_price == Decimal("0.1") + assert usage.completion_price == Decimal("0.1") + assert usage.total_price == Decimal("0.2") diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index 450501c256..e7733b2317 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -5,7 +5,6 @@ These tests verify the Celery-based asynchronous storage functionality for workflow execution data. """ -from datetime import UTC, datetime from unittest.mock import Mock, patch from uuid import uuid4 @@ -13,6 +12,7 @@ import pytest from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType +from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom @@ -56,7 +56,7 @@ def sample_workflow_execution(): workflow_version="1.0", graph={"nodes": [], "edges": []}, inputs={"input1": "value1"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) @@ -199,7 +199,7 @@ class TestCeleryWorkflowExecutionRepository: workflow_version="1.0", graph={"nodes": [], "edges": []}, inputs={"input1": "value1"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) exec2 = WorkflowExecution.new( id_=str(uuid4()), @@ -208,7 +208,7 @@ class TestCeleryWorkflowExecutionRepository: workflow_version="1.0", graph={"nodes": [], "edges": []}, inputs={"input2": "value2"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Save both executions @@ -235,7 +235,7 @@ class TestCeleryWorkflowExecutionRepository: workflow_version="1.0", graph={"nodes": [], "edges": []}, inputs={"input1": "value1"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) repo.save(execution) diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index b38d994f03..0c6fdc8f92 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -5,7 +5,6 @@ These tests verify the Celery-based asynchronous storage functionality for workflow node execution data. """ -from datetime import UTC, datetime from unittest.mock import Mock, patch from uuid import uuid4 @@ -18,6 +17,7 @@ from core.workflow.entities.workflow_node_execution import ( ) from core.workflow.nodes.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig +from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -65,7 +65,7 @@ def sample_workflow_node_execution(): title="Test Node", inputs={"input1": "value1"}, status=WorkflowNodeExecutionStatus.RUNNING, - created_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), ) @@ -263,7 +263,7 @@ class TestCeleryWorkflowNodeExecutionRepository: title="Node 1", inputs={"input1": "value1"}, status=WorkflowNodeExecutionStatus.RUNNING, - created_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), ) exec2 = WorkflowNodeExecution( id=str(uuid4()), @@ -276,7 +276,7 @@ class TestCeleryWorkflowNodeExecutionRepository: title="Node 2", inputs={"input2": "value2"}, status=WorkflowNodeExecutionStatus.RUNNING, - created_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), ) # Save both executions @@ -314,7 +314,7 @@ class TestCeleryWorkflowNodeExecutionRepository: title="Node 2", inputs={}, status=WorkflowNodeExecutionStatus.RUNNING, - created_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), ) exec2 = WorkflowNodeExecution( id=str(uuid4()), @@ -327,7 +327,7 @@ class TestCeleryWorkflowNodeExecutionRepository: title="Node 1", inputs={}, status=WorkflowNodeExecutionStatus.RUNNING, - created_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), ) # Save in random order diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index 5146e82e8f..30f51902ef 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -2,19 +2,19 @@ Unit tests for the RepositoryFactory. This module tests the factory pattern implementation for creating repository instances -based on configuration, including error handling and validation. +based on configuration, including error handling. """ from unittest.mock import MagicMock, patch import pytest -from pytest_mock import MockerFixture from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -23,98 +23,30 @@ from models.workflow import WorkflowNodeExecutionTriggeredFrom class TestRepositoryFactory: """Test cases for RepositoryFactory.""" - def test_import_class_success(self): + def test_import_string_success(self): """Test successful class import.""" # Test importing a real class class_path = "unittest.mock.MagicMock" - result = DifyCoreRepositoryFactory._import_class(class_path) + result = import_string(class_path) assert result is MagicMock - def test_import_class_invalid_path(self): + def test_import_string_invalid_path(self): """Test import with invalid module path.""" - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._import_class("invalid.module.path") - assert "Cannot import repository class" in str(exc_info.value) + with pytest.raises(ImportError) as exc_info: + import_string("invalid.module.path") + assert "No module named" in str(exc_info.value) - def test_import_class_invalid_class_name(self): + def test_import_string_invalid_class_name(self): """Test import with invalid class name.""" - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass") - assert "Cannot import repository class" in str(exc_info.value) + with pytest.raises(ImportError) as exc_info: + import_string("unittest.mock.NonExistentClass") + assert "does not define" in str(exc_info.value) - def test_import_class_malformed_path(self): + def test_import_string_malformed_path(self): """Test import with malformed path (no dots).""" - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._import_class("invalidpath") - assert "Cannot import repository class" in str(exc_info.value) - - def test_validate_repository_interface_success(self): - """Test successful interface validation.""" - - # Create a mock class that implements the required methods - class MockRepository: - def save(self): - pass - - def get_by_id(self): - pass - - # Create a mock interface class - class MockInterface: - def save(self): - pass - - def get_by_id(self): - pass - - # Should not raise an exception when all methods are present - DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) - - def test_validate_repository_interface_missing_methods(self): - """Test interface validation with missing methods.""" - - # Create a mock class that's missing required methods - class IncompleteRepository: - def save(self): - pass - - # Missing get_by_id method - - # Create a mock interface that requires both methods - class MockInterface: - def save(self): - pass - - def get_by_id(self): - pass - - def missing_method(self): - pass - - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) - assert "does not implement required methods" in str(exc_info.value) - - def test_validate_repository_interface_with_private_methods(self): - """Test that private methods are ignored during interface validation.""" - - class MockRepository: - def save(self): - pass - - def _private_method(self): - pass - - # Create a mock interface with private methods - class MockInterface: - def save(self): - pass - - def _private_method(self): - pass - - # Should not raise exception - private methods should be ignored - DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + with pytest.raises(ImportError) as exc_info: + import_string("invalidpath") + assert "doesn't look like a module path" in str(exc_info.value) @patch("core.repositories.factory.dify_config") def test_create_workflow_execution_repository_success(self, mock_config): @@ -133,11 +65,8 @@ class TestRepositoryFactory: mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) mock_repository_class.return_value = mock_repository_instance - # Mock the validation methods - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - ): + # Mock import_string + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, @@ -170,34 +99,7 @@ class TestRepositoryFactory: app_id="test-app-id", triggered_from=WorkflowRunTriggeredFrom.APP_RUN, ) - assert "Cannot import repository class" in str(exc_info.value) - - @patch("core.repositories.factory.dify_config") - def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture): - """Test WorkflowExecutionRepository creation with validation error.""" - # Setup mock configuration - mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" - - mock_session_factory = MagicMock(spec=sessionmaker) - mock_user = MagicMock(spec=Account) - - # Mock the import to succeed but validation to fail - mock_repository_class = MagicMock() - mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class) - mocker.patch.object( - DifyCoreRepositoryFactory, - "_validate_repository_interface", - side_effect=RepositoryImportError("Interface validation failed"), - ) - - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=mock_session_factory, - user=mock_user, - app_id="test-app-id", - triggered_from=WorkflowRunTriggeredFrom.APP_RUN, - ) - assert "Interface validation failed" in str(exc_info.value) + assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value) @patch("core.repositories.factory.dify_config") def test_create_workflow_execution_repository_instantiation_error(self, mock_config): @@ -212,11 +114,8 @@ class TestRepositoryFactory: mock_repository_class = MagicMock() mock_repository_class.side_effect = Exception("Instantiation failed") - # Mock the validation methods to succeed - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - ): + # Mock import_string to return a failing class + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): with pytest.raises(RepositoryImportError) as exc_info: DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, @@ -243,11 +142,8 @@ class TestRepositoryFactory: mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository) mock_repository_class.return_value = mock_repository_instance - # Mock the validation methods - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - ): + # Mock import_string + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, @@ -280,34 +176,7 @@ class TestRepositoryFactory: app_id="test-app-id", triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) - assert "Cannot import repository class" in str(exc_info.value) - - @patch("core.repositories.factory.dify_config") - def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture): - """Test WorkflowNodeExecutionRepository creation with validation error.""" - # Setup mock configuration - mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" - - mock_session_factory = MagicMock(spec=sessionmaker) - mock_user = MagicMock(spec=EndUser) - - # Mock the import to succeed but validation to fail - mock_repository_class = MagicMock() - mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class) - mocker.patch.object( - DifyCoreRepositoryFactory, - "_validate_repository_interface", - side_effect=RepositoryImportError("Interface validation failed"), - ) - - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=mock_session_factory, - user=mock_user, - app_id="test-app-id", - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, - ) - assert "Interface validation failed" in str(exc_info.value) + assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value) @patch("core.repositories.factory.dify_config") def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config): @@ -322,11 +191,8 @@ class TestRepositoryFactory: mock_repository_class = MagicMock() mock_repository_class.side_effect = Exception("Instantiation failed") - # Mock the validation methods to succeed - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - ): + # Mock import_string to return a failing class + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): with pytest.raises(RepositoryImportError) as exc_info: DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, @@ -359,11 +225,8 @@ class TestRepositoryFactory: mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) mock_repository_class.return_value = mock_repository_instance - # Mock the validation methods - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - ): + # Mock import_string + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_engine, # Using Engine instead of sessionmaker user=mock_user, diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py index 137e8b889d..8b1b9a55bc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -1,6 +1,5 @@ import uuid from collections.abc import Generator -from datetime import UTC, datetime from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( @@ -15,6 +14,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce from core.workflow.nodes.enums import NodeType from core.workflow.nodes.start.entities import StartNodeData from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: @@ -29,7 +29,7 @@ def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngine def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: - route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) + route_node_state = RouteNodeState(node_id=next_node_id, start_at=naive_utc_now()) parallel_id = graph.node_parallel_mapping.get(next_node_id) parallel_start_node_id = None @@ -68,7 +68,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve ) route_node_state.status = RouteNodeState.Status.SUCCESS - route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None) + route_node_state.finished_at = naive_utc_now() yield NodeRunSucceededEvent( id=node_execution_id, node_id=next_node_id, diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index 4866db1fdb..1d2eba1e71 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -1,5 +1,4 @@ import json -from datetime import UTC, datetime from unittest.mock import MagicMock import pytest @@ -23,6 +22,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager +from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import Workflow, WorkflowRun @@ -145,8 +145,8 @@ def real_workflow(): workflow.graph = json.dumps(graph_data) workflow.features = json.dumps({"file_upload": {"enabled": False}}) workflow.created_by = "test-user-id" - workflow.created_at = datetime.now(UTC).replace(tzinfo=None) - workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.created_at = naive_utc_now() + workflow.updated_at = naive_utc_now() workflow._environment_variables = "{}" workflow._conversation_variables = "{}" @@ -169,7 +169,7 @@ def real_workflow_run(): workflow_run.outputs = json.dumps({"answer": "test answer"}) workflow_run.created_by_role = CreatorUserRole.ACCOUNT workflow_run.created_by = "test-user-id" - workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow_run.created_at = naive_utc_now() return workflow_run @@ -211,7 +211,7 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu workflow_type=WorkflowType.CHAT, graph={"nodes": [], "edges": []}, inputs={"query": "test query"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Pre-populate the cache with the workflow execution @@ -245,7 +245,7 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut workflow_type=WorkflowType.CHAT, graph={"nodes": [], "edges": []}, inputs={"query": "test query"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Pre-populate the cache with the workflow execution @@ -282,7 +282,7 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu workflow_type=WorkflowType.CHAT, graph={"nodes": [], "edges": []}, inputs={"query": "test query"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Pre-populate the cache with the workflow execution @@ -335,7 +335,7 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work workflow_type=WorkflowType.CHAT, graph={"nodes": [], "edges": []}, inputs={"query": "test query"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Pre-populate the cache with the workflow execution @@ -366,7 +366,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): event.process_data = {"process": "test process"} event.outputs = {"output": "test output"} event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100} - event.start_at = datetime.now(UTC).replace(tzinfo=None) + event.start_at = naive_utc_now() # Create a real node execution @@ -379,7 +379,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): node_id="test-node-id", node_type=NodeType.LLM, title="Test Node", - created_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), ) # Pre-populate the cache with the node execution @@ -409,7 +409,7 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl workflow_type=WorkflowType.CHAT, graph={"nodes": [], "edges": []}, inputs={"query": "test query"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Pre-populate the cache with the workflow execution @@ -443,7 +443,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager): event.process_data = {"process": "test process"} event.outputs = {"output": "test output"} event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100} - event.start_at = datetime.now(UTC).replace(tzinfo=None) + event.start_at = naive_utc_now() event.error = "Test error message" # Create a real node execution @@ -457,7 +457,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager): node_id="test-node-id", node_type=NodeType.LLM, title="Test Node", - created_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), ) # Pre-populate the cache with the node execution diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py index dc09aca5b2..1881ceac26 100644 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py @@ -93,16 +93,15 @@ class TestDatasetServiceBatchUpdateDocumentStatus: with ( patch("services.dataset_service.DocumentService.get_document") as mock_get_doc, patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.datetime") as mock_datetime, + patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, ): current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC + mock_naive_utc_now.return_value = current_time yield { "get_document": mock_get_doc, "db_session": mock_db, - "datetime": mock_datetime, + "naive_utc_now": mock_naive_utc_now, "current_time": current_time, } @@ -120,21 +119,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus: assert document.enabled == True assert document.disabled_at is None assert document.disabled_by is None - assert document.updated_at == current_time.replace(tzinfo=None) + assert document.updated_at == current_time def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime): """Helper method to verify document was disabled correctly.""" assert document.enabled == False - assert document.disabled_at == current_time.replace(tzinfo=None) + assert document.disabled_at == current_time assert document.disabled_by == user_id - assert document.updated_at == current_time.replace(tzinfo=None) + assert document.updated_at == current_time def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime): """Helper method to verify document was archived correctly.""" assert document.archived == True - assert document.archived_at == current_time.replace(tzinfo=None) + assert document.archived_at == current_time assert document.archived_by == user_id - assert document.updated_at == current_time.replace(tzinfo=None) + assert document.updated_at == current_time def _assert_document_unarchived(self, document: Mock): """Helper method to verify document was unarchived correctly.""" @@ -430,7 +429,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: # Verify document attributes were updated correctly self._assert_document_unarchived(archived_doc) - assert archived_doc.updated_at == mock_document_service_dependencies["current_time"].replace(tzinfo=None) + assert archived_doc.updated_at == mock_document_service_dependencies["current_time"] # Verify Redis cache was set (because document is enabled) redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) @@ -495,9 +494,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: # Verify document was unarchived self._assert_document_unarchived(archived_disabled_doc) - assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"].replace( - tzinfo=None - ) + assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"] # Verify no Redis cache was set (document is disabled) redis_mock.setex.assert_not_called() diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index c4c7579e83..0fc36510b9 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -1,7 +1,7 @@ from unittest.mock import Mock, patch import pytest -from flask_restful import reqparse +from flask_restx import reqparse from werkzeug.exceptions import BadRequest from services.entities.knowledge_entities.knowledge_entities import MetadataArgs diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index ef4d05c1d9..7f6344f942 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -1,7 +1,7 @@ from unittest.mock import Mock, patch import pytest -from flask_restful import reqparse +from flask_restx import reqparse from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.metadata_service import MetadataService diff --git a/api/uv.lock b/api/uv.lock index 52eedd9c66..45b020e1dd 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -741,6 +741,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/af/0dcccc7fdcdf170f9a1585e5e96b6fb0ba1749ef6be8c89a6202284759bd/celery-5.5.3-py3-none-any.whl", hash = "sha256:0b5761a07057acee94694464ca482416b959568904c9dfa41ce8413a7d65d525", size = 438775, upload-time = "2025-06-01T11:08:09.94Z" }, ] +[[package]] +name = "celery-types" +version = "0.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/d1/0823e71c281e4ad0044e278cf1577d1a68e05f2809424bf94e1614925c5d/celery_types-0.23.0.tar.gz", hash = "sha256:402ed0555aea3cd5e1e6248f4632e4f18eec8edb2435173f9e6dc08449fa101e", size = 31479, upload-time = "2025-03-03T23:56:51.547Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/8b/92bb54dd74d145221c3854aa245c84f4dc04cc9366147496182cec8e88e3/celery_types-0.23.0-py3-none-any.whl", hash = "sha256:0cc495b8d7729891b7e070d0ec8d4906d2373209656a6e8b8276fe1ed306af9a", size = 50189, upload-time = "2025-03-03T23:56:50.458Z" }, +] + [[package]] name = "certifi" version = "2025.6.15" @@ -1254,7 +1266,7 @@ dependencies = [ { name = "flask-login" }, { name = "flask-migrate" }, { name = "flask-orjson" }, - { name = "flask-restful" }, + { name = "flask-restx" }, { name = "flask-sqlalchemy" }, { name = "gevent" }, { name = "gmpy2" }, @@ -1326,6 +1338,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "boto3-stubs" }, + { name = "celery-types" }, { name = "coverage" }, { name = "dotenv-linter" }, { name = "faker" }, @@ -1442,7 +1455,7 @@ requires-dist = [ { name = "flask-login", specifier = "~=0.6.3" }, { name = "flask-migrate", specifier = "~=4.0.7" }, { name = "flask-orjson", specifier = "~=2.0.0" }, - { name = "flask-restful", specifier = "~=0.3.10" }, + { name = "flask-restx", specifier = ">=1.3.0" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "gevent", specifier = "~=24.11.1" }, { name = "gmpy2", specifier = "~=2.2.1" }, @@ -1514,12 +1527,13 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "boto3-stubs", specifier = ">=1.38.20" }, + { name = "celery-types", specifier = ">=0.23.0" }, { name = "coverage", specifier = "~=7.2.4" }, { name = "dotenv-linter", specifier = "~=0.5.0" }, { name = "faker", specifier = "~=32.1.0" }, { name = "hypothesis", specifier = ">=6.131.15" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, - { name = "mypy", specifier = "~=1.16.0" }, + { name = "mypy", specifier = "~=1.17.1" }, { name = "pandas-stubs", specifier = "~=2.2.3" }, { name = "pytest", specifier = "~=8.3.2" }, { name = "pytest-benchmark", specifier = "~=4.0.0" }, @@ -1875,18 +1889,20 @@ wheels = [ ] [[package]] -name = "flask-restful" -version = "0.3.10" +name = "flask-restx" +version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aniso8601" }, { name = "flask" }, + { name = "importlib-resources" }, + { name = "jsonschema" }, { name = "pytz" }, - { name = "six" }, + { name = "werkzeug" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c0/ce/a0a133db616ea47f78a41e15c4c68b9f08cab3df31eb960f61899200a119/Flask-RESTful-0.3.10.tar.gz", hash = "sha256:fe4af2ef0027df8f9b4f797aba20c5566801b6ade995ac63b588abf1a59cec37", size = 110453, upload-time = "2023-05-21T03:58:55.781Z" } +sdist = { url = "https://files.pythonhosted.org/packages/45/4c/2e7d84e2b406b47cf3bf730f521efe474977b404ee170d8ea68dc37e6733/flask-restx-1.3.0.tar.gz", hash = "sha256:4f3d3fa7b6191fcc715b18c201a12cd875176f92ba4acc61626ccfd571ee1728", size = 2814072, upload-time = "2023-12-10T14:48:55.575Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d7/7b/f0b45f0df7d2978e5ae51804bb5939b7897b2ace24306009da0cc34d8d1f/Flask_RESTful-0.3.10-py2.py3-none-any.whl", hash = "sha256:1cf93c535172f112e080b0d4503a8d15f93a48c88bdd36dd87269bdaf405051b", size = 26217, upload-time = "2023-05-21T03:58:54.004Z" }, + { url = "https://files.pythonhosted.org/packages/a5/bf/1907369f2a7ee614dde5152ff8f811159d357e77962aa3f8c2e937f63731/flask_restx-1.3.0-py2.py3-none-any.whl", hash = "sha256:636c56c3fb3f2c1df979e748019f084a938c4da2035a3e535a4673e4fc177691", size = 2798683, upload-time = "2023-12-10T14:48:53.293Z" }, ] [[package]] @@ -3272,28 +3288,28 @@ wheels = [ [[package]] name = "mypy" -version = "1.16.1" +version = "1.17.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mypy-extensions" }, { name = "pathspec" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/81/69/92c7fa98112e4d9eb075a239caa4ef4649ad7d441545ccffbd5e34607cbb/mypy-1.16.1.tar.gz", hash = "sha256:6bd00a0a2094841c5e47e7374bb42b83d64c527a502e3334e1173a0c24437bab", size = 3324747, upload-time = "2025-06-16T16:51:35.145Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/22/ea637422dedf0bf36f3ef238eab4e455e2a0dcc3082b5cc067615347ab8e/mypy-1.17.1.tar.gz", hash = "sha256:25e01ec741ab5bb3eec8ba9cdb0f769230368a22c959c4937360efb89b7e9f01", size = 3352570, upload-time = "2025-07-31T07:54:19.204Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9a/61/ec1245aa1c325cb7a6c0f8570a2eee3bfc40fa90d19b1267f8e50b5c8645/mypy-1.16.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:472e4e4c100062488ec643f6162dd0d5208e33e2f34544e1fc931372e806c0cc", size = 10890557, upload-time = "2025-06-16T16:37:21.421Z" }, - { url = "https://files.pythonhosted.org/packages/6b/bb/6eccc0ba0aa0c7a87df24e73f0ad34170514abd8162eb0c75fd7128171fb/mypy-1.16.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ea16e2a7d2714277e349e24d19a782a663a34ed60864006e8585db08f8ad1782", size = 10012921, upload-time = "2025-06-16T16:51:28.659Z" }, - { url = "https://files.pythonhosted.org/packages/5f/80/b337a12e2006715f99f529e732c5f6a8c143bb58c92bb142d5ab380963a5/mypy-1.16.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08e850ea22adc4d8a4014651575567b0318ede51e8e9fe7a68f25391af699507", size = 11802887, upload-time = "2025-06-16T16:50:53.627Z" }, - { url = "https://files.pythonhosted.org/packages/d9/59/f7af072d09793d581a745a25737c7c0a945760036b16aeb620f658a017af/mypy-1.16.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22d76a63a42619bfb90122889b903519149879ddbf2ba4251834727944c8baca", size = 12531658, upload-time = "2025-06-16T16:33:55.002Z" }, - { url = "https://files.pythonhosted.org/packages/82/c4/607672f2d6c0254b94a646cfc45ad589dd71b04aa1f3d642b840f7cce06c/mypy-1.16.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2c7ce0662b6b9dc8f4ed86eb7a5d505ee3298c04b40ec13b30e572c0e5ae17c4", size = 12732486, upload-time = "2025-06-16T16:37:03.301Z" }, - { url = "https://files.pythonhosted.org/packages/b6/5e/136555ec1d80df877a707cebf9081bd3a9f397dedc1ab9750518d87489ec/mypy-1.16.1-cp311-cp311-win_amd64.whl", hash = "sha256:211287e98e05352a2e1d4e8759c5490925a7c784ddc84207f4714822f8cf99b6", size = 9479482, upload-time = "2025-06-16T16:47:37.48Z" }, - { url = "https://files.pythonhosted.org/packages/b4/d6/39482e5fcc724c15bf6280ff5806548c7185e0c090712a3736ed4d07e8b7/mypy-1.16.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:af4792433f09575d9eeca5c63d7d90ca4aeceda9d8355e136f80f8967639183d", size = 11066493, upload-time = "2025-06-16T16:47:01.683Z" }, - { url = "https://files.pythonhosted.org/packages/e6/e5/26c347890efc6b757f4d5bb83f4a0cf5958b8cf49c938ac99b8b72b420a6/mypy-1.16.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:66df38405fd8466ce3517eda1f6640611a0b8e70895e2a9462d1d4323c5eb4b9", size = 10081687, upload-time = "2025-06-16T16:48:19.367Z" }, - { url = "https://files.pythonhosted.org/packages/44/c7/b5cb264c97b86914487d6a24bd8688c0172e37ec0f43e93b9691cae9468b/mypy-1.16.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:44e7acddb3c48bd2713994d098729494117803616e116032af192871aed80b79", size = 11839723, upload-time = "2025-06-16T16:49:20.912Z" }, - { url = "https://files.pythonhosted.org/packages/15/f8/491997a9b8a554204f834ed4816bda813aefda31cf873bb099deee3c9a99/mypy-1.16.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ab5eca37b50188163fa7c1b73c685ac66c4e9bdee4a85c9adac0e91d8895e15", size = 12722980, upload-time = "2025-06-16T16:37:40.929Z" }, - { url = "https://files.pythonhosted.org/packages/df/f0/2bd41e174b5fd93bc9de9a28e4fb673113633b8a7f3a607fa4a73595e468/mypy-1.16.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:dedb6229b2c9086247e21a83c309754b9058b438704ad2f6807f0d8227f6ebdd", size = 12903328, upload-time = "2025-06-16T16:34:35.099Z" }, - { url = "https://files.pythonhosted.org/packages/61/81/5572108a7bec2c46b8aff7e9b524f371fe6ab5efb534d38d6b37b5490da8/mypy-1.16.1-cp312-cp312-win_amd64.whl", hash = "sha256:1f0435cf920e287ff68af3d10a118a73f212deb2ce087619eb4e648116d1fe9b", size = 9562321, upload-time = "2025-06-16T16:48:58.823Z" }, - { url = "https://files.pythonhosted.org/packages/cf/d3/53e684e78e07c1a2bf7105715e5edd09ce951fc3f47cf9ed095ec1b7a037/mypy-1.16.1-py3-none-any.whl", hash = "sha256:5fc2ac4027d0ef28d6ba69a0343737a23c4d1b83672bf38d1fe237bdc0643b37", size = 2265923, upload-time = "2025-06-16T16:48:02.366Z" }, + { url = "https://files.pythonhosted.org/packages/46/cf/eadc80c4e0a70db1c08921dcc220357ba8ab2faecb4392e3cebeb10edbfa/mypy-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad37544be07c5d7fba814eb370e006df58fed8ad1ef33ed1649cb1889ba6ff58", size = 10921009, upload-time = "2025-07-31T07:53:23.037Z" }, + { url = "https://files.pythonhosted.org/packages/5d/c1/c869d8c067829ad30d9bdae051046561552516cfb3a14f7f0347b7d973ee/mypy-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:064e2ff508e5464b4bd807a7c1625bc5047c5022b85c70f030680e18f37273a5", size = 10047482, upload-time = "2025-07-31T07:53:26.151Z" }, + { url = "https://files.pythonhosted.org/packages/98/b9/803672bab3fe03cee2e14786ca056efda4bb511ea02dadcedde6176d06d0/mypy-1.17.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:70401bbabd2fa1aa7c43bb358f54037baf0586f41e83b0ae67dd0534fc64edfd", size = 11832883, upload-time = "2025-07-31T07:53:47.948Z" }, + { url = "https://files.pythonhosted.org/packages/88/fb/fcdac695beca66800918c18697b48833a9a6701de288452b6715a98cfee1/mypy-1.17.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e92bdc656b7757c438660f775f872a669b8ff374edc4d18277d86b63edba6b8b", size = 12566215, upload-time = "2025-07-31T07:54:04.031Z" }, + { url = "https://files.pythonhosted.org/packages/7f/37/a932da3d3dace99ee8eb2043b6ab03b6768c36eb29a02f98f46c18c0da0e/mypy-1.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c1fdf4abb29ed1cb091cf432979e162c208a5ac676ce35010373ff29247bcad5", size = 12751956, upload-time = "2025-07-31T07:53:36.263Z" }, + { url = "https://files.pythonhosted.org/packages/8c/cf/6438a429e0f2f5cab8bc83e53dbebfa666476f40ee322e13cac5e64b79e7/mypy-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:ff2933428516ab63f961644bc49bc4cbe42bbffb2cd3b71cc7277c07d16b1a8b", size = 9507307, upload-time = "2025-07-31T07:53:59.734Z" }, + { url = "https://files.pythonhosted.org/packages/17/a2/7034d0d61af8098ec47902108553122baa0f438df8a713be860f7407c9e6/mypy-1.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:69e83ea6553a3ba79c08c6e15dbd9bfa912ec1e493bf75489ef93beb65209aeb", size = 11086295, upload-time = "2025-07-31T07:53:28.124Z" }, + { url = "https://files.pythonhosted.org/packages/14/1f/19e7e44b594d4b12f6ba8064dbe136505cec813549ca3e5191e40b1d3cc2/mypy-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b16708a66d38abb1e6b5702f5c2c87e133289da36f6a1d15f6a5221085c6403", size = 10112355, upload-time = "2025-07-31T07:53:21.121Z" }, + { url = "https://files.pythonhosted.org/packages/5b/69/baa33927e29e6b4c55d798a9d44db5d394072eef2bdc18c3e2048c9ed1e9/mypy-1.17.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89e972c0035e9e05823907ad5398c5a73b9f47a002b22359b177d40bdaee7056", size = 11875285, upload-time = "2025-07-31T07:53:55.293Z" }, + { url = "https://files.pythonhosted.org/packages/90/13/f3a89c76b0a41e19490b01e7069713a30949d9a6c147289ee1521bcea245/mypy-1.17.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03b6d0ed2b188e35ee6d5c36b5580cffd6da23319991c49ab5556c023ccf1341", size = 12737895, upload-time = "2025-07-31T07:53:43.623Z" }, + { url = "https://files.pythonhosted.org/packages/23/a1/c4ee79ac484241301564072e6476c5a5be2590bc2e7bfd28220033d2ef8f/mypy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c837b896b37cd103570d776bda106eabb8737aa6dd4f248451aecf53030cdbeb", size = 12931025, upload-time = "2025-07-31T07:54:17.125Z" }, + { url = "https://files.pythonhosted.org/packages/89/b8/7409477be7919a0608900e6320b155c72caab4fef46427c5cc75f85edadd/mypy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:665afab0963a4b39dff7c1fa563cc8b11ecff7910206db4b2e64dd1ba25aed19", size = 9584664, upload-time = "2025-07-31T07:54:12.842Z" }, + { url = "https://files.pythonhosted.org/packages/1d/f3/8fcd2af0f5b806f6cf463efaffd3c9548a28f84220493ecd38d127b6b66d/mypy-1.17.1-py3-none-any.whl", hash = "sha256:a9f52c0351c21fe24c21d8c0eb1f62967b262d6729393397b6f443c3b773c3b9", size = 2283411, upload-time = "2025-07-31T07:53:24.664Z" }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index 826b7b9fe6..711898016e 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -215,6 +215,8 @@ DB_DATABASE=dify # The size of the database connection pool. # The default is 30 connections, which can be appropriately increased. SQLALCHEMY_POOL_SIZE=30 +# The default is 10 connections, which allows temporary overflow beyond the pool size. +SQLALCHEMY_MAX_OVERFLOW=10 # Database connection pool recycling time, the default is 3600 seconds. SQLALCHEMY_POOL_RECYCLE=3600 # Whether to print SQL, default is false. diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 0c352e4658..d3b75d93af 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -57,6 +57,7 @@ x-shared-env: &shared-api-worker-env DB_PORT: ${DB_PORT:-5432} DB_DATABASE: ${DB_DATABASE:-dify} SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30} + SQLALCHEMY_MAX_OVERFLOW: ${SQLALCHEMY_MAX_OVERFLOW:-10} SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600} SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false} SQLALCHEMY_POOL_PRE_PING: ${SQLALCHEMY_POOL_PRE_PING:-false} diff --git a/web/Dockerfile b/web/Dockerfile index d284efca87..1376dec749 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -34,7 +34,7 @@ COPY --from=packages /app/web/ . COPY . . ENV NODE_OPTIONS="--max-old-space-size=4096" -RUN pnpm build +RUN pnpm build:docker # production stage diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 861020545d..4ba451452c 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -21,6 +21,10 @@ import Checkbox from '@/app/components/base/checkbox' import { DEFAULT_FILE_UPLOAD_SETTING } from '@/app/components/workflow/constants' import { DEFAULT_VALUE_MAX_LEN } from '@/config' import { SimpleSelect } from '@/app/components/base/select' +import Textarea from '@/app/components/base/textarea' +import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader' +import { TransferMethod } from '@/types/app' +import type { FileEntity } from '@/app/components/base/file-uploader/types' const TEXT_MAX_LENGTH = 256 @@ -82,6 +86,8 @@ const ConfigModal: FC = ({ return () => { const newPayload = produce(tempPayload, (draft) => { draft.type = type + // Clear default value when switching types + draft.default = undefined if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) { (Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => { if (key !== 'max_length') @@ -234,6 +240,41 @@ const ConfigModal: FC = ({ )} + + {/* Default value for text input */} + {type === InputVarType.textInput && ( + + handlePayloadChange('default')(e.target.value || undefined)} + placeholder={t('appDebug.variableConfig.inputPlaceholder')!} + /> + + )} + + {/* Default value for paragraph */} + {type === InputVarType.paragraph && ( + +