Merge branch 'main' into feat/memory-orchestration-be

Author: Stream
Date: 2025-10-09 15:01:03 +08:00
Commit: c367f80ec5 (no known key found for this signature in database; GPG Key ID: 033728094B100D70)
318 changed files with 7286 additions and 3758 deletions

View File

@ -1,4 +1,4 @@
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev

View File

@ -26,7 +26,6 @@ prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment

View File

@ -40,18 +40,18 @@
<p align="center">
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README/README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
<a href="./README/README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README/README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README/README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README/README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README/README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README/README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./README/README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
<a href="./README/README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
<a href="./README/README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
<a href="./README/README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
<a href="./README/README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
<a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
<a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
<a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
<a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
<a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
</p>
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.

View File

@ -427,8 +427,8 @@ CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=80000
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
CODE_MAX_STRING_LENGTH=400000
TEMPLATE_TRANSFORM_MAX_LENGTH=400000
CODE_MAX_STRING_ARRAY_LENGTH=30
CODE_MAX_OBJECT_ARRAY_LENGTH=30
CODE_MAX_NUMBER_ARRAY_LENGTH=1000

View File

@ -80,10 +80,10 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
```
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
```bash
uv run celery -A app.celery beat

View File

@ -150,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="Maximum allowed length for strings in code execution",
default=80000,
default=400_000,
)
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
@ -582,6 +582,11 @@ class WorkflowConfig(BaseSettings):
default=200 * 1024,
)
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
description="Maximum number of characters allowed in Template Transform node output",
default=400_000,
)
# GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance",

View File

@ -1,4 +1,5 @@
from configs import dify_config
from libs.collection_utils import convert_to_lower_and_upper_set
HIDDEN_VALUE = "[__HIDDEN__]"
UNKNOWN_VALUE = "[__UNKNOWN__]"
@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
_doc_extensions: list[str]
_doc_extensions: set[str]
if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
_doc_extensions = {
"txt",
"markdown",
"md",
"mdx",
"pdf",
"html",
"htm",
"xlsx",
"xls",
"vtt",
"properties",
"doc",
"docx",
"csv",
"eml",
"msg",
"pptx",
"xml",
"epub",
}
if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.append("ppt")
_doc_extensions.add("ppt")
else:
_doc_extensions = [
_doc_extensions = {
"txt",
"markdown",
"md",
@ -37,5 +53,5 @@ else:
"csv",
"vtt",
"properties",
]
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
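The extension collections above now go through `convert_to_lower_and_upper_set` from `libs.collection_utils`. That helper's implementation is not part of this diff; a minimal sketch of its assumed behavior, inferred from how it is used here:

```python
def convert_to_lower_and_upper_set(values: set[str]) -> set[str]:
    """Assumed behavior: return a set containing both the lower- and upper-case form of every entry."""
    return {v.lower() for v in values} | {v.upper() for v in values}
```

Under that assumption, `DOCUMENT_EXTENSIONS` keeps the same members as the old `list + [ext.upper() ...]` construction, just deduplicated as a set.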

View File

@ -19,6 +19,7 @@ from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required
from libs.validators import validate_description_length
from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
@ -28,12 +29,6 @@ from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/apps")
class AppListApi(Resource):
@api.doc("list_apps")
@ -138,7 +133,7 @@ class AppListApi(Resource):
"""Create app"""
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
@ -219,7 +214,7 @@ class AppApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
@ -297,7 +292,7 @@ class AppCopyApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")

View File

@ -1,4 +1,5 @@
import flask_restx
from typing import Any, cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@ -30,24 +31,20 @@ from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _validate_name(name):
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/datasets")
class DatasetListApi(Resource):
@api.doc("get_datasets")
@ -92,7 +89,7 @@ class DatasetListApi(Resource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields)
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
for item in data:
# convert embedding_model_provider to plugin standard format
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
@ -147,7 +144,7 @@ class DatasetListApi(Resource):
)
parser.add_argument(
"description",
type=_validate_description_length,
type=validate_description_length,
nullable=True,
required=False,
default="",
@ -192,7 +189,7 @@ class DatasetListApi(Resource):
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=current_user,
account=cast(Account, current_user),
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
@ -224,7 +221,7 @@ class DatasetApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
@ -288,7 +285,7 @@ class DatasetApi(Resource):
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
@ -369,7 +366,7 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields)
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
@ -688,7 +685,7 @@ class DatasetApiKeyApi(Resource):
)
if current_key_count >= self.max_keys:
flask_restx.abort(
api.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
@ -733,7 +730,7 @@ class DatasetApiDeleteApi(Resource):
)
if key is None:
flask_restx.abort(404, message="API key not found")
api.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@ -55,6 +55,7 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@ -418,7 +419,9 @@ class DatasetInitApi(Resource):
try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
tenant_id=current_user.current_tenant_id,
knowledge_config=knowledge_config,
account=cast(Account, current_user),
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -452,7 +455,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
@ -514,7 +517,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not documents:
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
extract_settings = []
for document in documents:
if document.indexing_status in {"completed", "error"}:
@ -753,7 +756,7 @@ class DocumentApi(DocumentResource):
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@ -1073,7 +1076,9 @@ class DocumentRenameApi(DocumentResource):
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_operator_permission(current_user, dataset)
if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()

View File

@ -392,7 +392,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
# send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting")
batch_create_segment_to_index_task.delay(
str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
str(job_id),
upload_file_id,
dataset_id,
document_id,
current_user.current_tenant_id,
current_user.id,
)
except Exception as e:
return {"error": str(e)}, 500
@ -468,7 +473,8 @@ class ChildChunkAddApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
content = args["content"]
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@ -557,7 +563,8 @@ class ChildChunkAddApi(Resource):
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
@ -674,9 +681,8 @@ class ChildChunkUpdateApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.update_child_chunk(
args.get("content"), child_chunk, segment, document, dataset
)
content = args["content"]
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200

View File

@ -1,3 +1,5 @@
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, reqparse
@ -9,13 +11,14 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.account import Account
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
def _validate_name(name):
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 100:
raise ValueError("Name must be between 1 to 100 characters.")
return name
@ -274,7 +277,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
response = HitTestingService.external_retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
account=cast(Account, current_user),
external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"],
)

View File

@ -1,10 +1,11 @@
import logging
from typing import cast
from flask_login import current_user
from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services.dataset_service
import services
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@ -20,6 +21,7 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from models.account import Account
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@ -59,7 +61,7 @@ class DatasetsHitTestingBase:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
account=cast(Account, current_user),
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
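Several handlers in this diff now pass `cast(Account, current_user)` into services. A short illustration of why, assuming the usual flask_login setup in which `current_user` is a loosely typed proxy:

```python
from typing import cast

from flask_login import current_user

from models.account import Account


def current_account() -> Account:
    # cast() only narrows the type for the checker; at runtime this is still the same
    # proxy object, so callers continue to rely on login_required having set a real Account.
    return cast(Account, current_user)
```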

View File

@ -62,6 +62,7 @@ class DatasetMetadataApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
name = args["name"]
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@ -70,7 +71,7 @@ class DatasetMetadataApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
return metadata, 200
@setup_required

View File

@ -1,4 +1,3 @@
from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse
@ -11,6 +10,7 @@ from controllers.console.wraps import (
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required

View File

@ -20,13 +20,13 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
def _validate_name(name):
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
def _validate_description_length(description: str) -> str:
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@ -76,7 +76,7 @@ class CustomizedPipelineTemplateApi(Resource):
)
parser.add_argument(
"description",
type=str,
type=_validate_description_length,
nullable=True,
required=False,
default="",
@ -133,7 +133,7 @@ class PublishCustomizedPipelineTemplateApi(Resource):
)
parser.add_argument(
"description",
type=str,
type=_validate_description_length,
nullable=True,
required=False,
default="",

View File

@ -1,10 +1,10 @@
from flask_login import current_user # type: ignore # type: ignore
from flask_restx import Resource, marshal, reqparse # type: ignore
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
import services
from controllers.console import api
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
account_initialization_required,
@ -20,18 +20,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@ -84,6 +73,7 @@ class CreateRagPipelineDatasetApi(Resource):
return import_info, 201
@console_ns.route("/rag/pipeline/empty-dataset")
class CreateEmptyRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@ -108,7 +98,3 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
),
)
return marshal(dataset, dataset_detail_fields), 201
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
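This file (and the pipeline/workflow controllers below) drop the trailing `api.add_resource(...)` registrations in favor of `@console_ns.route(...)` decorators. A minimal, standalone sketch of the flask_restx pattern being adopted; the namespace name and path here are placeholders rather than the project's actual values, and the namespace would still need to be attached to an `Api` elsewhere:

```python
from flask_restx import Namespace, Resource

console_ns = Namespace("console", path="/console/api")  # placeholder path


@console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource):
    def post(self):
        # The decorator registers the route at import time, so a separate
        # api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
        # call is no longer needed.
        return {"result": "success"}, 201
```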

View File

@ -1,24 +1,22 @@
import logging
from typing import Any, NoReturn
from typing import NoReturn
from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.app.error import (
DraftWorkflowNotExist,
)
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
@ -34,32 +32,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
@ -104,13 +76,14 @@ def _api_prerequisite(f):
@account_initialization_required
@get_rag_pipeline
def wrapper(*args, **kwargs):
if not isinstance(current_user, Account) or not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)
return wrapper
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
class RagPipelineVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
@ -168,6 +141,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
return None
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
class RagPipelineNodeVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@ -190,6 +164,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
return Response("", 204)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>")
class RagPipelineVariableApi(Resource):
_PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value"
@ -284,6 +259,7 @@ class RagPipelineVariableApi(Resource):
return Response("", 204)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class RagPipelineVariableResetApi(Resource):
@_api_prerequisite
def put(self, pipeline: Pipeline, variable_id: str):
@ -325,6 +301,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
return draft_vars
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
class RagPipelineSystemVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@ -332,6 +309,7 @@ class RagPipelineSystemVariableCollectionApi(Resource):
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables")
class RagPipelineEnvironmentVariableCollectionApi(Resource):
@_api_prerequisite
def get(self, pipeline: Pipeline):
@ -364,26 +342,3 @@ class RagPipelineEnvironmentVariableCollectionApi(Resource):
)
return {"items": env_vars_list}
api.add_resource(
RagPipelineVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
)
api.add_resource(
RagPipelineNodeVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
)
api.add_resource(
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
)
api.add_resource(
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
)
api.add_resource(
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
)
api.add_resource(
RagPipelineEnvironmentVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
)

View File

@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
@ -20,6 +20,7 @@ from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@ -66,6 +67,7 @@ class RagPipelineImportApi(Resource):
return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
class RagPipelineImportConfirmApi(Resource):
@setup_required
@login_required
@ -90,6 +92,7 @@ class RagPipelineImportConfirmApi(Resource):
return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
class RagPipelineImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@ -107,6 +110,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
class RagPipelineExportApi(Resource):
@setup_required
@login_required
@ -128,22 +132,3 @@ class RagPipelineExportApi(Resource):
)
return {"data": result}, 200
# Import Rag Pipeline
api.add_resource(
RagPipelineImportApi,
"/rag/pipelines/imports",
)
api.add_resource(
RagPipelineImportConfirmApi,
"/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
RagPipelineImportCheckDependenciesApi,
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
RagPipelineExportApi,
"/rag/pipelines/<string:pipeline_id>/exports",
)

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console import console_ns
from controllers.console.app.error import (
ConversationCompletedError,
DraftWorkflowNotExist,
@ -50,6 +50,7 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
logger = logging.getLogger(__name__)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
class DraftRagPipelineApi(Resource):
@setup_required
@login_required
@ -147,6 +148,7 @@ class DraftRagPipelineApi(Resource):
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@ -181,6 +183,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
raise InternalServerError()
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource):
@setup_required
@login_required
@ -215,6 +218,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError()
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource):
@setup_required
@login_required
@ -249,6 +253,7 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource):
@setup_required
@login_required
@ -369,6 +374,7 @@ class PublishedRagPipelineRunApi(Resource):
#
# return result
#
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@setup_required
@login_required
@ -411,6 +417,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource):
@setup_required
@login_required
@ -453,6 +460,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource):
@setup_required
@login_required
@ -486,6 +494,7 @@ class RagPipelineDraftNodeRunApi(Resource):
return workflow_node_execution
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop")
class RagPipelineTaskStopApi(Resource):
@setup_required
@login_required
@ -504,6 +513,7 @@ class RagPipelineTaskStopApi(Resource):
return {"result": "success"}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/publish")
class PublishedRagPipelineApi(Resource):
@setup_required
@login_required
@ -559,6 +569,7 @@ class PublishedRagPipelineApi(Resource):
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs")
class DefaultRagPipelineBlockConfigsApi(Resource):
@setup_required
@login_required
@ -577,6 +588,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs()
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource):
@setup_required
@login_required
@ -608,6 +620,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource):
@setup_required
@login_required
@ -656,6 +669,7 @@ class PublishedAllRagPipelineApi(Resource):
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource):
@setup_required
@login_required
@ -713,6 +727,7 @@ class RagPipelineByIdApi(Resource):
return workflow
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource):
@setup_required
@login_required
@ -738,6 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource):
@setup_required
@login_required
@ -763,6 +779,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource):
@setup_required
@login_required
@ -788,6 +805,7 @@ class DraftRagPipelineFirstStepApi(Resource):
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource):
@setup_required
@login_required
@ -814,6 +832,7 @@ class DraftRagPipelineSecondStepApi(Resource):
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource):
@setup_required
@login_required
@ -835,6 +854,7 @@ class RagPipelineWorkflowRunListApi(Resource):
return result
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>")
class RagPipelineWorkflowRunDetailApi(Resource):
@setup_required
@login_required
@ -853,6 +873,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
return workflow_run
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions")
class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@setup_required
@login_required
@ -876,6 +897,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
return {"data": node_executions}
@console_ns.route("/rag/pipelines/datasource-plugins")
class DatasourceListApi(Resource):
@setup_required
@login_required
@ -891,6 +913,7 @@ class DatasourceListApi(Resource):
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
class RagPipelineWorkflowLastRunApi(Resource):
@setup_required
@login_required
@ -912,6 +935,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
return node_exec
@console_ns.route("/rag/pipelines/transform/datasets/<uuid:dataset_id>")
class RagPipelineTransformApi(Resource):
@setup_required
@login_required
@ -929,6 +953,7 @@ class RagPipelineTransformApi(Resource):
return result
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource):
@setup_required
@login_required
@ -958,6 +983,7 @@ class RagPipelineDatasourceVariableApi(Resource):
return workflow_node_execution
@console_ns.route("/rag/pipelines/recommended-plugins")
class RagPipelineRecommendedPluginApi(Resource):
@setup_required
@login_required
@ -966,114 +992,3 @@ class RagPipelineRecommendedPluginApi(Resource):
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
return recommended_plugins
api.add_resource(
DraftRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
)
api.add_resource(
DraftRagPipelineRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run",
)
api.add_resource(
PublishedRagPipelineRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/run",
)
api.add_resource(
RagPipelineTaskStopApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
RagPipelineDraftNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelinePublishedDatasourceNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelineDraftDatasourceNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelineDraftRunIterationNodeApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelineDraftRunLoopNodeApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
PublishedRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/publish",
)
api.add_resource(
PublishedAllRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows",
)
api.add_resource(
DefaultRagPipelineBlockConfigsApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
DefaultRagPipelineBlockConfigApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
RagPipelineByIdApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
)
api.add_resource(
RagPipelineWorkflowRunListApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs",
)
api.add_resource(
RagPipelineWorkflowRunDetailApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>",
)
api.add_resource(
RagPipelineWorkflowRunNodeExecutionListApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions",
)
api.add_resource(
DatasourceListApi,
"/rag/pipelines/datasource-plugins",
)
api.add_resource(
PublishedRagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
)
api.add_resource(
PublishedRagPipelineFirstStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters",
)
api.add_resource(
DraftRagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
)
api.add_resource(
DraftRagPipelineFirstStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters",
)
api.add_resource(
RagPipelineWorkflowLastRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run",
)
api.add_resource(
RagPipelineTransformApi,
"/rag/pipelines/transform/datasets/<uuid:dataset_id>",
)
api.add_resource(
RagPipelineDatasourceVariableApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect",
)
api.add_resource(
RagPipelineRecommendedPluginApi,
"/rag/pipelines/recommended-plugins",
)

View File

@ -3,7 +3,7 @@ from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import login_required
@ -17,6 +17,7 @@ def _validate_name(name):
return name
@console_ns.route("/tags")
class TagListApi(Resource):
@setup_required
@login_required
@ -52,6 +53,7 @@ class TagListApi(Resource):
return response, 200
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@setup_required
@login_required
@ -89,6 +91,7 @@ class TagUpdateDeleteApi(Resource):
return 204
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@setup_required
@login_required
@ -114,6 +117,7 @@ class TagBindingCreateApi(Resource):
return {"result": "success"}, 200
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
@setup_required
@login_required
@ -133,9 +137,3 @@ class TagBindingDeleteApi(Resource):
TagService.delete_tag_binding(args)
return {"result": "success"}, 200
api.add_resource(TagListApi, "/tags")
api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")

View File

@ -1,10 +1,10 @@
from typing import Literal
from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service
import services
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import (
@ -17,6 +17,7 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from libs.validators import validate_description_length
from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum
from models.provider_ids import ModelProviderID
@ -31,12 +32,6 @@ def _validate_name(name):
return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
# Define parsers for dataset operations
dataset_create_parser = reqparse.RequestParser()
dataset_create_parser.add_argument(
@ -48,7 +43,7 @@ dataset_create_parser.add_argument(
)
dataset_create_parser.add_argument(
"description",
type=_validate_description_length,
type=validate_description_length,
nullable=True,
required=False,
default="",
@ -101,7 +96,7 @@ dataset_update_parser.add_argument(
type=_validate_name,
)
dataset_update_parser.add_argument(
"description", location="json", store_missing=False, type=_validate_description_length
"description", location="json", store_missing=False, type=validate_description_length
)
dataset_update_parser.add_argument(
"indexing_technique",
@ -254,19 +249,21 @@ class DatasetListApi(DatasetApiResource):
"""Resource for creating datasets."""
args = dataset_create_parser.parse_args()
if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
)
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
)
try:
@ -317,7 +314,7 @@ class DatasetApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
@ -331,8 +328,8 @@ class DatasetApi(DatasetApiResource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if data.get("indexing_technique") == "high_quality":
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
if item_model in model_names:
data["embedding_available"] = True
else:
@ -341,7 +338,9 @@ class DatasetApi(DatasetApiResource):
data["embedding_available"] = True
# force update search method to keyword_search if indexing_technique is economic
data["retrieval_model_dict"]["search_method"] = "keyword_search"
retrieval_model_dict = data.get("retrieval_model_dict")
if retrieval_model_dict:
retrieval_model_dict["search_method"] = "keyword_search"
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@ -372,19 +371,24 @@ class DatasetApi(DatasetApiResource):
data = request.get_json()
# check embedding model setting
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
embedding_model_provider = data.get("embedding_model_provider")
embedding_model = data.get("embedding_model")
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model
)
retrieval_model = data.get("retrieval_model")
if (
data.get("retrieval_model")
and data.get("retrieval_model").get("reranking_model")
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
dataset.tenant_id,
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@ -397,7 +401,7 @@ class DatasetApi(DatasetApiResource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields)
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
@ -591,9 +595,10 @@ class DatasetTagsApi(DatasetApiResource):
args = tag_update_parser.parse_args()
args["type"] = "knowledge"
tag = TagService.update_tags(args, args.get("tag_id"))
tag_id = args["tag_id"]
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
binding_count = TagService.get_tag_binding_count(tag_id)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
@ -616,7 +621,7 @@ class DatasetTagsApi(DatasetApiResource):
if not current_user.has_edit_permission:
raise Forbidden()
args = tag_delete_parser.parse_args()
TagService.delete_tag(args.get("tag_id"))
TagService.delete_tag(args["tag_id"])
return 204
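The repeated `args.get("retrieval_model").get(...)` chains above are replaced by hoisted locals, so the None-check happens once and type checkers can narrow the values. An illustrative condensation of that pattern (not code from the repo):

```python
from typing import Any


def extract_reranking_setting(args: dict[str, Any]) -> tuple[str, str] | None:
    # Hoist nested lookups into locals instead of repeating args.get(...) chains.
    retrieval_model = args.get("retrieval_model")
    if not retrieval_model:
        return None
    reranking_model = retrieval_model.get("reranking_model") or {}
    provider = reranking_model.get("reranking_provider_name")
    model_name = reranking_model.get("reranking_model_name")
    if provider and model_name:
        return provider, model_name
    return None
```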

View File

@ -108,19 +108,21 @@ class DocumentAddByTextApi(DatasetApiResource):
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")
if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
)
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
)
if not current_user:
@ -187,15 +189,16 @@ class DocumentUpdateByTextApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
retrieval_model = args.get("retrieval_model")
if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
)
# indexing_technique is already set in dataset since this is an update

View File

@ -106,7 +106,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
return marshal(metadata, dataset_metadata_fields), 200
@service_api_ns.doc("delete_dataset_metadata")

View File

@ -1,4 +1,5 @@
import uuid
from typing import Literal, cast
from core.app.app_config.entities import (
DatasetEntity,
@ -74,6 +75,9 @@ class DatasetConfigManager:
return None
query_variable = config.get("dataset_query_variable")
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
if dataset_configs["retrieval_model"] == "single":
return DatasetEntity(
dataset_ids=dataset_ids,
@ -82,18 +86,23 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
if dataset_configs.get("metadata_model_config")
metadata_filtering_mode=cast(
Literal["disabled", "automatic", "manual"],
dataset_configs.get("metadata_filtering_mode", "disabled"),
),
metadata_model_config=ModelConfig(**metadata_model_config_dict)
if isinstance(metadata_model_config_dict, dict)
else None,
metadata_filtering_conditions=MetadataFilteringCondition(
**dataset_configs.get("metadata_filtering_conditions", {})
)
if dataset_configs.get("metadata_filtering_conditions")
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
if isinstance(metadata_filtering_conditions_dict, dict)
else None,
),
)
else:
score_threshold_val = dataset_configs.get("score_threshold")
reranking_model_val = dataset_configs.get("reranking_model")
weights_val = dataset_configs.get("weights")
return DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
@ -101,22 +110,23 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
),
top_k=dataset_configs.get("top_k", 4),
score_threshold=dataset_configs.get("score_threshold")
if dataset_configs.get("score_threshold_enabled", False)
top_k=int(dataset_configs.get("top_k", 4)),
score_threshold=float(score_threshold_val)
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
else None,
reranking_model=dataset_configs.get("reranking_model"),
weights=dataset_configs.get("weights"),
reranking_enabled=dataset_configs.get("reranking_enabled", True),
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
weights=weights_val if isinstance(weights_val, dict) else None,
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
if dataset_configs.get("metadata_model_config")
metadata_filtering_mode=cast(
Literal["disabled", "automatic", "manual"],
dataset_configs.get("metadata_filtering_mode", "disabled"),
),
metadata_model_config=ModelConfig(**metadata_model_config_dict)
if isinstance(metadata_model_config_dict, dict)
else None,
metadata_filtering_conditions=MetadataFilteringCondition(
**dataset_configs.get("metadata_filtering_conditions", {})
)
if dataset_configs.get("metadata_filtering_conditions")
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
if isinstance(metadata_filtering_conditions_dict, dict)
else None,
),
)
@ -134,18 +144,17 @@ class DatasetConfigManager:
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
# dataset_configs
if not config.get("dataset_configs"):
config["dataset_configs"] = {"retrieval_model": "single"}
if "dataset_configs" not in config or not config.get("dataset_configs"):
config["dataset_configs"] = {}
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if not config["dataset_configs"].get("datasets"):
if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {}
).get("datasets")
need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
@ -166,8 +175,8 @@ class DatasetConfigManager:
:param config: app model config args
"""
# Extract dataset config for legacy compatibility
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}
if "agent_mode" not in config or not config.get("agent_mode"):
config["agent_mode"] = {}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@ -180,19 +189,22 @@ class DatasetConfigManager:
raise ValueError("enabled in agent_mode must be of boolean type")
# tools
if not config["agent_mode"].get("tools"):
if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
config["agent_mode"]["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects")
# strategy
if not config["agent_mode"].get("strategy"):
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
for tool in config["agent_mode"]["tools"]:
if config.get("agent_mode", {}).get("strategy") in {
PlanningStrategy.ROUTER.value,
PlanningStrategy.REACT_ROUTER.value,
}:
for tool in config.get("agent_mode", {}).get("tools", []):
key = list(tool.keys())[0]
if key == "dataset":
# old style, use tool name as key
@ -217,7 +229,7 @@ class DatasetConfigManager:
has_datasets = True
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
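
Note: the hunks above replace direct .get() passthroughs with isinstance guards plus a typing.cast for the literal-typed filtering mode. A minimal sketch of that defensive-parsing pattern, where ModelConfigSketch and parse_metadata_filtering are illustrative names rather than the project's actual classes:

from typing import Any, Literal, cast

from pydantic import BaseModel


class ModelConfigSketch(BaseModel):
    # Hypothetical stand-in for the real ModelConfig entity.
    provider: str = ""
    name: str = ""


def parse_metadata_filtering(
    raw: dict[str, Any],
) -> tuple[Literal["disabled", "automatic", "manual"], ModelConfigSketch | None]:
    # Narrow the loosely typed dict before constructing typed entities,
    # mirroring the isinstance checks added in the diff above.
    mode = cast(
        Literal["disabled", "automatic", "manual"],
        raw.get("metadata_filtering_mode", "disabled"),
    )
    model_config_dict = raw.get("metadata_model_config")
    model_config = ModelConfigSketch(**model_config_dict) if isinstance(model_config_dict, dict) else None
    return mode, model_config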

View File

@ -1,9 +1,11 @@
import logging
import queue
import time
from abc import abstractmethod
from enum import IntEnum, auto
from typing import Any
from redis.exceptions import RedisError
from sqlalchemy.orm import DeclarativeMeta
from configs import dify_config
@ -18,6 +20,8 @@ from core.app.entities.queue_entities import (
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class PublishFrom(IntEnum):
APPLICATION_MANAGER = auto()
@ -35,9 +39,8 @@ class AppQueueManager:
self.invoke_from = invoke_from # Public accessor for invoke_from
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex(
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)
self._task_belong_cache_key = AppQueueManager._generate_task_belong_cache_key(self._task_id)
redis_client.setex(self._task_belong_cache_key, 1800, f"{user_prefix}-{self._user_id}")
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
@ -79,9 +82,21 @@ class AppQueueManager:
Stop listening to the queue
:return:
"""
self._clear_task_belong_cache()
self._q.put(None)
def publish_error(self, e, pub_from: PublishFrom):
def _clear_task_belong_cache(self) -> None:
"""
Remove the task belong cache key once listening is finished.
"""
try:
redis_client.delete(self._task_belong_cache_key)
except RedisError:
logger.exception(
"Failed to clear task belong cache for task %s (key: %s)", self._task_id, self._task_belong_cache_key
)
def publish_error(self, e, pub_from: PublishFrom) -> None:
"""
Publish error
:param e: error
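
The change above keeps the computed cache key on the instance and deletes it when listening stops, swallowing Redis failures. A standalone sketch of that lifecycle, assuming a reachable redis-py client and an illustrative key format (not the exact key used by AppQueueManager):

import logging

import redis
from redis.exceptions import RedisError

logger = logging.getLogger(__name__)
client = redis.Redis()  # assumes a locally reachable Redis instance


def start_task(task_id: str, owner: str) -> str:
    # Record who the task belongs to, with a TTL as a safety net (1800s in the diff).
    key = f"generate_task_belong:{task_id}"  # illustrative key format
    client.setex(key, 1800, owner)
    return key


def stop_task(key: str, task_id: str) -> None:
    # Best-effort cleanup: a RedisError must not break queue shutdown.
    try:
        client.delete(key)
    except RedisError:
        logger.exception("Failed to clear task belong cache for task %s (key: %s)", task_id, key)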

View File

@ -107,7 +107,6 @@ class MessageCycleManager:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
db.session.merge(conversation)
db.session.commit()
db.session.close()

View File

@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional
from openai import BaseModel
from pydantic import Field
from pydantic import BaseModel, Field
# Import InvokeFrom locally to avoid circular import
from core.app.entities.app_invoke_entities import InvokeFrom

View File

@ -74,7 +74,7 @@ class TextPromptMessageContent(PromptMessageContent):
Model class for text prompt message content.
"""
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore
data: str
@ -95,11 +95,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore
class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore
class ImagePromptMessageContent(MultiModalPromptMessageContent):
@ -111,12 +111,12 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
LOW = auto()
HIGH = auto()
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore
detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore
PromptMessageContentUnionTypes = Annotated[

View File

@ -15,7 +15,7 @@ class GPT2Tokenizer:
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text)
tokens = _tokenizer.encode(text) # type: ignore
return len(tokens)
@staticmethod

View File

@ -196,15 +196,15 @@ def jsonable_encoder(
return encoder(obj)
try:
data = dict(obj)
data = dict(obj) # type: ignore
except Exception as e:
errors: list[Exception] = []
errors.append(e)
try:
data = vars(obj)
data = vars(obj) # type: ignore
except Exception as e:
errors.append(e)
raise ValueError(errors) from e
raise ValueError(str(errors)) from e
return jsonable_encoder(
data,
by_alias=by_alias,

View File

@ -3,7 +3,8 @@ from dataclasses import dataclass
from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event, Status, StatusCode
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field

View File

@ -155,7 +155,10 @@ class OpsTraceManager:
if key in tracing_config:
if "*" in tracing_config[key]:
# If the key contains '*', retain the original value from the current config
new_config[key] = current_trace_config.get(key, tracing_config[key])
if current_trace_config:
new_config[key] = current_trace_config.get(key, tracing_config[key])
else:
new_config[key] = tracing_config[key]
else:
# Otherwise, encrypt the key
new_config[key] = encrypt_token(tenant_id, tracing_config[key])
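
The "*" branch now only falls back to the stored config when one actually exists. A rough sketch of the masked-secret merge idea, with the encryption step left out and all names illustrative:

def merge_masked_config(
    submitted: dict[str, str], stored: dict[str, str] | None, secret_keys: set[str]
) -> dict[str, str]:
    merged: dict[str, str] = {}
    for key, value in submitted.items():
        if key in secret_keys and "*" in value:
            # A masked value means "keep whatever is already stored", if anything is stored.
            merged[key] = stored.get(key, value) if stored else value
        else:
            merged[key] = value
    return merged


assert merge_masked_config({"api_key": "sk-***"}, {"api_key": "sk-real"}, {"api_key"}) == {"api_key": "sk-real"}
assert merge_masked_config({"api_key": "sk-***"}, None, {"api_key"}) == {"api_key": "sk-***"}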

View File

@ -62,7 +62,8 @@ class WeaveDataTrace(BaseTraceInstance):
self,
):
try:
project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
project_identifier = f"{self.entity}/{self.project_name}" if self.entity else self.project_name
project_url = f"https://wandb.ai/{project_identifier}"
return project_url
except Exception as e:
logger.debug("Weave get run url failed: %s", str(e))
@ -424,7 +425,23 @@ class WeaveDataTrace(BaseTraceInstance):
raise ValueError(f"Weave API check failed: {str(e)}")
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
inputs = run_data.inputs
if inputs is None:
inputs = {}
elif not isinstance(inputs, dict):
inputs = {"inputs": str(inputs)}
attributes = run_data.attributes
if attributes is None:
attributes = {}
elif not isinstance(attributes, dict):
attributes = {"attributes": str(attributes)}
call = self.weave_client.create_call(
op=run_data.op,
inputs=inputs,
attributes=attributes,
)
self.calls[run_data.id] = call
if parent_run_id:
self.calls[run_data.id].parent_id = parent_run_id
@ -432,6 +449,7 @@ class WeaveDataTrace(BaseTraceInstance):
def finish_call(self, run_data: WeaveTraceModel):
call = self.calls.get(run_data.id)
if call:
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
exception = Exception(run_data.exception) if run_data.exception else None
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=exception)
else:
raise ValueError(f"Call with id {run_data.id} not found")

View File

@ -106,7 +106,9 @@ class RetrievalService:
if exceptions:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
@ -143,6 +145,40 @@ class RetrievalService:
)
return all_documents
@classmethod
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
if not documents:
return documents
unique_documents = []
seen_doc_ids = set()
for document in documents:
# For dify provider documents, use doc_id for deduplication
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
doc_id = document.metadata["doc_id"]
if doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
unique_documents.append(document)
# If duplicate, keep the one with higher score
elif "score" in document.metadata:
# Find existing document with same doc_id and compare scores
for i, existing_doc in enumerate(unique_documents):
if (
existing_doc.metadata
and existing_doc.metadata.get("doc_id") == doc_id
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
):
unique_documents[i] = document
break
else:
# For non-dify documents, use content-based deduplication
if document not in unique_documents:
unique_documents.append(document)
return unique_documents
@classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
with Session(db.engine) as session:

View File

@ -1,7 +1,6 @@
from typing import Any
from openai import BaseModel
from pydantic import Field
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom

View File

@ -18,6 +18,10 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
retriever_from: str
model_config = ConfigDict(arbitrary_types_allowed=True)
def run(self, query: str) -> str:
"""Use the tool."""
return self._run(query)
@abstractmethod
def _run(self, query: str) -> str:
"""Use the tool.

View File

@ -124,7 +124,7 @@ class DatasetRetrieverTool(Tool):
yield self.create_text_message(text="please input query")
else:
# invoke dataset retriever tool
result = self.retrieval_tool._run(query=query)
result = self.retrieval_tool.run(query=query)
yield self.create_text_message(text=result)
def validate_credentials(

View File

@ -2,6 +2,7 @@ import re
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Any
from flask import request
from requests import get
@ -127,34 +128,34 @@ class ApiBasedToolSchemaParser:
if "allOf" in prop_dict:
del prop_dict["allOf"]
# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ
parameters.append(tool)
parameters.append(tool)
# check if parameters are duplicated
parameters_count = {}
@ -241,7 +242,9 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None):
def parse_swagger_to_openapi(
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
) -> dict[str, Any]:
warning = warning or {}
"""
parse swagger to openapi
@ -257,7 +260,7 @@ class ApiBasedToolSchemaParser:
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
openapi = {
converted_openapi: dict[str, Any] = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),
@ -275,7 +278,7 @@ class ApiBasedToolSchemaParser:
# convert paths
for path, path_item in swagger["paths"].items():
openapi["paths"][path] = {}
converted_openapi["paths"][path] = {}
for method, operation in path_item.items():
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
@ -286,7 +289,7 @@ class ApiBasedToolSchemaParser:
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = {
converted_openapi["paths"][path][method] = {
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
@ -295,13 +298,14 @@ class ApiBasedToolSchemaParser:
}
if "requestBody" in operation:
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
converted_openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
# convert definitions
for name, definition in swagger["definitions"].items():
openapi["components"]["schemas"][name] = definition
if "definitions" in swagger:
for name, definition in swagger["definitions"].items():
converted_openapi["components"]["schemas"][name] = definition
return openapi
return converted_openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(

View File

@ -191,11 +191,22 @@ class VariablePool(BaseModel):
"""Extract the actual value from an ObjectSegment."""
return obj.value if isinstance(obj, ObjectSegment) else obj
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str):
"""Get a nested attribute from a dictionary-like object."""
if not isinstance(obj, dict):
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None:
"""
Get a nested attribute from a dictionary-like object.
Args:
obj: The dictionary-like object to search.
attr: The key to look up.
Returns:
Segment | None:
The corresponding Segment built from the attribute value if the key exists,
otherwise None.
"""
if not isinstance(obj, dict) or attr not in obj:
return None
return obj.get(attr)
return variable_factory.build_segment(obj.get(attr))
def remove(self, selector: Sequence[str], /):
"""

View File

@ -20,6 +20,7 @@ class ModelInvokeCompletedEvent(NodeEventBase):
usage: LLMUsage
finish_reason: str | None = None
reasoning_content: str | None = None
structured_output: dict | None = None
class RunRetryEvent(NodeEventBase):

View File

@ -87,7 +87,7 @@ class Executor:
node_data.authorization.config.api_key
).text
self.url: str = node_data.url
self.url = node_data.url
self.method = node_data.method
self.auth = node_data.authorization
self.timeout = timeout
@ -349,11 +349,10 @@ class Executor:
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"ssl_verify": self.ssl_verify,
"follow_redirects": True,
"max_retries": self.max_retries,
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response: httpx.Response = _METHOD_MAP[method_lc](**request_args)
response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue

View File

@ -165,6 +165,8 @@ class HttpRequestNode(Node):
body_type = typed_node_data.body.type
data = typed_node_data.body.data
match body_type:
case "none":
pass
case "binary":
if len(data) != 1:
raise RequestBodyError("invalid body data, should have only one item")

View File

@ -83,7 +83,7 @@ class IfElseNode(Node):
else:
# TODO: Update database then remove this
# Fallback to old structure if cases are not defined
input_conditions, group_result, final_result = _should_not_use_old_function( # ty: ignore [deprecated]
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated]
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=self._node_data.conditions or [],

View File

@ -10,6 +10,8 @@ from typing_extensions import TypeIs
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
@ -217,6 +219,13 @@ class IterationNode(Node):
graph_engine=graph_engine,
)
# Sync conversation variables after each iteration completes
self._sync_conversation_variables_from_snapshot(
self._extract_conversation_variable_snapshot(
variable_pool=graph_engine.graph_runtime_state.variable_pool
)
)
# Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
@ -235,7 +244,10 @@ class IterationNode(Node):
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
future_to_index: dict[
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
int,
] = {}
for index, item in enumerate(iterator_list_value):
yield IterationNextEvent(index=index)
future = executor.submit(
@ -252,7 +264,7 @@ class IterationNode(Node):
index = future_to_index[future]
try:
result = future.result()
iter_start_at, events, output_value, tokens_used = result
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
# Update outputs at the correct index
outputs[index] = output_value
@ -264,6 +276,9 @@ class IterationNode(Node):
self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
# Sync conversation variables after iteration completion
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
except Exception as e:
# Handle errors based on error_handle_mode
match self._node_data.error_handle_mode:
@ -288,7 +303,7 @@ class IterationNode(Node):
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -307,8 +322,17 @@ class IterationNode(Node):
# Get the output value from the temporary outputs list
output_value = outputs_temp[0] if outputs_temp else None
conversation_snapshot = self._extract_conversation_variable_snapshot(
variable_pool=graph_engine.graph_runtime_state.variable_pool
)
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
return (
iter_start_at,
events,
output_value,
graph_engine.graph_runtime_state.total_tokens,
conversation_snapshot,
)
def _handle_iteration_success(
self,
@ -430,6 +454,23 @@ class IterationNode(Node):
return variable_mapping
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
current_keys = set(parent_conversations.keys())
snapshot_keys = set(snapshot.keys())
for removed_key in current_keys - snapshot_keys:
parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key))
for name, variable in snapshot.items():
parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable)
def _append_iteration_info_to_event(
self,
event: GraphNodeEventBase,

View File

@ -136,6 +136,11 @@ class KnowledgeIndexNode(Node):
document = db.session.query(Document).filter_by(id=document_id.value).first()
if not document:
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
doc_id_value = document.id
ds_id_value = dataset.id
dataset_name_value = dataset.name
document_name_value = document.name
created_at_value = document.created_at
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
@ -161,16 +166,16 @@ class KnowledgeIndexNode(Node):
document.word_count = (
db.session.query(func.sum(DocumentSegment.word_count))
.where(
DocumentSegment.document_id == document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == doc_id_value,
DocumentSegment.dataset_id == ds_id_value,
)
.scalar()
)
db.session.add(document)
# update document segment status
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == doc_id_value,
DocumentSegment.dataset_id == ds_id_value,
).update(
{
DocumentSegment.status: "completed",
@ -182,13 +187,13 @@ class KnowledgeIndexNode(Node):
db.session.commit()
return {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"dataset_id": ds_id_value,
"dataset_name": dataset_name_value,
"batch": batch.value,
"document_id": document.id,
"document_name": document.name,
"created_at": document.created_at.timestamp(),
"display_status": document.indexing_status,
"document_id": doc_id_value,
"document_name": document_name_value,
"created_at": created_at_value.timestamp(),
"display_status": "completed",
}
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:

View File

@ -107,7 +107,7 @@ class KnowledgeRetrievalNode(Node):
graph_runtime_state=graph_runtime_state,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
self._file_outputs = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(

View File

@ -161,6 +161,8 @@ class ListOperatorNode(Node):
elif isinstance(variable, ArrayFileSegment):
if isinstance(condition.value, str):
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
elif isinstance(condition.value, bool):
raise ValueError(f"File filter expects a string value, got {type(condition.value)}")
else:
value = condition.value
filter_func = _get_file_filter_func(

View File

@ -46,7 +46,7 @@ class LLMFileSaver(tp.Protocol):
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
and `tar.gz` are not.
"""
pass
raise NotImplementedError()
def save_remote_url(self, url: str, file_type: FileType) -> File:
"""save_remote_url saves the file from a remote url returned by LLM.
@ -56,7 +56,7 @@ class LLMFileSaver(tp.Protocol):
:param url: the url of the file.
:param file_type: the file type of the file, check `FileType` enum for reference.
"""
pass
raise NotImplementedError()
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]

View File

@ -27,6 +27,7 @@ from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMStructuredOutput,
LLMUsage,
)
@ -134,7 +135,7 @@ class LLMNode(Node):
graph_runtime_state=graph_runtime_state,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
self._file_outputs = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@ -172,6 +173,7 @@ class LLMNode(Node):
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
result_text = ""
clean_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
reasoning_content = None
@ -285,6 +287,13 @@ class LLMNode(Node):
# Extract clean text from <think> tags
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
# Process structured output if available from the event.
structured_output = (
LLMStructuredOutput(structured_output=event.structured_output)
if event.structured_output
else None
)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
@ -1060,7 +1069,7 @@ class LLMNode(Node):
@staticmethod
def handle_blocking_result(
*,
invoke_result: LLMResult,
invoke_result: LLMResult | LLMResultWithStructuredOutput,
saver: LLMFileSaver,
file_outputs: list["File"],
reasoning_format: Literal["separated", "tagged"] = "tagged",
@ -1091,6 +1100,8 @@ class LLMNode(Node):
finish_reason=None,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
# Pass structured output if enabled
structured_output=getattr(invoke_result, "structured_output", None),
)
@staticmethod

View File

@ -179,6 +179,6 @@ CHAT_EXAMPLE = [
"required": ["food"],
},
},
"assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}},
"assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}},
},
]

View File

@ -68,7 +68,7 @@ class QuestionClassifierNode(Node):
graph_runtime_state=graph_runtime_state,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
self._file_outputs = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@ -111,9 +111,9 @@ class QuestionClassifierNode(Node):
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=node_data.model,
model_instance, model_config = llm_utils.fetch_model_config(
tenant_id=self.tenant_id,
node_data_model=node_data.model,
)
# fetch memory
memory = llm_utils.fetch_memory(

View File

@ -1,7 +1,7 @@
import os
from collections.abc import Mapping, Sequence
from typing import Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -9,7 +9,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node):

View File

@ -416,4 +416,8 @@ class WorkflowEntry:
# append variable and value to variable pool
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
# In single run, the input_value is set as the LLM's structured output value within the variable_pool.
if len(variable_key_list) == 2 and variable_key_list[0] == "structured_output":
input_value = {variable_key_list[1]: input_value}
variable_key_list = variable_key_list[0:1]
variable_pool.add([variable_node_id] + variable_key_list, input_value)
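
The single-run path folds selectors of the form [node_id, "structured_output", key] into a one-level mapping before adding them to the pool. The rewrap in isolation, with an illustrative key and value:

variable_key_list = ["structured_output", "answer"]
input_value = "42"

if len(variable_key_list) == 2 and variable_key_list[0] == "structured_output":
    input_value = {variable_key_list[1]: input_value}
    variable_key_list = variable_key_list[0:1]

assert variable_key_list == ["structured_output"]
assert input_value == {"answer": "42"}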

View File

@ -10,14 +10,14 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
@app.after_request
def after_request(response):
def after_request(response): # pyright: ignore[reportUnusedFunction]
"""Add Version headers to the response."""
response.headers.add("X-Version", dify_config.project.version)
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
return response
@app.route("/health")
def health():
def health(): # pyright: ignore[reportUnusedFunction]
return Response(
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.project.version}),
status=200,
@ -25,7 +25,7 @@ def init_app(app: DifyApp):
)
@app.route("/threads")
def threads():
def threads(): # pyright: ignore[reportUnusedFunction]
num_threads = threading.active_count()
threads = threading.enumerate()
@ -50,7 +50,7 @@ def init_app(app: DifyApp):
}
@app.route("/db-pool-stat")
def pool_stat():
def pool_stat(): # pyright: ignore[reportUnusedFunction]
from extensions.ext_database import db
engine = db.engine

View File

@ -145,6 +145,7 @@ def init_app(app: DifyApp) -> Celery:
}
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
imports.append("schedule.check_upgradable_plugin_task")
imports.append("tasks.process_tenant_plugin_autoupgrade_check_task")
beat_schedule["check_upgradable_plugin_task"] = {
"task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",
"schedule": crontab(minute="*/15"),

View File

@ -10,7 +10,7 @@ from models.engine import db
logger = logging.getLogger(__name__)
# Global flag to avoid duplicate registration of event listener
_GEVENT_COMPATIBILITY_SETUP: bool = False
_gevent_compatibility_setup: bool = False
def _safe_rollback(connection):
@ -26,14 +26,14 @@ def _safe_rollback(connection):
def _setup_gevent_compatibility():
global _GEVENT_COMPATIBILITY_SETUP # pylint: disable=global-statement
global _gevent_compatibility_setup # pylint: disable=global-statement
# Avoid duplicate registration
if _GEVENT_COMPATIBILITY_SETUP:
if _gevent_compatibility_setup:
return
@event.listens_for(Pool, "reset")
def _safe_reset(dbapi_connection, connection_record, reset_state): # pylint: disable=unused-argument
def _safe_reset(dbapi_connection, connection_record, reset_state): # pyright: ignore[reportUnusedFunction]
if reset_state.terminate_only:
return
@ -47,7 +47,7 @@ def _setup_gevent_compatibility():
except (AttributeError, ImportError):
_safe_rollback(dbapi_connection)
_GEVENT_COMPATIBILITY_SETUP = True
_gevent_compatibility_setup = True
def init_app(app: DifyApp):

View File

@ -2,4 +2,4 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from events import event_handlers # noqa: F401
from events import event_handlers # noqa: F401 # pyright: ignore[reportUnusedImport]

View File

@ -136,6 +136,7 @@ def init_app(app: DifyApp):
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
from opentelemetry.instrumentation.celery import CeleryInstrumentor
from opentelemetry.instrumentation.flask import FlaskInstrumentor
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
@ -238,6 +239,7 @@ def init_app(app: DifyApp):
init_sqlalchemy_instrumentor(app)
RedisInstrumentor().instrument()
RequestsInstrumentor().instrument()
HTTPXClientInstrumentor().instrument()
atexit.register(shutdown_tracer)

View File

@ -4,7 +4,6 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
if dify_config.SENTRY_DSN:
import openai
import sentry_sdk
from langfuse import parse_error # type: ignore
from sentry_sdk.integrations.celery import CeleryIntegration
@ -28,7 +27,6 @@ def init_app(app: DifyApp):
HTTPException,
ValueError,
FileNotFoundError,
openai.APIStatusError,
InvokeRateLimitError,
parse_error.defaultErrorResponse,
],

View File

@ -33,7 +33,9 @@ class AliyunOssStorage(BaseStorage):
def load_once(self, filename: str) -> bytes:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
data: bytes = obj.read()
data = obj.read()
if not isinstance(data, bytes):
return b""
return data
def load_stream(self, filename: str) -> Generator:

View File

@ -39,10 +39,10 @@ class AwsS3Storage(BaseStorage):
self.client.head_bucket(Bucket=self.bucket_name)
except ClientError as e:
# if bucket not exists, create it
if e.response["Error"]["Code"] == "404":
if e.response.get("Error", {}).get("Code") == "404":
self.client.create_bucket(Bucket=self.bucket_name)
# if bucket is not accessible, pass, maybe the bucket is existing but not accessible
elif e.response["Error"]["Code"] == "403":
elif e.response.get("Error", {}).get("Code") == "403":
pass
else:
# other error, raise exception
@ -55,7 +55,7 @@ class AwsS3Storage(BaseStorage):
try:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
raise FileNotFoundError("File not found")
else:
raise
@ -66,7 +66,7 @@ class AwsS3Storage(BaseStorage):
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response["Body"].iter_chunks()
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
raise FileNotFoundError("file not found")
elif "reached max retries" in str(ex):
raise ValueError("please do not request the same file too frequently")

View File

@ -27,24 +27,38 @@ class AzureBlobStorage(BaseStorage):
self.credential = None
def save(self, filename, data):
if not self.bucket_name:
return
client = self._sync_client()
blob_container = client.get_container_client(container=self.bucket_name)
blob_container.upload_blob(filename, data)
def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise FileNotFoundError("Azure bucket name is not configured.")
client = self._sync_client()
blob = client.get_container_client(container=self.bucket_name)
blob = blob.get_blob_client(blob=filename)
data: bytes = blob.download_blob().readall()
data = blob.download_blob().readall()
if not isinstance(data, bytes):
raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}")
return data
def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise FileNotFoundError("Azure bucket name is not configured.")
client = self._sync_client()
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
blob_data = blob.download_blob()
yield from blob_data.chunks()
def download(self, filename, target_filepath):
if not self.bucket_name:
return
client = self._sync_client()
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
@ -53,12 +67,18 @@ class AzureBlobStorage(BaseStorage):
blob_data.readinto(my_blob)
def exists(self, filename):
if not self.bucket_name:
return False
client = self._sync_client()
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
return blob.exists()
def delete(self, filename):
if not self.bucket_name:
return
client = self._sync_client()
blob_container = client.get_container_client(container=self.bucket_name)

View File

@ -430,7 +430,7 @@ class ClickZettaVolumeStorage(BaseStorage):
rows = self._execute_sql(sql, fetch=True)
exists = len(rows) > 0
exists = len(rows) > 0 if rows else False
logger.debug("File %s exists check: %s", filename, exists)
return exists
except Exception as e:
@ -509,16 +509,17 @@ class ClickZettaVolumeStorage(BaseStorage):
rows = self._execute_sql(sql, fetch=True)
result = []
for row in rows:
file_path = row[0] # relative_path column
if rows:
for row in rows:
file_path = row[0] # relative_path column
# For User Volume, remove dify prefix from results
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
# For User Volume, remove dify prefix from results
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
result.append(file_path)
if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
result.append(file_path)
logger.debug("Scanned %d items in path %s", len(result), path)
return result

View File

@ -439,6 +439,11 @@ class VolumePermissionManager:
self._permission_cache.clear()
logger.debug("Permission cache cleared")
@property
def volume_type(self) -> str | None:
"""Get the volume type."""
return self._volume_type
def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]:
"""Get permission summary
@ -632,13 +637,13 @@ def check_volume_permission(permission_manager: VolumePermissionManager, operati
VolumePermissionError: If no permission
"""
if not permission_manager.validate_operation(operation, dataset_id):
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
error_message = f"Permission denied for operation '{operation}' on {permission_manager.volume_type} volume"
if dataset_id:
error_message += f" (dataset: {dataset_id})"
raise VolumePermissionError(
error_message,
operation=operation,
volume_type=permission_manager._volume_type or "unknown",
volume_type=permission_manager.volume_type or "unknown",
dataset_id=dataset_id,
)

View File

@ -35,12 +35,16 @@ class GoogleCloudStorage(BaseStorage):
def load_once(self, filename: str) -> bytes:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
if blob is None:
raise FileNotFoundError("File not found")
data: bytes = blob.download_as_bytes()
return data
def load_stream(self, filename: str) -> Generator:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
if blob is None:
raise FileNotFoundError("File not found")
with blob.open(mode="rb") as blob_stream:
while chunk := blob_stream.read(4096):
yield chunk
@ -48,6 +52,8 @@ class GoogleCloudStorage(BaseStorage):
def download(self, filename, target_filepath):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
if blob is None:
raise FileNotFoundError("File not found")
blob.download_to_filename(target_filepath)
def exists(self, filename):

View File

@ -45,7 +45,7 @@ class HuaweiObsStorage(BaseStorage):
def _get_meta(self, filename):
res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename)
if res.status < 300:
if res and res.status and res.status < 300:
return res
else:
return None

View File

@ -3,9 +3,9 @@ import os
from collections.abc import Generator
from pathlib import Path
import opendal
from dotenv import dotenv_values
from opendal import Operator
from opendal.layers import RetryLayer
from extensions.storage.base_storage import BaseStorage
@ -35,7 +35,7 @@ class OpenDALStorage(BaseStorage):
root = kwargs.get("root", "storage")
Path(root).mkdir(parents=True, exist_ok=True)
retry_layer = RetryLayer(max_times=3, factor=2.0, jitter=True)
retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)
self.op = Operator(scheme=scheme, **kwargs).layer(retry_layer)
logger.debug("opendal operator created with scheme %s", scheme)
logger.debug("added retry layer to opendal operator")

View File

@ -29,7 +29,7 @@ class OracleOCIStorage(BaseStorage):
try:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
raise FileNotFoundError("File not found")
else:
raise
@ -40,7 +40,7 @@ class OracleOCIStorage(BaseStorage):
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response["Body"].iter_chunks()
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
raise FileNotFoundError("File not found")
else:
raise

View File

@ -46,13 +46,13 @@ class SupabaseStorage(BaseStorage):
Path(target_filepath).write_bytes(result)
def exists(self, filename):
result = self.client.storage.from_(self.bucket_name).list(filename)
if result.count() > 0:
result = self.client.storage.from_(self.bucket_name).list(path=filename)
if len(result) > 0:
return True
return False
def delete(self, filename):
self.client.storage.from_(self.bucket_name).remove(filename)
self.client.storage.from_(self.bucket_name).remove([filename])
def bucket_exists(self):
buckets = self.client.storage.list_buckets()

View File

@ -11,6 +11,14 @@ class VolcengineTosStorage(BaseStorage):
def __init__(self):
super().__init__()
if not dify_config.VOLCENGINE_TOS_ACCESS_KEY:
raise ValueError("VOLCENGINE_TOS_ACCESS_KEY is not set")
if not dify_config.VOLCENGINE_TOS_SECRET_KEY:
raise ValueError("VOLCENGINE_TOS_SECRET_KEY is not set")
if not dify_config.VOLCENGINE_TOS_ENDPOINT:
raise ValueError("VOLCENGINE_TOS_ENDPOINT is not set")
if not dify_config.VOLCENGINE_TOS_REGION:
raise ValueError("VOLCENGINE_TOS_REGION is not set")
self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME
self.client = tos.TosClientV2(
ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY,
@ -20,27 +28,39 @@ class VolcengineTosStorage(BaseStorage):
)
def save(self, filename, data):
if not self.bucket_name:
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
self.client.put_object(bucket=self.bucket_name, key=filename, content=data)
def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
data = self.client.get_object(bucket=self.bucket_name, key=filename).read()
if not isinstance(data, bytes):
raise TypeError(f"Expected bytes, got {type(data).__name__}")
return data
def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
response = self.client.get_object(bucket=self.bucket_name, key=filename)
while chunk := response.read(4096):
yield chunk
def download(self, filename, target_filepath):
if not self.bucket_name:
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath)
def exists(self, filename):
if not self.bucket_name:
return False
res = self.client.head_object(bucket=self.bucket_name, key=filename)
if res.status_code != 200:
return False
return True
def delete(self, filename):
if not self.bucket_name:
return
self.client.delete_object(bucket=self.bucket_name, key=filename)

View File

@ -146,6 +146,8 @@ def build_segment(value: Any, /) -> Segment:
# below
if value is None:
return NoneSegment()
if isinstance(value, Segment):
return value
if isinstance(value, str):
return StringSegment(value=value)
if isinstance(value, bool):

View File

@ -0,0 +1,14 @@
def convert_to_lower_and_upper_set(inputs: list[str] | set[str]) -> set[str]:
"""
Convert a list or set of strings to a set containing both lower and upper case versions of each string.
Args:
inputs (list[str] | set[str]): A list or set of strings to be converted.
Returns:
set[str]: A set containing both lower and upper case versions of each string.
"""
if not inputs:
return set()
else:
return {case for s in inputs if s for case in (s.lower(), s.upper())}
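
A quick usage check for the helper above, assuming it is imported from the new module:

assert convert_to_lower_and_upper_set(["Pdf", "docx"]) == {"pdf", "PDF", "docx", "DOCX"}
assert convert_to_lower_and_upper_set([]) == set()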

View File

@ -94,7 +94,7 @@ def register_external_error_handlers(api: Api):
got_request_exception.send(current_app, exception=e)
status_code = 500
data = getattr(e, "data", {"message": http_status_message(status_code)})
data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})
# 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
if not isinstance(data, dict):

View File

@ -27,7 +27,7 @@ import gmpy2 # type: ignore
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
from Crypto.Util.py3compat import _copy_bytes, bord
from Crypto.Util.py3compat import bord
from Crypto.Util.strxor import strxor
@ -72,7 +72,7 @@ class PKCS1OAepCipher:
else:
self._mgf = lambda x, y: MGF1(x, y, self._hashObj)
self._label = _copy_bytes(None, None, label)
self._label = bytes(label)
self._randfunc = randfunc
def can_encrypt(self):
@ -120,7 +120,7 @@ class PKCS1OAepCipher:
# Step 2b
ps = b"\x00" * ps_len
# Step 2c
db = lHash + ps + b"\x01" + _copy_bytes(None, None, message)
db = lHash + ps + b"\x01" + bytes(message)
# Step 2d
ros = self._randfunc(hLen)
# Step 2e

View File

@ -14,7 +14,7 @@ class SendGridClient:
def send(self, mail: dict):
logger.debug("Sending email with SendGrid")
_to = ""
try:
_to = mail["to"]
@ -28,7 +28,7 @@ class SendGridClient:
content = Content("text/html", mail["html"])
sg_mail = Mail(from_email, to_email, subject, content)
mail_json = sg_mail.get()
response = sg.client.mail.send.post(request_body=mail_json) # ty: ignore [call-non-callable]
response = sg.client.mail.send.post(request_body=mail_json) # type: ignore
logger.debug(response.status_code)
logger.debug(response.body)
logger.debug(response.headers)

api/libs/validators.py (new file)
View File

@ -0,0 +1,5 @@
def validate_description_length(description: str | None) -> str | None:
"""Validate description length."""
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
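
A quick usage check for the new validator, assuming it is imported from api/libs/validators.py:

assert validate_description_length("short enough") == "short enough"
assert validate_description_length(None) is None
try:
    validate_description_length("x" * 401)
except ValueError as err:
    print(err)  # Description cannot exceed 400 characters.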

View File

@ -1,11 +1,10 @@
[project]
name = "dify-api"
version = "1.9.0"
version = "1.9.1"
requires-python = ">=3.11,<3.13"
dependencies = [
"arize-phoenix-otel~=0.9.2",
"authlib==1.6.4",
"azure-identity==1.16.1",
"beautifulsoup4==4.12.2",
"boto3==1.35.99",
@ -34,10 +33,8 @@ dependencies = [
"json-repair>=0.41.1",
"langfuse~=2.51.3",
"langsmith~=0.1.77",
"mailchimp-transactional~=1.0.50",
"markdown~=3.5.1",
"numpy~=1.26.4",
"openai~=1.61.0",
"openpyxl~=3.1.5",
"opik~=1.7.25",
"opentelemetry-api==1.27.0",
@ -49,6 +46,7 @@ dependencies = [
"opentelemetry-instrumentation==0.48b0",
"opentelemetry-instrumentation-celery==0.48b0",
"opentelemetry-instrumentation-flask==0.48b0",
"opentelemetry-instrumentation-httpx==0.48b0",
"opentelemetry-instrumentation-redis==0.48b0",
"opentelemetry-instrumentation-requests==0.48b0",
"opentelemetry-instrumentation-sqlalchemy==0.48b0",
@ -60,7 +58,6 @@ dependencies = [
"opentelemetry-semantic-conventions==0.48b0",
"opentelemetry-util-http==0.48b0",
"pandas[excel,output-formatting,performance]~=2.2.2",
"pandoc~=2.4",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
"pycryptodome==3.19.1",
@ -178,10 +175,10 @@ dev = [
# Required for storage clients
############################################################
storage = [
"azure-storage-blob==12.13.0",
"azure-storage-blob==12.26.0",
"bce-python-sdk~=0.9.23",
"cos-python-sdk-v5==1.9.30",
"esdk-obs-python==3.24.6.1",
"cos-python-sdk-v5==1.9.38",
"esdk-obs-python==3.25.8",
"google-cloud-storage==2.16.0",
"opendal~=0.46.0",
"oss2==2.18.5",
@ -207,7 +204,7 @@ vdb = [
"couchbase~=4.3.0",
"elasticsearch==8.14.0",
"opensearch-py==2.4.0",
"oracledb==3.0.0",
"oracledb==3.3.0",
"pgvecto-rs[sqlalchemy]~=0.2.1",
"pgvector==0.2.5",
"pymilvus~=2.5.0",

View File

@ -1,19 +1,10 @@
{
"include": ["."],
"exclude": [
".venv",
"tests/",
".venv",
"migrations/",
"core/rag",
"extensions",
"libs",
"controllers/console/datasets",
"controllers/service_api/dataset",
"core/ops",
"core/tools",
"core/model_runtime",
"core/workflow/nodes",
"core/app/app_config/easy_ui_based_app/dataset"
"core/rag"
],
"typeCheckingMode": "strict",
"allowedUntypedLibraries": [
@ -21,6 +12,7 @@
"flask_login",
"opentelemetry.instrumentation.celery",
"opentelemetry.instrumentation.flask",
"opentelemetry.instrumentation.httpx",
"opentelemetry.instrumentation.requests",
"opentelemetry.instrumentation.sqlalchemy",
"opentelemetry.instrumentation.redis"
@ -32,7 +24,6 @@
"reportUnknownLambdaType": "hint",
"reportMissingParameterType": "hint",
"reportMissingTypeArgument": "hint",
"reportUnnecessaryContains": "hint",
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryCast": "hint",
"reportUnnecessaryIsInstance": "hint",
@ -41,4 +32,4 @@
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11",
"pythonPlatform": "All"
}
}

View File

@ -7,7 +7,7 @@ env =
CHATGLM_API_BASE = http://a.abc.com:11451
CODE_EXECUTION_API_KEY = dify-sandbox
CODE_EXECUTION_ENDPOINT = http://127.0.0.1:8194
CODE_MAX_STRING_LENGTH = 80000
CODE_MAX_STRING_LENGTH = 400000
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_MAX_PACKAGE_SIZE=15728640

View File

@ -6,7 +6,7 @@ import click
import app
from extensions.ext_database import db
from models.account import TenantPluginAutoUpgradeStrategy
from tasks.process_tenant_plugin_autoupgrade_check_task import process_tenant_plugin_autoupgrade_check_task
from tasks import process_tenant_plugin_autoupgrade_check_task as check_task
AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes
MAX_CONCURRENT_CHECK_TASKS = 20
@ -43,7 +43,7 @@ def check_upgradable_plugin_task():
for i in range(0, total_strategies, MAX_CONCURRENT_CHECK_TASKS):
batch_strategies = strategies[i : i + MAX_CONCURRENT_CHECK_TASKS]
for strategy in batch_strategies:
process_tenant_plugin_autoupgrade_check_task.delay(
check_task.process_tenant_plugin_autoupgrade_check_task.delay(
strategy.tenant_id,
strategy.strategy_setting,
strategy.upgrade_time_of_day,
@ -52,7 +52,8 @@ def check_upgradable_plugin_task():
strategy.include_plugins,
)
if batch_interval_time > 0.0001: # if lower than 1ms, skip
# Only sleep if batch_interval_time > 0.0001 AND current batch is not the last one
if batch_interval_time > 0.0001 and i + MAX_CONCURRENT_CHECK_TASKS < total_strategies:
time.sleep(batch_interval_time)
end_at = time.perf_counter()
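
The scheduler now skips the inter-batch pause after the final batch. The loop shape in isolation, with an illustrative batch size and delay:

import time

items = list(range(45))
BATCH = 20
DELAY = 0.01

for start in range(0, len(items), BATCH):
    batch = items[start : start + BATCH]
    # ... dispatch one task per item in `batch` ...
    if DELAY > 0.0001 and start + BATCH < len(items):
        # Sleep only between batches; no pointless pause after the last one.
        time.sleep(DELAY)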

View File

@ -2,8 +2,6 @@ import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union
from openai._exceptions import RateLimitError
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
@ -122,8 +120,6 @@ class AppGenerateService:
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
except RateLimitError as e:
raise InvokeRateLimitError(str(e))
except Exception:
rate_limit.exit(request_id)
raise

View File

@ -93,7 +93,7 @@ logger = logging.getLogger(__name__)
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
if user:
# get permitted dataset ids

View File

@ -149,8 +149,7 @@ class RagPipelineTransformService:
file_extensions = node.get("data", {}).get("fileExtensions", [])
if not file_extensions:
return node
file_extensions = [file_extension.lower() for file_extension in file_extensions]
node["data"]["fileExtensions"] = DOCUMENT_EXTENSIONS
node["data"]["fileExtensions"] = [ext.lower() for ext in file_extensions if ext in DOCUMENT_EXTENSIONS]
return node
def _deal_knowledge_index(

View File

@ -349,14 +349,10 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
credentials: list[ToolProviderCredentialApiEntity] = []
encrypters = {}
for provider in providers:
credential_type = provider.credential_type
if credential_type not in encrypters:
encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter(
tenant_id, provider, provider.provider, provider_controller
)[0]
encrypter = encrypters[credential_type]
encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
tenant_id, provider, provider.provider, provider_controller
)
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,

View File

@ -79,7 +79,6 @@ class WorkflowConverter:
new_app.updated_by = account.id
db.session.add(new_app)
db.session.flush()
db.session.commit()
workflow.app_id = new_app.id
db.session.commit()

View File

@ -1,5 +1,5 @@
import json
import operator
import traceback
import typing
import click
@ -9,38 +9,106 @@ from core.helper import marketplace
from core.helper.marketplace import MarketplacePluginDeclaration
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_redis import redis_client
from models.account import TenantPluginAutoUpgradeStrategy
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:"
CACHE_REDIS_TTL = 60 * 15 # 15 minutes
cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
def _get_redis_cache_key(plugin_id: str) -> str:
"""Generate Redis cache key for plugin manifest."""
return f"{CACHE_REDIS_KEY_PREFIX}{plugin_id}"
def _get_cached_manifest(plugin_id: str) -> typing.Union[MarketplacePluginDeclaration, None, bool]:
"""
Get cached plugin manifest from Redis.
Returns:
- MarketplacePluginDeclaration: if found in cache
- None: if cached as not found (marketplace returned no result)
- False: if not in cache at all
"""
try:
key = _get_redis_cache_key(plugin_id)
cached_data = redis_client.get(key)
if cached_data is None:
return False
cached_json = json.loads(cached_data)
if cached_json is None:
return None
return MarketplacePluginDeclaration.model_validate(cached_json)
except Exception:
return False
def _set_cached_manifest(plugin_id: str, manifest: typing.Union[MarketplacePluginDeclaration, None]) -> None:
"""
Cache plugin manifest in Redis.
Args:
plugin_id: The plugin ID
manifest: The manifest to cache, or None if not found in marketplace
"""
try:
key = _get_redis_cache_key(plugin_id)
if manifest is None:
# Cache the fact that this plugin was not found
redis_client.setex(key, CACHE_REDIS_TTL, json.dumps(None))
else:
# Cache the manifest data
redis_client.setex(key, CACHE_REDIS_TTL, manifest.model_dump_json())
except Exception:
# If Redis fails, continue without caching
# traceback.print_exc()
pass
def marketplace_batch_fetch_plugin_manifests(
plugin_ids_plain_list: list[str],
) -> list[MarketplacePluginDeclaration]:
global cached_plugin_manifests
# return marketplace.batch_fetch_plugin_manifests(plugin_ids_plain_list)
not_included_plugin_ids = [
plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests
]
if not_included_plugin_ids:
manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_included_plugin_ids)
"""Fetch plugin manifests with Redis caching support."""
cached_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
not_cached_plugin_ids: list[str] = []
# Check Redis cache for each plugin
for plugin_id in plugin_ids_plain_list:
cached_result = _get_cached_manifest(plugin_id)
if cached_result is False:
# Not in cache, need to fetch
not_cached_plugin_ids.append(plugin_id)
else:
# Either found manifest or cached as None (not found in marketplace)
# At this point, cached_result is either MarketplacePluginDeclaration or None
if isinstance(cached_result, bool):
# This should never happen due to the if condition above, but for type safety
continue
cached_manifests[plugin_id] = cached_result
# Fetch uncached plugins from marketplace
if not_cached_plugin_ids:
manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_cached_plugin_ids)
# Cache the fetched manifests
for manifest in manifests:
cached_plugin_manifests[manifest.plugin_id] = manifest
cached_manifests[manifest.plugin_id] = manifest
_set_cached_manifest(manifest.plugin_id, manifest)
if (
len(manifests) == 0
): # this indicates that the plugin not found in marketplace, should set None in cache to prevent future check
for plugin_id in not_included_plugin_ids:
cached_plugin_manifests[plugin_id] = None
# Cache plugins that were not found in marketplace
fetched_plugin_ids = {manifest.plugin_id for manifest in manifests}
for plugin_id in not_cached_plugin_ids:
if plugin_id not in fetched_plugin_ids:
cached_manifests[plugin_id] = None
_set_cached_manifest(plugin_id, None)
# Build result list from cached manifests
result: list[MarketplacePluginDeclaration] = []
for plugin_id in plugin_ids_plain_list:
final_manifest = cached_plugin_manifests.get(plugin_id)
if final_manifest is not None:
result.append(final_manifest)
cached_manifest: typing.Union[MarketplacePluginDeclaration, None] = cached_manifests.get(plugin_id)
if cached_manifest is not None:
result.append(cached_manifest)
return result
@ -157,10 +225,10 @@ def process_tenant_plugin_autoupgrade_check_task(
)
except Exception as e:
click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red"))
traceback.print_exc()
# traceback.print_exc()
break
except Exception as e:
click.echo(click.style(f"Error when checking upgradable plugin: {e}", fg="red"))
traceback.print_exc()
# traceback.print_exc()
return
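The cache helpers above use a tri-state return: a manifest object is a hit, `None` is a cached "not found", and `False` is a miss that requires a marketplace fetch. A generic sketch of the same negative-caching pattern over a Redis TTL key; the local `redis_client` wiring, key naming, and dict payloads below are illustrative, not Dify's:

```python
# Generic sketch of tri-state negative caching over a Redis TTL key, mirroring
# the convention above: object = hit, None = cached "not found", False = miss.
import json
from typing import Any, Callable, Optional, Union

import redis

redis_client = redis.Redis()  # assumes a local Redis; illustrative only
CACHE_TTL_SECONDS = 60 * 15


def cache_lookup(key: str) -> Union[dict[str, Any], None, bool]:
    raw = redis_client.get(key)
    if raw is None:
        return False  # miss: nothing cached yet
    value = json.loads(raw)
    if value is None:
        return None  # negative hit: the upstream previously returned nothing
    return value  # positive hit


def cache_store(key: str, value: Optional[dict[str, Any]]) -> None:
    # json.dumps(None) == "null", so "not found" is representable in the cache.
    redis_client.setex(key, CACHE_TTL_SECONDS, json.dumps(value))


def fetch_with_cache(key: str, fetch_upstream: Callable[[], Optional[dict[str, Any]]]):
    cached = cache_lookup(key)
    if cached is not False:
        return cached  # either the value or a cached "not found"
    value = fetch_upstream()  # may legitimately return None
    cache_store(key, value)
    return value
```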

View File

@ -29,23 +29,10 @@ def priority_rag_pipeline_run_task(
tenant_id: str,
):
"""
Async Run rag pipeline
:param rag_pipeline_invoke_entities: Rag pipeline invoke entities
rag_pipeline_invoke_entities include:
:param pipeline_id: Pipeline ID
:param user_id: User ID
:param tenant_id: Tenant ID
:param workflow_id: Workflow ID
:param invoke_from: Invoke source (debugger, published, etc.)
:param streaming: Whether to stream results
:param datasource_type: Type of datasource
:param datasource_info: Datasource information dict
:param batch: Batch identifier
:param document_id: Document ID (optional)
:param start_node_id: Starting node ID
:param inputs: Input parameters dict
:param workflow_execution_id: Workflow execution ID
:param workflow_thread_pool_id: Thread pool ID for workflow execution
Asynchronously run a RAG pipeline task using the high-priority queue.
:param rag_pipeline_invoke_entities_file_id: File ID containing serialized RAG pipeline invoke entities
:param tenant_id: Tenant ID for the pipeline execution
"""
# run with threading, thread pool size is 10

View File

@ -30,23 +30,10 @@ def rag_pipeline_run_task(
tenant_id: str,
):
"""
Async Run rag pipeline
:param rag_pipeline_invoke_entities: Rag pipeline invoke entities
rag_pipeline_invoke_entities include:
:param pipeline_id: Pipeline ID
:param user_id: User ID
:param tenant_id: Tenant ID
:param workflow_id: Workflow ID
:param invoke_from: Invoke source (debugger, published, etc.)
:param streaming: Whether to stream results
:param datasource_type: Type of datasource
:param datasource_info: Datasource information dict
:param batch: Batch identifier
:param document_id: Document ID (optional)
:param start_node_id: Starting node ID
:param inputs: Input parameters dict
:param workflow_execution_id: Workflow execution ID
:param workflow_thread_pool_id: Thread pool ID for workflow execution
Asynchronously run a RAG pipeline task using the regular-priority queue.
:param rag_pipeline_invoke_entities_file_id: File ID containing serialized RAG pipeline invoke entities
:param tenant_id: Tenant ID for the pipeline execution
"""
# run with threading, thread pool size is 10
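Per the rewritten docstrings, callers now pass a file ID pointing at the serialized invoke entities instead of the entities themselves. Assuming these are Celery shared tasks like the rest of the `tasks/` package, the caller-side flow might look like the self-contained sketch below; the storage helper, `/tmp` location, task name, and queue name are assumptions, and only `shared_task`/`apply_async` are real Celery API:

```python
# Self-contained sketch of the "serialize, store, enqueue a file ID" flow,
# assuming Celery. Names and storage below are illustrative, not Dify's.
import json
import uuid

from celery import shared_task


def store_invoke_entities(invoke_entities: dict) -> str:
    """Hypothetical helper: persist the payload and return its file ID."""
    file_id = str(uuid.uuid4())
    with open(f"/tmp/{file_id}.json", "w", encoding="utf-8") as f:
        json.dump(invoke_entities, f)
    return file_id


@shared_task
def run_pipeline_from_file(rag_pipeline_invoke_entities_file_id: str, tenant_id: str) -> None:
    # The worker reloads the payload from storage instead of receiving a large
    # body over the message broker.
    with open(f"/tmp/{rag_pipeline_invoke_entities_file_id}.json", encoding="utf-8") as f:
        invoke_entities = json.load(f)
    print(f"running pipeline for tenant {tenant_id} with {len(invoke_entities)} entries")


def enqueue(invoke_entities: dict, tenant_id: str) -> None:
    run_pipeline_from_file.apply_async(
        kwargs={
            "rag_pipeline_invoke_entities_file_id": store_invoke_entities(invoke_entities),
            "tenant_id": tenant_id,
        },
        queue="pipeline",  # queue name is an assumption, not Dify's configuration
    )
```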

View File

@ -5,15 +5,10 @@ These tasks provide asynchronous storage capabilities for workflow execution dat
improving performance by offloading storage operations to background workers.
"""
import logging
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy.orm import Session
from extensions.ext_database import db
_logger = logging.getLogger(__name__)
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService

View File

@ -0,0 +1,316 @@
app:
description: 'This chatflow receives a sys.query, writes it into the `answer` variable,
and then outputs the `answer` variable.
`answer` is a conversation variable with a blank default value; it will be updated
in an iteration node.
If this chatflow works correctly, its output will be identical to `sys.query`.'
icon: 🤖
icon_background: '#FFEAD5'
mode: advanced-chat
name: update-conversation-variable-in-iteration
use_icon_as_answer_icon: false
dependencies: []
kind: app
version: 0.4.0
workflow:
conversation_variables:
- description: ''
id: c30af82d-b2ec-417d-a861-4dd78584faa4
name: answer
selector:
- conversation
- answer
value: ''
value_type: string
environment_variables: []
features:
file_upload:
allowed_file_extensions:
- .JPG
- .JPEG
- .PNG
- .GIF
- .WEBP
- .SVG
allowed_file_types:
- image
allowed_file_upload_methods:
- local_file
- remote_url
enabled: false
fileUploadConfig:
audio_file_size_limit: 50
batch_count_limit: 5
file_size_limit: 15
image_file_size_limit: 10
video_file_size_limit: 100
workflow_file_upload_limit: 10
image:
enabled: false
number_limits: 3
transfer_methods:
- local_file
- remote_url
number_limits: 3
opening_statement: ''
retriever_resource:
enabled: true
sensitive_word_avoidance:
enabled: false
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
language: ''
voice: ''
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: start
targetType: code
id: 1759032354471-source-1759032363865-target
source: '1759032354471'
sourceHandle: source
target: '1759032363865'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: code
targetType: iteration
id: 1759032363865-source-1759032379989-target
source: '1759032363865'
sourceHandle: source
target: '1759032379989'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: true
isInLoop: false
iteration_id: '1759032379989'
sourceType: iteration-start
targetType: assigner
id: 1759032379989start-source-1759032394460-target
source: 1759032379989start
sourceHandle: source
target: '1759032394460'
targetHandle: target
type: custom
zIndex: 1002
- data:
isInIteration: false
isInLoop: false
sourceType: iteration
targetType: answer
id: 1759032379989-source-1759032410331-target
source: '1759032379989'
sourceHandle: source
target: '1759032410331'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: true
isInLoop: false
iteration_id: '1759032379989'
sourceType: assigner
targetType: code
id: 1759032394460-source-1759032476318-target
source: '1759032394460'
sourceHandle: source
target: '1759032476318'
targetHandle: target
type: custom
zIndex: 1002
nodes:
- data:
selected: false
title: Start
type: start
variables: []
height: 52
id: '1759032354471'
position:
x: 30
y: 302
positionAbsolute:
x: 30
y: 302
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
code: "\ndef main():\n return {\n \"result\": [1],\n }\n"
code_language: python3
outputs:
result:
children: null
type: array[number]
selected: false
title: Code
type: code
variables: []
height: 52
id: '1759032363865'
position:
x: 332
y: 302
positionAbsolute:
x: 332
y: 302
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
error_handle_mode: terminated
height: 204
is_parallel: false
iterator_input_type: array[number]
iterator_selector:
- '1759032363865'
- result
output_selector:
- '1759032476318'
- result
output_type: array[string]
parallel_nums: 10
selected: false
start_node_id: 1759032379989start
title: Iteration
type: iteration
width: 808
height: 204
id: '1759032379989'
position:
x: 634
y: 302
positionAbsolute:
x: 634
y: 302
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 808
zIndex: 1
- data:
desc: ''
isInIteration: true
selected: false
title: ''
type: iteration-start
draggable: false
height: 48
id: 1759032379989start
parentId: '1759032379989'
position:
x: 60
y: 78
positionAbsolute:
x: 694
y: 380
selectable: false
sourcePosition: right
targetPosition: left
type: custom-iteration-start
width: 44
zIndex: 1002
- data:
isInIteration: true
isInLoop: false
items:
- input_type: variable
operation: over-write
value:
- sys
- query
variable_selector:
- conversation
- answer
write_mode: over-write
iteration_id: '1759032379989'
selected: false
title: Variable Assigner
type: assigner
version: '2'
height: 84
id: '1759032394460'
parentId: '1759032379989'
position:
x: 204
y: 60
positionAbsolute:
x: 838
y: 362
sourcePosition: right
targetPosition: left
type: custom
width: 242
zIndex: 1002
- data:
answer: '{{#conversation.answer#}}'
selected: false
title: Answer
type: answer
variables: []
height: 104
id: '1759032410331'
position:
x: 1502
y: 302
positionAbsolute:
x: 1502
y: 302
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
code: "\ndef main():\n return {\n \"result\": '',\n }\n"
code_language: python3
isInIteration: true
isInLoop: false
iteration_id: '1759032379989'
outputs:
result:
children: null
type: string
selected: false
title: Code 2
type: code
variables: []
height: 52
id: '1759032476318'
parentId: '1759032379989'
position:
x: 506
y: 76
positionAbsolute:
x: 1140
y: 378
sourcePosition: right
targetPosition: left
type: custom
width: 242
zIndex: 1002
viewport:
x: 120.39999999999998
y: 85.20000000000005
zoom: 0.7
rag_pipeline_variables: []

View File

@ -11,8 +11,8 @@ from controllers.console.app import completion as completion_api
from controllers.console.app import message as message_api
from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import Account, App, Tenant
from models.account import TenantAccountRole
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -31,9 +31,8 @@ class TestChatMessageApiPermissions:
return app
@pytest.fixture
def mock_account(self):
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing."""
account = Account()
account.id = str(uuid.uuid4())
account.name = "Test User"
@ -42,12 +41,24 @@ class TestChatMessageApiPermissions:
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
# Create mock tenant
tenant = Tenant()
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
account._current_tenant = tenant
mock_session_instance = mock.Mock()
mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
mock_scalars_result = mock.Mock()
mock_scalars_result.one.return_value = tenant
monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
mock_session_context = mock.Mock()
mock_session_context.__enter__.return_value = mock_session_instance
monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
account.current_tenant = tenant
return account
@pytest.mark.parametrize(

View File

@ -18,124 +18,87 @@ class TestAppDescriptionValidationUnit:
"""Unit tests for description validation function"""
def test_validate_description_length_function(self):
"""Test the _validate_description_length function directly"""
from controllers.console.app.app import _validate_description_length
"""Test the validate_description_length function directly"""
from libs.validators import validate_description_length
# Test valid descriptions
assert _validate_description_length("") == ""
assert _validate_description_length("x" * 400) == "x" * 400
assert _validate_description_length(None) is None
assert validate_description_length("") == ""
assert validate_description_length("x" * 400) == "x" * 400
assert validate_description_length(None) is None
# Test invalid descriptions
with pytest.raises(ValueError) as exc_info:
_validate_description_length("x" * 401)
validate_description_length("x" * 401)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
_validate_description_length("x" * 500)
validate_description_length("x" * 500)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
_validate_description_length("x" * 1000)
validate_description_length("x" * 1000)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
def test_validation_consistency_with_dataset(self):
"""Test that App and Dataset validation functions are consistent"""
from controllers.console.app.app import _validate_description_length as app_validate
from controllers.console.datasets.datasets import _validate_description_length as dataset_validate
from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate
# Test same valid inputs
valid_desc = "x" * 400
assert app_validate(valid_desc) == dataset_validate(valid_desc) == service_dataset_validate(valid_desc)
assert app_validate("") == dataset_validate("") == service_dataset_validate("")
assert app_validate(None) == dataset_validate(None) == service_dataset_validate(None)
# Test same invalid inputs produce same error
invalid_desc = "x" * 401
app_error = None
dataset_error = None
service_dataset_error = None
try:
app_validate(invalid_desc)
except ValueError as e:
app_error = str(e)
try:
dataset_validate(invalid_desc)
except ValueError as e:
dataset_error = str(e)
try:
service_dataset_validate(invalid_desc)
except ValueError as e:
service_dataset_error = str(e)
assert app_error == dataset_error == service_dataset_error
assert app_error == "Description cannot exceed 400 characters."
def test_boundary_values(self):
"""Test boundary values for description validation"""
from controllers.console.app.app import _validate_description_length
from libs.validators import validate_description_length
# Test exact boundary
exactly_400 = "x" * 400
assert _validate_description_length(exactly_400) == exactly_400
assert validate_description_length(exactly_400) == exactly_400
# Test just over boundary
just_over_400 = "x" * 401
with pytest.raises(ValueError):
_validate_description_length(just_over_400)
validate_description_length(just_over_400)
# Test just under boundary
just_under_400 = "x" * 399
assert _validate_description_length(just_under_400) == just_under_400
assert validate_description_length(just_under_400) == just_under_400
def test_edge_cases(self):
"""Test edge cases for description validation"""
from controllers.console.app.app import _validate_description_length
from libs.validators import validate_description_length
# Test None input
assert _validate_description_length(None) is None
assert validate_description_length(None) is None
# Test empty string
assert _validate_description_length("") == ""
assert validate_description_length("") == ""
# Test single character
assert _validate_description_length("a") == "a"
assert validate_description_length("a") == "a"
# Test unicode characters
unicode_desc = "测试" * 200 # 400 characters in Chinese
assert _validate_description_length(unicode_desc) == unicode_desc
assert validate_description_length(unicode_desc) == unicode_desc
# Test unicode over limit
unicode_over = "测试" * 201 # 402 characters
with pytest.raises(ValueError):
_validate_description_length(unicode_over)
validate_description_length(unicode_over)
def test_whitespace_handling(self):
"""Test how validation handles whitespace"""
from controllers.console.app.app import _validate_description_length
from libs.validators import validate_description_length
# Test description with spaces
spaces_400 = " " * 400
assert _validate_description_length(spaces_400) == spaces_400
assert validate_description_length(spaces_400) == spaces_400
# Test description with spaces over limit
spaces_401 = " " * 401
with pytest.raises(ValueError):
_validate_description_length(spaces_401)
validate_description_length(spaces_401)
# Test mixed content
mixed_400 = "a" * 200 + " " * 200
assert _validate_description_length(mixed_400) == mixed_400
assert validate_description_length(mixed_400) == mixed_400
# Test mixed over limit
mixed_401 = "a" * 200 + " " * 201
with pytest.raises(ValueError):
_validate_description_length(mixed_401)
validate_description_length(mixed_401)
if __name__ == "__main__":
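The tests above target `validate_description_length` from `libs.validators`. A minimal sketch that satisfies the asserted behavior (pass through `None` and strings up to 400 characters, raise `ValueError` beyond that); this is an illustration, not necessarily the actual implementation:

```python
# Minimal sketch matching the behavior these tests assert; not necessarily the
# real libs.validators implementation.
from typing import Optional


def validate_description_length(description: Optional[str]) -> Optional[str]:
    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


# Mirrors the boundary cases exercised above.
assert validate_description_length(None) is None
assert validate_description_length("x" * 400) == "x" * 400
try:
    validate_description_length("x" * 401)
except ValueError as exc:
    assert str(exc) == "Description cannot exceed 400 characters."
```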

View File

@ -9,8 +9,8 @@ from flask.testing import FlaskClient
from controllers.console.app import model_config as model_config_api
from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import Account, App, Tenant
from models.account import TenantAccountRole
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.model import AppMode
from services.app_model_config_service import AppModelConfigService
@ -30,9 +30,8 @@ class TestModelConfigResourcePermissions:
return app
@pytest.fixture
def mock_account(self):
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing."""
account = Account()
account.id = str(uuid.uuid4())
account.name = "Test User"
@ -41,12 +40,24 @@ class TestModelConfigResourcePermissions:
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
# Create mock tenant
tenant = Tenant()
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
account._current_tenant = tenant
mock_session_instance = mock.Mock()
mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
mock_scalars_result = mock.Mock()
mock_scalars_result.one.return_value = tenant
monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
mock_session_context = mock.Mock()
mock_session_context.__enter__.return_value = mock_session_instance
monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
account.current_tenant = tenant
return account
@pytest.mark.parametrize(

View File

@ -1,9 +1,9 @@
import time
import uuid
from os import getenv
import pytest
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
@ -15,7 +15,7 @@ from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))
CODE_MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
def init_code_node(code_config: dict):

View File

@ -18,6 +18,7 @@ from flask.testing import FlaskClient
from sqlalchemy import Engine, text
from sqlalchemy.orm import Session
from testcontainers.core.container import DockerContainer
from testcontainers.core.network import Network
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.postgres import PostgresContainer
from testcontainers.redis import RedisContainer
@ -41,6 +42,7 @@ class DifyTestContainers:
def __init__(self):
"""Initialize container management with default configurations."""
self.network: Network | None = None
self.postgres: PostgresContainer | None = None
self.redis: RedisContainer | None = None
self.dify_sandbox: DockerContainer | None = None
@ -62,12 +64,18 @@ class DifyTestContainers:
logger.info("Starting test containers for Dify integration tests...")
# Create Docker network for container communication
logger.info("Creating Docker network for container communication...")
self.network = Network()
self.network.create()
logger.info("Docker network created successfully with name: %s", self.network.name)
# Start PostgreSQL container for main application database
# PostgreSQL is used for storing user data, workflows, and application state
logger.info("Initializing PostgreSQL container...")
self.postgres = PostgresContainer(
image="postgres:14-alpine",
)
).with_network(self.network)
self.postgres.start()
db_host = self.postgres.get_container_host_ip()
db_port = self.postgres.get_exposed_port(5432)
@ -137,7 +145,7 @@ class DifyTestContainers:
# Start Redis container for caching and session management
# Redis is used for storing session data, cache entries, and temporary data
logger.info("Initializing Redis container...")
self.redis = RedisContainer(image="redis:6-alpine", port=6379)
self.redis = RedisContainer(image="redis:6-alpine", port=6379).with_network(self.network)
self.redis.start()
redis_host = self.redis.get_container_host_ip()
redis_port = self.redis.get_exposed_port(6379)
@ -153,7 +161,7 @@ class DifyTestContainers:
# Start Dify Sandbox container for code execution environment
# Dify Sandbox provides a secure environment for executing user code
logger.info("Initializing Dify Sandbox container...")
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest")
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest").with_network(self.network)
self.dify_sandbox.with_exposed_ports(8194)
self.dify_sandbox.env = {
"API_KEY": "test_api_key",
@ -173,22 +181,28 @@ class DifyTestContainers:
# Start Dify Plugin Daemon container for plugin management
# Dify Plugin Daemon provides plugin lifecycle management and execution
logger.info("Initializing Dify Plugin Daemon container...")
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local")
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local").with_network(
self.network
)
self.dify_plugin_daemon.with_exposed_ports(5002)
# Get container internal network addresses
postgres_container_name = self.postgres.get_wrapped_container().name
redis_container_name = self.redis.get_wrapped_container().name
self.dify_plugin_daemon.env = {
"DB_HOST": db_host,
"DB_PORT": str(db_port),
"DB_HOST": postgres_container_name, # Use container name for internal network communication
"DB_PORT": "5432", # Use internal port
"DB_USERNAME": self.postgres.username,
"DB_PASSWORD": self.postgres.password,
"DB_DATABASE": "dify_plugin",
"REDIS_HOST": redis_host,
"REDIS_PORT": str(redis_port),
"REDIS_HOST": redis_container_name, # Use container name for internal network communication
"REDIS_PORT": "6379", # Use internal port
"REDIS_PASSWORD": "",
"SERVER_PORT": "5002",
"SERVER_KEY": "test_plugin_daemon_key",
"MAX_PLUGIN_PACKAGE_SIZE": "52428800",
"PPROF_ENABLED": "false",
"DIFY_INNER_API_URL": f"http://{db_host}:5001",
"DIFY_INNER_API_URL": f"http://{postgres_container_name}:5001",
"DIFY_INNER_API_KEY": "test_inner_api_key",
"PLUGIN_REMOTE_INSTALLING_HOST": "0.0.0.0",
"PLUGIN_REMOTE_INSTALLING_PORT": "5003",
@ -253,6 +267,15 @@ class DifyTestContainers:
# Log error but don't fail the test cleanup
logger.warning("Failed to stop container %s: %s", container, e)
# Stop and remove the network
if self.network:
try:
logger.info("Removing Docker network...")
self.network.remove()
logger.info("Successfully removed Docker network")
except Exception as e:
logger.warning("Failed to remove Docker network: %s", e)
self._containers_started = False
logger.info("All test containers stopped and cleaned up successfully")

View File

@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from openai._exceptions import RateLimitError
from core.app.entities.app_invoke_entities import InvokeFrom
from models.model import EndUser
@ -484,36 +483,6 @@ class TestAppGenerateService:
# Verify error message
assert "Rate limit exceeded" in str(exc_info.value)
def test_generate_with_rate_limit_error_from_openai(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test generation when OpenAI rate limit error occurs.
"""
fake = Faker()
app, account = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies, mode="completion"
)
# Setup completion generator to raise RateLimitError
mock_response = MagicMock()
mock_response.request = MagicMock()
mock_external_service_dependencies["completion_generator"].return_value.generate.side_effect = RateLimitError(
"Rate limit exceeded", response=mock_response, body=None
)
# Setup test arguments
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
# Execute the method under test and expect rate limit error
with pytest.raises(InvokeRateLimitError) as exc_info:
AppGenerateService.generate(
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
)
# Verify error message
assert "Rate limit exceeded" in str(exc_info.value)
def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test generation with invalid app mode.

View File

@ -784,133 +784,6 @@ class TestCleanDatasetTask:
print(f"Total cleanup time: {cleanup_duration:.3f} seconds")
print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds")
def test_clean_dataset_task_concurrent_cleanup_scenarios(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test dataset cleanup with concurrent cleanup scenarios and race conditions.
This test verifies that the task can properly:
1. Handle multiple cleanup operations on the same dataset
2. Prevent data corruption during concurrent access
3. Maintain data consistency across multiple cleanup attempts
4. Handle race conditions gracefully
5. Ensure idempotent cleanup operations
"""
# Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(db_session_with_containers, account, tenant)
document = self._create_test_document(db_session_with_containers, account, tenant, dataset)
segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document)
upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant)
# Update document with file reference
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
from extensions.ext_database import db
db.session.commit()
# Save IDs for verification
dataset_id = dataset.id
tenant_id = tenant.id
upload_file_id = upload_file.id
# Mock storage to simulate slow operations
mock_storage = mock_external_service_dependencies["storage"]
original_delete = mock_storage.delete
def slow_delete(key):
import time
time.sleep(0.1) # Simulate slow storage operation
return original_delete(key)
mock_storage.delete.side_effect = slow_delete
# Execute multiple cleanup operations concurrently
import threading
cleanup_results = []
cleanup_errors = []
def run_cleanup():
try:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid.uuid4()),
doc_form="paragraph_index",
)
cleanup_results.append("success")
except Exception as e:
cleanup_errors.append(str(e))
# Start multiple cleanup threads
threads = []
for i in range(3):
thread = threading.Thread(target=run_cleanup)
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join()
# Verify results
# Check that all documents were deleted (only once)
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset_id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted (only once)
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset_id).all()
assert len(remaining_segments) == 0
# Check that upload file was deleted (only once)
# Note: In concurrent scenarios, the first thread deletes documents and segments,
# subsequent threads may not find the related data to clean up upload files
# This demonstrates the idempotent nature of the cleanup process
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all()
# The upload file should be deleted by the first successful cleanup operation
# However, in concurrent scenarios, this may not always happen due to race conditions
# This test demonstrates the idempotent nature of the cleanup process
if len(remaining_files) > 0:
print(f"Warning: Upload file {upload_file_id} was not deleted in concurrent scenario")
print("This is expected behavior demonstrating the idempotent nature of cleanup")
# We don't assert here as the behavior depends on timing and race conditions
# Verify that storage.delete was called (may be called multiple times in concurrent scenarios)
# In concurrent scenarios, storage operations may be called multiple times due to race conditions
assert mock_storage.delete.call_count > 0
# Verify that index processor was called (may be called multiple times in concurrent scenarios)
mock_index_processor = mock_external_service_dependencies["index_processor"]
assert mock_index_processor.clean.call_count > 0
# Check cleanup results
assert len(cleanup_results) == 3, "All cleanup operations should complete"
assert len(cleanup_errors) == 0, "No cleanup errors should occur"
# Verify idempotency by running cleanup again on the same dataset
# This should not perform any additional operations since data is already cleaned
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid.uuid4()),
doc_form="paragraph_index",
)
# Verify that no additional storage operations were performed
# Note: In concurrent scenarios, the exact count may vary due to race conditions
print(f"Final storage delete calls: {mock_storage.delete.call_count}")
print(f"Final index processor calls: {mock_index_processor.clean.call_count}")
print("Note: Multiple calls in concurrent scenarios are expected due to race conditions")
def test_clean_dataset_task_storage_exception_handling(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -0,0 +1,450 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
class TestEnableSegmentsToIndexTask:
"""Integration tests for enable_segments_to_index_task using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.enable_segments_to_index_task.IndexProcessorFactory") as mock_index_processor_factory,
):
# Setup mock index processor
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
yield {
"index_processor_factory": mock_index_processor_factory,
"index_processor": mock_processor,
}
def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test dataset and document for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (dataset, document) - Created dataset and document instances
"""
fake = Faker()
# Create account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
# Create dataset
dataset = Dataset(
id=fake.uuid4(),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
# Create document
document = Document(
id=fake.uuid4(),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="completed",
enabled=True,
doc_form=IndexType.PARAGRAPH_INDEX,
)
db.session.add(document)
db.session.commit()
# Refresh dataset to ensure doc_form property works correctly
db.session.refresh(dataset)
return dataset, document
def _create_test_segments(
self, db_session_with_containers, document, dataset, count=3, enabled=False, status="completed"
):
"""
Helper method to create test document segments.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
document: Document instance
dataset: Dataset instance
count: Number of segments to create
enabled: Whether segments should be enabled
status: Status of the segments
Returns:
list: List of created DocumentSegment instances
"""
fake = Faker()
segments = []
for i in range(count):
text = fake.text(max_nb_chars=200)
segment = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
position=i,
content=text,
word_count=len(text.split()),
tokens=len(text.split()) * 2,
index_node_id=f"node_{i}",
index_node_hash=f"hash_{i}",
enabled=enabled,
status=status,
created_by=document.created_by,
)
db.session.add(segment)
segments.append(segment)
db.session.commit()
return segments
def test_enable_segments_to_index_with_different_index_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test segments indexing with different index types.
This test verifies:
- Proper handling of different index types
- Index processor factory integration
- Document processing with various configurations
- Redis cache key deletion
"""
# Arrange: Create test data with different index type
dataset, document = self._create_test_dataset_and_document(
db_session_with_containers, mock_external_service_dependencies
)
# Update document to use different index type
document.doc_form = IndexType.QA_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset)
# Create segments
segments = self._create_test_segments(db_session_with_containers, document, dataset)
# Set up Redis cache keys
segment_ids = [segment.id for segment in segments]
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.set(indexing_cache_key, "processing", ex=300)
# Act: Execute the task
enable_segments_to_index_task(segment_ids, dataset.id, document.id)
# Assert: Verify different index type handling
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters
call_args = mock_external_service_dependencies["index_processor"].load.call_args
assert call_args is not None
documents = call_args[0][1] # Second argument should be documents list
assert len(documents) == 3
# Verify Redis cache keys were deleted
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
assert redis_client.exists(indexing_cache_key) == 0
def test_enable_segments_to_index_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling of non-existent dataset.
This test verifies:
- Proper error handling for missing datasets
- Early return without processing
- Database session cleanup
- No unnecessary index processor calls
"""
# Arrange: Use non-existent dataset ID
fake = Faker()
non_existent_dataset_id = fake.uuid4()
non_existent_document_id = fake.uuid4()
segment_ids = [fake.uuid4()]
# Act: Execute the task with non-existent dataset
enable_segments_to_index_task(segment_ids, non_existent_dataset_id, non_existent_document_id)
# Assert: Verify no processing occurred
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_document_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling of non-existent document.
This test verifies:
- Proper error handling for missing documents
- Early return without processing
- Database session cleanup
- No unnecessary index processor calls
"""
# Arrange: Create dataset but use non-existent document ID
dataset, _ = self._create_test_dataset_and_document(
db_session_with_containers, mock_external_service_dependencies
)
fake = Faker()
non_existent_document_id = fake.uuid4()
segment_ids = [fake.uuid4()]
# Act: Execute the task with non-existent document
enable_segments_to_index_task(segment_ids, dataset.id, non_existent_document_id)
# Assert: Verify no processing occurred
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_invalid_document_status(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling of document with invalid status.
This test verifies:
- Early return when document is disabled, archived, or not completed
- No index processing for documents not ready for indexing
- Proper database session cleanup
- No unnecessary external service calls
"""
# Arrange: Create test data with invalid document status
dataset, document = self._create_test_dataset_and_document(
db_session_with_containers, mock_external_service_dependencies
)
# Test different invalid statuses
invalid_statuses = [
("disabled", {"enabled": False}),
("archived", {"archived": True}),
("not_completed", {"indexing_status": "processing"}),
]
for _, status_attrs in invalid_statuses:
# Reset document status
document.enabled = True
document.archived = False
document.indexing_status = "completed"
db.session.commit()
# Set invalid status
for attr, value in status_attrs.items():
setattr(document, attr, value)
db.session.commit()
# Create segments
segments = self._create_test_segments(db_session_with_containers, document, dataset)
segment_ids = [segment.id for segment in segments]
# Act: Execute the task
enable_segments_to_index_task(segment_ids, dataset.id, document.id)
# Assert: Verify no processing occurred
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
mock_external_service_dependencies["index_processor"].load.assert_not_called()
# Clean up segments for next iteration
for segment in segments:
db.session.delete(segment)
db.session.commit()
def test_enable_segments_to_index_segments_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling when no segments are found.
This test verifies:
- Proper handling when segments don't exist
- Early return without processing
- Database session cleanup
- Index processor is created but load is not called
"""
# Arrange: Create test data
dataset, document = self._create_test_dataset_and_document(
db_session_with_containers, mock_external_service_dependencies
)
# Use non-existent segment IDs
fake = Faker()
non_existent_segment_ids = [fake.uuid4() for _ in range(3)]
# Act: Execute the task with non-existent segments
enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id)
# Assert: Verify index processor was created but load was not called
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_with_parent_child_structure(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test segments indexing with parent-child structure.
This test verifies:
- Proper handling of PARENT_CHILD_INDEX type
- Child document creation from segments
- Correct document structure for parent-child indexing
- Index processor receives properly structured documents
- Redis cache key deletion
"""
# Arrange: Create test data with parent-child index type
dataset, document = self._create_test_dataset_and_document(
db_session_with_containers, mock_external_service_dependencies
)
# Update document to use parent-child index type
document.doc_form = IndexType.PARENT_CHILD_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset)
# Create segments with mock child chunks
segments = self._create_test_segments(db_session_with_containers, document, dataset)
# Set up Redis cache keys
segment_ids = [segment.id for segment in segments]
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.set(indexing_cache_key, "processing", ex=300)
# Mock the get_child_chunks method for each segment
with patch.object(DocumentSegment, "get_child_chunks") as mock_get_child_chunks:
# Setup mock to return child chunks for each segment
mock_child_chunks = []
for i in range(2): # Each segment has 2 child chunks
mock_child = MagicMock()
mock_child.content = f"child_content_{i}"
mock_child.index_node_id = f"child_node_{i}"
mock_child.index_node_hash = f"child_hash_{i}"
mock_child_chunks.append(mock_child)
mock_get_child_chunks.return_value = mock_child_chunks
# Act: Execute the task
enable_segments_to_index_task(segment_ids, dataset.id, document.id)
# Assert: Verify parent-child index processing
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexType.PARENT_CHILD_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters
call_args = mock_external_service_dependencies["index_processor"].load.call_args
assert call_args is not None
documents = call_args[0][1] # Second argument should be documents list
assert len(documents) == 3 # 3 segments
# Verify each document has children
for doc in documents:
assert hasattr(doc, "children")
assert len(doc.children) == 2 # Each document has 2 children
# Verify Redis cache keys were deleted
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
assert redis_client.exists(indexing_cache_key) == 0
def test_enable_segments_to_index_general_exception_handling(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test general exception handling during indexing process.
This test verifies:
- Exceptions are properly caught and handled
- Segment status is set to error
- Segments are disabled
- Error information is recorded
- Redis cache is still cleared
- Database session is properly closed
"""
# Arrange: Create test data
dataset, document = self._create_test_dataset_and_document(
db_session_with_containers, mock_external_service_dependencies
)
segments = self._create_test_segments(db_session_with_containers, document, dataset)
# Set up Redis cache keys
segment_ids = [segment.id for segment in segments]
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.set(indexing_cache_key, "processing", ex=300)
# Mock the index processor to raise an exception
mock_external_service_dependencies["index_processor"].load.side_effect = Exception("Index processing failed")
# Act: Execute the task
enable_segments_to_index_task(segment_ids, dataset.id, document.id)
# Assert: Verify error handling
for segment in segments:
db.session.refresh(segment)
assert segment.enabled is False
assert segment.status == "error"
assert segment.error is not None
assert "Index processing failed" in segment.error
assert segment.disabled_at is not None
# Verify Redis cache keys were still cleared despite error
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
assert redis_client.exists(indexing_cache_key) == 0
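Several of these tests assert on the `segment_<id>_indexing` Redis key, which marks a segment as being indexed and is expected to be deleted once the task completes, even when indexing fails. A small sketch of that flag convention; the key format comes from the tests, while the Redis client wiring is illustrative:

```python
# Sketch of the "segment is being indexed" flag these tests assert on. The key
# format comes from the tests above; the Redis client wiring is illustrative.
import redis

redis_client = redis.Redis()


def mark_segment_indexing(segment_id: str, ttl_seconds: int = 300) -> None:
    # Set while a segment is queued or being indexed so duplicate work is skipped.
    redis_client.set(f"segment_{segment_id}_indexing", "processing", ex=ttl_seconds)


def clear_segment_indexing(segment_id: str) -> None:
    # The task deletes the key when it finishes (even after an error); the tests
    # confirm this via redis_client.exists(...) == 0.
    redis_client.delete(f"segment_{segment_id}_indexing")


def is_segment_indexing(segment_id: str) -> bool:
    return bool(redis_client.exists(f"segment_{segment_id}_indexing"))
```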

View File

@ -0,0 +1,242 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from extensions.ext_database import db
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task
class TestMailAccountDeletionTask:
"""Integration tests for mail account deletion tasks using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.mail_account_deletion_task.mail") as mock_mail,
patch("tasks.mail_account_deletion_task.get_email_i18n_service") as mock_get_email_service,
):
# Setup mock mail service
mock_mail.is_inited.return_value = True
# Setup mock email service
mock_email_service = MagicMock()
mock_get_email_service.return_value = mock_email_service
yield {
"mail": mock_mail,
"get_email_service": mock_get_email_service,
"email_service": mock_email_service,
}
def _create_test_account(self, db_session_with_containers):
"""
Helper method to create a test account for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
Returns:
Account: Created account instance
"""
fake = Faker()
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
return account
def test_send_deletion_success_task_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful account deletion success email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls
- Template context is properly formatted
- Email type is correctly specified
"""
# Arrange: Create test data
account = self._create_test_account(db_session_with_containers)
test_email = account.email
test_language = "en-US"
# Act: Execute the task
send_deletion_success_task(test_email, test_language)
# Assert: Verify the expected outcomes
# Verify mail service was checked
mock_external_service_dependencies["mail"].is_inited.assert_called_once()
# Verify email service was retrieved
mock_external_service_dependencies["get_email_service"].assert_called_once()
# Verify email was sent with correct parameters
mock_external_service_dependencies["email_service"].send_email.assert_called_once_with(
email_type=EmailType.ACCOUNT_DELETION_SUCCESS,
language_code=test_language,
to=test_email,
template_context={
"to": test_email,
"email": test_email,
},
)
def test_send_deletion_success_task_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test account deletion success email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Setup mail service to return not initialized
mock_external_service_dependencies["mail"].is_inited.return_value = False
account = self._create_test_account(db_session_with_containers)
test_email = account.email
# Act: Execute the task
send_deletion_success_task(test_email)
# Assert: Verify no email service calls were made
mock_external_service_dependencies["get_email_service"].assert_not_called()
mock_external_service_dependencies["email_service"].send_email.assert_not_called()
def test_send_deletion_success_task_email_service_exception(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test account deletion success email when email service raises exception.
This test verifies:
- Exception is properly caught and logged
- Task completes without raising exception
- Error logging is recorded
"""
# Arrange: Setup email service to raise exception
mock_external_service_dependencies["email_service"].send_email.side_effect = Exception("Email service failed")
account = self._create_test_account(db_session_with_containers)
test_email = account.email
# Act: Execute the task (should not raise exception)
send_deletion_success_task(test_email)
# Assert: Verify email service was called but exception was handled
mock_external_service_dependencies["email_service"].send_email.assert_called_once()
def test_send_account_deletion_verification_code_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful account deletion verification code email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls
- Template context includes verification code
- Email type is correctly specified
"""
# Arrange: Create test data
account = self._create_test_account(db_session_with_containers)
test_email = account.email
test_code = "123456"
test_language = "en-US"
# Act: Execute the task
send_account_deletion_verification_code(test_email, test_code, test_language)
# Assert: Verify the expected outcomes
# Verify mail service was checked
mock_external_service_dependencies["mail"].is_inited.assert_called_once()
# Verify email service was retrieved
mock_external_service_dependencies["get_email_service"].assert_called_once()
# Verify email was sent with correct parameters
mock_external_service_dependencies["email_service"].send_email.assert_called_once_with(
email_type=EmailType.ACCOUNT_DELETION_VERIFICATION,
language_code=test_language,
to=test_email,
template_context={
"to": test_email,
"code": test_code,
},
)
def test_send_account_deletion_verification_code_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test account deletion verification code email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Setup mail service to return not initialized
mock_external_service_dependencies["mail"].is_inited.return_value = False
account = self._create_test_account(db_session_with_containers)
test_email = account.email
test_code = "123456"
# Act: Execute the task
send_account_deletion_verification_code(test_email, test_code)
# Assert: Verify no email service calls were made
mock_external_service_dependencies["get_email_service"].assert_not_called()
mock_external_service_dependencies["email_service"].send_email.assert_not_called()
def test_send_account_deletion_verification_code_email_service_exception(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test account deletion verification code email when email service raises exception.
This test verifies:
- Exception is properly caught and logged
- Task completes without raising exception
- Error logging is recorded
"""
# Arrange: Setup email service to raise exception
mock_external_service_dependencies["email_service"].send_email.side_effect = Exception("Email service failed")
account = self._create_test_account(db_session_with_containers)
test_email = account.email
test_code = "123456"
# Act: Execute the task (should not raise exception)
send_account_deletion_verification_code(test_email, test_code)
# Assert: Verify email service was called but exception was handled
mock_external_service_dependencies["email_service"].send_email.assert_called_once()

Some files were not shown because too many files have changed in this diff.