mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/memory-orchestration-be
This commit is contained in:
commit
c367f80ec5
|
|
@ -1,4 +1,4 @@
|
|||
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
|
|
|||
1
Makefile
1
Makefile
|
|
@ -26,7 +26,6 @@ prepare-web:
|
|||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cd web && pnpm build
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
|
|
|
|||
24
README.md
24
README.md
|
|
@ -40,18 +40,18 @@
|
|||
|
||||
<p align="center">
|
||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README/README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./README/README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README/README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README/README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README/README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README/README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./README/README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./README/README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./README/README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./README/README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./README/README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./README/README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
<a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
|
||||
|
|
|
|||
|
|
@ -427,8 +427,8 @@ CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
|||
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
|
||||
CODE_MAX_NUMBER=9223372036854775807
|
||||
CODE_MIN_NUMBER=-9223372036854775808
|
||||
CODE_MAX_STRING_LENGTH=80000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
|
||||
CODE_MAX_STRING_LENGTH=400000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=400000
|
||||
CODE_MAX_STRING_ARRAY_LENGTH=30
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH=30
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
||||
|
|
|
|||
|
|
@ -80,10 +80,10 @@
|
|||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
```
|
||||
|
||||
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery beat
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
|||
|
||||
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
|
||||
description="Maximum allowed length for strings in code execution",
|
||||
default=80000,
|
||||
default=400_000,
|
||||
)
|
||||
|
||||
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
|
||||
|
|
@ -582,6 +582,11 @@ class WorkflowConfig(BaseSettings):
|
|||
default=200 * 1024,
|
||||
)
|
||||
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
|
||||
description="Maximum number of characters allowed in Template Transform node output",
|
||||
default=400_000,
|
||||
)
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from configs import dify_config
|
||||
from libs.collection_utils import convert_to_lower_and_upper_set
|
||||
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||
UNKNOWN_VALUE = "[__UNKNOWN__]"
|
||||
|
|
@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
|||
|
||||
DEFAULT_FILE_NUMBER_LIMITS = 3
|
||||
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
|
||||
|
||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
|
||||
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
|
||||
|
||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
|
||||
|
||||
|
||||
_doc_extensions: list[str]
|
||||
_doc_extensions: set[str]
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
||||
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
"mdx",
|
||||
"pdf",
|
||||
"html",
|
||||
"htm",
|
||||
"xlsx",
|
||||
"xls",
|
||||
"vtt",
|
||||
"properties",
|
||||
"doc",
|
||||
"docx",
|
||||
"csv",
|
||||
"eml",
|
||||
"msg",
|
||||
"pptx",
|
||||
"xml",
|
||||
"epub",
|
||||
}
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
_doc_extensions.append("ppt")
|
||||
_doc_extensions.add("ppt")
|
||||
else:
|
||||
_doc_extensions = [
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
|
|
@ -37,5 +53,5 @@ else:
|
|||
"csv",
|
||||
"vtt",
|
||||
"properties",
|
||||
]
|
||||
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
|
||||
}
|
||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from core.ops.ops_trace_manager import OpsTraceManager
|
|||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import Account, App
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
|
|
@ -28,12 +29,6 @@ from services.feature_service import FeatureService
|
|||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@api.doc("list_apps")
|
||||
|
|
@ -138,7 +133,7 @@ class AppListApi(Resource):
|
|||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
|
|
@ -219,7 +214,7 @@ class AppApi(Resource):
|
|||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
|
|
@ -297,7 +292,7 @@ class AppCopyApi(Resource):
|
|||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import flask_restx
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
|
|
@ -30,24 +31,20 @@ from fields.app_fields import related_app_list
|
|||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/datasets")
|
||||
class DatasetListApi(Resource):
|
||||
@api.doc("get_datasets")
|
||||
|
|
@ -92,7 +89,7 @@ class DatasetListApi(Resource):
|
|||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
|
||||
for item in data:
|
||||
# convert embedding_model_provider to plugin standard format
|
||||
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
|
||||
|
|
@ -147,7 +144,7 @@ class DatasetListApi(Resource):
|
|||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
|
|
@ -192,7 +189,7 @@ class DatasetListApi(Resource):
|
|||
name=args["name"],
|
||||
description=args["description"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
provider=args["provider"],
|
||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||
|
|
@ -224,7 +221,7 @@ class DatasetApi(Resource):
|
|||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.embedding_model_provider:
|
||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||
|
|
@ -288,7 +285,7 @@ class DatasetApi(Resource):
|
|||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
|
||||
parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
|
|
@ -369,7 +366,7 @@ class DatasetApi(Resource):
|
|||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
|
|
@ -688,7 +685,7 @@ class DatasetApiKeyApi(Resource):
|
|||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
flask_restx.abort(
|
||||
api.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
|
|
@ -733,7 +730,7 @@ class DatasetApiDeleteApi(Resource):
|
|||
)
|
||||
|
||||
if key is None:
|
||||
flask_restx.abort(404, message="API key not found")
|
||||
api.abort(404, message="API key not found")
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ from fields.document_fields import (
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.account import Account
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
|
|
@ -418,7 +419,9 @@ class DatasetInitApi(Resource):
|
|||
|
||||
try:
|
||||
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
||||
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
knowledge_config=knowledge_config,
|
||||
account=cast(Account, current_user),
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
|
@ -452,7 +455,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
raise DocumentAlreadyFinishedError()
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
|
||||
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
|
||||
|
||||
|
|
@ -514,7 +517,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
if not documents:
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
|
|
@ -753,7 +756,7 @@ class DocumentApi(DocumentResource):
|
|||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
|
|
@ -1073,7 +1076,9 @@ class DocumentRenameApi(DocumentResource):
|
|||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -392,7 +392,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_create_segment_to_index_task.delay(
|
||||
str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
|
||||
str(job_id),
|
||||
upload_file_id,
|
||||
dataset_id,
|
||||
document_id,
|
||||
current_user.current_tenant_id,
|
||||
current_user.id,
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 500
|
||||
|
|
@ -468,7 +473,8 @@ class ChildChunkAddApi(Resource):
|
|||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
|
|
@ -557,7 +563,8 @@ class ChildChunkAddApi(Resource):
|
|||
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
|
||||
chunks_data = args["chunks"]
|
||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
|
||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
|
|
@ -674,9 +681,8 @@ class ChildChunkUpdateApi(Resource):
|
|||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
child_chunk = SegmentService.update_child_chunk(
|
||||
args.get("content"), child_chunk, segment, document, dataset
|
||||
)
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, reqparse
|
||||
|
|
@ -9,13 +11,14 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
return name
|
||||
|
|
@ -274,7 +277,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||
response = HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
import services
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
|
|
@ -20,6 +21,7 @@ from core.errors.error import (
|
|||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
|
@ -59,7 +61,7 @@ class DatasetsHitTestingBase:
|
|||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ class DatasetMetadataApi(Resource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
name = args["name"]
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
|
|
@ -70,7 +71,7 @@ class DatasetMetadataApi(Resource):
|
|||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
|
||||
return metadata, 200
|
||||
|
||||
@setup_required
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from fastapi.encoders import jsonable_encoder
|
||||
from flask import make_response, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
|
@ -11,6 +10,7 @@ from controllers.console.wraps import (
|
|||
setup_required,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import login_required
|
||||
|
|
|
|||
|
|
@ -20,13 +20,13 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
def _validate_description_length(description: str) -> str:
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
|
@ -76,7 +76,7 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
|
|
@ -133,7 +133,7 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
|||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restx import Resource, marshal, reqparse # type: ignore
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
|
|
@ -20,18 +20,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
|
|||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/dataset")
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -84,6 +73,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||
return import_info, 201
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/empty-dataset")
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -108,7 +98,3 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
|||
),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
||||
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
||||
|
|
|
|||
|
|
@ -1,24 +1,22 @@
|
|||
import logging
|
||||
from typing import Any, NoReturn
|
||||
from typing import NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -34,32 +32,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
|
|
@ -104,13 +76,14 @@ def _api_prerequisite(f):
|
|||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args, **kwargs):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
|
||||
class RagPipelineVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
|
|
@ -168,6 +141,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
|||
return None
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
|
|
@ -190,6 +164,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
|||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
class RagPipelineVariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
|
@ -284,6 +259,7 @@ class RagPipelineVariableApi(Resource):
|
|||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: str):
|
||||
|
|
@ -325,6 +301,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
|
|||
return draft_vars
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
|
||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
|
|
@ -332,6 +309,7 @@ class RagPipelineSystemVariableCollectionApi(Resource):
|
|||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables")
|
||||
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, pipeline: Pipeline):
|
||||
|
|
@ -364,26 +342,3 @@ class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
|||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
|||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
|
|
@ -20,6 +20,7 @@ from services.app_dsl_service import ImportStatus
|
|||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports")
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -66,6 +67,7 @@ class RagPipelineImportApi(Resource):
|
|||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -90,6 +92,7 @@ class RagPipelineImportConfirmApi(Resource):
|
|||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -107,6 +110,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
|||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
|
||||
class RagPipelineExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -128,22 +132,3 @@ class RagPipelineExportApi(Resource):
|
|||
)
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
RagPipelineImportApi,
|
||||
"/rag/pipelines/imports",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportConfirmApi,
|
||||
"/rag/pipelines/imports/<string:import_id>/confirm",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineExportApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/exports",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
DraftWorkflowNotExist,
|
||||
|
|
@ -50,6 +50,7 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
|
||||
class DraftRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -147,6 +148,7 @@ class DraftRagPipelineApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -181,6 +183,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -215,6 +218,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
||||
class DraftRagPipelineRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -249,6 +253,7 @@ class DraftRagPipelineRunApi(Resource):
|
|||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
||||
class PublishedRagPipelineRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -369,6 +374,7 @@ class PublishedRagPipelineRunApi(Resource):
|
|||
#
|
||||
# return result
|
||||
#
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -411,6 +417,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -453,6 +460,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -486,6 +494,7 @@ class RagPipelineDraftNodeRunApi(Resource):
|
|||
return workflow_node_execution
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class RagPipelineTaskStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -504,6 +513,7 @@ class RagPipelineTaskStopApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/publish")
|
||||
class PublishedRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -559,6 +569,7 @@ class PublishedRagPipelineApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs")
|
||||
class DefaultRagPipelineBlockConfigsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -577,6 +588,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
|||
return rag_pipeline_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -608,6 +620,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
|||
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||
class PublishedAllRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -656,6 +669,7 @@ class PublishedAllRagPipelineApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -713,6 +727,7 @@ class RagPipelineByIdApi(Resource):
|
|||
return workflow
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
||||
class PublishedRagPipelineSecondStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -738,6 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
||||
class PublishedRagPipelineFirstStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -763,6 +779,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
||||
class DraftRagPipelineFirstStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -788,6 +805,7 @@ class DraftRagPipelineFirstStepApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
||||
class DraftRagPipelineSecondStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -814,6 +832,7 @@ class DraftRagPipelineSecondStepApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
||||
class RagPipelineWorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -835,6 +854,7 @@ class RagPipelineWorkflowRunListApi(Resource):
|
|||
return result
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>")
|
||||
class RagPipelineWorkflowRunDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -853,6 +873,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
|
|||
return workflow_run
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -876,6 +897,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
|||
return {"data": node_executions}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/datasource-plugins")
|
||||
class DatasourceListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -891,6 +913,7 @@ class DatasourceListApi(Resource):
|
|||
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
class RagPipelineWorkflowLastRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -912,6 +935,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
|
|||
return node_exec
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/transform/datasets/<uuid:dataset_id>")
|
||||
class RagPipelineTransformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -929,6 +953,7 @@ class RagPipelineTransformApi(Resource):
|
|||
return result
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
||||
class RagPipelineDatasourceVariableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -958,6 +983,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
|||
return workflow_node_execution
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/recommended-plugins")
|
||||
class RagPipelineRecommendedPluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -966,114 +992,3 @@ class RagPipelineRecommendedPluginApi(Resource):
|
|||
rag_pipeline_service = RagPipelineService()
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
|
||||
return recommended_plugins
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedRagPipelineRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineTaskStopApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineDraftNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelinePublishedDatasourceNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftDatasourceNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftRunIterationNodeApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftRunLoopNodeApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
PublishedRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/publish",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedAllRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultRagPipelineBlockConfigsApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultRagPipelineBlockConfigApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineByIdApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowRunListApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowRunDetailApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowRunNodeExecutionListApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceListApi,
|
||||
"/rag/pipelines/datasource-plugins",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedRagPipelineSecondStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedRagPipelineFirstStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineSecondStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineFirstStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowLastRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineTransformApi,
|
||||
"/rag/pipelines/transform/datasets/<uuid:dataset_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineDatasourceVariableApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineRecommendedPluginApi,
|
||||
"/rag/pipelines/recommended-plugins",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from flask_login import current_user
|
|||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import login_required
|
||||
|
|
@ -17,6 +17,7 @@ def _validate_name(name):
|
|||
return name
|
||||
|
||||
|
||||
@console_ns.route("/tags")
|
||||
class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -52,6 +53,7 @@ class TagListApi(Resource):
|
|||
return response, 200
|
||||
|
||||
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -89,6 +91,7 @@ class TagUpdateDeleteApi(Resource):
|
|||
return 204
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -114,6 +117,7 @@ class TagBindingCreateApi(Resource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -133,9 +137,3 @@ class TagBindingDeleteApi(Resource):
|
|||
TagService.delete_tag_binding(args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(TagListApi, "/tags")
|
||||
api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
|
||||
api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
|
||||
api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
import services
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
from controllers.service_api.wraps import (
|
||||
|
|
@ -17,6 +17,7 @@ from core.provider_manager import ProviderManager
|
|||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from libs.validators import validate_description_length
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
|
@ -31,12 +32,6 @@ def _validate_name(name):
|
|||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
# Define parsers for dataset operations
|
||||
dataset_create_parser = reqparse.RequestParser()
|
||||
dataset_create_parser.add_argument(
|
||||
|
|
@ -48,7 +43,7 @@ dataset_create_parser.add_argument(
|
|||
)
|
||||
dataset_create_parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
|
|
@ -101,7 +96,7 @@ dataset_update_parser.add_argument(
|
|||
type=_validate_name,
|
||||
)
|
||||
dataset_update_parser.add_argument(
|
||||
"description", location="json", store_missing=False, type=_validate_description_length
|
||||
"description", location="json", store_missing=False, type=validate_description_length
|
||||
)
|
||||
dataset_update_parser.add_argument(
|
||||
"indexing_technique",
|
||||
|
|
@ -254,19 +249,21 @@ class DatasetListApi(DatasetApiResource):
|
|||
"""Resource for creating datasets."""
|
||||
args = dataset_create_parser.parse_args()
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -317,7 +314,7 @@ class DatasetApi(DatasetApiResource):
|
|||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
assert isinstance(current_user, Account)
|
||||
|
|
@ -331,8 +328,8 @@ class DatasetApi(DatasetApiResource):
|
|||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
if data["indexing_technique"] == "high_quality":
|
||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
||||
if data.get("indexing_technique") == "high_quality":
|
||||
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
|
||||
if item_model in model_names:
|
||||
data["embedding_available"] = True
|
||||
else:
|
||||
|
|
@ -341,7 +338,9 @@ class DatasetApi(DatasetApiResource):
|
|||
data["embedding_available"] = True
|
||||
|
||||
# force update search method to keyword_search if indexing_technique is economic
|
||||
data["retrieval_model_dict"]["search_method"] = "keyword_search"
|
||||
retrieval_model_dict = data.get("retrieval_model_dict")
|
||||
if retrieval_model_dict:
|
||||
retrieval_model_dict["search_method"] = "keyword_search"
|
||||
|
||||
if data.get("permission") == "partial_members":
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
|
@ -372,19 +371,24 @@ class DatasetApi(DatasetApiResource):
|
|||
data = request.get_json()
|
||||
|
||||
# check embedding model setting
|
||||
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = data.get("embedding_model_provider")
|
||||
embedding_model = data.get("embedding_model")
|
||||
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, embedding_model_provider, embedding_model
|
||||
)
|
||||
|
||||
retrieval_model = data.get("retrieval_model")
|
||||
if (
|
||||
data.get("retrieval_model")
|
||||
and data.get("retrieval_model").get("reranking_model")
|
||||
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
dataset.tenant_id,
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
|
|
@ -397,7 +401,7 @@ class DatasetApi(DatasetApiResource):
|
|||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
@ -591,9 +595,10 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
|
||||
args = tag_update_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.update_tags(args, args.get("tag_id"))
|
||||
tag_id = args["tag_id"]
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
|
|
@ -616,7 +621,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
args = tag_delete_parser.parse_args()
|
||||
TagService.delete_tag(args.get("tag_id"))
|
||||
TagService.delete_tag(args["tag_id"])
|
||||
|
||||
return 204
|
||||
|
||||
|
|
|
|||
|
|
@ -108,19 +108,21 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
if text is None or name is None:
|
||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
|
|
@ -187,15 +189,16 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
|||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
|||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
|
||||
return marshal(metadata, dataset_metadata_fields), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset_metadata")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import uuid
|
||||
from typing import Literal, cast
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
|
|
@ -74,6 +75,9 @@ class DatasetConfigManager:
|
|||
return None
|
||||
query_variable = config.get("dataset_query_variable")
|
||||
|
||||
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
|
||||
|
||||
if dataset_configs["retrieval_model"] == "single":
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
|
|
@ -82,18 +86,23 @@ class DatasetConfigManager:
|
|||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
||||
if dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_mode=cast(
|
||||
Literal["disabled", "automatic", "manual"],
|
||||
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
),
|
||||
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||
if isinstance(metadata_model_config_dict, dict)
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
||||
)
|
||||
if dataset_configs.get("metadata_filtering_conditions")
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||
else None,
|
||||
),
|
||||
)
|
||||
else:
|
||||
score_threshold_val = dataset_configs.get("score_threshold")
|
||||
reranking_model_val = dataset_configs.get("reranking_model")
|
||||
weights_val = dataset_configs.get("weights")
|
||||
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
|
|
@ -101,22 +110,23 @@ class DatasetConfigManager:
|
|||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
top_k=dataset_configs.get("top_k", 4),
|
||||
score_threshold=dataset_configs.get("score_threshold")
|
||||
if dataset_configs.get("score_threshold_enabled", False)
|
||||
top_k=int(dataset_configs.get("top_k", 4)),
|
||||
score_threshold=float(score_threshold_val)
|
||||
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
|
||||
else None,
|
||||
reranking_model=dataset_configs.get("reranking_model"),
|
||||
weights=dataset_configs.get("weights"),
|
||||
reranking_enabled=dataset_configs.get("reranking_enabled", True),
|
||||
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
|
||||
weights=weights_val if isinstance(weights_val, dict) else None,
|
||||
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
||||
if dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_mode=cast(
|
||||
Literal["disabled", "automatic", "manual"],
|
||||
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
),
|
||||
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||
if isinstance(metadata_model_config_dict, dict)
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
||||
)
|
||||
if dataset_configs.get("metadata_filtering_conditions")
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||
else None,
|
||||
),
|
||||
)
|
||||
|
|
@ -134,18 +144,17 @@ class DatasetConfigManager:
|
|||
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
|
||||
|
||||
# dataset_configs
|
||||
if not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {"retrieval_model": "single"}
|
||||
if "dataset_configs" not in config or not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {}
|
||||
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if not config["dataset_configs"].get("datasets"):
|
||||
if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
|
||||
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
|
||||
|
||||
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
|
||||
"datasets", {}
|
||||
).get("datasets")
|
||||
need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
|
|
@ -166,8 +175,8 @@ class DatasetConfigManager:
|
|||
:param config: app model config args
|
||||
"""
|
||||
# Extract dataset config for legacy compatibility
|
||||
if not config.get("agent_mode"):
|
||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
||||
if "agent_mode" not in config or not config.get("agent_mode"):
|
||||
config["agent_mode"] = {}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
|
|
@ -180,19 +189,22 @@ class DatasetConfigManager:
|
|||
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||
|
||||
# tools
|
||||
if not config["agent_mode"].get("tools"):
|
||||
if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
|
||||
config["agent_mode"]["tools"] = []
|
||||
|
||||
if not isinstance(config["agent_mode"]["tools"], list):
|
||||
raise ValueError("tools in agent_mode must be a list of objects")
|
||||
|
||||
# strategy
|
||||
if not config["agent_mode"].get("strategy"):
|
||||
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||
|
||||
has_datasets = False
|
||||
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
|
||||
for tool in config["agent_mode"]["tools"]:
|
||||
if config.get("agent_mode", {}).get("strategy") in {
|
||||
PlanningStrategy.ROUTER.value,
|
||||
PlanningStrategy.REACT_ROUTER.value,
|
||||
}:
|
||||
for tool in config.get("agent_mode", {}).get("tools", []):
|
||||
key = list(tool.keys())[0]
|
||||
if key == "dataset":
|
||||
# old style, use tool name as key
|
||||
|
|
@ -217,7 +229,7 @@ class DatasetConfigManager:
|
|||
|
||||
has_datasets = True
|
||||
|
||||
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
|
||||
need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import logging
|
||||
import queue
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -18,6 +20,8 @@ from core.app.entities.queue_entities import (
|
|||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PublishFrom(IntEnum):
|
||||
APPLICATION_MANAGER = auto()
|
||||
|
|
@ -35,9 +39,8 @@ class AppQueueManager:
|
|||
self.invoke_from = invoke_from # Public accessor for invoke_from
|
||||
|
||||
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
redis_client.setex(
|
||||
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
||||
)
|
||||
self._task_belong_cache_key = AppQueueManager._generate_task_belong_cache_key(self._task_id)
|
||||
redis_client.setex(self._task_belong_cache_key, 1800, f"{user_prefix}-{self._user_id}")
|
||||
|
||||
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||
|
||||
|
|
@ -79,9 +82,21 @@ class AppQueueManager:
|
|||
Stop listen to queue
|
||||
:return:
|
||||
"""
|
||||
self._clear_task_belong_cache()
|
||||
self._q.put(None)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom):
|
||||
def _clear_task_belong_cache(self) -> None:
|
||||
"""
|
||||
Remove the task belong cache key once listening is finished.
|
||||
"""
|
||||
try:
|
||||
redis_client.delete(self._task_belong_cache_key)
|
||||
except RedisError:
|
||||
logger.exception(
|
||||
"Failed to clear task belong cache for task %s (key: %s)", self._task_id, self._task_belong_cache_key
|
||||
)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish error
|
||||
:param e: error
|
||||
|
|
|
|||
|
|
@@ -107,7 +107,6 @@ class MessageCycleManager:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)

db.session.merge(conversation)
db.session.commit()
db.session.close()
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional

from openai import BaseModel
from pydantic import Field
from pydantic import BaseModel, Field

# Import InvokeFrom locally to avoid circular import
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class TextPromptMessageContent(PromptMessageContent):
|
|||
Model class for text prompt message content.
|
||||
"""
|
||||
|
||||
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
|
||||
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore
|
||||
data: str
|
||||
|
||||
|
||||
|
|
@ -95,11 +95,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
|||
|
||||
|
||||
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
|
||||
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore
|
||||
|
||||
|
||||
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
|
||||
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore
|
||||
|
||||
|
||||
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
|
|
@ -111,12 +111,12 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
|||
LOW = auto()
|
||||
HIGH = auto()
|
||||
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
|
||||
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore
|
||||
|
||||
|
||||
PromptMessageContentUnionTypes = Annotated[
|
||||
|
|
|
|||
|
|
@@ -15,7 +15,7 @@ class GPT2Tokenizer:
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text)
tokens = _tokenizer.encode(text) # type: ignore
return len(tokens)

@staticmethod
|
|
|
|||
|
|
@@ -196,15 +196,15 @@ def jsonable_encoder(
return encoder(obj)

try:
data = dict(obj)
data = dict(obj) # type: ignore
except Exception as e:
errors: list[Exception] = []
errors.append(e)
try:
data = vars(obj)
data = vars(obj) # type: ignore
except Exception as e:
errors.append(e)
raise ValueError(errors) from e
raise ValueError(str(errors)) from e
return jsonable_encoder(
data,
by_alias=by_alias,
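A simplified, self-contained sketch of the `dict()` / `vars()` fallback kept in the hunk above — not the full `jsonable_encoder`, just the error-accumulating conversion step with the stringified error list:

```python
# Try mapping conversion first, then the instance __dict__, and surface both
# failures together if neither works (simplified sketch).
from typing import Any


def to_plain_data(obj: Any) -> dict[str, Any]:
    errors: list[Exception] = []
    try:
        return dict(obj)  # works for mappings and key/value iterables
    except Exception as e:
        errors.append(e)
    try:
        return vars(obj)  # works for plain objects with a __dict__
    except Exception as e:
        errors.append(e)
        raise ValueError(str(errors)) from e


class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y


print(to_plain_data({"a": 1}))     # {'a': 1}
print(to_plain_data(Point(1, 2)))  # {'x': 1, 'y': 2}
```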
@@ -3,7 +3,8 @@ from dataclasses import dataclass
from typing import Any

from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event, Status, StatusCode
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
|
|||
|
|
@ -155,7 +155,10 @@ class OpsTraceManager:
|
|||
if key in tracing_config:
|
||||
if "*" in tracing_config[key]:
|
||||
# If the key contains '*', retain the original value from the current config
|
||||
new_config[key] = current_trace_config.get(key, tracing_config[key])
|
||||
if current_trace_config:
|
||||
new_config[key] = current_trace_config.get(key, tracing_config[key])
|
||||
else:
|
||||
new_config[key] = tracing_config[key]
|
||||
else:
|
||||
# Otherwise, encrypt the key
|
||||
new_config[key] = encrypt_token(tenant_id, tracing_config[key])
|
||||
|
|
|
|||
|
|
@ -62,7 +62,8 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
self,
|
||||
):
|
||||
try:
|
||||
project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
|
||||
project_identifier = f"{self.entity}/{self.project_name}" if self.entity else self.project_name
|
||||
project_url = f"https://wandb.ai/{project_identifier}"
|
||||
return project_url
|
||||
except Exception as e:
|
||||
logger.debug("Weave get run url failed: %s", str(e))
|
||||
|
|
@ -424,7 +425,23 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
raise ValueError(f"Weave API check failed: {str(e)}")
|
||||
|
||||
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
|
||||
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
|
||||
inputs = run_data.inputs
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
elif not isinstance(inputs, dict):
|
||||
inputs = {"inputs": str(inputs)}
|
||||
|
||||
attributes = run_data.attributes
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
elif not isinstance(attributes, dict):
|
||||
attributes = {"attributes": str(attributes)}
|
||||
|
||||
call = self.weave_client.create_call(
|
||||
op=run_data.op,
|
||||
inputs=inputs,
|
||||
attributes=attributes,
|
||||
)
|
||||
self.calls[run_data.id] = call
|
||||
if parent_run_id:
|
||||
self.calls[run_data.id].parent_id = parent_run_id
|
||||
|
|
@ -432,6 +449,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
def finish_call(self, run_data: WeaveTraceModel):
|
||||
call = self.calls.get(run_data.id)
|
||||
if call:
|
||||
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
|
||||
exception = Exception(run_data.exception) if run_data.exception else None
|
||||
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=exception)
|
||||
else:
|
||||
raise ValueError(f"Call with id {run_data.id} not found")
|
||||
|
|
|
|||
|
|
@@ -106,7 +106,9 @@ class RetrievalService:
if exceptions:
raise ValueError(";\n".join(exceptions))

# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)

@@ -143,6 +145,40 @@ class RetrievalService:
)
return all_documents

@classmethod
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
if not documents:
return documents

unique_documents = []
seen_doc_ids = set()

for document in documents:
# For dify provider documents, use doc_id for deduplication
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
doc_id = document.metadata["doc_id"]
if doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
unique_documents.append(document)
# If duplicate, keep the one with higher score
elif "score" in document.metadata:
# Find existing document with same doc_id and compare scores
for i, existing_doc in enumerate(unique_documents):
if (
existing_doc.metadata
and existing_doc.metadata.get("doc_id") == doc_id
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
):
unique_documents[i] = document
break
else:
# For non-dify documents, use content-based deduplication
if document not in unique_documents:
unique_documents.append(document)

return unique_documents

@classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
with Session(db.engine) as session:
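A standalone sketch of the `doc_id`-based de-duplication added above, with plain dicts standing in for the project's `Document` objects (the `provider`/`metadata` shape is assumed for illustration, and ordering is simplified):

```python
# Keep one chunk per doc_id, preferring the higher score; pass other chunks
# through with simple content-based de-duplication.
def deduplicate(chunks: list[dict]) -> list[dict]:
    best_by_id: dict[str, dict] = {}
    passthrough: list[dict] = []
    for chunk in chunks:
        doc_id = (chunk.get("metadata") or {}).get("doc_id")
        if chunk.get("provider") == "dify" and doc_id is not None:
            current = best_by_id.get(doc_id)
            # keep the chunk with the higher score when the same doc_id repeats
            if current is None or chunk["metadata"].get("score", 0) > current["metadata"].get("score", 0):
                best_by_id[doc_id] = chunk
        elif chunk not in passthrough:
            passthrough.append(chunk)
    return list(best_by_id.values()) + passthrough


chunks = [
    {"provider": "dify", "metadata": {"doc_id": "a", "score": 0.4}},
    {"provider": "dify", "metadata": {"doc_id": "a", "score": 0.9}},
    {"provider": "external", "metadata": {}},
]
print(deduplicate(chunks))  # doc_id "a" appears once, with score 0.9
```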
@@ -1,7 +1,6 @@
from typing import Any

from openai import BaseModel
from pydantic import Field
from pydantic import BaseModel, Field

from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom

@@ -18,6 +18,10 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
retriever_from: str
model_config = ConfigDict(arbitrary_types_allowed=True)

def run(self, query: str) -> str:
"""Use the tool."""
return self._run(query)

@abstractmethod
def _run(self, query: str) -> str:
"""Use the tool.

@@ -124,7 +124,7 @@ class DatasetRetrieverTool(Tool):
yield self.create_text_message(text="please input query")
else:
# invoke dataset retriever tool
result = self.retrieval_tool._run(query=query)
result = self.retrieval_tool.run(query=query)
yield self.create_text_message(text=result)

def validate_credentials(
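The change above routes callers through a public `run()` wrapper instead of the protected `_run()`. A toy version of that template-method shape (names mirror the hunk, but this is not the project's class):

```python
# Subclasses implement _run(); callers always go through run().
from abc import ABC, abstractmethod


class RetrieverBase(ABC):
    def run(self, query: str) -> str:
        """Public entry point used by callers."""
        return self._run(query)

    @abstractmethod
    def _run(self, query: str) -> str:
        """Subclass-specific retrieval logic."""


class EchoRetriever(RetrieverBase):
    def _run(self, query: str) -> str:
        return f"results for: {query}"


print(EchoRetriever().run("what is dify?"))
```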
|||
|
|
@ -2,6 +2,7 @@ import re
|
|||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
|
|
@ -127,34 +128,34 @@ class ApiBasedToolSchemaParser:
|
|||
if "allOf" in prop_dict:
|
||||
del prop_dict["allOf"]
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
|
|
@ -241,7 +242,9 @@ class ApiBasedToolSchemaParser:
|
|||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None):
|
||||
def parse_swagger_to_openapi(
|
||||
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> dict[str, Any]:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
|
@ -257,7 +260,7 @@ class ApiBasedToolSchemaParser:
|
|||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
converted_openapi: dict[str, Any] = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
|
|
@ -275,7 +278,7 @@ class ApiBasedToolSchemaParser:
|
|||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {}
|
||||
converted_openapi["paths"][path] = {}
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
|
@ -286,7 +289,7 @@ class ApiBasedToolSchemaParser:
|
|||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = {
|
||||
converted_openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
|
|
@ -295,13 +298,14 @@ class ApiBasedToolSchemaParser:
|
|||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
converted_openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition
|
||||
if "definitions" in swagger:
|
||||
for name, definition in swagger["definitions"].items():
|
||||
converted_openapi["components"]["schemas"][name] = definition
|
||||
|
||||
return openapi
|
||||
return converted_openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
|
|
|
|||
|
|
@ -191,11 +191,22 @@ class VariablePool(BaseModel):
|
|||
"""Extract the actual value from an ObjectSegment."""
|
||||
return obj.value if isinstance(obj, ObjectSegment) else obj
|
||||
|
||||
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str):
|
||||
"""Get a nested attribute from a dictionary-like object."""
|
||||
if not isinstance(obj, dict):
|
||||
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None:
|
||||
"""
|
||||
Get a nested attribute from a dictionary-like object.
|
||||
|
||||
Args:
|
||||
obj: The dictionary-like object to search.
|
||||
attr: The key to look up.
|
||||
|
||||
Returns:
|
||||
Segment | None:
|
||||
The corresponding Segment built from the attribute value if the key exists,
|
||||
otherwise None.
|
||||
"""
|
||||
if not isinstance(obj, dict) or attr not in obj:
|
||||
return None
|
||||
return obj.get(attr)
|
||||
return variable_factory.build_segment(obj.get(attr))
|
||||
|
||||
def remove(self, selector: Sequence[str], /):
|
||||
"""
|
||||
|
|
|
|||
|
|
@@ -20,6 +20,7 @@ class ModelInvokeCompletedEvent(NodeEventBase):
usage: LLMUsage
finish_reason: str | None = None
reasoning_content: str | None = None
structured_output: dict | None = None


class RunRetryEvent(NodeEventBase):
|
|
|
|||
|
|
@ -87,7 +87,7 @@ class Executor:
|
|||
node_data.authorization.config.api_key
|
||||
).text
|
||||
|
||||
self.url: str = node_data.url
|
||||
self.url = node_data.url
|
||||
self.method = node_data.method
|
||||
self.auth = node_data.authorization
|
||||
self.timeout = timeout
|
||||
|
|
@ -349,11 +349,10 @@ class Executor:
|
|||
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||
"ssl_verify": self.ssl_verify,
|
||||
"follow_redirects": True,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](**request_args)
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
|
|
|
|||
|
|
@ -165,6 +165,8 @@ class HttpRequestNode(Node):
|
|||
body_type = typed_node_data.body.type
|
||||
data = typed_node_data.body.data
|
||||
match body_type:
|
||||
case "none":
|
||||
pass
|
||||
case "binary":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("invalid body data, should have only one item")
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class IfElseNode(Node):
|
|||
else:
|
||||
# TODO: Update database then remove this
|
||||
# Fallback to old structure if cases are not defined
|
||||
input_conditions, group_result, final_result = _should_not_use_old_function( # ty: ignore [deprecated]
|
||||
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated]
|
||||
condition_processor=condition_processor,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=self._node_data.conditions or [],
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from typing_extensions import TypeIs
|
|||
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
|
|
@ -217,6 +219,13 @@ class IterationNode(Node):
|
|||
graph_engine=graph_engine,
|
||||
)
|
||||
|
||||
# Sync conversation variables after each iteration completes
|
||||
self._sync_conversation_variables_from_snapshot(
|
||||
self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
)
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
|
@ -235,7 +244,10 @@ class IterationNode(Node):
|
|||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
|
||||
future_to_index: dict[
|
||||
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
|
||||
int,
|
||||
] = {}
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
yield IterationNextEvent(index=index)
|
||||
future = executor.submit(
|
||||
|
|
@ -252,7 +264,7 @@ class IterationNode(Node):
|
|||
index = future_to_index[future]
|
||||
try:
|
||||
result = future.result()
|
||||
iter_start_at, events, output_value, tokens_used = result
|
||||
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
|
||||
|
||||
# Update outputs at the correct index
|
||||
outputs[index] = output_value
|
||||
|
|
@ -264,6 +276,9 @@ class IterationNode(Node):
|
|||
self.graph_runtime_state.total_tokens += tokens_used
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
# Sync conversation variables after iteration completion
|
||||
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors based on error_handle_mode
|
||||
match self._node_data.error_handle_mode:
|
||||
|
|
@ -288,7 +303,7 @@ class IterationNode(Node):
|
|||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
|
@ -307,8 +322,17 @@ class IterationNode(Node):
|
|||
|
||||
# Get the output value from the temporary outputs list
|
||||
output_value = outputs_temp[0] if outputs_temp else None
|
||||
conversation_snapshot = self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
|
||||
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
||||
return (
|
||||
iter_start_at,
|
||||
events,
|
||||
output_value,
|
||||
graph_engine.graph_runtime_state.total_tokens,
|
||||
conversation_snapshot,
|
||||
)
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
|
|
@@ -430,6 +454,23 @@ class IterationNode(Node):

return variable_mapping

def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}

def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})

current_keys = set(parent_conversations.keys())
snapshot_keys = set(snapshot.keys())

for removed_key in current_keys - snapshot_keys:
parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key))

for name, variable in snapshot.items():
parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable)

def _append_iteration_info_to_event(
self,
event: GraphNodeEventBase,
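A rough sketch of the snapshot-and-sync idea in the two helpers above, using plain dicts in place of the workflow `VariablePool` (structure assumed for illustration): deep-copy the child's conversation variables, then make the parent match the snapshot exactly.

```python
import copy

CONVERSATION = "conversation"


def extract_snapshot(pool: dict[str, dict[str, object]]) -> dict[str, object]:
    return {k: copy.deepcopy(v) for k, v in pool.get(CONVERSATION, {}).items()}


def sync_from_snapshot(parent: dict[str, dict[str, object]], snapshot: dict[str, object]) -> None:
    current = parent.setdefault(CONVERSATION, {})
    for removed in set(current) - set(snapshot):
        del current[removed]      # variable deleted inside the iteration
    current.update(snapshot)      # add/overwrite everything else


parent_pool = {CONVERSATION: {"topic": "old", "obsolete": 1}}
child_pool = {CONVERSATION: {"topic": "new", "count": 3}}
sync_from_snapshot(parent_pool, extract_snapshot(child_pool))
print(parent_pool)  # {'conversation': {'topic': 'new', 'count': 3}}
```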
|
|
|
|||
|
|
@ -136,6 +136,11 @@ class KnowledgeIndexNode(Node):
|
|||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if not document:
|
||||
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
|
||||
doc_id_value = document.id
|
||||
ds_id_value = dataset.id
|
||||
dataset_name_value = dataset.name
|
||||
document_name_value = document.name
|
||||
created_at_value = document.created_at
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
|
||||
|
|
@ -161,16 +166,16 @@ class KnowledgeIndexNode(Node):
|
|||
document.word_count = (
|
||||
db.session.query(func.sum(DocumentSegment.word_count))
|
||||
.where(
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == doc_id_value,
|
||||
DocumentSegment.dataset_id == ds_id_value,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
db.session.add(document)
|
||||
# update document segment status
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == doc_id_value,
|
||||
DocumentSegment.dataset_id == ds_id_value,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
|
|
@ -182,13 +187,13 @@ class KnowledgeIndexNode(Node):
|
|||
db.session.commit()
|
||||
|
||||
return {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"dataset_id": ds_id_value,
|
||||
"dataset_name": dataset_name_value,
|
||||
"batch": batch.value,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"display_status": document.indexing_status,
|
||||
"document_id": doc_id_value,
|
||||
"document_name": document_name_value,
|
||||
"created_at": created_at_value.timestamp(),
|
||||
"display_status": "completed",
|
||||
}
|
||||
|
||||
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ class KnowledgeRetrievalNode(Node):
|
|||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
self._file_outputs = []
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
|
|
|
|||
|
|
@ -161,6 +161,8 @@ class ListOperatorNode(Node):
|
|||
elif isinstance(variable, ArrayFileSegment):
|
||||
if isinstance(condition.value, str):
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
elif isinstance(condition.value, bool):
|
||||
raise ValueError(f"File filter expects a string value, got {type(condition.value)}")
|
||||
else:
|
||||
value = condition.value
|
||||
filter_func = _get_file_filter_func(
|
||||
|
|
|
|||
|
|
@@ -46,7 +46,7 @@ class LLMFileSaver(tp.Protocol):
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
and `tar.gz` are not.
"""
pass
raise NotImplementedError()

def save_remote_url(self, url: str, file_type: FileType) -> File:
"""save_remote_url saves the file from a remote url returned by LLM.

@@ -56,7 +56,7 @@ class LLMFileSaver(tp.Protocol):
:param url: the url of the file.
:param file_type: the file type of the file, check `FileType` enum for reference.
"""
pass
raise NotImplementedError()


EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|||
|
|
@ -27,6 +27,7 @@ from core.model_runtime.entities.llm_entities import (
|
|||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
LLMStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
|
|
@ -134,7 +135,7 @@ class LLMNode(Node):
|
|||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
self._file_outputs = []
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
|
|
@ -172,6 +173,7 @@ class LLMNode(Node):
|
|||
node_inputs: dict[str, Any] = {}
|
||||
process_data: dict[str, Any] = {}
|
||||
result_text = ""
|
||||
clean_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
reasoning_content = None
|
||||
|
|
@ -285,6 +287,13 @@ class LLMNode(Node):
|
|||
# Extract clean text from <think> tags
|
||||
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
|
||||
|
||||
# Process structured output if available from the event.
|
||||
structured_output = (
|
||||
LLMStructuredOutput(structured_output=event.structured_output)
|
||||
if event.structured_output
|
||||
else None
|
||||
)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
|
|
@ -1060,7 +1069,7 @@ class LLMNode(Node):
|
|||
@staticmethod
|
||||
def handle_blocking_result(
|
||||
*,
|
||||
invoke_result: LLMResult,
|
||||
invoke_result: LLMResult | LLMResultWithStructuredOutput,
|
||||
saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
|
|
@ -1091,6 +1100,8 @@ class LLMNode(Node):
|
|||
finish_reason=None,
|
||||
# Reasoning content for workflow variables and downstream nodes
|
||||
reasoning_content=reasoning_content,
|
||||
# Pass structured output if enabled
|
||||
structured_output=getattr(invoke_result, "structured_output", None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@@ -179,6 +179,6 @@ CHAT_EXAMPLE = [
"required": ["food"],
},
},
"assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}},
"assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}},
},
]
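The key rename above matters because the example schema lists `food` under `required`. A quick check with the third-party `jsonschema` package (used here only for illustration; it is not implied by the diff):

```python
from jsonschema import ValidationError, validate

schema = {
    "type": "object",
    "properties": {"food": {"type": "string"}},
    "required": ["food"],
}

validate({"food": "apple pie"}, schema)  # passes

try:
    validate({"result": "apple pie"}, schema)
except ValidationError as exc:
    print(f"rejected: {exc.message}")  # 'food' is a required property
```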
|||
|
|
@ -68,7 +68,7 @@ class QuestionClassifierNode(Node):
|
|||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
self._file_outputs = []
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
|
|
@ -111,9 +111,9 @@ class QuestionClassifierNode(Node):
|
|||
query = variable.value if variable else None
|
||||
variables = {"query": query}
|
||||
# fetch model config
|
||||
model_instance, model_config = LLMNode._fetch_model_config(
|
||||
node_data_model=node_data.model,
|
||||
model_instance, model_config = llm_utils.fetch_model_config(
|
||||
tenant_id=self.tenant_id,
|
||||
node_data_model=node_data.model,
|
||||
)
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,7 @@
import os
from collections.abc import Mapping, Sequence
from typing import Any

from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult

@@ -9,7 +9,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData

MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH


class TemplateTransformNode(Node):
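The hunk above reads the limit from `dify_config` instead of `os.environ`. A hedged sketch of how such a field could be declared with `pydantic-settings` (assumed shape, not the project's configs module; the default mirrors the old env-var fallback):

```python
from pydantic_settings import BaseSettings


class SketchConfig(BaseSettings):
    # still overridable via the TEMPLATE_TRANSFORM_MAX_LENGTH environment variable
    TEMPLATE_TRANSFORM_MAX_LENGTH: int = 80000


dify_config = SketchConfig()
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
print(MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH)
```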
|
|
|
|||
|
|
@@ -416,4 +416,8 @@ class WorkflowEntry:

# append variable and value to variable pool
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
# In single run, the input_value is set as the LLM's structured output value within the variable_pool.
if len(variable_key_list) == 2 and variable_key_list[0] == "structured_output":
input_value = {variable_key_list[1]: input_value}
variable_key_list = variable_key_list[0:1]
variable_pool.add([variable_node_id] + variable_key_list, input_value)
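A tiny sketch of the selector rewrite added above, with a plain dict standing in for the variable pool: a single-run input addressed as `node.structured_output.key` is wrapped into a one-key dict stored under `structured_output`.

```python
def add_single_run_input(pool: dict, node_id: str, key_list: list[str], value: object) -> None:
    if len(key_list) == 2 and key_list[0] == "structured_output":
        value = {key_list[1]: value}   # wrap the leaf value under its key
        key_list = key_list[:1]        # store it at the structured_output level
    pool[(node_id, *key_list)] = value


pool: dict = {}
add_single_run_input(pool, "node1", ["structured_output", "answer"], "42")
add_single_run_input(pool, "node1", ["text"], "hello")
print(pool)  # {('node1', 'structured_output'): {'answer': '42'}, ('node1', 'text'): 'hello'}
```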
|
|
|||
|
|
@ -10,14 +10,14 @@ from dify_app import DifyApp
|
|||
|
||||
def init_app(app: DifyApp):
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
def after_request(response): # pyright: ignore[reportUnusedFunction]
|
||||
"""Add Version headers to the response."""
|
||||
response.headers.add("X-Version", dify_config.project.version)
|
||||
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
|
||||
return response
|
||||
|
||||
@app.route("/health")
|
||||
def health():
|
||||
def health(): # pyright: ignore[reportUnusedFunction]
|
||||
return Response(
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.project.version}),
|
||||
status=200,
|
||||
|
|
@ -25,7 +25,7 @@ def init_app(app: DifyApp):
|
|||
)
|
||||
|
||||
@app.route("/threads")
|
||||
def threads():
|
||||
def threads(): # pyright: ignore[reportUnusedFunction]
|
||||
num_threads = threading.active_count()
|
||||
threads = threading.enumerate()
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ def init_app(app: DifyApp):
|
|||
}
|
||||
|
||||
@app.route("/db-pool-stat")
|
||||
def pool_stat():
|
||||
def pool_stat(): # pyright: ignore[reportUnusedFunction]
|
||||
from extensions.ext_database import db
|
||||
|
||||
engine = db.engine
|
||||
|
|
|
|||
|
|
@ -145,6 +145,7 @@ def init_app(app: DifyApp) -> Celery:
|
|||
}
|
||||
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
|
||||
imports.append("schedule.check_upgradable_plugin_task")
|
||||
imports.append("tasks.process_tenant_plugin_autoupgrade_check_task")
|
||||
beat_schedule["check_upgradable_plugin_task"] = {
|
||||
"task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",
|
||||
"schedule": crontab(minute="*/15"),
|
||||
|
|
|
|||
|
|
@@ -10,7 +10,7 @@ from models.engine import db
logger = logging.getLogger(__name__)

# Global flag to avoid duplicate registration of event listener
_GEVENT_COMPATIBILITY_SETUP: bool = False
_gevent_compatibility_setup: bool = False


def _safe_rollback(connection):

@@ -26,14 +26,14 @@ def _safe_rollback(connection):


def _setup_gevent_compatibility():
global _GEVENT_COMPATIBILITY_SETUP # pylint: disable=global-statement
global _gevent_compatibility_setup # pylint: disable=global-statement

# Avoid duplicate registration
if _GEVENT_COMPATIBILITY_SETUP:
if _gevent_compatibility_setup:
return

@event.listens_for(Pool, "reset")
def _safe_reset(dbapi_connection, connection_record, reset_state): # pylint: disable=unused-argument
def _safe_reset(dbapi_connection, connection_record, reset_state): # pyright: ignore[reportUnusedFunction]
if reset_state.terminate_only:
return

@@ -47,7 +47,7 @@ def _setup_gevent_compatibility():
except (AttributeError, ImportError):
_safe_rollback(dbapi_connection)

_GEVENT_COMPATIBILITY_SETUP = True
_gevent_compatibility_setup = True


def init_app(app: DifyApp):
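A minimal sketch of the guarded one-time registration pattern shown above — a module-level flag plus a SQLAlchemy 2.x pool `reset` listener (illustrative only; the real handler also performs gevent-specific cleanup):

```python
from sqlalchemy import event
from sqlalchemy.pool import Pool

_listener_registered: bool = False


def setup_pool_reset_listener() -> None:
    global _listener_registered
    if _listener_registered:      # module-level flag avoids duplicate registration
        return

    @event.listens_for(Pool, "reset")
    def _on_reset(dbapi_connection, connection_record, reset_state):
        if reset_state.terminate_only:
            return                # connection is being discarded, nothing to roll back
        dbapi_connection.rollback()

    _listener_registered = True


setup_pool_reset_listener()
setup_pool_reset_listener()  # second call is a no-op thanks to the flag
```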
|
|
|
|||
|
|
@@ -2,4 +2,4 @@ from dify_app import DifyApp


def init_app(app: DifyApp):
from events import event_handlers # noqa: F401
from events import event_handlers # noqa: F401 # pyright: ignore[reportUnusedImport]
|
|
|
|||
|
|
@ -136,6 +136,7 @@ def init_app(app: DifyApp):
|
|||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
|
||||
from opentelemetry.instrumentation.celery import CeleryInstrumentor
|
||||
from opentelemetry.instrumentation.flask import FlaskInstrumentor
|
||||
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
||||
from opentelemetry.instrumentation.redis import RedisInstrumentor
|
||||
from opentelemetry.instrumentation.requests import RequestsInstrumentor
|
||||
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
||||
|
|
@ -238,6 +239,7 @@ def init_app(app: DifyApp):
|
|||
init_sqlalchemy_instrumentor(app)
|
||||
RedisInstrumentor().instrument()
|
||||
RequestsInstrumentor().instrument()
|
||||
HTTPXClientInstrumentor().instrument()
|
||||
atexit.register(shutdown_tracer)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from dify_app import DifyApp
|
|||
|
||||
def init_app(app: DifyApp):
|
||||
if dify_config.SENTRY_DSN:
|
||||
import openai
|
||||
import sentry_sdk
|
||||
from langfuse import parse_error # type: ignore
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
|
|
@ -28,7 +27,6 @@ def init_app(app: DifyApp):
|
|||
HTTPException,
|
||||
ValueError,
|
||||
FileNotFoundError,
|
||||
openai.APIStatusError,
|
||||
InvokeRateLimitError,
|
||||
parse_error.defaultErrorResponse,
|
||||
],
|
||||
|
|
|
|||
|
|
@ -33,7 +33,9 @@ class AliyunOssStorage(BaseStorage):
|
|||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
|
||||
data: bytes = obj.read()
|
||||
data = obj.read()
|
||||
if not isinstance(data, bytes):
|
||||
return b""
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
|
|
|
|||
|
|
@ -39,10 +39,10 @@ class AwsS3Storage(BaseStorage):
|
|||
self.client.head_bucket(Bucket=self.bucket_name)
|
||||
except ClientError as e:
|
||||
# if bucket not exists, create it
|
||||
if e.response["Error"]["Code"] == "404":
|
||||
if e.response.get("Error", {}).get("Code") == "404":
|
||||
self.client.create_bucket(Bucket=self.bucket_name)
|
||||
# if bucket is not accessible, pass, maybe the bucket is existing but not accessible
|
||||
elif e.response["Error"]["Code"] == "403":
|
||||
elif e.response.get("Error", {}).get("Code") == "403":
|
||||
pass
|
||||
else:
|
||||
# other error, raise exception
|
||||
|
|
@ -55,7 +55,7 @@ class AwsS3Storage(BaseStorage):
|
|||
try:
|
||||
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||
except ClientError as ex:
|
||||
if ex.response["Error"]["Code"] == "NoSuchKey":
|
||||
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
|
||||
raise FileNotFoundError("File not found")
|
||||
else:
|
||||
raise
|
||||
|
|
@ -66,7 +66,7 @@ class AwsS3Storage(BaseStorage):
|
|||
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
|
||||
yield from response["Body"].iter_chunks()
|
||||
except ClientError as ex:
|
||||
if ex.response["Error"]["Code"] == "NoSuchKey":
|
||||
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
|
||||
raise FileNotFoundError("file not found")
|
||||
elif "reached max retries" in str(ex):
|
||||
raise ValueError("please do not request the same file too frequently")
|
||||
|
|
|
|||
|
|
@ -27,24 +27,38 @@ class AzureBlobStorage(BaseStorage):
        self.credential = None

    def save(self, filename, data):
        if not self.bucket_name:
            return

        client = self._sync_client()
        blob_container = client.get_container_client(container=self.bucket_name)
        blob_container.upload_blob(filename, data)

    def load_once(self, filename: str) -> bytes:
        if not self.bucket_name:
            raise FileNotFoundError("Azure bucket name is not configured.")

        client = self._sync_client()
        blob = client.get_container_client(container=self.bucket_name)
        blob = blob.get_blob_client(blob=filename)
-        data: bytes = blob.download_blob().readall()
+        data = blob.download_blob().readall()
        if not isinstance(data, bytes):
            raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}")
        return data

    def load_stream(self, filename: str) -> Generator:
        if not self.bucket_name:
            raise FileNotFoundError("Azure bucket name is not configured.")

        client = self._sync_client()
        blob = client.get_blob_client(container=self.bucket_name, blob=filename)
        blob_data = blob.download_blob()
        yield from blob_data.chunks()

    def download(self, filename, target_filepath):
        if not self.bucket_name:
            return

        client = self._sync_client()

        blob = client.get_blob_client(container=self.bucket_name, blob=filename)

@@ -53,12 +67,18 @@ class AzureBlobStorage(BaseStorage):
        blob_data.readinto(my_blob)

    def exists(self, filename):
        if not self.bucket_name:
            return False

        client = self._sync_client()

        blob = client.get_blob_client(container=self.bucket_name, blob=filename)
        return blob.exists()

    def delete(self, filename):
        if not self.bucket_name:
            return

        client = self._sync_client()

        blob_container = client.get_container_client(container=self.bucket_name)
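Every method touched above gained the same `if not self.bucket_name` guard. Purely as a sketch of that shared shape (the decorator, its name, and ExampleStorage are illustrative, not part of this change), the guard could also be written once:

from collections.abc import Callable
from functools import wraps
from typing import Any


def require_bucket(default: Any = None, *, raise_not_found: bool = False) -> Callable:
    """Short-circuit a storage method when bucket_name is not configured."""

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if not self.bucket_name:
                if raise_not_found:
                    raise FileNotFoundError("Azure bucket name is not configured.")
                return default
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


class ExampleStorage:
    def __init__(self, bucket_name: str | None):
        self.bucket_name = bucket_name

    @require_bucket(default=False)
    def exists(self, filename: str) -> bool:
        return True  # a real implementation would ask the blob client

    @require_bucket(raise_not_found=True)
    def load_once(self, filename: str) -> bytes:
        return b"..."  # a real implementation would download the blob

Note that load_stream in the diff is a generator function, so its in-body guard only fires once iteration starts, while a wrapper like the one above checks eagerly at call time; that difference is worth weighing before factoring the guards out.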
@ -430,7 +430,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
|||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
exists = len(rows) > 0
|
||||
exists = len(rows) > 0 if rows else False
|
||||
logger.debug("File %s exists check: %s", filename, exists)
|
||||
return exists
|
||||
except Exception as e:
|
||||
|
|
@ -509,16 +509,17 @@ class ClickZettaVolumeStorage(BaseStorage):
|
|||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
result = []
|
||||
for row in rows:
|
||||
file_path = row[0] # relative_path column
|
||||
if rows:
|
||||
for row in rows:
|
||||
file_path = row[0] # relative_path column
|
||||
|
||||
# For User Volume, remove dify prefix from results
|
||||
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
|
||||
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
|
||||
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
|
||||
# For User Volume, remove dify prefix from results
|
||||
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
|
||||
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
|
||||
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
|
||||
|
||||
if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
|
||||
result.append(file_path)
|
||||
if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
|
||||
result.append(file_path)
|
||||
|
||||
logger.debug("Scanned %d items in path %s", len(result), path)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -439,6 +439,11 @@ class VolumePermissionManager:
|
|||
self._permission_cache.clear()
|
||||
logger.debug("Permission cache cleared")
|
||||
|
||||
@property
|
||||
def volume_type(self) -> str | None:
|
||||
"""Get the volume type."""
|
||||
return self._volume_type
|
||||
|
||||
def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]:
|
||||
"""Get permission summary
|
||||
|
||||
|
|
@ -632,13 +637,13 @@ def check_volume_permission(permission_manager: VolumePermissionManager, operati
|
|||
VolumePermissionError: If no permission
|
||||
"""
|
||||
if not permission_manager.validate_operation(operation, dataset_id):
|
||||
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
|
||||
error_message = f"Permission denied for operation '{operation}' on {permission_manager.volume_type} volume"
|
||||
if dataset_id:
|
||||
error_message += f" (dataset: {dataset_id})"
|
||||
|
||||
raise VolumePermissionError(
|
||||
error_message,
|
||||
operation=operation,
|
||||
volume_type=permission_manager._volume_type or "unknown",
|
||||
volume_type=permission_manager.volume_type or "unknown",
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -35,12 +35,16 @@ class GoogleCloudStorage(BaseStorage):
    def load_once(self, filename: str) -> bytes:
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(filename)
        if blob is None:
            raise FileNotFoundError("File not found")
        data: bytes = blob.download_as_bytes()
        return data

    def load_stream(self, filename: str) -> Generator:
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(filename)
        if blob is None:
            raise FileNotFoundError("File not found")
        with blob.open(mode="rb") as blob_stream:
            while chunk := blob_stream.read(4096):
                yield chunk
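`while chunk := blob_stream.read(4096)` is the usual way to drain a binary stream in fixed-size chunks: `read()` returns `b""` at end of stream, which is falsy, so the walrus loop stops cleanly. The same pattern works for any file-like object; a self-contained sketch (the path is illustrative):

from collections.abc import Generator


def iter_file_chunks(path: str, chunk_size: int = 4096) -> Generator[bytes, None, None]:
    # read() yields at most chunk_size bytes per call and b"" once the stream is exhausted
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            yield chunk


# e.g. total_bytes = sum(len(c) for c in iter_file_chunks("/tmp/example.bin"))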
|
||||
|
|
@ -48,6 +52,8 @@ class GoogleCloudStorage(BaseStorage):
|
|||
def download(self, filename, target_filepath):
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blob = bucket.get_blob(filename)
|
||||
if blob is None:
|
||||
raise FileNotFoundError("File not found")
|
||||
blob.download_to_filename(target_filepath)
|
||||
|
||||
def exists(self, filename):
|
||||
|
|
|
|||
|
|
@@ -45,7 +45,7 @@ class HuaweiObsStorage(BaseStorage):

    def _get_meta(self, filename):
        res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename)
-        if res.status < 300:
+        if res and res.status and res.status < 300:
            return res
        else:
            return None
|
||||
|
|
|
|||
|
|
@@ -3,9 +3,9 @@ import os
from collections.abc import Generator
from pathlib import Path

+import opendal
from dotenv import dotenv_values
from opendal import Operator
-from opendal.layers import RetryLayer

from extensions.storage.base_storage import BaseStorage

@@ -35,7 +35,7 @@ class OpenDALStorage(BaseStorage):
        root = kwargs.get("root", "storage")
        Path(root).mkdir(parents=True, exist_ok=True)

-        retry_layer = RetryLayer(max_times=3, factor=2.0, jitter=True)
+        retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)
        self.op = Operator(scheme=scheme, **kwargs).layer(retry_layer)
        logger.debug("opendal operator created with scheme %s", scheme)
        logger.debug("added retry layer to opendal operator")
|
||||
|
|
|
|||
|
|
@@ -29,7 +29,7 @@ class OracleOCIStorage(BaseStorage):
        try:
            data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
        except ClientError as ex:
-            if ex.response["Error"]["Code"] == "NoSuchKey":
+            if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
                raise FileNotFoundError("File not found")
            else:
                raise

@@ -40,7 +40,7 @@ class OracleOCIStorage(BaseStorage):
            response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
            yield from response["Body"].iter_chunks()
        except ClientError as ex:
-            if ex.response["Error"]["Code"] == "NoSuchKey":
+            if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
                raise FileNotFoundError("File not found")
            else:
                raise
|
||||
|
|
|
|||
|
|
@@ -46,13 +46,13 @@ class SupabaseStorage(BaseStorage):
        Path(target_filepath).write_bytes(result)

    def exists(self, filename):
-        result = self.client.storage.from_(self.bucket_name).list(filename)
-        if result.count() > 0:
+        result = self.client.storage.from_(self.bucket_name).list(path=filename)
+        if len(result) > 0:
            return True
        return False

    def delete(self, filename):
-        self.client.storage.from_(self.bucket_name).remove(filename)
+        self.client.storage.from_(self.bucket_name).remove([filename])

    def bucket_exists(self):
        buckets = self.client.storage.list_buckets()
|
||||
|
|
|
|||
|
|
@@ -11,6 +11,14 @@ class VolcengineTosStorage(BaseStorage):

    def __init__(self):
        super().__init__()
        if not dify_config.VOLCENGINE_TOS_ACCESS_KEY:
            raise ValueError("VOLCENGINE_TOS_ACCESS_KEY is not set")
        if not dify_config.VOLCENGINE_TOS_SECRET_KEY:
            raise ValueError("VOLCENGINE_TOS_SECRET_KEY is not set")
        if not dify_config.VOLCENGINE_TOS_ENDPOINT:
            raise ValueError("VOLCENGINE_TOS_ENDPOINT is not set")
        if not dify_config.VOLCENGINE_TOS_REGION:
            raise ValueError("VOLCENGINE_TOS_REGION is not set")
        self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME
        self.client = tos.TosClientV2(
            ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY,

@@ -20,27 +28,39 @@ class VolcengineTosStorage(BaseStorage):
        )

    def save(self, filename, data):
        if not self.bucket_name:
            raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
        self.client.put_object(bucket=self.bucket_name, key=filename, content=data)

    def load_once(self, filename: str) -> bytes:
        if not self.bucket_name:
            raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
        data = self.client.get_object(bucket=self.bucket_name, key=filename).read()
        if not isinstance(data, bytes):
            raise TypeError(f"Expected bytes, got {type(data).__name__}")
        return data

    def load_stream(self, filename: str) -> Generator:
        if not self.bucket_name:
            raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
        response = self.client.get_object(bucket=self.bucket_name, key=filename)
        while chunk := response.read(4096):
            yield chunk

    def download(self, filename, target_filepath):
        if not self.bucket_name:
            raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
        self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath)

    def exists(self, filename):
        if not self.bucket_name:
            return False
        res = self.client.head_object(bucket=self.bucket_name, key=filename)
        if res.status_code != 200:
            return False
        return True

    def delete(self, filename):
        if not self.bucket_name:
            return
        self.client.delete_object(bucket=self.bucket_name, key=filename)
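The four constructor checks above repeat the same `if not ...: raise ValueError(... is not set)` shape. A hedged sketch of the same idea driven by a list of setting names; the helper and the fake config object are stand-ins for illustration, not Dify's actual config API:

class _FakeConfig:
    VOLCENGINE_TOS_ACCESS_KEY = "ak"
    VOLCENGINE_TOS_SECRET_KEY = ""  # deliberately empty to trigger the error below
    VOLCENGINE_TOS_ENDPOINT = "https://tos.example.com"
    VOLCENGINE_TOS_REGION = "cn-beijing"


def require_settings(config, names: list[str]) -> None:
    # Fail on the first empty value so the error message names the exact missing setting.
    for name in names:
        if not getattr(config, name, None):
            raise ValueError(f"{name} is not set")


try:
    require_settings(
        _FakeConfig,
        [
            "VOLCENGINE_TOS_ACCESS_KEY",
            "VOLCENGINE_TOS_SECRET_KEY",
            "VOLCENGINE_TOS_ENDPOINT",
            "VOLCENGINE_TOS_REGION",
        ],
    )
except ValueError as exc:
    print(exc)  # VOLCENGINE_TOS_SECRET_KEY is not set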
|
||||
|
|
|
|||
|
|
@ -146,6 +146,8 @@ def build_segment(value: Any, /) -> Segment:
|
|||
# below
|
||||
if value is None:
|
||||
return NoneSegment()
|
||||
if isinstance(value, Segment):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return StringSegment(value=value)
|
||||
if isinstance(value, bool):
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,14 @@
def convert_to_lower_and_upper_set(inputs: list[str] | set[str]) -> set[str]:
    """
    Convert a list or set of strings to a set containing both lower and upper case versions of each string.

    Args:
        inputs (list[str] | set[str]): A list or set of strings to be converted.

    Returns:
        set[str]: A set containing both lower and upper case versions of each string.
    """
    if not inputs:
        return set()
    else:
        return {case for s in inputs if s for case in (s.lower(), s.upper())}
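A few concrete inputs and outputs for the helper above (values chosen arbitrarily); note that a mixed-case input such as "Pdf" contributes only its fully lower- and upper-cased forms, not the original spelling:

assert convert_to_lower_and_upper_set(["Pdf", "TXT"]) == {"pdf", "PDF", "txt", "TXT"}
assert convert_to_lower_and_upper_set(set()) == set()
assert convert_to_lower_and_upper_set(["", "md"]) == {"md", "MD"}  # empty strings are skipped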
@ -94,7 +94,7 @@ def register_external_error_handlers(api: Api):
|
|||
got_request_exception.send(current_app, exception=e)
|
||||
|
||||
status_code = 500
|
||||
data = getattr(e, "data", {"message": http_status_message(status_code)})
|
||||
data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})
|
||||
|
||||
# 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
|
||||
if not isinstance(data, dict):
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import gmpy2 # type: ignore
|
|||
from Crypto import Random
|
||||
from Crypto.Signature.pss import MGF1
|
||||
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
|
||||
from Crypto.Util.py3compat import _copy_bytes, bord
|
||||
from Crypto.Util.py3compat import bord
|
||||
from Crypto.Util.strxor import strxor
|
||||
|
||||
|
||||
|
|
@ -72,7 +72,7 @@ class PKCS1OAepCipher:
|
|||
else:
|
||||
self._mgf = lambda x, y: MGF1(x, y, self._hashObj)
|
||||
|
||||
self._label = _copy_bytes(None, None, label)
|
||||
self._label = bytes(label)
|
||||
self._randfunc = randfunc
|
||||
|
||||
def can_encrypt(self):
|
||||
|
|
@ -120,7 +120,7 @@ class PKCS1OAepCipher:
|
|||
# Step 2b
|
||||
ps = b"\x00" * ps_len
|
||||
# Step 2c
|
||||
db = lHash + ps + b"\x01" + _copy_bytes(None, None, message)
|
||||
db = lHash + ps + b"\x01" + bytes(message)
|
||||
# Step 2d
|
||||
ros = self._randfunc(hLen)
|
||||
# Step 2e
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class SendGridClient:
|
|||
|
||||
def send(self, mail: dict):
|
||||
logger.debug("Sending email with SendGrid")
|
||||
|
||||
_to = ""
|
||||
try:
|
||||
_to = mail["to"]
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ class SendGridClient:
|
|||
content = Content("text/html", mail["html"])
|
||||
sg_mail = Mail(from_email, to_email, subject, content)
|
||||
mail_json = sg_mail.get()
|
||||
response = sg.client.mail.send.post(request_body=mail_json) # ty: ignore [call-non-callable]
|
||||
response = sg.client.mail.send.post(request_body=mail_json) # type: ignore
|
||||
logger.debug(response.status_code)
|
||||
logger.debug(response.body)
|
||||
logger.debug(response.headers)
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,5 @@
def validate_description_length(description: str | None) -> str | None:
    """Validate description length."""
    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description
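Because the validator raises ValueError on bad input and otherwise returns the value unchanged, it can be called directly or passed as the type= callable of a flask-restful reqparse argument. A minimal usage sketch:

try:
    validate_description_length("x" * 401)
except ValueError as exc:
    print(exc)  # Description cannot exceed 400 characters.

assert validate_description_length(None) is None
assert validate_description_length("short description") == "short description"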
@ -1,11 +1,10 @@
|
|||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.9.0"
|
||||
version = "1.9.1"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
"arize-phoenix-otel~=0.9.2",
|
||||
"authlib==1.6.4",
|
||||
"azure-identity==1.16.1",
|
||||
"beautifulsoup4==4.12.2",
|
||||
"boto3==1.35.99",
|
||||
|
|
@ -34,10 +33,8 @@ dependencies = [
|
|||
"json-repair>=0.41.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langsmith~=0.1.77",
|
||||
"mailchimp-transactional~=1.0.50",
|
||||
"markdown~=3.5.1",
|
||||
"numpy~=1.26.4",
|
||||
"openai~=1.61.0",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.7.25",
|
||||
"opentelemetry-api==1.27.0",
|
||||
|
|
@ -49,6 +46,7 @@ dependencies = [
|
|||
"opentelemetry-instrumentation==0.48b0",
|
||||
"opentelemetry-instrumentation-celery==0.48b0",
|
||||
"opentelemetry-instrumentation-flask==0.48b0",
|
||||
"opentelemetry-instrumentation-httpx==0.48b0",
|
||||
"opentelemetry-instrumentation-redis==0.48b0",
|
||||
"opentelemetry-instrumentation-requests==0.48b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.48b0",
|
||||
|
|
@ -60,7 +58,6 @@ dependencies = [
|
|||
"opentelemetry-semantic-conventions==0.48b0",
|
||||
"opentelemetry-util-http==0.48b0",
|
||||
"pandas[excel,output-formatting,performance]~=2.2.2",
|
||||
"pandoc~=2.4",
|
||||
"psycogreen~=1.0.2",
|
||||
"psycopg2-binary~=2.9.6",
|
||||
"pycryptodome==3.19.1",
|
||||
|
|
@ -178,10 +175,10 @@ dev = [
|
|||
# Required for storage clients
|
||||
############################################################
|
||||
storage = [
|
||||
"azure-storage-blob==12.13.0",
|
||||
"azure-storage-blob==12.26.0",
|
||||
"bce-python-sdk~=0.9.23",
|
||||
"cos-python-sdk-v5==1.9.30",
|
||||
"esdk-obs-python==3.24.6.1",
|
||||
"cos-python-sdk-v5==1.9.38",
|
||||
"esdk-obs-python==3.25.8",
|
||||
"google-cloud-storage==2.16.0",
|
||||
"opendal~=0.46.0",
|
||||
"oss2==2.18.5",
|
||||
|
|
@ -207,7 +204,7 @@ vdb = [
|
|||
"couchbase~=4.3.0",
|
||||
"elasticsearch==8.14.0",
|
||||
"opensearch-py==2.4.0",
|
||||
"oracledb==3.0.0",
|
||||
"oracledb==3.3.0",
|
||||
"pgvecto-rs[sqlalchemy]~=0.2.1",
|
||||
"pgvector==0.2.5",
|
||||
"pymilvus~=2.5.0",
|
||||
|
|
|
|||
|
|
@ -1,19 +1,10 @@
|
|||
{
|
||||
"include": ["."],
|
||||
"exclude": [
|
||||
".venv",
|
||||
"tests/",
|
||||
".venv",
|
||||
"migrations/",
|
||||
"core/rag",
|
||||
"extensions",
|
||||
"libs",
|
||||
"controllers/console/datasets",
|
||||
"controllers/service_api/dataset",
|
||||
"core/ops",
|
||||
"core/tools",
|
||||
"core/model_runtime",
|
||||
"core/workflow/nodes",
|
||||
"core/app/app_config/easy_ui_based_app/dataset"
|
||||
"core/rag"
|
||||
],
|
||||
"typeCheckingMode": "strict",
|
||||
"allowedUntypedLibraries": [
|
||||
|
|
@ -21,6 +12,7 @@
|
|||
"flask_login",
|
||||
"opentelemetry.instrumentation.celery",
|
||||
"opentelemetry.instrumentation.flask",
|
||||
"opentelemetry.instrumentation.httpx",
|
||||
"opentelemetry.instrumentation.requests",
|
||||
"opentelemetry.instrumentation.sqlalchemy",
|
||||
"opentelemetry.instrumentation.redis"
|
||||
|
|
@ -32,7 +24,6 @@
|
|||
"reportUnknownLambdaType": "hint",
|
||||
"reportMissingParameterType": "hint",
|
||||
"reportMissingTypeArgument": "hint",
|
||||
"reportUnnecessaryContains": "hint",
|
||||
"reportUnnecessaryComparison": "hint",
|
||||
"reportUnnecessaryCast": "hint",
|
||||
"reportUnnecessaryIsInstance": "hint",
|
||||
|
|
@ -41,4 +32,4 @@
|
|||
"reportAttributeAccessIssue": "hint",
|
||||
"pythonVersion": "3.11",
|
||||
"pythonPlatform": "All"
|
||||
}
|
||||
}
|
||||
|
|
@ -7,7 +7,7 @@ env =
|
|||
CHATGLM_API_BASE = http://a.abc.com:11451
|
||||
CODE_EXECUTION_API_KEY = dify-sandbox
|
||||
CODE_EXECUTION_ENDPOINT = http://127.0.0.1:8194
|
||||
CODE_MAX_STRING_LENGTH = 80000
|
||||
CODE_MAX_STRING_LENGTH = 400000
|
||||
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
|
||||
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
|
||||
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import click
|
|||
import app
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
from tasks.process_tenant_plugin_autoupgrade_check_task import process_tenant_plugin_autoupgrade_check_task
|
||||
from tasks import process_tenant_plugin_autoupgrade_check_task as check_task
|
||||
|
||||
AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes
|
||||
MAX_CONCURRENT_CHECK_TASKS = 20
|
||||
|
|
@ -43,7 +43,7 @@ def check_upgradable_plugin_task():
|
|||
for i in range(0, total_strategies, MAX_CONCURRENT_CHECK_TASKS):
|
||||
batch_strategies = strategies[i : i + MAX_CONCURRENT_CHECK_TASKS]
|
||||
for strategy in batch_strategies:
|
||||
process_tenant_plugin_autoupgrade_check_task.delay(
|
||||
check_task.process_tenant_plugin_autoupgrade_check_task.delay(
|
||||
strategy.tenant_id,
|
||||
strategy.strategy_setting,
|
||||
strategy.upgrade_time_of_day,
|
||||
|
|
@ -52,7 +52,8 @@ def check_upgradable_plugin_task():
|
|||
strategy.include_plugins,
|
||||
)
|
||||
|
||||
if batch_interval_time > 0.0001: # if lower than 1ms, skip
|
||||
# Only sleep if batch_interval_time > 0.0001 AND current batch is not the last one
|
||||
if batch_interval_time > 0.0001 and i + MAX_CONCURRENT_CHECK_TASKS < total_strategies:
|
||||
time.sleep(batch_interval_time)
|
||||
|
||||
end_at = time.perf_counter()
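The reworked condition above pauses only between batches and never after the last one. A small self-contained illustration of the same loop shape; the batch size and interval are made up for the example:

import time

MAX_CONCURRENT = 20
BATCH_INTERVAL = 0.01  # seconds, illustrative only

items = list(range(45))
total = len(items)

for i in range(0, total, MAX_CONCURRENT):
    batch = items[i : i + MAX_CONCURRENT]
    # ... dispatch the batch here ...
    # Sleep only if an interval is configured AND another batch is still coming.
    if BATCH_INTERVAL > 0.0001 and i + MAX_CONCURRENT < total:
        time.sleep(BATCH_INTERVAL)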
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ import uuid
|
|||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from openai._exceptions import RateLimitError
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||
|
|
@ -122,8 +120,6 @@ class AppGenerateService:
|
|||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
except RateLimitError as e:
|
||||
raise InvokeRateLimitError(str(e))
|
||||
except Exception:
|
||||
rate_limit.exit(request_id)
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ logger = logging.getLogger(__name__)
|
|||
class DatasetService:
|
||||
@staticmethod
|
||||
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
|
||||
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
|
||||
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
|
||||
|
||||
if user:
|
||||
# get permitted dataset ids
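Appending `Dataset.id` as a secondary sort key matters for pagination: `created_at` alone is not unique, so rows created at the same instant have no guaranteed relative order and a page boundary can repeat or skip one of them. With the unique tie-breaker the ordering is total and stable across requests. Spelled out (Dataset and tenant_id are the model and variable already in scope in this hunk):

query = (
    select(Dataset)
    .where(Dataset.tenant_id == tenant_id)
    # newest first, then id as a unique tie-breaker for deterministic paging
    .order_by(Dataset.created_at.desc(), Dataset.id)
)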
|
||||
|
|
|
|||
|
|
@ -149,8 +149,7 @@ class RagPipelineTransformService:
|
|||
file_extensions = node.get("data", {}).get("fileExtensions", [])
|
||||
if not file_extensions:
|
||||
return node
|
||||
file_extensions = [file_extension.lower() for file_extension in file_extensions]
|
||||
node["data"]["fileExtensions"] = DOCUMENT_EXTENSIONS
|
||||
node["data"]["fileExtensions"] = [ext.lower() for ext in file_extensions if ext in DOCUMENT_EXTENSIONS]
|
||||
return node
|
||||
|
||||
def _deal_knowledge_index(
|
||||
|
|
|
|||
|
|
@ -349,14 +349,10 @@ class BuiltinToolManageService:
|
|||
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
||||
|
||||
credentials: list[ToolProviderCredentialApiEntity] = []
|
||||
encrypters = {}
|
||||
for provider in providers:
|
||||
credential_type = provider.credential_type
|
||||
if credential_type not in encrypters:
|
||||
encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, provider, provider.provider, provider_controller
|
||||
)[0]
|
||||
encrypter = encrypters[credential_type]
|
||||
encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, provider, provider.provider, provider_controller
|
||||
)
|
||||
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
|
||||
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||
provider=provider,
|
||||
|
|
|
|||
|
|
@ -79,7 +79,6 @@ class WorkflowConverter:
|
|||
new_app.updated_by = account.id
|
||||
db.session.add(new_app)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
workflow.app_id = new_app.id
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
import operator
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
import click
|
||||
|
|
@ -9,38 +9,106 @@ from core.helper import marketplace
|
|||
from core.helper.marketplace import MarketplacePluginDeclaration
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
|
||||
CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:"
|
||||
CACHE_REDIS_TTL = 60 * 15 # 15 minutes
|
||||
|
||||
|
||||
cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
|
||||
def _get_redis_cache_key(plugin_id: str) -> str:
|
||||
"""Generate Redis cache key for plugin manifest."""
|
||||
return f"{CACHE_REDIS_KEY_PREFIX}{plugin_id}"
|
||||
|
||||
|
||||
def _get_cached_manifest(plugin_id: str) -> typing.Union[MarketplacePluginDeclaration, None, bool]:
|
||||
"""
|
||||
Get cached plugin manifest from Redis.
|
||||
Returns:
|
||||
- MarketplacePluginDeclaration: if found in cache
|
||||
- None: if cached as not found (marketplace returned no result)
|
||||
- False: if not in cache at all
|
||||
"""
|
||||
try:
|
||||
key = _get_redis_cache_key(plugin_id)
|
||||
cached_data = redis_client.get(key)
|
||||
if cached_data is None:
|
||||
return False
|
||||
|
||||
cached_json = json.loads(cached_data)
|
||||
if cached_json is None:
|
||||
return None
|
||||
|
||||
return MarketplacePluginDeclaration.model_validate(cached_json)
|
||||
except Exception:
|
||||
return False
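Because the return value is overloaded (False means "not cached", None means "cached as not found"), callers must distinguish the two falsy cases with identity checks rather than truthiness. A short sketch of the intended dispatch; the plugin id is made up:

cached = _get_cached_manifest("langgenius/example-plugin")
if cached is False:
    ...  # cache miss: fall through to a marketplace fetch
elif cached is None:
    ...  # negative cache hit: the marketplace previously reported no such plugin
else:
    ...  # cached MarketplacePluginDeclaration: use it directly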
|
||||
|
||||
|
||||
def _set_cached_manifest(plugin_id: str, manifest: typing.Union[MarketplacePluginDeclaration, None]) -> None:
|
||||
"""
|
||||
Cache plugin manifest in Redis.
|
||||
Args:
|
||||
plugin_id: The plugin ID
|
||||
manifest: The manifest to cache, or None if not found in marketplace
|
||||
"""
|
||||
try:
|
||||
key = _get_redis_cache_key(plugin_id)
|
||||
if manifest is None:
|
||||
# Cache the fact that this plugin was not found
|
||||
redis_client.setex(key, CACHE_REDIS_TTL, json.dumps(None))
|
||||
else:
|
||||
# Cache the manifest data
|
||||
redis_client.setex(key, CACHE_REDIS_TTL, manifest.model_dump_json())
|
||||
except Exception:
|
||||
# If Redis fails, continue without caching
|
||||
# traceback.print_exc()
|
||||
pass
|
||||
|
||||
|
||||
def marketplace_batch_fetch_plugin_manifests(
|
||||
plugin_ids_plain_list: list[str],
|
||||
) -> list[MarketplacePluginDeclaration]:
|
||||
global cached_plugin_manifests
|
||||
# return marketplace.batch_fetch_plugin_manifests(plugin_ids_plain_list)
|
||||
not_included_plugin_ids = [
|
||||
plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests
|
||||
]
|
||||
if not_included_plugin_ids:
|
||||
manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_included_plugin_ids)
|
||||
"""Fetch plugin manifests with Redis caching support."""
|
||||
cached_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
|
||||
not_cached_plugin_ids: list[str] = []
|
||||
|
||||
# Check Redis cache for each plugin
|
||||
for plugin_id in plugin_ids_plain_list:
|
||||
cached_result = _get_cached_manifest(plugin_id)
|
||||
if cached_result is False:
|
||||
# Not in cache, need to fetch
|
||||
not_cached_plugin_ids.append(plugin_id)
|
||||
else:
|
||||
# Either found manifest or cached as None (not found in marketplace)
|
||||
# At this point, cached_result is either MarketplacePluginDeclaration or None
|
||||
if isinstance(cached_result, bool):
|
||||
# This should never happen due to the if condition above, but for type safety
|
||||
continue
|
||||
cached_manifests[plugin_id] = cached_result
|
||||
|
||||
# Fetch uncached plugins from marketplace
|
||||
if not_cached_plugin_ids:
|
||||
manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_cached_plugin_ids)
|
||||
|
||||
# Cache the fetched manifests
|
||||
for manifest in manifests:
|
||||
cached_plugin_manifests[manifest.plugin_id] = manifest
|
||||
cached_manifests[manifest.plugin_id] = manifest
|
||||
_set_cached_manifest(manifest.plugin_id, manifest)
|
||||
|
||||
if (
|
||||
len(manifests) == 0
|
||||
): # this indicates that the plugin not found in marketplace, should set None in cache to prevent future check
|
||||
for plugin_id in not_included_plugin_ids:
|
||||
cached_plugin_manifests[plugin_id] = None
|
||||
# Cache plugins that were not found in marketplace
|
||||
fetched_plugin_ids = {manifest.plugin_id for manifest in manifests}
|
||||
for plugin_id in not_cached_plugin_ids:
|
||||
if plugin_id not in fetched_plugin_ids:
|
||||
cached_manifests[plugin_id] = None
|
||||
_set_cached_manifest(plugin_id, None)
|
||||
|
||||
# Build result list from cached manifests
|
||||
result: list[MarketplacePluginDeclaration] = []
|
||||
for plugin_id in plugin_ids_plain_list:
|
||||
final_manifest = cached_plugin_manifests.get(plugin_id)
|
||||
if final_manifest is not None:
|
||||
result.append(final_manifest)
|
||||
cached_manifest: typing.Union[MarketplacePluginDeclaration, None] = cached_manifests.get(plugin_id)
|
||||
if cached_manifest is not None:
|
||||
result.append(cached_manifest)
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -157,10 +225,10 @@ def process_tenant_plugin_autoupgrade_check_task(
|
|||
)
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red"))
|
||||
traceback.print_exc()
|
||||
# traceback.print_exc()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error when checking upgradable plugin: {e}", fg="red"))
|
||||
traceback.print_exc()
|
||||
# traceback.print_exc()
|
||||
return
|
||||
|
|
|
|||
|
|
@ -29,23 +29,10 @@ def priority_rag_pipeline_run_task(
|
|||
tenant_id: str,
|
||||
):
|
||||
"""
|
||||
Async Run rag pipeline
|
||||
:param rag_pipeline_invoke_entities: Rag pipeline invoke entities
|
||||
rag_pipeline_invoke_entities include:
|
||||
:param pipeline_id: Pipeline ID
|
||||
:param user_id: User ID
|
||||
:param tenant_id: Tenant ID
|
||||
:param workflow_id: Workflow ID
|
||||
:param invoke_from: Invoke source (debugger, published, etc.)
|
||||
:param streaming: Whether to stream results
|
||||
:param datasource_type: Type of datasource
|
||||
:param datasource_info: Datasource information dict
|
||||
:param batch: Batch identifier
|
||||
:param document_id: Document ID (optional)
|
||||
:param start_node_id: Starting node ID
|
||||
:param inputs: Input parameters dict
|
||||
:param workflow_execution_id: Workflow execution ID
|
||||
:param workflow_thread_pool_id: Thread pool ID for workflow execution
|
||||
Async Run rag pipeline task using high priority queue.
|
||||
|
||||
:param rag_pipeline_invoke_entities_file_id: File ID containing serialized RAG pipeline invoke entities
|
||||
:param tenant_id: Tenant ID for the pipeline execution
|
||||
"""
|
||||
# run with threading, thread pool size is 10
|
||||
|
||||
|
|
|
|||
|
|
@ -30,23 +30,10 @@ def rag_pipeline_run_task(
|
|||
tenant_id: str,
|
||||
):
|
||||
"""
|
||||
Async Run rag pipeline
|
||||
:param rag_pipeline_invoke_entities: Rag pipeline invoke entities
|
||||
rag_pipeline_invoke_entities include:
|
||||
:param pipeline_id: Pipeline ID
|
||||
:param user_id: User ID
|
||||
:param tenant_id: Tenant ID
|
||||
:param workflow_id: Workflow ID
|
||||
:param invoke_from: Invoke source (debugger, published, etc.)
|
||||
:param streaming: Whether to stream results
|
||||
:param datasource_type: Type of datasource
|
||||
:param datasource_info: Datasource information dict
|
||||
:param batch: Batch identifier
|
||||
:param document_id: Document ID (optional)
|
||||
:param start_node_id: Starting node ID
|
||||
:param inputs: Input parameters dict
|
||||
:param workflow_execution_id: Workflow execution ID
|
||||
:param workflow_thread_pool_id: Thread pool ID for workflow execution
|
||||
Async Run rag pipeline task using regular priority queue.
|
||||
|
||||
:param rag_pipeline_invoke_entities_file_id: File ID containing serialized RAG pipeline invoke entities
|
||||
:param tenant_id: Tenant ID for the pipeline execution
|
||||
"""
|
||||
# run with threading, thread pool size is 10
|
||||
|
||||
|
|
|
|||
|
|
@ -5,15 +5,10 @@ These tasks provide asynchronous storage capabilities for workflow execution dat
|
|||
improving performance by offloading storage operations to background workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from celery import shared_task # type: ignore[import-untyped]
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,316 @@
|
|||
app:
|
||||
description: 'This chatflow receives a sys.query, writes it into the `answer` variable,
|
||||
and then outputs the `answer` variable.
|
||||
|
||||
|
||||
`answer` is a conversation variable with a blank default value; it will be updated
|
||||
in an iteration node.
|
||||
|
||||
|
||||
if this chatflow works correctly, it will output the `sys.query` as the same.'
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: advanced-chat
|
||||
name: update-conversation-variable-in-iteration
|
||||
use_icon_as_answer_icon: false
|
||||
dependencies: []
|
||||
kind: app
|
||||
version: 0.4.0
|
||||
workflow:
|
||||
conversation_variables:
|
||||
- description: ''
|
||||
id: c30af82d-b2ec-417d-a861-4dd78584faa4
|
||||
name: answer
|
||||
selector:
|
||||
- conversation
|
||||
- answer
|
||||
value: ''
|
||||
value_type: string
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
allowed_file_types:
|
||||
- image
|
||||
allowed_file_upload_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
enabled: false
|
||||
fileUploadConfig:
|
||||
audio_file_size_limit: 50
|
||||
batch_count_limit: 5
|
||||
file_size_limit: 15
|
||||
image_file_size_limit: 10
|
||||
video_file_size_limit: 100
|
||||
workflow_file_upload_limit: 10
|
||||
image:
|
||||
enabled: false
|
||||
number_limits: 3
|
||||
transfer_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
number_limits: 3
|
||||
opening_statement: ''
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
graph:
|
||||
edges:
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: start
|
||||
targetType: code
|
||||
id: 1759032354471-source-1759032363865-target
|
||||
source: '1759032354471'
|
||||
sourceHandle: source
|
||||
target: '1759032363865'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: code
|
||||
targetType: iteration
|
||||
id: 1759032363865-source-1759032379989-target
|
||||
source: '1759032363865'
|
||||
sourceHandle: source
|
||||
target: '1759032379989'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
- data:
|
||||
isInIteration: true
|
||||
isInLoop: false
|
||||
iteration_id: '1759032379989'
|
||||
sourceType: iteration-start
|
||||
targetType: assigner
|
||||
id: 1759032379989start-source-1759032394460-target
|
||||
source: 1759032379989start
|
||||
sourceHandle: source
|
||||
target: '1759032394460'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 1002
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: iteration
|
||||
targetType: answer
|
||||
id: 1759032379989-source-1759032410331-target
|
||||
source: '1759032379989'
|
||||
sourceHandle: source
|
||||
target: '1759032410331'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
- data:
|
||||
isInIteration: true
|
||||
isInLoop: false
|
||||
iteration_id: '1759032379989'
|
||||
sourceType: assigner
|
||||
targetType: code
|
||||
id: 1759032394460-source-1759032476318-target
|
||||
source: '1759032394460'
|
||||
sourceHandle: source
|
||||
target: '1759032476318'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 1002
|
||||
nodes:
|
||||
- data:
|
||||
selected: false
|
||||
title: Start
|
||||
type: start
|
||||
variables: []
|
||||
height: 52
|
||||
id: '1759032354471'
|
||||
position:
|
||||
x: 30
|
||||
y: 302
|
||||
positionAbsolute:
|
||||
x: 30
|
||||
y: 302
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
code: "\ndef main():\n return {\n \"result\": [1],\n }\n"
|
||||
code_language: python3
|
||||
outputs:
|
||||
result:
|
||||
children: null
|
||||
type: array[number]
|
||||
selected: false
|
||||
title: Code
|
||||
type: code
|
||||
variables: []
|
||||
height: 52
|
||||
id: '1759032363865'
|
||||
position:
|
||||
x: 332
|
||||
y: 302
|
||||
positionAbsolute:
|
||||
x: 332
|
||||
y: 302
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
error_handle_mode: terminated
|
||||
height: 204
|
||||
is_parallel: false
|
||||
iterator_input_type: array[number]
|
||||
iterator_selector:
|
||||
- '1759032363865'
|
||||
- result
|
||||
output_selector:
|
||||
- '1759032476318'
|
||||
- result
|
||||
output_type: array[string]
|
||||
parallel_nums: 10
|
||||
selected: false
|
||||
start_node_id: 1759032379989start
|
||||
title: Iteration
|
||||
type: iteration
|
||||
width: 808
|
||||
height: 204
|
||||
id: '1759032379989'
|
||||
position:
|
||||
x: 634
|
||||
y: 302
|
||||
positionAbsolute:
|
||||
x: 634
|
||||
y: 302
|
||||
selected: true
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 808
|
||||
zIndex: 1
|
||||
- data:
|
||||
desc: ''
|
||||
isInIteration: true
|
||||
selected: false
|
||||
title: ''
|
||||
type: iteration-start
|
||||
draggable: false
|
||||
height: 48
|
||||
id: 1759032379989start
|
||||
parentId: '1759032379989'
|
||||
position:
|
||||
x: 60
|
||||
y: 78
|
||||
positionAbsolute:
|
||||
x: 694
|
||||
y: 380
|
||||
selectable: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom-iteration-start
|
||||
width: 44
|
||||
zIndex: 1002
|
||||
- data:
|
||||
isInIteration: true
|
||||
isInLoop: false
|
||||
items:
|
||||
- input_type: variable
|
||||
operation: over-write
|
||||
value:
|
||||
- sys
|
||||
- query
|
||||
variable_selector:
|
||||
- conversation
|
||||
- answer
|
||||
write_mode: over-write
|
||||
iteration_id: '1759032379989'
|
||||
selected: false
|
||||
title: Variable Assigner
|
||||
type: assigner
|
||||
version: '2'
|
||||
height: 84
|
||||
id: '1759032394460'
|
||||
parentId: '1759032379989'
|
||||
position:
|
||||
x: 204
|
||||
y: 60
|
||||
positionAbsolute:
|
||||
x: 838
|
||||
y: 362
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
zIndex: 1002
|
||||
- data:
|
||||
answer: '{{#conversation.answer#}}'
|
||||
selected: false
|
||||
title: Answer
|
||||
type: answer
|
||||
variables: []
|
||||
height: 104
|
||||
id: '1759032410331'
|
||||
position:
|
||||
x: 1502
|
||||
y: 302
|
||||
positionAbsolute:
|
||||
x: 1502
|
||||
y: 302
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
code: "\ndef main():\n return {\n \"result\": '',\n }\n"
|
||||
code_language: python3
|
||||
isInIteration: true
|
||||
isInLoop: false
|
||||
iteration_id: '1759032379989'
|
||||
outputs:
|
||||
result:
|
||||
children: null
|
||||
type: string
|
||||
selected: false
|
||||
title: Code 2
|
||||
type: code
|
||||
variables: []
|
||||
height: 52
|
||||
id: '1759032476318'
|
||||
parentId: '1759032379989'
|
||||
position:
|
||||
x: 506
|
||||
y: 76
|
||||
positionAbsolute:
|
||||
x: 1140
|
||||
y: 378
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
zIndex: 1002
|
||||
viewport:
|
||||
x: 120.39999999999998
|
||||
y: 85.20000000000005
|
||||
zoom: 0.7
|
||||
rag_pipeline_variables: []
|
||||
|
|
@ -11,8 +11,8 @@ from controllers.console.app import completion as completion_api
|
|||
from controllers.console.app import message as message_api
|
||||
from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, Tenant
|
||||
from models.account import TenantAccountRole
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
|
@ -31,9 +31,8 @@ class TestChatMessageApiPermissions:
|
|||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Create a mock Account for testing."""
|
||||
|
||||
account = Account()
|
||||
account.id = str(uuid.uuid4())
|
||||
account.name = "Test User"
|
||||
|
|
@ -42,12 +41,24 @@ class TestChatMessageApiPermissions:
|
|||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
|
||||
# Create mock tenant
|
||||
tenant = Tenant()
|
||||
tenant.id = str(uuid.uuid4())
|
||||
tenant.name = "Test Tenant"
|
||||
|
||||
account._current_tenant = tenant
|
||||
mock_session_instance = mock.Mock()
|
||||
|
||||
mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
|
||||
monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
|
||||
|
||||
mock_scalars_result = mock.Mock()
|
||||
mock_scalars_result.one.return_value = tenant
|
||||
monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
|
||||
|
||||
mock_session_context = mock.Mock()
|
||||
mock_session_context.__enter__.return_value = mock_session_instance
|
||||
monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
|
|
@ -18,124 +18,87 @@ class TestAppDescriptionValidationUnit:
|
|||
"""Unit tests for description validation function"""
|
||||
|
||||
def test_validate_description_length_function(self):
|
||||
"""Test the _validate_description_length function directly"""
|
||||
from controllers.console.app.app import _validate_description_length
|
||||
"""Test the validate_description_length function directly"""
|
||||
from libs.validators import validate_description_length
|
||||
|
||||
# Test valid descriptions
|
||||
assert _validate_description_length("") == ""
|
||||
assert _validate_description_length("x" * 400) == "x" * 400
|
||||
assert _validate_description_length(None) is None
|
||||
assert validate_description_length("") == ""
|
||||
assert validate_description_length("x" * 400) == "x" * 400
|
||||
assert validate_description_length(None) is None
|
||||
|
||||
# Test invalid descriptions
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_validate_description_length("x" * 401)
|
||||
validate_description_length("x" * 401)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_validate_description_length("x" * 500)
|
||||
validate_description_length("x" * 500)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_validate_description_length("x" * 1000)
|
||||
validate_description_length("x" * 1000)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
def test_validation_consistency_with_dataset(self):
|
||||
"""Test that App and Dataset validation functions are consistent"""
|
||||
from controllers.console.app.app import _validate_description_length as app_validate
|
||||
from controllers.console.datasets.datasets import _validate_description_length as dataset_validate
|
||||
from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate
|
||||
|
||||
# Test same valid inputs
|
||||
valid_desc = "x" * 400
|
||||
assert app_validate(valid_desc) == dataset_validate(valid_desc) == service_dataset_validate(valid_desc)
|
||||
assert app_validate("") == dataset_validate("") == service_dataset_validate("")
|
||||
assert app_validate(None) == dataset_validate(None) == service_dataset_validate(None)
|
||||
|
||||
# Test same invalid inputs produce same error
|
||||
invalid_desc = "x" * 401
|
||||
|
||||
app_error = None
|
||||
dataset_error = None
|
||||
service_dataset_error = None
|
||||
|
||||
try:
|
||||
app_validate(invalid_desc)
|
||||
except ValueError as e:
|
||||
app_error = str(e)
|
||||
|
||||
try:
|
||||
dataset_validate(invalid_desc)
|
||||
except ValueError as e:
|
||||
dataset_error = str(e)
|
||||
|
||||
try:
|
||||
service_dataset_validate(invalid_desc)
|
||||
except ValueError as e:
|
||||
service_dataset_error = str(e)
|
||||
|
||||
assert app_error == dataset_error == service_dataset_error
|
||||
assert app_error == "Description cannot exceed 400 characters."
|
||||
|
||||
def test_boundary_values(self):
|
||||
"""Test boundary values for description validation"""
|
||||
from controllers.console.app.app import _validate_description_length
|
||||
from libs.validators import validate_description_length
|
||||
|
||||
# Test exact boundary
|
||||
exactly_400 = "x" * 400
|
||||
assert _validate_description_length(exactly_400) == exactly_400
|
||||
assert validate_description_length(exactly_400) == exactly_400
|
||||
|
||||
# Test just over boundary
|
||||
just_over_400 = "x" * 401
|
||||
with pytest.raises(ValueError):
|
||||
_validate_description_length(just_over_400)
|
||||
validate_description_length(just_over_400)
|
||||
|
||||
# Test just under boundary
|
||||
just_under_400 = "x" * 399
|
||||
assert _validate_description_length(just_under_400) == just_under_400
|
||||
assert validate_description_length(just_under_400) == just_under_400
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""Test edge cases for description validation"""
|
||||
from controllers.console.app.app import _validate_description_length
|
||||
from libs.validators import validate_description_length
|
||||
|
||||
# Test None input
|
||||
assert _validate_description_length(None) is None
|
||||
assert validate_description_length(None) is None
|
||||
|
||||
# Test empty string
|
||||
assert _validate_description_length("") == ""
|
||||
assert validate_description_length("") == ""
|
||||
|
||||
# Test single character
|
||||
assert _validate_description_length("a") == "a"
|
||||
assert validate_description_length("a") == "a"
|
||||
|
||||
# Test unicode characters
|
||||
unicode_desc = "测试" * 200 # 400 characters in Chinese
|
||||
assert _validate_description_length(unicode_desc) == unicode_desc
|
||||
assert validate_description_length(unicode_desc) == unicode_desc
|
||||
|
||||
# Test unicode over limit
|
||||
unicode_over = "测试" * 201 # 402 characters
|
||||
with pytest.raises(ValueError):
|
||||
_validate_description_length(unicode_over)
|
||||
validate_description_length(unicode_over)
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
"""Test how validation handles whitespace"""
|
||||
from controllers.console.app.app import _validate_description_length
|
||||
from libs.validators import validate_description_length
|
||||
|
||||
# Test description with spaces
|
||||
spaces_400 = " " * 400
|
||||
assert _validate_description_length(spaces_400) == spaces_400
|
||||
assert validate_description_length(spaces_400) == spaces_400
|
||||
|
||||
# Test description with spaces over limit
|
||||
spaces_401 = " " * 401
|
||||
with pytest.raises(ValueError):
|
||||
_validate_description_length(spaces_401)
|
||||
validate_description_length(spaces_401)
|
||||
|
||||
# Test mixed content
|
||||
mixed_400 = "a" * 200 + " " * 200
|
||||
assert _validate_description_length(mixed_400) == mixed_400
|
||||
assert validate_description_length(mixed_400) == mixed_400
|
||||
|
||||
# Test mixed over limit
|
||||
mixed_401 = "a" * 200 + " " * 201
|
||||
with pytest.raises(ValueError):
|
||||
_validate_description_length(mixed_401)
|
||||
validate_description_length(mixed_401)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from flask.testing import FlaskClient
|
|||
from controllers.console.app import model_config as model_config_api
|
||||
from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, Tenant
|
||||
from models.account import TenantAccountRole
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import AppMode
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
|
@ -30,9 +30,8 @@ class TestModelConfigResourcePermissions:
|
|||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Create a mock Account for testing."""
|
||||
|
||||
account = Account()
|
||||
account.id = str(uuid.uuid4())
|
||||
account.name = "Test User"
|
||||
|
|
@ -41,12 +40,24 @@ class TestModelConfigResourcePermissions:
|
|||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
|
||||
# Create mock tenant
|
||||
tenant = Tenant()
|
||||
tenant.id = str(uuid.uuid4())
|
||||
tenant.name = "Test Tenant"
|
||||
|
||||
account._current_tenant = tenant
|
||||
mock_session_instance = mock.Mock()
|
||||
|
||||
mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
|
||||
monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
|
||||
|
||||
mock_scalars_result = mock.Mock()
|
||||
mock_scalars_result.one.return_value = tenant
|
||||
monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
|
||||
|
||||
mock_session_context = mock.Mock()
|
||||
mock_session_context.__enter__.return_value = mock_session_instance
|
||||
monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import time
|
||||
import uuid
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
|
|
@ -15,7 +15,7 @@ from core.workflow.system_variable import SystemVariable
|
|||
from models.enums import UserFrom
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))
|
||||
CODE_MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
|
||||
|
||||
|
||||
def init_code_node(code_config: dict):
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from flask.testing import FlaskClient
|
|||
from sqlalchemy import Engine, text
|
||||
from sqlalchemy.orm import Session
|
||||
from testcontainers.core.container import DockerContainer
|
||||
from testcontainers.core.network import Network
|
||||
from testcontainers.core.waiting_utils import wait_for_logs
|
||||
from testcontainers.postgres import PostgresContainer
|
||||
from testcontainers.redis import RedisContainer
|
||||
|
|
@ -41,6 +42,7 @@ class DifyTestContainers:
|
|||
|
||||
def __init__(self):
|
||||
"""Initialize container management with default configurations."""
|
||||
self.network: Network | None = None
|
||||
self.postgres: PostgresContainer | None = None
|
||||
self.redis: RedisContainer | None = None
|
||||
self.dify_sandbox: DockerContainer | None = None
|
||||
|
|
@ -62,12 +64,18 @@ class DifyTestContainers:
|
|||
|
||||
logger.info("Starting test containers for Dify integration tests...")
|
||||
|
||||
# Create Docker network for container communication
|
||||
logger.info("Creating Docker network for container communication...")
|
||||
self.network = Network()
|
||||
self.network.create()
|
||||
logger.info("Docker network created successfully with name: %s", self.network.name)
|
||||
|
||||
# Start PostgreSQL container for main application database
|
||||
# PostgreSQL is used for storing user data, workflows, and application state
|
||||
logger.info("Initializing PostgreSQL container...")
|
||||
self.postgres = PostgresContainer(
|
||||
image="postgres:14-alpine",
|
||||
)
|
||||
).with_network(self.network)
|
||||
self.postgres.start()
|
||||
db_host = self.postgres.get_container_host_ip()
|
||||
db_port = self.postgres.get_exposed_port(5432)
|
||||
|
|
@ -137,7 +145,7 @@ class DifyTestContainers:
|
|||
# Start Redis container for caching and session management
|
||||
# Redis is used for storing session data, cache entries, and temporary data
|
||||
logger.info("Initializing Redis container...")
|
||||
self.redis = RedisContainer(image="redis:6-alpine", port=6379)
|
||||
self.redis = RedisContainer(image="redis:6-alpine", port=6379).with_network(self.network)
|
||||
self.redis.start()
|
||||
redis_host = self.redis.get_container_host_ip()
|
||||
redis_port = self.redis.get_exposed_port(6379)
|
||||
|
|
@ -153,7 +161,7 @@ class DifyTestContainers:
|
|||
# Start Dify Sandbox container for code execution environment
|
||||
# Dify Sandbox provides a secure environment for executing user code
|
||||
logger.info("Initializing Dify Sandbox container...")
|
||||
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest")
|
||||
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest").with_network(self.network)
|
||||
self.dify_sandbox.with_exposed_ports(8194)
|
||||
self.dify_sandbox.env = {
|
||||
"API_KEY": "test_api_key",
|
||||
|
|
@ -173,22 +181,28 @@ class DifyTestContainers:
|
|||
# Start Dify Plugin Daemon container for plugin management
|
||||
# Dify Plugin Daemon provides plugin lifecycle management and execution
|
||||
logger.info("Initializing Dify Plugin Daemon container...")
|
||||
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local")
|
||||
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local").with_network(
|
||||
self.network
|
||||
)
|
||||
self.dify_plugin_daemon.with_exposed_ports(5002)
|
||||
# Get container internal network addresses
|
||||
postgres_container_name = self.postgres.get_wrapped_container().name
|
||||
redis_container_name = self.redis.get_wrapped_container().name
|
||||
|
||||
self.dify_plugin_daemon.env = {
|
||||
"DB_HOST": db_host,
|
||||
"DB_PORT": str(db_port),
|
||||
"DB_HOST": postgres_container_name, # Use container name for internal network communication
|
||||
"DB_PORT": "5432", # Use internal port
|
||||
"DB_USERNAME": self.postgres.username,
|
||||
"DB_PASSWORD": self.postgres.password,
|
||||
"DB_DATABASE": "dify_plugin",
|
||||
"REDIS_HOST": redis_host,
|
||||
"REDIS_PORT": str(redis_port),
|
||||
"REDIS_HOST": redis_container_name, # Use container name for internal network communication
|
||||
"REDIS_PORT": "6379", # Use internal port
|
||||
"REDIS_PASSWORD": "",
|
||||
"SERVER_PORT": "5002",
|
||||
"SERVER_KEY": "test_plugin_daemon_key",
|
||||
"MAX_PLUGIN_PACKAGE_SIZE": "52428800",
|
||||
"PPROF_ENABLED": "false",
|
||||
"DIFY_INNER_API_URL": f"http://{db_host}:5001",
|
||||
"DIFY_INNER_API_URL": f"http://{postgres_container_name}:5001",
|
||||
"DIFY_INNER_API_KEY": "test_inner_api_key",
|
||||
"PLUGIN_REMOTE_INSTALLING_HOST": "0.0.0.0",
|
||||
"PLUGIN_REMOTE_INSTALLING_PORT": "5003",
|
||||
|
|
@ -253,6 +267,15 @@ class DifyTestContainers:
|
|||
# Log error but don't fail the test cleanup
|
||||
logger.warning("Failed to stop container %s: %s", container, e)
|
||||
|
||||
# Stop and remove the network
|
||||
if self.network:
|
||||
try:
|
||||
logger.info("Removing Docker network...")
|
||||
self.network.remove()
|
||||
logger.info("Successfully removed Docker network")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to remove Docker network: %s", e)
|
||||
|
||||
self._containers_started = False
|
||||
logger.info("All test containers stopped and cleaned up successfully")
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from openai._exceptions import RateLimitError
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import EndUser
|
||||
|
|
@ -484,36 +483,6 @@ class TestAppGenerateService:
|
|||
# Verify error message
|
||||
assert "Rate limit exceeded" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_rate_limit_error_from_openai(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation when OpenAI rate limit error occurs.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
# Setup completion generator to raise RateLimitError
|
||||
mock_response = MagicMock()
|
||||
mock_response.request = MagicMock()
|
||||
mock_external_service_dependencies["completion_generator"].return_value.generate.side_effect = RateLimitError(
|
||||
"Rate limit exceeded", response=mock_response, body=None
|
||||
)
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test and expect rate limit error
|
||||
with pytest.raises(InvokeRateLimitError) as exc_info:
|
||||
AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Rate limit exceeded" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test generation with invalid app mode.
|
||||
|
|
|
|||
|
|
@@ -784,133 +784,6 @@ class TestCleanDatasetTask:
        print(f"Total cleanup time: {cleanup_duration:.3f} seconds")
        print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds")

    def test_clean_dataset_task_concurrent_cleanup_scenarios(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test dataset cleanup with concurrent cleanup scenarios and race conditions.

        This test verifies that the task can properly:
        1. Handle multiple cleanup operations on the same dataset
        2. Prevent data corruption during concurrent access
        3. Maintain data consistency across multiple cleanup attempts
        4. Handle race conditions gracefully
        5. Ensure idempotent cleanup operations
        """
        # Create test data
        account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
        dataset = self._create_test_dataset(db_session_with_containers, account, tenant)
        document = self._create_test_document(db_session_with_containers, account, tenant, dataset)
        segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document)
        upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant)

        # Update document with file reference
        import json

        document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
        from extensions.ext_database import db

        db.session.commit()

        # Save IDs for verification
        dataset_id = dataset.id
        tenant_id = tenant.id
        upload_file_id = upload_file.id

        # Mock storage to simulate slow operations
        mock_storage = mock_external_service_dependencies["storage"]
        original_delete = mock_storage.delete

        def slow_delete(key):
            import time

            time.sleep(0.1)  # Simulate slow storage operation
            return original_delete(key)

        mock_storage.delete.side_effect = slow_delete

        # Execute multiple cleanup operations concurrently
        import threading

        cleanup_results = []
        cleanup_errors = []

        def run_cleanup():
            try:
                clean_dataset_task(
                    dataset_id=dataset_id,
                    tenant_id=tenant_id,
                    indexing_technique="high_quality",
                    index_struct='{"type": "paragraph"}',
                    collection_binding_id=str(uuid.uuid4()),
                    doc_form="paragraph_index",
                )
                cleanup_results.append("success")
            except Exception as e:
                cleanup_errors.append(str(e))

        # Start multiple cleanup threads
        threads = []
        for i in range(3):
            thread = threading.Thread(target=run_cleanup)
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

        # Verify results
        # Check that all documents were deleted (only once)
        remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset_id).all()
        assert len(remaining_documents) == 0

        # Check that all segments were deleted (only once)
        remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset_id).all()
        assert len(remaining_segments) == 0

        # Check that upload file was deleted (only once)
        # Note: In concurrent scenarios, the first thread deletes documents and segments,
        # subsequent threads may not find the related data to clean up upload files
        # This demonstrates the idempotent nature of the cleanup process
        remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all()
        # The upload file should be deleted by the first successful cleanup operation
        # However, in concurrent scenarios, this may not always happen due to race conditions
        # This test demonstrates the idempotent nature of the cleanup process
        if len(remaining_files) > 0:
            print(f"Warning: Upload file {upload_file_id} was not deleted in concurrent scenario")
            print("This is expected behavior demonstrating the idempotent nature of cleanup")
            # We don't assert here as the behavior depends on timing and race conditions

        # Verify that storage.delete was called (may be called multiple times in concurrent scenarios)
        # In concurrent scenarios, storage operations may be called multiple times due to race conditions
        assert mock_storage.delete.call_count > 0

        # Verify that index processor was called (may be called multiple times in concurrent scenarios)
        mock_index_processor = mock_external_service_dependencies["index_processor"]
        assert mock_index_processor.clean.call_count > 0

        # Check cleanup results
        assert len(cleanup_results) == 3, "All cleanup operations should complete"
        assert len(cleanup_errors) == 0, "No cleanup errors should occur"

        # Verify idempotency by running cleanup again on the same dataset
        # This should not perform any additional operations since data is already cleaned
        clean_dataset_task(
            dataset_id=dataset_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            index_struct='{"type": "paragraph"}',
            collection_binding_id=str(uuid.uuid4()),
            doc_form="paragraph_index",
        )

        # Verify that no additional storage operations were performed
        # Note: In concurrent scenarios, the exact count may vary due to race conditions
        print(f"Final storage delete calls: {mock_storage.delete.call_count}")
        print(f"Final index processor calls: {mock_index_processor.clean.call_count}")
        print("Note: Multiple calls in concurrent scenarios are expected due to race conditions")
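The removed concurrency test drives the task from plain threading.Thread objects; below is an equivalent sketch (not part of the diff) using concurrent.futures, purely as an illustration of the test technique.

from concurrent.futures import ThreadPoolExecutor


def run_concurrently(run_cleanup, n: int = 3) -> None:
    # Submit the same cleanup callable n times and wait for all of them;
    # future.result() re-raises any worker exception, failing the test early.
    with ThreadPoolExecutor(max_workers=n) as pool:
        futures = [pool.submit(run_cleanup) for _ in range(n)]
        for future in futures:
            future.result()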
    def test_clean_dataset_task_storage_exception_handling(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
@@ -0,0 +1,450 @@
from unittest.mock import MagicMock, patch

import pytest
from faker import Faker

from core.rag.index_processor.constant.index_type import IndexType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.enable_segments_to_index_task import enable_segments_to_index_task


class TestEnableSegmentsToIndexTask:
    """Integration tests for enable_segments_to_index_task using testcontainers."""

    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies."""
        with (
            patch("tasks.enable_segments_to_index_task.IndexProcessorFactory") as mock_index_processor_factory,
        ):
            # Setup mock index processor
            mock_processor = MagicMock()
            mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor

            yield {
                "index_processor_factory": mock_index_processor_factory,
                "index_processor": mock_processor,
            }

    def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test dataset and document for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies

        Returns:
            tuple: (dataset, document) - Created dataset and document instances
        """
        fake = Faker()

        # Create account and tenant
        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        db.session.add(account)
        db.session.commit()

        tenant = Tenant(
            name=fake.company(),
            status="normal",
        )
        db.session.add(tenant)
        db.session.commit()

        # Create tenant-account join
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER.value,
            current=True,
        )
        db.session.add(join)
        db.session.commit()

        # Create dataset
        dataset = Dataset(
            id=fake.uuid4(),
            tenant_id=tenant.id,
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            data_source_type="upload_file",
            indexing_technique="high_quality",
            created_by=account.id,
        )
        db.session.add(dataset)
        db.session.commit()

        # Create document
        document = Document(
            id=fake.uuid4(),
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            position=1,
            data_source_type="upload_file",
            batch="test_batch",
            name=fake.file_name(),
            created_from="upload_file",
            created_by=account.id,
            indexing_status="completed",
            enabled=True,
            doc_form=IndexType.PARAGRAPH_INDEX,
        )
        db.session.add(document)
        db.session.commit()

        # Refresh dataset to ensure doc_form property works correctly
        db.session.refresh(dataset)

        return dataset, document

    def _create_test_segments(
        self, db_session_with_containers, document, dataset, count=3, enabled=False, status="completed"
    ):
        """
        Helper method to create test document segments.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            document: Document instance
            dataset: Dataset instance
            count: Number of segments to create
            enabled: Whether segments should be enabled
            status: Status of the segments

        Returns:
            list: List of created DocumentSegment instances
        """
        fake = Faker()
        segments = []

        for i in range(count):
            text = fake.text(max_nb_chars=200)
            segment = DocumentSegment(
                id=fake.uuid4(),
                tenant_id=document.tenant_id,
                dataset_id=dataset.id,
                document_id=document.id,
                position=i,
                content=text,
                word_count=len(text.split()),
                tokens=len(text.split()) * 2,
                index_node_id=f"node_{i}",
                index_node_hash=f"hash_{i}",
                enabled=enabled,
                status=status,
                created_by=document.created_by,
            )
            db.session.add(segment)
            segments.append(segment)

        db.session.commit()
        return segments
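The tests below rely on a per-segment Redis cache key of the form segment_{id}_indexing, set before indexing and deleted by the task. A standalone sketch (not part of the diff) of that convention, assuming a locally reachable Redis instance:

import redis

# Assumes Redis with default settings; the app itself uses
# extensions.ext_redis.redis_client rather than constructing a client here.
redis_client = redis.Redis()


def indexing_cache_key(segment_id: str) -> str:
    # Key convention observed in the tests: one transient flag per segment.
    return f"segment_{segment_id}_indexing"


key = indexing_cache_key("some-segment-id")
redis_client.set(key, "processing", ex=300)  # mark "indexing in progress", 5-minute TTL
assert redis_client.exists(key) == 1
redis_client.delete(key)  # what the task is expected to do once indexing finishes
assert redis_client.exists(key) == 0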
    def test_enable_segments_to_index_with_different_index_type(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test segments indexing with different index types.

        This test verifies:
        - Proper handling of different index types
        - Index processor factory integration
        - Document processing with various configurations
        - Redis cache key deletion
        """
        # Arrange: Create test data with different index type
        dataset, document = self._create_test_dataset_and_document(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Update document to use different index type
        document.doc_form = IndexType.QA_INDEX
        db.session.commit()

        # Refresh dataset to ensure doc_form property reflects the updated document
        db.session.refresh(dataset)

        # Create segments
        segments = self._create_test_segments(db_session_with_containers, document, dataset)

        # Set up Redis cache keys
        segment_ids = [segment.id for segment in segments]
        for segment in segments:
            indexing_cache_key = f"segment_{segment.id}_indexing"
            redis_client.set(indexing_cache_key, "processing", ex=300)

        # Act: Execute the task
        enable_segments_to_index_task(segment_ids, dataset.id, document.id)

        # Assert: Verify different index type handling
        mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
        mock_external_service_dependencies["index_processor"].load.assert_called_once()

        # Verify the load method was called with correct parameters
        call_args = mock_external_service_dependencies["index_processor"].load.call_args
        assert call_args is not None
        documents = call_args[0][1]  # Second argument should be documents list
        assert len(documents) == 3

        # Verify Redis cache keys were deleted
        for segment in segments:
            indexing_cache_key = f"segment_{segment.id}_indexing"
            assert redis_client.exists(indexing_cache_key) == 0
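The call_args indexing used above maps onto the positional arguments of the mocked load call: call_args[0] is the positional-arguments tuple, so call_args[0][1] is the documents list. A self-contained illustration (not part of the diff) with MagicMock:

from unittest.mock import MagicMock

processor = MagicMock()
processor.load("dataset", ["doc1", "doc2", "doc3"])

call_args = processor.load.call_args
assert call_args[0][0] == "dataset"  # first positional argument
assert call_args[0][1] == ["doc1", "doc2", "doc3"]  # second positional argument: the documents
assert len(call_args[0][1]) == 3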
    def test_enable_segments_to_index_dataset_not_found(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test handling of non-existent dataset.

        This test verifies:
        - Proper error handling for missing datasets
        - Early return without processing
        - Database session cleanup
        - No unnecessary index processor calls
        """
        # Arrange: Use non-existent dataset ID
        fake = Faker()
        non_existent_dataset_id = fake.uuid4()
        non_existent_document_id = fake.uuid4()
        segment_ids = [fake.uuid4()]

        # Act: Execute the task with non-existent dataset
        enable_segments_to_index_task(segment_ids, non_existent_dataset_id, non_existent_document_id)

        # Assert: Verify no processing occurred
        mock_external_service_dependencies["index_processor_factory"].assert_not_called()
        mock_external_service_dependencies["index_processor"].load.assert_not_called()

    def test_enable_segments_to_index_document_not_found(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test handling of non-existent document.

        This test verifies:
        - Proper error handling for missing documents
        - Early return without processing
        - Database session cleanup
        - No unnecessary index processor calls
        """
        # Arrange: Create dataset but use non-existent document ID
        dataset, _ = self._create_test_dataset_and_document(
            db_session_with_containers, mock_external_service_dependencies
        )
        fake = Faker()
        non_existent_document_id = fake.uuid4()
        segment_ids = [fake.uuid4()]

        # Act: Execute the task with non-existent document
        enable_segments_to_index_task(segment_ids, dataset.id, non_existent_document_id)

        # Assert: Verify no processing occurred
        mock_external_service_dependencies["index_processor_factory"].assert_not_called()
        mock_external_service_dependencies["index_processor"].load.assert_not_called()

    def test_enable_segments_to_index_invalid_document_status(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test handling of document with invalid status.

        This test verifies:
        - Early return when document is disabled, archived, or not completed
        - No index processing for documents not ready for indexing
        - Proper database session cleanup
        - No unnecessary external service calls
        """
        # Arrange: Create test data with invalid document status
        dataset, document = self._create_test_dataset_and_document(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Test different invalid statuses
        invalid_statuses = [
            ("disabled", {"enabled": False}),
            ("archived", {"archived": True}),
            ("not_completed", {"indexing_status": "processing"}),
        ]

        for _, status_attrs in invalid_statuses:
            # Reset document status
            document.enabled = True
            document.archived = False
            document.indexing_status = "completed"
            db.session.commit()

            # Set invalid status
            for attr, value in status_attrs.items():
                setattr(document, attr, value)
            db.session.commit()

            # Create segments
            segments = self._create_test_segments(db_session_with_containers, document, dataset)
            segment_ids = [segment.id for segment in segments]

            # Act: Execute the task
            enable_segments_to_index_task(segment_ids, dataset.id, document.id)

            # Assert: Verify no processing occurred
            mock_external_service_dependencies["index_processor_factory"].assert_not_called()
            mock_external_service_dependencies["index_processor"].load.assert_not_called()

            # Clean up segments for next iteration
            for segment in segments:
                db.session.delete(segment)
            db.session.commit()

    def test_enable_segments_to_index_segments_not_found(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test handling when no segments are found.

        This test verifies:
        - Proper handling when segments don't exist
        - Early return without processing
        - Database session cleanup
        - Index processor is created but load is not called
        """
        # Arrange: Create test data
        dataset, document = self._create_test_dataset_and_document(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Use non-existent segment IDs
        fake = Faker()
        non_existent_segment_ids = [fake.uuid4() for _ in range(3)]

        # Act: Execute the task with non-existent segments
        enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id)

        # Assert: Verify index processor was created but load was not called
        mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
        mock_external_service_dependencies["index_processor"].load.assert_not_called()

    def test_enable_segments_to_index_with_parent_child_structure(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test segments indexing with parent-child structure.

        This test verifies:
        - Proper handling of PARENT_CHILD_INDEX type
        - Child document creation from segments
        - Correct document structure for parent-child indexing
        - Index processor receives properly structured documents
        - Redis cache key deletion
        """
        # Arrange: Create test data with parent-child index type
        dataset, document = self._create_test_dataset_and_document(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Update document to use parent-child index type
        document.doc_form = IndexType.PARENT_CHILD_INDEX
        db.session.commit()

        # Refresh dataset to ensure doc_form property reflects the updated document
        db.session.refresh(dataset)

        # Create segments with mock child chunks
        segments = self._create_test_segments(db_session_with_containers, document, dataset)

        # Set up Redis cache keys
        segment_ids = [segment.id for segment in segments]
        for segment in segments:
            indexing_cache_key = f"segment_{segment.id}_indexing"
            redis_client.set(indexing_cache_key, "processing", ex=300)

        # Mock the get_child_chunks method for each segment
        with patch.object(DocumentSegment, "get_child_chunks") as mock_get_child_chunks:
            # Setup mock to return child chunks for each segment
            mock_child_chunks = []
            for i in range(2):  # Each segment has 2 child chunks
                mock_child = MagicMock()
                mock_child.content = f"child_content_{i}"
                mock_child.index_node_id = f"child_node_{i}"
                mock_child.index_node_hash = f"child_hash_{i}"
                mock_child_chunks.append(mock_child)

            mock_get_child_chunks.return_value = mock_child_chunks

            # Act: Execute the task
            enable_segments_to_index_task(segment_ids, dataset.id, document.id)

            # Assert: Verify parent-child index processing
            mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
                IndexType.PARENT_CHILD_INDEX
            )
            mock_external_service_dependencies["index_processor"].load.assert_called_once()

            # Verify the load method was called with correct parameters
            call_args = mock_external_service_dependencies["index_processor"].load.call_args
            assert call_args is not None
            documents = call_args[0][1]  # Second argument should be documents list
            assert len(documents) == 3  # 3 segments

            # Verify each document has children
            for doc in documents:
                assert hasattr(doc, "children")
                assert len(doc.children) == 2  # Each document has 2 children

            # Verify Redis cache keys were deleted
            for segment in segments:
                indexing_cache_key = f"segment_{segment.id}_indexing"
                assert redis_client.exists(indexing_cache_key) == 0
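The parent-child assertions above only require that each loaded document exposes a children list built from the segment's child chunks. Below is a hedged sketch (not part of the diff) of that shape using illustrative dataclasses, not Dify's actual RAG document models.

from dataclasses import dataclass, field


@dataclass
class ChildDoc:
    page_content: str
    metadata: dict = field(default_factory=dict)


@dataclass
class ParentDoc:
    page_content: str
    metadata: dict = field(default_factory=dict)
    children: list[ChildDoc] = field(default_factory=list)


def to_parent_doc(segment, child_chunks) -> ParentDoc:
    # One parent document per segment; each child chunk becomes a child entry.
    return ParentDoc(
        page_content=segment.content,
        metadata={"doc_id": segment.index_node_id, "doc_hash": segment.index_node_hash},
        children=[
            ChildDoc(
                page_content=chunk.content,
                metadata={"doc_id": chunk.index_node_id, "doc_hash": chunk.index_node_hash},
            )
            for chunk in child_chunks
        ],
    )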
    def test_enable_segments_to_index_general_exception_handling(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test general exception handling during indexing process.

        This test verifies:
        - Exceptions are properly caught and handled
        - Segment status is set to error
        - Segments are disabled
        - Error information is recorded
        - Redis cache is still cleared
        - Database session is properly closed
        """
        # Arrange: Create test data
        dataset, document = self._create_test_dataset_and_document(
            db_session_with_containers, mock_external_service_dependencies
        )
        segments = self._create_test_segments(db_session_with_containers, document, dataset)

        # Set up Redis cache keys
        segment_ids = [segment.id for segment in segments]
        for segment in segments:
            indexing_cache_key = f"segment_{segment.id}_indexing"
            redis_client.set(indexing_cache_key, "processing", ex=300)

        # Mock the index processor to raise an exception
        mock_external_service_dependencies["index_processor"].load.side_effect = Exception("Index processing failed")

        # Act: Execute the task
        enable_segments_to_index_task(segment_ids, dataset.id, document.id)

        # Assert: Verify error handling
        for segment in segments:
            db.session.refresh(segment)
            assert segment.enabled is False
            assert segment.status == "error"
            assert segment.error is not None
            assert "Index processing failed" in segment.error
            assert segment.disabled_at is not None

        # Verify Redis cache keys were still cleared despite error
        for segment in segments:
            indexing_cache_key = f"segment_{segment.id}_indexing"
            assert redis_client.exists(indexing_cache_key) == 0
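The exception-handling test above pins down a failure contract: on an indexing error the segments are disabled and marked "error", and the Redis cache keys are cleared regardless of the outcome. A hedged sketch (not part of the diff) of that contract, with illustrative names rather than the task's real implementation:

from datetime import UTC, datetime


def index_segments_safely(index_processor, dataset, documents, segments, redis_client, db):
    # Illustrative only: on failure, disable the segments and record the error;
    # always clear the per-segment indexing cache keys afterwards.
    try:
        index_processor.load(dataset, documents)
    except Exception as e:
        for segment in segments:
            segment.enabled = False
            segment.disabled_at = datetime.now(UTC)
            segment.status = "error"
            segment.error = str(e)
        db.session.commit()
    finally:
        for segment in segments:
            redis_client.delete(f"segment_{segment.id}_indexing")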
@@ -0,0 +1,242 @@
from unittest.mock import MagicMock, patch

import pytest
from faker import Faker

from extensions.ext_database import db
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task


class TestMailAccountDeletionTask:
    """Integration tests for mail account deletion tasks using testcontainers."""

    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies."""
        with (
            patch("tasks.mail_account_deletion_task.mail") as mock_mail,
            patch("tasks.mail_account_deletion_task.get_email_i18n_service") as mock_get_email_service,
        ):
            # Setup mock mail service
            mock_mail.is_inited.return_value = True

            # Setup mock email service
            mock_email_service = MagicMock()
            mock_get_email_service.return_value = mock_email_service

            yield {
                "mail": mock_mail,
                "get_email_service": mock_get_email_service,
                "email_service": mock_email_service,
            }

    def _create_test_account(self, db_session_with_containers):
        """
        Helper method to create a test account for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure

        Returns:
            Account: Created account instance
        """
        fake = Faker()

        # Create account
        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        db.session.add(account)
        db.session.commit()

        # Create tenant
        tenant = Tenant(
            name=fake.company(),
            status="normal",
        )
        db.session.add(tenant)
        db.session.commit()

        # Create tenant-account join
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER.value,
            current=True,
        )
        db.session.add(join)
        db.session.commit()

        return account

    def test_send_deletion_success_task_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful account deletion success email sending.

        This test verifies:
        - Proper email service initialization check
        - Correct email service method calls
        - Template context is properly formatted
        - Email type is correctly specified
        """
        # Arrange: Create test data
        account = self._create_test_account(db_session_with_containers)
        test_email = account.email
        test_language = "en-US"

        # Act: Execute the task
        send_deletion_success_task(test_email, test_language)

        # Assert: Verify the expected outcomes
        # Verify mail service was checked
        mock_external_service_dependencies["mail"].is_inited.assert_called_once()

        # Verify email service was retrieved
        mock_external_service_dependencies["get_email_service"].assert_called_once()

        # Verify email was sent with correct parameters
        mock_external_service_dependencies["email_service"].send_email.assert_called_once_with(
            email_type=EmailType.ACCOUNT_DELETION_SUCCESS,
            language_code=test_language,
            to=test_email,
            template_context={
                "to": test_email,
                "email": test_email,
            },
        )
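The "mail not initialized" tests below depend on an early-return guard in the task. Here is a hedged sketch (not part of the diff) of that guard, mirroring only what the tests assert; everything beyond is_inited/send_email is illustrative.

def send_deletion_success_sketch(mail, get_email_i18n_service, to: str, language: str = "en-US"):
    # Early return when the mail extension is not configured: nothing is sent
    # and the email service is never resolved (which is what the tests assert).
    if not mail.is_inited():
        return
    email_service = get_email_i18n_service()
    email_service.send_email(
        email_type="account_deletion_success",  # EmailType.ACCOUNT_DELETION_SUCCESS in the real code
        language_code=language,
        to=to,
        template_context={"to": to, "email": to},
    )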
    def test_send_deletion_success_task_mail_not_initialized(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test account deletion success email when mail service is not initialized.

        This test verifies:
        - Early return when mail service is not initialized
        - No email service calls are made
        - No exceptions are raised
        """
        # Arrange: Setup mail service to return not initialized
        mock_external_service_dependencies["mail"].is_inited.return_value = False
        account = self._create_test_account(db_session_with_containers)
        test_email = account.email

        # Act: Execute the task
        send_deletion_success_task(test_email)

        # Assert: Verify no email service calls were made
        mock_external_service_dependencies["get_email_service"].assert_not_called()
        mock_external_service_dependencies["email_service"].send_email.assert_not_called()

    def test_send_deletion_success_task_email_service_exception(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test account deletion success email when email service raises exception.

        This test verifies:
        - Exception is properly caught and logged
        - Task completes without raising exception
        - Error logging is recorded
        """
        # Arrange: Setup email service to raise exception
        mock_external_service_dependencies["email_service"].send_email.side_effect = Exception("Email service failed")
        account = self._create_test_account(db_session_with_containers)
        test_email = account.email

        # Act: Execute the task (should not raise exception)
        send_deletion_success_task(test_email)

        # Assert: Verify email service was called but exception was handled
        mock_external_service_dependencies["email_service"].send_email.assert_called_once()

    def test_send_account_deletion_verification_code_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful account deletion verification code email sending.

        This test verifies:
        - Proper email service initialization check
        - Correct email service method calls
        - Template context includes verification code
        - Email type is correctly specified
        """
        # Arrange: Create test data
        account = self._create_test_account(db_session_with_containers)
        test_email = account.email
        test_code = "123456"
        test_language = "en-US"

        # Act: Execute the task
        send_account_deletion_verification_code(test_email, test_code, test_language)

        # Assert: Verify the expected outcomes
        # Verify mail service was checked
        mock_external_service_dependencies["mail"].is_inited.assert_called_once()

        # Verify email service was retrieved
        mock_external_service_dependencies["get_email_service"].assert_called_once()

        # Verify email was sent with correct parameters
        mock_external_service_dependencies["email_service"].send_email.assert_called_once_with(
            email_type=EmailType.ACCOUNT_DELETION_VERIFICATION,
            language_code=test_language,
            to=test_email,
            template_context={
                "to": test_email,
                "code": test_code,
            },
        )

    def test_send_account_deletion_verification_code_mail_not_initialized(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test account deletion verification code email when mail service is not initialized.

        This test verifies:
        - Early return when mail service is not initialized
        - No email service calls are made
        - No exceptions are raised
        """
        # Arrange: Setup mail service to return not initialized
        mock_external_service_dependencies["mail"].is_inited.return_value = False
        account = self._create_test_account(db_session_with_containers)
        test_email = account.email
        test_code = "123456"

        # Act: Execute the task
        send_account_deletion_verification_code(test_email, test_code)

        # Assert: Verify no email service calls were made
        mock_external_service_dependencies["get_email_service"].assert_not_called()
        mock_external_service_dependencies["email_service"].send_email.assert_not_called()

    def test_send_account_deletion_verification_code_email_service_exception(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test account deletion verification code email when email service raises exception.

        This test verifies:
        - Exception is properly caught and logged
        - Task completes without raising exception
        - Error logging is recorded
        """
        # Arrange: Setup email service to raise exception
        mock_external_service_dependencies["email_service"].send_email.side_effect = Exception("Email service failed")
        account = self._create_test_account(db_session_with_containers)
        test_email = account.email
        test_code = "123456"

        # Act: Execute the task (should not raise exception)
        send_account_deletion_verification_code(test_email, test_code)

        # Assert: Verify email service was called but exception was handled
        mock_external_service_dependencies["email_service"].send_email.assert_called_once()