Merge branch 'test/workflow-part-8' into test/workflow-app

This commit is contained in:
CodingOnStar 2026-03-25 15:21:27 +08:00
commit 77e7f0a7de
172 changed files with 6254 additions and 3076 deletions

View File

@ -10,7 +10,7 @@ from configs import dify_config
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
@ -86,7 +86,7 @@ def migrate_annotation_vector_database():
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
@ -178,7 +178,9 @@ def migrate_knowledge_vector_database():
while True:
try:
stmt = (
select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
select(Dataset)
.where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY)
.order_by(Dataset.created_at.desc())
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)

View File

@ -3,7 +3,7 @@ from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy import func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -29,6 +29,7 @@ from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
@ -355,7 +356,7 @@ class DatasetListApi(Resource):
for item in data:
# convert embedding_model_provider to plugin standard format
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
@ -436,7 +437,7 @@ class DatasetApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
data["embedding_model_provider"] = str(provider_id)
@ -454,7 +455,7 @@ class DatasetApi(Resource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality":
if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
@ -485,7 +486,7 @@ class DatasetApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
payload.indexing_technique == "high_quality"
payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY
and payload.embedding_model_provider is not None
and payload.embedding_model is not None
):
@ -738,20 +739,23 @@ class DatasetIndexingStatusApi(Resource):
documents_status = []
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
total_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
# Create a dictionary with document attributes and additional fields
document_dict = {
@ -802,9 +806,12 @@ class DatasetApiKeyApi(Resource):
_, current_tenant_id = current_account_with_tenant()
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count()
db.session.scalar(
select(func.count(ApiToken.id)).where(
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id
)
)
or 0
)
if current_key_count >= self.max_keys:
@ -839,14 +846,14 @@ class DatasetApiDeleteApi(Resource):
def delete(self, api_key_id):
_, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id)
key = (
db.session.query(ApiToken)
key = db.session.scalar(
select(ApiToken)
.where(
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
.limit(1)
)
if key is None:
@ -857,7 +864,7 @@ class DatasetApiDeleteApi(Resource):
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.delete(key)
db.session.commit()
return {"result": "success"}, 204

View File

@ -27,6 +27,7 @@ from core.model_manager import ModelManager
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from extensions.ext_database import db
@ -449,7 +450,7 @@ class DatasetInitApi(Resource):
raise Forbidden()
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
@ -463,7 +464,7 @@ class DatasetInitApi(Resource):
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment]
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
@ -1337,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource):
raise BadRequest("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"

View File

@ -26,6 +26,7 @@ from controllers.console.wraps import (
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -279,7 +280,7 @@ class DatasetDocumentSegmentApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
try:
model_manager = ModelManager()
@ -333,7 +334,7 @@ class DatasetDocumentSegmentAddApi(Resource):
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -383,7 +384,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
try:
model_manager = ModelManager()
@ -401,10 +402,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -447,10 +448,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -494,7 +495,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
payload = BatchImportPayload.model_validate(console_ns.payload or {})
upload_file_id = payload.upload_file_id
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
if not upload_file:
raise NotFound("UploadFile not found.")
@ -559,17 +560,17 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -616,10 +617,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -666,10 +667,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -714,24 +715,24 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
.first()
.limit(1)
)
if not child_chunk:
raise NotFound("Child chunk not found.")
@ -771,24 +772,24 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
.first()
.limit(1)
)
if not child_chunk:
raise NotFound("Child chunk not found.")

View File

@ -15,6 +15,7 @@ from controllers.service_api.wraps import (
cloud_edition_billing_rate_limit_check,
)
from core.provider_manager import ProviderManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import DataSetTag
@ -153,9 +154,14 @@ class DatasetListApi(DatasetApiResource):
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: # type: ignore
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) # type: ignore
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # type: ignore
if (
item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index]
and item["embedding_model_provider"] # pyrefly: ignore[bad-index]
):
item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation]
ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index]
)
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index]
if item_model in model_names:
item["embedding_available"] = True # type: ignore
else:
@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data.get("indexing_technique") == "high_quality":
if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY:
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
if item_model in model_names:
data["embedding_available"] = True
@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource):
# check embedding model setting
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if payload.indexing_technique == "high_quality" or embedding_model_provider:
if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider:
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model

View File

@ -17,6 +17,7 @@ from controllers.service_api.wraps import (
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from fields.segment_fields import child_chunk_fields, segment_fields
@ -103,7 +104,7 @@ class SegmentApi(DatasetApiResource):
if not document.enabled:
raise NotFound("Document is disabled.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -157,7 +158,7 @@ class SegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -262,7 +263,7 @@ class DatasetSegmentApi(DatasetApiResource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
try:
model_manager = ModelManager()
@ -358,7 +359,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(

View File

@ -4,6 +4,7 @@ from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import CollectionBindingType, ConversationFromSource
@ -50,7 +51,7 @@ class AnnotationReplyFeature:
dataset = Dataset(
id=app_record.id,
tenant_id=app_record.tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,

View File

@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
@ -271,7 +271,7 @@ class IndexingRunner:
doc_form: str | None = None,
doc_language: str = "English",
dataset_id: str | None = None,
indexing_technique: str = "economy",
indexing_technique: str = IndexTechniqueType.ECONOMY,
) -> IndexingEstimate:
"""
Estimate the indexing for the document.
@ -289,7 +289,7 @@ class IndexingRunner:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=tenant_id,
@ -303,7 +303,7 @@ class IndexingRunner:
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
@ -573,7 +573,7 @@ class IndexingRunner:
"""
embedding_model_instance = None
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
@ -587,7 +587,7 @@ class IndexingRunner:
create_keyword_thread = None
if (
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
and dataset.indexing_technique == IndexTechniqueType.ECONOMY
):
# create keyword index
create_keyword_thread = threading.Thread(
@ -597,7 +597,7 @@ class IndexingRunner:
create_keyword_thread.start()
max_workers = 10
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
@ -628,7 +628,7 @@ class IndexingRunner:
tokens += future.result()
if (
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
and dataset.indexing_technique == IndexTechniqueType.ECONOMY
and create_keyword_thread is not None
):
create_keyword_thread.join()
@ -654,7 +654,7 @@ class IndexingRunner:
raise ValueError("no dataset found")
keyword = Keyword(dataset)
keyword.create(documents)
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
document_ids = [document.metadata["doc_id"] for document in documents]
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
@ -764,7 +764,7 @@ class IndexingRunner:
) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,

View File

@ -6,6 +6,7 @@ from typing import Any
from sqlalchemy import func, select
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import AttachmentDocument, Document
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
@ -71,7 +72,7 @@ class DatasetDocumentStore:
if max_position is None:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == "high_quality":
if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,

View File

@ -9,6 +9,7 @@ from flask import current_app
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
@ -159,7 +160,7 @@ class IndexProcessor:
tenant_id = dataset.tenant_id
preview_output = self.format_preview(chunk_structure, chunks)
if indexing_technique != "high_quality":
if indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return preview_output
if not summary_index_setting or not summary_index_setting.get("enable"):

View File

@ -22,7 +22,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -117,7 +117,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
with_keywords: bool = True,
**kwargs,
) -> None:
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
@ -155,7 +155,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
@ -253,12 +253,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
vector.create(documents)
if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents)
elif dataset.indexing_technique == "economy":
elif dataset.indexing_technique == IndexTechniqueType.ECONOMY:
keyword = Keyword(dataset)
keyword.add_texts(documents)

View File

@ -18,7 +18,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
with_keywords: bool = True,
**kwargs,
) -> None:
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
for document in documents:
child_documents = document.children
@ -166,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
delete_child_chunks = kwargs.get("delete_child_chunks") or False
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
vector = Vector(dataset)
@ -332,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=True)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
all_child_documents = []
all_multimodal_documents = []
for doc in documents:

View File

@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor):
with_keywords: bool = True,
**kwargs,
) -> None:
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor):
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
vector.create(documents)
else:

View File

@ -675,7 +675,7 @@ class DatasetRetrieval:
# get top k
top_k = retrieval_model_config["top_k"]
# get retrieval method
if selected_dataset.indexing_technique == "economy":
if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY:
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
else:
retrieval_method = retrieval_model_config["search_method"]
@ -752,7 +752,7 @@ class DatasetRetrieval:
"The configured knowledge base list have different indexing technique, please set reranking model."
)
index_type = available_datasets[0].indexing_technique
if index_type == "high_quality":
if index_type == IndexTechniqueType.HIGH_QUALITY:
embedding_model_check = all(
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
)
@ -1068,7 +1068,7 @@ class DatasetRetrieval:
else default_retrieval_model
)
if dataset.indexing_technique == "economy":
if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,

View File

@ -2,6 +2,7 @@ import concurrent.futures
import logging
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
@ -21,7 +22,7 @@ class SummaryIndex:
if is_preview:
with session_factory.create_session() as session:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset or dataset.indexing_technique != "high_quality":
if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
if summary_index_setting is None:

View File

@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,

View File

@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy":
if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
if dataset.indexing_technique != IndexTechniqueType.ECONOMY:
for item in documents:
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

View File

@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
@ -137,7 +137,7 @@ class Dataset(Base):
default=DatasetPermissionEnum.ONLY_ME,
)
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
indexing_technique: Mapped[str | None] = mapped_column(String(255))
indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255))
index_struct = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -940,7 +940,9 @@ class AccountTrialAppRecord(Base):
class ExporleBanner(TypeBase):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
@ -1849,7 +1851,9 @@ class AppAnnotationHitHistory(TypeBase):
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
source: Mapped[str] = mapped_column(LongText, nullable=False)

View File

@ -21,7 +21,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.file import helpers as file_helpers
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
@ -228,7 +228,7 @@ class DatasetService:
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
if embedding_model_provider and embedding_model_name:
# check if embedding model setting is valid
@ -254,7 +254,10 @@ class DatasetService:
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
dataset = Dataset(
name=name,
indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None,
)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.description = description
dataset.created_by = account.id
@ -349,7 +352,7 @@ class DatasetService:
@staticmethod
def check_dataset_model_setting(dataset):
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -717,13 +720,13 @@ class DatasetService:
if "indexing_technique" not in data:
return None
if dataset.indexing_technique != data["indexing_technique"]:
if data["indexing_technique"] == "economy":
if data["indexing_technique"] == IndexTechniqueType.ECONOMY:
# Remove embedding model configuration for economy mode
filtered_data["embedding_model"] = None
filtered_data["embedding_model_provider"] = None
filtered_data["collection_binding_id"] = None
return "remove"
elif data["indexing_technique"] == "high_quality":
elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
# Configure embedding model for high quality mode
DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
return "add"
@ -953,8 +956,8 @@ class DatasetService:
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure
dataset.indexing_technique = knowledge_configuration.indexing_technique
if knowledge_configuration.indexing_technique == "high_quality":
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, # ignore type error
@ -976,7 +979,7 @@ class DatasetService:
embedding_model_name,
)
dataset.collection_binding_id = dataset_collection_binding.id
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
else:
raise ValueError("Invalid index method")
@ -991,9 +994,9 @@ class DatasetService:
action = None
if dataset.indexing_technique != knowledge_configuration.indexing_technique:
# if update indexing_technique
if knowledge_configuration.indexing_technique == "economy":
if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
elif knowledge_configuration.indexing_technique == "high_quality":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
action = "add"
# get embedding model setting
try:
@ -1018,7 +1021,7 @@ class DatasetService:
)
dataset.is_multimodal = is_multimodal
dataset.collection_binding_id = dataset_collection_binding.id
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
@ -1029,7 +1032,7 @@ class DatasetService:
else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
# Skip embedding model checks if not provided in the update request
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
skip_embedding_update = False
try:
# Handle existing model provider
@ -1089,7 +1092,7 @@ class DatasetService:
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
elif dataset.indexing_technique == "economy":
elif dataset.indexing_technique == IndexTechniqueType.ECONOMY:
if dataset.keyword_number != knowledge_configuration.keyword_number:
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
@ -1907,8 +1910,8 @@ class DocumentService:
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid")
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique)
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
@ -2689,7 +2692,7 @@ class DocumentService:
dataset_collection_binding_id = None
retrieval_model = None
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
assert knowledge_config.embedding_model_provider
assert knowledge_config.embedding_model
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@ -2712,7 +2715,7 @@ class DocumentService:
tenant_id=tenant_id,
name="",
data_source_type=knowledge_config.data_source.info_list.data_source_type,
indexing_technique=knowledge_config.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique),
created_by=account.id,
embedding_model=knowledge_config.embedding_model,
embedding_model_provider=knowledge_config.embedding_model_provider,
@ -3125,7 +3128,7 @@ class SegmentService:
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@ -3208,7 +3211,7 @@ class SegmentService:
try:
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@ -3230,7 +3233,7 @@ class SegmentService:
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model:
# calc embedding use tokens
if document.doc_form == IndexStructureType.QA_INDEX:
tokens = embedding_model.get_text_embedding_num_tokens(
@ -3345,7 +3348,7 @@ class SegmentService:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks
# get embedding model instance
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@ -3382,7 +3385,7 @@ class SegmentService:
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist
# summary_index_setting is only needed for LLM generation, not for manual summary vectorization
# Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# Query existing summary from database
from models.dataset import DocumentSegmentSummary
@ -3409,7 +3412,7 @@ class SegmentService:
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@ -3449,7 +3452,7 @@ class SegmentService:
db.session.commit()
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# get embedding model instance
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@ -3481,7 +3484,7 @@ class SegmentService:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# Handle summary index when content changed
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
from models.dataset import DocumentSegmentSummary
existing_summary = (

View File

@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.helper.name_generator import generate_incremental_name
from core.plugin.entities.plugin import PluginDependency
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
@ -311,13 +312,13 @@ class RagPipelineDslService:
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique),
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
if knowledge_configuration.indexing_technique == "high_quality":
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
@ -343,7 +344,7 @@ class RagPipelineDslService:
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
@ -443,18 +444,18 @@ class RagPipelineDslService:
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique),
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
@ -480,7 +481,7 @@ class RagPipelineDslService:
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
@ -772,7 +773,7 @@ class RagPipelineDslService:
)
case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE:
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
if knowledge_index_entity.indexing_technique == "high_quality":
if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if knowledge_index_entity.embedding_model_provider:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(

View File

@ -9,7 +9,7 @@ from flask_login import current_user
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
@ -105,29 +105,29 @@ class RagPipelineTransformService:
if doc_form == IndexStructureType.PARAGRAPH_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.file-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.file-general-economy.yml
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case DataSourceType.NOTION_IMPORT:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.notion-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.notion-general-economy.yml
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case DataSourceType.WEBSITE_CRAWL:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.website-crawl-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.website-crawl-general-economy.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
@ -170,11 +170,11 @@ class RagPipelineTransformService:
):
knowledge_configuration_dict = node.get("data", {})
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH
knowledge_configuration.retrieval_model = retrieval_model
else:

View File

@ -12,6 +12,7 @@ from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.models.document import Document
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
@ -140,7 +141,7 @@ class SummaryIndexService:
session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one.
If not provided, creates a new session and commits automatically.
"""
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.warning(
"Summary vectorization skipped for dataset %s: indexing_technique is not high_quality",
dataset.id,
@ -724,7 +725,7 @@ class SummaryIndexService:
List of created DocumentSegmentSummary instances
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
"Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'",
dataset.id,
@ -851,7 +852,7 @@ class SummaryIndexService:
)
# Remove from vector database (but keep records)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
try:
@ -889,7 +890,7 @@ class SummaryIndexService:
segment_ids: List of segment IDs to enable summaries for. If None, enable all.
"""
# Only enable summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
with session_factory.create_session() as session:
@ -981,7 +982,7 @@ class SummaryIndexService:
return
# Delete from vector database
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
vector = Vector(dataset)
@ -1012,7 +1013,7 @@ class SummaryIndexService:
Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality
"""
# Only update summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return None
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist

View File

@ -4,7 +4,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, Document
@ -45,7 +45,7 @@ class VectorService:
if not processing_rule:
raise ValueError("No processing rule found.")
# get embedding model instance
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@ -112,7 +112,7 @@ class VectorService:
"dataset_id": segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
vector.delete_by_ids([segment.index_node_id])
@ -197,7 +197,7 @@ class VectorService:
"dataset_id": child_segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts([child_document], duplicate_check=True)
@ -237,7 +237,7 @@ class VectorService:
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
if delete_node_ids:
@ -252,7 +252,7 @@ class VectorService:
@classmethod
def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
attachments = segment.attachments

View File

@ -5,6 +5,7 @@ import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -36,7 +37,7 @@ def add_annotation_to_index_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

View File

@ -5,6 +5,7 @@ import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
collection_binding_id=dataset_collection_binding.id,
)

View File

@ -7,6 +7,7 @@ from sqlalchemy import exists, select
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
collection_binding_id=app_annotation_setting.collection_binding_id,
)

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@ -64,7 +65,7 @@ def enable_annotation_reply_task(
old_dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=old_dataset_collection_binding.provider_name,
embedding_model=old_dataset_collection_binding.model_name,
collection_binding_id=old_dataset_collection_binding.id,
@ -93,7 +94,7 @@ def enable_annotation_reply_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,

View File

@ -5,6 +5,7 @@ import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -37,7 +38,7 @@ def update_annotation_to_index_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

View File

@ -11,7 +11,7 @@ from sqlalchemy import func
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -120,7 +120,7 @@ def batch_create_segment_to_index_task(
document_segments = []
embedding_model = None
if dataset_config["indexing_technique"] == "high_quality":
if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset_config["tenant_id"],

View File

@ -10,7 +10,7 @@ from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
@ -127,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
logger.warning("Dataset %s not found after indexing", dataset_id)
return
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_index_setting = dataset.summary_index_setting
if summary_index_setting and summary_index_setting.get("enable"):
# expire all session to get latest document's indexing status

View File

@ -7,6 +7,7 @@ import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids:
return
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
click.style(
f"Skipping summary generation for dataset {dataset_id}: "

View File

@ -9,7 +9,7 @@ from celery import shared_task
from sqlalchemy import or_, select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
@ -53,7 +53,7 @@ def regenerate_summary_index_task(
return
# Only regenerate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
click.style(
f"Skipping summary regeneration for dataset {dataset_id}: "

View File

@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
@ -39,7 +39,7 @@ class TestGetAvailableDatasetsIntegration:
provider="dify",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -460,7 +460,7 @@ class TestKnowledgeRetrievalIntegration:
provider="dify",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
db_session_with_containers.add(dataset)

View File

@ -13,6 +13,7 @@ import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
from models.enums import DataSourceType
@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory:
name=name,
description="Test description",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=permission,
provider="vendor",

View File

@ -1245,3 +1245,51 @@ class TestAppService:
assert paginated_apps is not None
assert paginated_apps.total == 1
assert all("50%" in app.name for app in paginated_apps.items)
def test_get_app_code_by_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_code_by_id raises ValueError when site is missing."""
from uuid import uuid4
from services.app_service import AppService
with pytest.raises(ValueError, match="not found"):
AppService.get_app_code_by_id(str(uuid4()))
def test_get_app_id_by_code_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_id_by_code raises ValueError when code does not exist."""
from services.app_service import AppService
with pytest.raises(ValueError, match="not found"):
AppService.get_app_id_by_code("nonexistent-code")
def test_get_app_meta_returns_empty_when_workflow_missing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_meta returns empty tool_icons when workflow is None."""
from types import SimpleNamespace
from services.app_service import AppService
app_service = AppService()
workflow_app = SimpleNamespace(mode="workflow", workflow=None)
meta = app_service.get_app_meta(workflow_app)
assert meta == {"tool_icons": {}}
def test_get_app_meta_returns_empty_when_model_config_missing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_meta returns empty tool_icons when app_model_config is None."""
from types import SimpleNamespace
from services.app_service import AppService
app_service = AppService()
chat_app = SimpleNamespace(mode="chat", app_model_config=None)
meta = app_service.get_app_meta(chat_app)
assert meta == {"tool_icons": {}}

View File

@ -9,6 +9,7 @@ from uuid import uuid4
import pytest
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory:
name=name,
description="desc",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=permission,
provider="vendor",

View File

@ -11,7 +11,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -63,7 +63,7 @@ class DatasetServiceIntegrationDataFactory:
name: str = "Test Dataset",
description: str | None = "Test description",
provider: str = "vendor",
indexing_technique: str | None = "high_quality",
indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY,
permission: str = DatasetPermissionEnum.ONLY_ME,
retrieval_model: dict | None = None,
embedding_model_provider: str | None = None,
@ -157,13 +157,13 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="Economy Dataset",
description=None,
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
account=account,
)
# Assert
db_session_with_containers.refresh(result)
assert result.indexing_technique == "economy"
assert result.indexing_technique == IndexTechniqueType.ECONOMY
assert result.embedding_model_provider is None
assert result.embedding_model is None
@ -181,13 +181,13 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="High Quality Dataset",
description=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
account=account,
)
# Assert
db_session_with_containers.refresh(result)
assert result.indexing_technique == "high_quality"
assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model_name
mock_model_manager.return_value.get_default_model_instance.assert_called_once_with(
@ -273,7 +273,7 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="Dataset With Reranking",
description=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
account=account,
retrieval_model=retrieval_model,
)
@ -306,7 +306,7 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="Custom Embedding Dataset",
description=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
account=account,
embedding_model_provider=embedding_provider,
embedding_model_name=embedding_model_name,
@ -314,7 +314,7 @@ class TestDatasetServiceCreateDataset:
# Assert
db_session_with_containers.refresh(result)
assert result.indexing_technique == "high_quality"
assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert result.embedding_model_provider == embedding_provider
assert result.embedding_model == embedding_model_name
mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name)
@ -589,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure="text_model",
)
DatasetServiceIntegrationDataFactory.create_document(
@ -685,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration:
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0},
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=str(uuid4()),
)
update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": {
"search_method": "full_text_search",
"top_k": 10,

View File

@ -3,7 +3,7 @@
from unittest.mock import patch
from uuid import uuid4
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom
@ -109,7 +109,7 @@ class TestDatasetServiceDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure=None,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),
@ -208,7 +208,7 @@ class TestDatasetServiceDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure=None,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),

View File

@ -12,6 +12,7 @@ from uuid import uuid4
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory:
name=f"Test Dataset {uuid4()}",
description="Test description",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=DatasetPermissionEnum.ONLY_ME,
provider="vendor",

View File

@ -15,6 +15,7 @@ from uuid import uuid4
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory:
name=name,
description="desc",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=permission,
provider="vendor",

View File

@ -4,6 +4,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, ExternalKnowledgeBindings
@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory:
provider: str = "vendor",
name: str = "old_name",
description: str = "old_description",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
retrieval_model: str = "old_model",
permission: str = "only_me",
embedding_model_provider: str | None = None,
@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset:
update_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset:
assert dataset.name == "new_name"
assert dataset.description == "new_description"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.retrieval_model == "new_model"
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset:
update_data = {
"name": "new_name",
"description": None,
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
"embedding_model_provider": None,
"embedding_model": None,
@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
)
update_data = {
"indexing_technique": "economy",
"indexing_technique": IndexTechniqueType.ECONOMY,
"retrieval_model": "new_model",
}
@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset:
mock_task.delay.assert_called_once_with(dataset.id, "remove")
db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "economy"
assert dataset.indexing_technique == IndexTechniqueType.ECONOMY
assert dataset.embedding_model is None
assert dataset.embedding_model_provider is None
assert dataset.collection_binding_id is None
@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)
embedding_model = Mock()
@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset:
binding.id = str(uuid4())
update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"retrieval_model": "new_model",
@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset:
mock_task.delay.assert_called_once_with(dataset.id, "add")
db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == binding.id
@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset:
update_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
}
@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset:
db_session_with_containers.refresh(dataset)
assert dataset.name == "new_name"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.collection_binding_id == existing_binding_id
@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -449,7 +450,7 @@ class TestDatasetServiceUpdateDataset:
binding.id = str(uuid4())
update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-3-small",
"retrieval_model": "new_model",
@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)
update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "invalid_provider",
"embedding_model": "invalid_model",
"retrieval_model": "new_model",

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType, TagType
@ -102,7 +103,7 @@ class TestTagService:
provider="vendor",
permission="only_me",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
tenant_id=tenant_id,
created_by=mock_external_service_dependencies["current_user"].id,
)

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import json
import uuid
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import patch
import pytest
@ -8,14 +11,14 @@ from faker import Faker
from sqlalchemy.orm import Session
from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowAppLogCreatedFrom
from services.account_service import AccountService, TenantService
# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
from services.workflow_app_service import WorkflowAppService
from services.workflow_app_service import LogView, WorkflowAppService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -1525,3 +1528,168 @@ class TestWorkflowAppService:
# Should not find tenant2's data when searching from tenant1's context
assert result_cross_tenant["total"] == 0
def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()
with pytest.raises(ValueError, match="Account not found: nonexistent@example.com"):
service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account="nonexistent@example.com",
)
def test_get_paginate_workflow_app_logs_filters_by_account(
self, db_session_with_containers, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()
workflow, workflow_run, _log = self._create_test_workflow_data(db_session_with_containers, app, account)
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account=account.email,
)
assert result["total"] >= 0
assert isinstance(result["data"], list)
def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()
end_user = EndUser(
tenant_id=app.tenant_id,
app_id=app.id,
type="browser",
is_anonymous=False,
session_id="session-1",
)
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
now = datetime.now(UTC)
archive_defaults = {
"workflow_id": str(uuid.uuid4()),
"run_version": "1.0.0",
"run_status": WorkflowExecutionStatus.SUCCEEDED,
"run_triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
"run_error": None,
"run_elapsed_time": 1.0,
"run_total_tokens": 0,
"run_total_steps": 0,
"run_created_at": now,
"run_finished_at": now,
"run_exceptions_count": 0,
"trigger_metadata": '{"type":"trigger-webhook"}',
"log_created_at": now,
"log_created_from": WorkflowAppLogCreatedFrom.SERVICE_API,
}
archive_account = WorkflowArchiveLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_run_id=str(uuid.uuid4()),
log_id=str(uuid.uuid4()),
created_by=account.id,
created_by_role=CreatorUserRole.ACCOUNT,
**archive_defaults,
)
archive_end_user = WorkflowArchiveLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_run_id=str(uuid.uuid4()),
log_id=str(uuid.uuid4()),
created_by=end_user.id,
created_by_role=CreatorUserRole.END_USER,
**archive_defaults,
)
db_session_with_containers.add_all([archive_account, archive_end_user])
db_session_with_containers.commit()
result = service.get_paginate_workflow_archive_logs(
session=db_session_with_containers,
app_model=app,
page=1,
limit=20,
)
assert result["total"] == 2
assert len(result["data"]) == 2
account_item = next(d for d in result["data"] if d["created_by_account"] is not None)
end_user_item = next(d for d in result["data"] if d["created_by_end_user"] is not None)
assert account_item["created_by_account"].id == account.id
assert end_user_item["created_by_end_user"].id == end_user.id
class TestLogView:
def test_details_and_proxy_attributes(self):
log = SimpleNamespace(id="log-1", status="succeeded")
view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}})
assert view.details == {"trigger_metadata": {"type": "plugin"}}
assert view.status == "succeeded"
class TestHandleTriggerMetadata:
def test_returns_empty_dict_when_metadata_missing(self):
service = WorkflowAppService()
assert service.handle_trigger_metadata("tenant-1", None) == {}
def test_enriches_plugin_icons(self):
service = WorkflowAppService()
meta = {
"type": AppTriggerType.TRIGGER_PLUGIN.value,
"icon_filename": "light.png",
"icon_dark_filename": "dark.png",
}
with patch(
"services.workflow_app_service.PluginService.get_plugin_icon_url",
side_effect=["https://cdn/light.png", "https://cdn/dark.png"],
) as mock_icon:
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
assert result["icon"] == "https://cdn/light.png"
assert result["icon_dark"] == "https://cdn/dark.png"
assert mock_icon.call_count == 2
def test_non_plugin_metadata_without_icon_lookup(self):
service = WorkflowAppService()
meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value}
with patch("services.workflow_app_service.PluginService.get_plugin_icon_url") as mock_icon:
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value
mock_icon.assert_not_called()
class TestSafeJsonLoads:
@pytest.mark.parametrize(
("value", "expected"),
[
(None, None),
("", None),
('{"k":"v"}', {"k": "v"}),
("not-json", None),
({"raw": True}, {"raw": True}),
],
)
def test_handles_various_inputs(self, value, expected):
assert WorkflowAppService._safe_json_loads(value) == expected
class TestSafeParseUuid:
def test_returns_none_for_short_or_invalid_values(self):
service = WorkflowAppService()
assert service._safe_parse_uuid("short") is None
assert service._safe_parse_uuid("x" * 40) is None
def test_returns_uuid_for_valid_string(self):
service = WorkflowAppService()
raw = str(uuid.uuid4())
result = service._safe_parse_uuid(raw)
assert result is not None
assert str(result) == raw

View File

@ -1,12 +1,24 @@
from __future__ import annotations
from unittest.mock import Mock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
from core.tools.entities.tool_entities import (
ApiProviderSchemaType,
ToolDescription,
ToolEntity,
ToolIdentity,
ToolParameter,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType,
)
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
@ -786,3 +798,192 @@ class TestToolTransformService:
assert result is not None
assert result == mock_controller
mock_from_db.assert_called_once_with(provider)
def _mock_tool(*, base_params, runtime_params):
"""Helper to build a Mock tool with real entity objects.
Tool is abstract and requires runtime behaviour (fork_tool_runtime,
get_runtime_parameters), so it stays as a Mock. Everything else uses
real Pydantic instances.
"""
entity = ToolEntity(
identity=ToolIdentity(
author="test_author",
name="test_tool",
label=I18nObject(en_US="Test Tool"),
provider="test_provider",
),
parameters=base_params or [],
description=ToolDescription(
human=I18nObject(en_US="Test description"),
llm="Test description for LLM",
),
output_schema={},
)
mock_tool = Mock(spec=Tool)
mock_tool.entity = entity
mock_tool.get_runtime_parameters.return_value = runtime_params
mock_tool.fork_tool_runtime.return_value = mock_tool
return mock_tool
def _param(name, *, form=ToolParameter.ToolParameterForm.FORM, label=None):
return ToolParameter(
name=name,
label=I18nObject(en_US=label or name),
human_description=I18nObject(en_US=name),
type=ToolParameter.ToolParameterType.STRING,
form=form,
)
class TestConvertToolEntityToApiEntity:
"""Tests for ToolTransformService.convert_tool_entity_to_api_entity."""
def test_parameter_override(self):
base = [_param("param1", label="Base 1"), _param("param2", label="Base 2")]
runtime = [_param("param1", label="Runtime 1")]
tool = _mock_tool(base_params=base, runtime_params=runtime)
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 2
assert next(p for p in result.parameters if p.name == "param1").label.en_US == "Runtime 1"
assert next(p for p in result.parameters if p.name == "param2").label.en_US == "Base 2"
def test_additional_runtime_parameters(self):
base = [_param("param1", label="Base 1")]
runtime = [_param("param1", label="Runtime 1"), _param("runtime_only", label="Runtime Only")]
tool = _mock_tool(base_params=base, runtime_params=runtime)
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert len(result.parameters) == 2
names = [p.name for p in result.parameters]
assert "param1" in names
assert "runtime_only" in names
def test_non_form_runtime_parameters_excluded(self):
base = [_param("param1")]
runtime = [
_param("param1", label="Runtime 1"),
_param("llm_param", form=ToolParameter.ToolParameterForm.LLM),
]
tool = _mock_tool(base_params=base, runtime_params=runtime)
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert len(result.parameters) == 1
assert result.parameters[0].name == "param1"
def test_empty_parameters(self):
tool = _mock_tool(base_params=[], runtime_params=[])
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 0
def test_none_parameters(self):
tool = _mock_tool(base_params=None, runtime_params=[])
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 0
def test_parameter_order_preserved(self):
base = [_param("p1", label="B1"), _param("p2", label="B2"), _param("p3", label="B3")]
runtime = [_param("p2", label="R2"), _param("p4", label="R4")]
tool = _mock_tool(base_params=base, runtime_params=runtime)
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert [p.name for p in result.parameters] == ["p1", "p2", "p3", "p4"]
assert result.parameters[1].label.en_US == "R2"
class TestWorkflowProviderToUserProvider:
"""Tests for ToolTransformService.workflow_provider_to_user_provider."""
@staticmethod
def _make_controller(provider_id="provider_123", **identity_overrides):
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
defaults = {
"author": "test_author",
"name": "test_workflow_tool",
"description": I18nObject(en_US="Test description"),
"icon": '{"type": "emoji", "content": "🔧"}',
"icon_dark": None,
"label": I18nObject(en_US="Test Workflow Tool"),
}
defaults.update(identity_overrides)
identity = ToolProviderIdentity(**defaults)
entity = ToolProviderEntity(identity=identity)
return WorkflowToolProviderController(entity=entity, provider_id=provider_id)
def test_with_workflow_app_id(self):
ctrl = self._make_controller()
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["l1", "l2"],
workflow_app_id="app_123",
)
assert isinstance(result, ToolProviderApiEntity)
assert result.id == "provider_123"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == "app_123"
assert result.labels == ["l1", "l2"]
assert result.is_team_authorization is True
def test_without_workflow_app_id(self):
ctrl = self._make_controller()
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["l1"],
)
assert result.workflow_app_id is None
def test_workflow_app_id_none_explicit(self):
ctrl = self._make_controller()
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=None,
workflow_app_id=None,
)
assert result.workflow_app_id is None
assert result.labels == []
def test_preserves_other_fields(self):
ctrl = self._make_controller(
"provider_456",
author="another_author",
name="another_workflow_tool",
description=I18nObject(en_US="Another desc", zh_Hans="Another desc"),
icon='{"type": "emoji", "content": "⚙️"}',
icon_dark='{"type": "emoji", "content": "🔧"}',
label=I18nObject(en_US="Another Tool", zh_Hans="Another Tool"),
)
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["automation"],
workflow_app_id="app_456",
)
assert result.id == "provider_456"
assert result.author == "another_author"
assert result.name == "another_workflow_tool"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == "app_456"
assert result.is_team_authorization is True
assert result.allow_delete is True

View File

@ -4,7 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -19,7 +19,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -142,7 +142,7 @@ class TestBatchCreateSegmentToIndexTask:
name=fake.company(),
description=fake.text(),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
created_by=account.id,

View File

@ -18,7 +18,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
@ -154,7 +154,7 @@ class TestCleanDatasetTask:
tenant_id=tenant.id,
name="test_dataset",
description="Test dataset for cleanup testing",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid.uuid4()),
created_by=account.id,
@ -870,7 +870,7 @@ class TestCleanDatasetTask:
tenant_id=tenant.id,
name=long_name,
description=long_description,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph", "max_length": 10000}',
collection_binding_id=str(uuid.uuid4()),
created_by=account.id,

View File

@ -12,7 +12,7 @@ from uuid import uuid4
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -121,7 +121,7 @@ class TestCreateSegmentToIndexTask:
description=fake.text(max_nb_chars=100),
tenant_id=tenant_id,
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
created_by=account_id,

View File

@ -8,6 +8,7 @@ import pytest
from faker import Faker
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Dataset, Document, DocumentSegment, Tenant
from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask:
dataset.provider = "vendor"
dataset.permission = "only_me"
dataset.data_source_type = DataSourceType.UPLOAD_FILE
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.index_struct = '{"type": "paragraph"}'
dataset.created_by = account.id
dataset.created_at = fake.date_time_this_year()

View File

@ -15,7 +15,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -100,7 +100,7 @@ class TestDisableSegmentFromIndexTask:
name=fake.sentence(nb_words=3),
description=fake.text(max_nb_chars=200),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Dataset, DocumentSegment
from models import Document as DatasetDocument
from models.dataset import DatasetProcessRule
@ -103,7 +103,7 @@ class TestDisableSegmentsFromIndexTask:
provider="vendor",
permission="only_me",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
updated_by=account.id,
embedding_model="text-embedding-ada-002",

View File

@ -14,7 +14,7 @@ from uuid import uuid4
import pytest
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
@ -57,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory:
name=f"dataset-{uuid4()}",
description="sync test dataset",
data_source_type=DataSourceType.NOTION_IMPORT,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
)
db_session_with_containers.add(dataset)

View File

@ -5,6 +5,7 @@ import pytest
from faker import Faker
from core.entities.document_task import DocumentTask
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
@ -99,7 +100,7 @@ class TestDocumentIndexingTasks:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -181,7 +182,7 @@ class TestDocumentIndexingTasks:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
@ -64,7 +64,7 @@ class TestDocumentIndexingUpdateTask:
name=fake.company(),
description=fake.text(max_nb_chars=64),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -4,7 +4,7 @@ import pytest
from faker import Faker
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -110,7 +110,7 @@ class TestDuplicateDocumentIndexingTasks:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -245,7 +245,7 @@ class TestDuplicateDocumentIndexingTasks:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -4,7 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -81,7 +81,7 @@ class TestEnableSegmentsToIndexTask:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -1476,8 +1476,8 @@ class TestDatasetIndexingStatusApi:
return_value=MagicMock(all=lambda: [document]),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=3,
),
):
response, status = method(api, "dataset-1")
@ -1526,13 +1526,6 @@ class TestDatasetIndexingStatusApi:
document.error = None
document.stopped_at = None
# First count = completed segments, second = total segments
query_mock = MagicMock()
query_mock.where.side_effect = [
MagicMock(count=lambda: 2),
MagicMock(count=lambda: 5),
]
with (
app.test_request_context("/"),
patch(
@ -1544,8 +1537,8 @@ class TestDatasetIndexingStatusApi:
return_value=MagicMock(all=lambda: [document]),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=query_mock,
"controllers.console.datasets.datasets.db.session.scalar",
side_effect=[2, 5],
),
):
response, status = method(api, "dataset-1")
@ -1591,8 +1584,8 @@ class TestDatasetApiKeyApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=3,
),
patch(
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
@ -1625,8 +1618,8 @@ class TestDatasetApiKeyApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=10,
),
):
with pytest.raises(BadRequest) as exc_info:
@ -1653,8 +1646,8 @@ class TestDatasetApiDeleteApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=mock_key,
),
patch(
"controllers.console.datasets.datasets.db.session.commit",
@ -1681,8 +1674,8 @@ class TestDatasetApiDeleteApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=None,
),
):
with pytest.raises(NotFound):

View File

@ -526,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -621,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.redis_client.setnx",
@ -706,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=None,
),
):
with pytest.raises(NotFound):
@ -738,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
):
with pytest.raises(ValueError):
@ -770,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.redis_client.setnx",
@ -831,8 +831,8 @@ class TestChildChunkAddApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -880,8 +880,8 @@ class TestChildChunkAddApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -924,11 +924,8 @@ class TestChildChunkUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
side_effect=[
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
],
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, child_chunk],
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -970,11 +967,8 @@ class TestChildChunkUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
side_effect=[
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
],
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, child_chunk],
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -1180,8 +1174,8 @@ class TestSegmentOperationCases:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
):
with pytest.raises(NotFound):
@ -1215,8 +1209,8 @@ class TestSegmentOperationCases:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",

View File

@ -4,6 +4,7 @@ from unittest.mock import Mock, patch
import pytest
from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.models.document import AttachmentDocument, Document
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@ -21,7 +22,7 @@ class TestParagraphIndexProcessor:
dataset = Mock()
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.is_multimodal = True
return dataset
@ -167,7 +168,7 @@ class TestParagraphIndexProcessor:
def test_load_uses_keyword_add_texts_with_keywords_when_economy(
self, processor: ParagraphIndexProcessor, dataset: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
docs = [Document(page_content="chunk", metadata={})]
with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
@ -178,7 +179,7 @@ class TestParagraphIndexProcessor:
def test_load_uses_keyword_add_texts_without_keywords_when_economy(
self, processor: ParagraphIndexProcessor, dataset: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
docs = [Document(page_content="chunk", metadata={})]
with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
@ -208,7 +209,7 @@ class TestParagraphIndexProcessor:
def test_clean_economy_deletes_summaries_and_keywords(
self, processor: ParagraphIndexProcessor, dataset: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
with (
patch(
@ -222,7 +223,7 @@ class TestParagraphIndexProcessor:
mock_keyword_cls.return_value.delete.assert_called_once()
def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
processor.clean(dataset, ["node-2"], with_keywords=True)
@ -267,7 +268,7 @@ class TestParagraphIndexProcessor:
def test_index_list_chunks_economy(
self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
with (
patch(
"core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",

View File

@ -4,6 +4,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from services.entities.knowledge_entities.knowledge_entities import ParentMode
@ -19,7 +20,7 @@ class TestParentChildIndexProcessor:
dataset = Mock()
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.is_multimodal = True
return dataset

View File

@ -6,6 +6,7 @@ import pytest
from werkzeug.datastructures import FileStorage
from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
from core.rag.models.document import AttachmentDocument, Document
@ -33,7 +34,7 @@ class TestQAIndexProcessor:
dataset = Mock()
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.is_multimodal = True
return dataset
@ -207,7 +208,7 @@ class TestQAIndexProcessor:
vector.create_multimodal.assert_called_once_with(multimodal_docs)
def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
docs = [Document(page_content="Q1", metadata={"answer": "A1"})]
with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls:
@ -298,7 +299,7 @@ class TestQAIndexProcessor:
def test_index_requires_high_quality(
self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")])
with (

View File

@ -61,7 +61,7 @@ from core.indexing_runner import (
DocumentIsPausedError,
IndexingRunner,
)
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.models.document import ChildDocument, Document
from dify_graph.model_runtime.entities.model_entities import ModelType
from libs.datetime_utils import naive_utc_now
@ -76,7 +76,7 @@ from models.dataset import Document as DatasetDocument
def create_mock_dataset(
dataset_id: str | None = None,
tenant_id: str | None = None,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
) -> Mock:
@ -458,7 +458,7 @@ class TestIndexingRunnerTransform:
dataset = Mock(spec=Dataset)
dataset.id = str(uuid.uuid4())
dataset.tenant_id = str(uuid.uuid4())
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset
@ -521,7 +521,7 @@ class TestIndexingRunnerTransform:
"""Test transformation with economy indexing (no embeddings)."""
# Arrange
runner = IndexingRunner()
sample_dataset.indexing_technique = "economy"
sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_processor = MagicMock()
transformed_docs = [
@ -605,7 +605,7 @@ class TestIndexingRunnerLoad:
dataset = Mock(spec=Dataset)
dataset.id = str(uuid.uuid4())
dataset.tenant_id = str(uuid.uuid4())
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset
@ -674,7 +674,7 @@ class TestIndexingRunnerLoad:
"""Test loading with economy indexing (keyword only)."""
# Arrange
runner = IndexingRunner()
sample_dataset.indexing_technique = "economy"
sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_processor = MagicMock()
@ -701,7 +701,7 @@ class TestIndexingRunnerLoad:
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
sample_dataset.indexing_technique = "high_quality"
sample_dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
# Add child documents
for doc in sample_documents:
@ -795,7 +795,7 @@ class TestIndexingRunnerRun:
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = doc.dataset_id
mock_dataset.tenant_id = doc.tenant_id
mock_dataset.indexing_technique = "economy"
mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
mock_process_rule = Mock(spec=DatasetProcessRule)
@ -949,7 +949,7 @@ class TestIndexingRunnerRun:
mock_dependencies["db"].session.get.side_effect = get_side_effect
mock_dataset = Mock(spec=Dataset)
mock_dataset.indexing_technique = "economy"
mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
mock_process_rule = Mock(spec=DatasetProcessRule)

View File

@ -5,6 +5,7 @@ from unittest.mock import Mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
@ -78,7 +79,7 @@ def sample_node_data():
type="knowledge-index",
chunk_structure="general_structure",
index_chunk_variable_selector=["start", "chunks"],
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
summary_index_setting=None,
)

View File

@ -15,6 +15,7 @@ from datetime import UTC, datetime
from unittest.mock import patch
from uuid import uuid4
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import (
AppDatasetJoin,
ChildChunk,
@ -67,14 +68,14 @@ class TestDatasetModelValidation:
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
description="Test description",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
)
# Assert
assert dataset.description == "Test description"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
@ -86,21 +87,21 @@ class TestDatasetModelValidation:
name="High Quality Dataset",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
dataset_economy = Dataset(
tenant_id=str(uuid4()),
name="Economy Dataset",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)
# Assert
assert dataset_high_quality.indexing_technique == "high_quality"
assert dataset_economy.indexing_technique == "economy"
assert "high_quality" in Dataset.INDEXING_TECHNIQUE_LIST
assert "economy" in Dataset.INDEXING_TECHNIQUE_LIST
assert dataset_high_quality.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset_economy.indexing_technique == IndexTechniqueType.ECONOMY
assert IndexTechniqueType.HIGH_QUALITY in Dataset.INDEXING_TECHNIQUE_LIST
assert IndexTechniqueType.ECONOMY in Dataset.INDEXING_TECHNIQUE_LIST
def test_dataset_provider_validation(self):
"""Test dataset provider values."""
@ -983,7 +984,7 @@ class TestModelIntegration:
name="Test Dataset",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
dataset.id = dataset_id
@ -1019,7 +1020,7 @@ class TestModelIntegration:
assert document.dataset_id == dataset_id
assert segment.dataset_id == dataset_id
assert segment.document_id == document_id
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert document.word_count == 100
assert segment.status == SegmentStatus.COMPLETED

View File

@ -97,6 +97,7 @@ from unittest.mock import Mock, create_autospec, patch
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -149,7 +150,7 @@ class DatasetUpdateDeleteTestDataFactory:
name: str = "Test Dataset",
description: str = "Test description",
tenant_id: str = "tenant-123",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str | None = "openai",
embedding_model: str | None = "text-embedding-ada-002",
collection_binding_id: str | None = "binding-123",
@ -237,7 +238,7 @@ class DatasetUpdateDeleteTestDataFactory:
@staticmethod
def create_knowledge_configuration_mock(
chunk_structure: str = "tree",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
keyword_number: int = 10,
@ -630,12 +631,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
chunk_structure="tree",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
chunk_structure="list",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)
@ -671,7 +672,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
# Assert
assert dataset.chunk_structure == "list"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == "binding-123"
@ -698,12 +699,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
chunk_structure="tree", # Existing structure
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
chunk_structure="list", # Different structure
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
mock_session.merge.return_value = dataset
@ -735,11 +736,11 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
indexing_technique="high_quality", # Current technique
indexing_technique=IndexTechniqueType.HIGH_QUALITY, # Current technique
)
knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
indexing_technique="economy", # Trying to change to economy
indexing_technique=IndexTechniqueType.ECONOMY, # Trying to change to economy
)
mock_session.merge.return_value = dataset

View File

@ -111,7 +111,7 @@ from unittest.mock import Mock, patch
import pytest
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset, DatasetProcessRule, Document
from services.dataset_service import DatasetService, DocumentService
@ -154,7 +154,7 @@ class DocumentValidationTestDataFactory:
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
doc_form: str | None = None,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
**kwargs,
@ -190,7 +190,7 @@ class DocumentValidationTestDataFactory:
data_source: DataSource | None = None,
process_rule: ProcessRule | None = None,
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
**kwargs,
) -> Mock:
"""
@ -448,7 +448,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)
@ -481,7 +481,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
- No errors are raised
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
# Act (should not raise)
DatasetService.check_dataset_model_setting(dataset)
@ -503,7 +503,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="invalid-model",
)
@ -533,7 +533,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)

View File

@ -2,7 +2,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
from models.enums import SegmentType
@ -111,7 +111,7 @@ class SegmentTestDataFactory:
def create_dataset_mock(
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model: str = "text-embedding-ada-002",
embedding_model_provider: str = "openai",
**kwargs,
@ -163,7 +163,7 @@ class TestSegmentServiceCreateSegment:
"""Test successful creation of a segment."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = {"content": "New segment content", "keywords": ["test", "segment"]}
mock_query = MagicMock()
@ -212,7 +212,7 @@ class TestSegmentServiceCreateSegment:
"""Test creation of segment with QA model (requires answer)."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]}
mock_query = MagicMock()
@ -247,7 +247,7 @@ class TestSegmentServiceCreateSegment:
"""Test creation of segment with high quality indexing technique."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
args = {"content": "New segment content", "keywords": ["test"]}
mock_query = MagicMock()
@ -289,7 +289,7 @@ class TestSegmentServiceCreateSegment:
"""Test segment creation when vector indexing fails."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = {"content": "New segment content", "keywords": ["test"]}
mock_query = MagicMock()
@ -342,7 +342,7 @@ class TestSegmentServiceUpdateSegment:
# Arrange
segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10)
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = SegmentUpdateArgs(content="Updated content", keywords=["updated"])
mock_db_session.query.return_value.where.return_value.first.return_value = segment
@ -431,7 +431,7 @@ class TestSegmentServiceUpdateSegment:
# Arrange
segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10)
document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"])
mock_db_session.query.return_value.where.return_value.first.return_value = segment

View File

@ -1,214 +0,0 @@
"""
Unit tests for services.advanced_prompt_template_service
"""
import copy
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_CONTEXT,
CHAT_APP_CHAT_PROMPT_CONFIG,
CHAT_APP_COMPLETION_PROMPT_CONFIG,
COMPLETION_APP_CHAT_PROMPT_CONFIG,
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
CONTEXT,
)
from models.model import AppMode
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class TestAdvancedPromptTemplateService:
"""Test suite for AdvancedPromptTemplateService."""
def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None:
"""Test baichuan model names use baichuan context prompt."""
# Arrange
args = {
"app_mode": AppMode.CHAT,
"model_mode": "chat",
"model_name": "Baichuan2-13B",
"has_context": "true",
}
# Act
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT)
def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None:
"""Test non-baichuan model names use common prompt."""
# Arrange
args = {
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-4",
"has_context": "false",
}
original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert
assert result == original_config
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG
def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None:
"""Test invalid app mode returns empty dict."""
# Arrange
app_mode = "invalid"
model_mode = "chat"
# Act
result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true")
# Assert
assert result == {}
def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None:
"""Test context is prepended for completion prompt when has_context is true."""
# Arrange
original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")
# Assert
assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT)
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG
def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None:
"""Test context is prepended for chat prompt when has_context is true."""
# Arrange
original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT)
assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None:
"""Test chat prompt remains unchanged when has_context is false."""
# Arrange
original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false")
# Assert
assert result == original_config
assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG
def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None:
"""Test completion app mode with completion model returns completion prompt."""
# Arrange
original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false")
# Assert
assert result == original_config
assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None:
"""Test invalid model mode returns empty dict."""
# Arrange
app_mode = AppMode.CHAT
model_mode = "invalid"
# Act
result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false")
# Assert
assert result == {}
def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None:
"""Test helper keeps completion prompt unchanged when context is disabled."""
# Arrange
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
# Act
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT)
# Assert
assert result["completion_prompt_config"]["prompt"]["text"] == original_text
def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None:
"""Test helper keeps chat prompt unchanged when context is disabled."""
# Arrange
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
# Act
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT)
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text
def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None:
"""Test baichuan chat/completion returns the expected config."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")
# Assert
assert result == original_config
assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None:
"""Test baichuan completion/chat returns the expected config."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false")
# Assert
assert result == original_config
assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None:
"""Test baichuan completion/completion prepends baichuan context when enabled."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")
# Assert
assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT)
assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None:
"""Test baichuan chat/chat prepends baichuan context when enabled."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT)
assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None:
"""Test invalid baichuan mode combinations return empty dict."""
# Arrange
app_mode = "invalid"
model_mode = "invalid"
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true")
# Assert
assert result == {}

View File

@ -1,683 +0,0 @@
"""Unit tests for services.app_service."""
import json
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
from core.errors.error import ProviderTokenNotInitError
from models import Account, Tenant
from models.model import App, AppMode, IconType
from services.app_service import AppService
@pytest.fixture
def service() -> AppService:
"""Provide AppService instance."""
return AppService()
@pytest.fixture
def account() -> Account:
"""Create account object for create_app tests."""
tenant = Tenant(name="Tenant")
tenant.id = "tenant-1"
result = Account(name="Account User", email="account@example.com")
result.id = "acc-1"
result._current_tenant = tenant
return result
@pytest.fixture
def default_args() -> dict:
"""Create default create_app args."""
return {
"name": "Test App",
"mode": AppMode.CHAT.value,
"icon": "🤖",
"icon_background": "#FFFFFF",
}
@pytest.fixture
def app_template() -> dict:
"""Create basic app template for create_app tests."""
return {
AppMode.CHAT: {
"app": {},
"model_config": {
"model": {
"provider": "provider-a",
"name": "model-a",
"mode": "chat",
"completion_params": {},
}
},
}
}
def _make_current_user() -> Account:
user = Account(name="Tester", email="tester@example.com")
user.id = "user-1"
tenant = Tenant(name="Tenant")
tenant.id = "tenant-1"
user._current_tenant = tenant
return user
class TestAppServicePagination:
"""Test suite for get_paginate_apps."""
def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None:
"""Test pagination returns None when tag filter has no targets."""
# Arrange
args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]}
with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]):
# Act
result = service.get_paginate_apps("user-1", "tenant-1", args)
# Assert
assert result is None
def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None:
"""Test pagination delegates to db.paginate when filters are valid."""
# Arrange
args = {
"mode": "workflow",
"is_created_by_me": True,
"name": "My_App%",
"tag_ids": ["tag-1"],
"page": 2,
"limit": 10,
}
expected_pagination = MagicMock()
with (
patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]),
patch("libs.helper.escape_like_pattern", return_value="escaped"),
patch("services.app_service.db") as mock_db,
):
mock_db.paginate.return_value = expected_pagination
# Act
result = service.get_paginate_apps("user-1", "tenant-1", args)
# Assert
assert result is expected_pagination
mock_db.paginate.assert_called_once()
class TestAppServiceCreate:
"""Test suite for create_app."""
def test_create_app_should_create_with_matching_default_model(
self,
service: AppService,
account: Account,
default_args: dict,
app_template: dict,
) -> None:
"""Test create_app uses matching default model and persists app config."""
# Arrange
app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
app_model_config = SimpleNamespace(id="cfg-1")
model_instance = SimpleNamespace(
model_name="model-a",
provider="provider-a",
model_type_instance=MagicMock(),
credentials={"k": "v"},
)
with (
patch("services.app_service.default_app_templates", app_template),
patch("services.app_service.App", return_value=app_instance),
patch("services.app_service.AppModelConfig", return_value=app_model_config),
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.app_service.db") as mock_db,
patch("services.app_service.app_was_created") as mock_event,
patch("services.app_service.FeatureService.get_system_features") as mock_features,
patch("services.app_service.BillingService") as mock_billing,
patch("services.app_service.dify_config") as mock_config,
):
manager = mock_model_manager.return_value
manager.get_default_model_instance.return_value = model_instance
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
mock_config.BILLING_ENABLED = True
# Act
result = service.create_app("tenant-1", default_args, account)
# Assert
assert result is app_instance
assert app_instance.app_model_config_id == "cfg-1"
mock_db.session.add.assert_any_call(app_instance)
mock_db.session.add.assert_any_call(app_model_config)
assert mock_db.session.flush.call_count == 2
mock_db.session.commit.assert_called_once()
mock_event.send.assert_called_once_with(app_instance, account=account)
mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1")
def test_create_app_should_raise_when_model_schema_missing(
self,
service: AppService,
account: Account,
default_args: dict,
app_template: dict,
) -> None:
"""Test create_app raises ValueError when non-matching model has no schema."""
# Arrange
app_instance = SimpleNamespace(id="app-1")
model_instance = SimpleNamespace(
model_name="model-b",
provider="provider-b",
model_type_instance=MagicMock(),
credentials={"k": "v"},
)
model_instance.model_type_instance.get_model_schema.return_value = None
with (
patch("services.app_service.default_app_templates", app_template),
patch("services.app_service.App", return_value=app_instance),
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.app_service.db") as mock_db,
):
manager = mock_model_manager.return_value
manager.get_default_model_instance.return_value = model_instance
# Act & Assert
with pytest.raises(ValueError, match="model schema not found"):
service.create_app("tenant-1", default_args, account)
mock_db.session.commit.assert_not_called()
def test_create_app_should_fallback_to_default_provider_when_model_missing(
self,
service: AppService,
account: Account,
default_args: dict,
app_template: dict,
) -> None:
"""Test create_app falls back to provider/model name when no default model instance is available."""
# Arrange
app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
app_model_config = SimpleNamespace(id="cfg-1")
with (
patch("services.app_service.default_app_templates", app_template),
patch("services.app_service.App", return_value=app_instance),
patch("services.app_service.AppModelConfig", return_value=app_model_config),
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.app_service.db") as mock_db,
patch("services.app_service.app_was_created") as mock_event,
patch("services.app_service.FeatureService.get_system_features") as mock_features,
patch("services.app_service.EnterpriseService") as mock_enterprise,
patch("services.app_service.dify_config") as mock_config,
):
manager = mock_model_manager.return_value
manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready")
manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model")
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
mock_config.BILLING_ENABLED = False
# Act
result = service.create_app("tenant-1", default_args, account)
# Assert
assert result is app_instance
mock_event.send.assert_called_once_with(app_instance, account=account)
mock_db.session.commit.assert_called_once()
mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private")
def test_create_app_should_log_and_fallback_on_unexpected_model_error(
self,
service: AppService,
account: Account,
default_args: dict,
app_template: dict,
) -> None:
"""Test unexpected model manager errors are logged and fallback provider is used."""
# Arrange
app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
app_model_config = SimpleNamespace(id="cfg-1")
with (
patch("services.app_service.default_app_templates", app_template),
patch("services.app_service.App", return_value=app_instance),
patch("services.app_service.AppModelConfig", return_value=app_model_config),
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.app_service.db"),
patch("services.app_service.app_was_created"),
patch(
"services.app_service.FeatureService.get_system_features",
return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)),
),
patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)),
patch("services.app_service.logger") as mock_logger,
):
manager = mock_model_manager.return_value
manager.get_default_model_instance.side_effect = RuntimeError("boom")
manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model")
# Act
result = service.create_app("tenant-1", default_args, account)
# Assert
assert result is app_instance
mock_logger.exception.assert_called_once()
class TestAppServiceGetAndUpdate:
"""Test suite for app retrieval and update methods."""
def test_get_app_should_return_original_when_not_agent_app(self, service: AppService) -> None:
"""Test get_app returns original app for non-agent modes."""
# Arrange
app = MagicMock()
app.mode = AppMode.CHAT
app.is_agent = False
with patch("services.app_service.current_user", _make_current_user()):
# Act
result = service.get_app(app)
# Assert
assert result is app
def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None:
"""Test get_app returns app when agent mode has no model config."""
# Arrange
app = MagicMock()
app.id = "app-1"
app.mode = AppMode.AGENT_CHAT
app.is_agent = False
app.app_model_config = None
with patch("services.app_service.current_user", _make_current_user()):
# Act
result = service.get_app(app)
# Assert
assert result is app
def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None:
"""Test get_app decrypts and masks secret tool parameters."""
# Arrange
tool = {
"provider_type": "builtin",
"provider_id": "provider-1",
"tool_name": "tool-a",
"tool_parameters": {"secret": "encrypted"},
"extra": True,
}
model_config = MagicMock()
model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]}
app = MagicMock()
app.id = "app-1"
app.mode = AppMode.AGENT_CHAT
app.is_agent = False
app.app_model_config = model_config
manager = MagicMock()
manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"}
manager.mask_tool_parameters.return_value = {"secret": "***"}
with (
patch("services.app_service.current_user", _make_current_user()),
patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()),
patch("services.app_service.ToolParameterConfigurationManager", return_value=manager),
):
# Act
result = service.get_app(app)
# Assert
assert result.app_model_config is model_config
assert tool["tool_parameters"] == {"secret": "***"}
assert json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"}
def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None:
"""Test get_app logs and continues when masking fails."""
# Arrange
tool = {
"provider_type": "builtin",
"provider_id": "provider-1",
"tool_name": "tool-a",
"tool_parameters": {"secret": "encrypted"},
"extra": True,
}
model_config = MagicMock()
model_config.agent_mode_dict = {"tools": [tool]}
app = MagicMock()
app.id = "app-1"
app.mode = AppMode.AGENT_CHAT
app.is_agent = False
app.app_model_config = model_config
with (
patch("services.app_service.current_user", _make_current_user()),
patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")),
patch("services.app_service.logger") as mock_logger,
):
# Act
result = service.get_app(app)
# Assert
assert result.app_model_config is model_config
mock_logger.exception.assert_called_once()
def test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None:
"""Test update methods set fields and commit changes."""
# Arrange
app = cast(
App,
SimpleNamespace(
name="old",
description="old",
icon_type="emoji",
icon="a",
icon_background="#111",
enable_site=True,
enable_api=True,
),
)
args = {
"name": "new",
"description": "new-desc",
"icon_type": "image",
"icon": "new-icon",
"icon_background": "#222",
"use_icon_as_answer_icon": True,
"max_active_requests": 5,
}
user = SimpleNamespace(id="user-1")
with (
patch("services.app_service.current_user", user),
patch("services.app_service.db") as mock_db,
patch("services.app_service.naive_utc_now", return_value="now"),
):
# Act
updated = service.update_app(app, args)
renamed = service.update_app_name(app, "rename")
iconed = service.update_app_icon(app, "icon-2", "#333")
site_same = service.update_app_site_status(app, app.enable_site)
api_same = service.update_app_api_status(app, app.enable_api)
site_changed = service.update_app_site_status(app, False)
api_changed = service.update_app_api_status(app, False)
# Assert
assert updated is app
assert updated.icon_type == IconType.IMAGE
assert renamed is app
assert iconed is app
assert site_same is app
assert api_same is app
assert site_changed is app
assert api_changed is app
assert mock_db.session.commit.call_count >= 5
def test_update_app_should_preserve_icon_type_when_not_provided(self, service: AppService) -> None:
"""Test update_app keeps the existing icon_type when the payload omits it."""
# Arrange
app = cast(
App,
SimpleNamespace(
name="old",
description="old",
icon_type=IconType.EMOJI,
icon="a",
icon_background="#111",
use_icon_as_answer_icon=False,
max_active_requests=1,
),
)
args = {
"name": "new",
"description": "new-desc",
"icon_type": None,
"icon": "new-icon",
"icon_background": "#222",
"use_icon_as_answer_icon": True,
"max_active_requests": 5,
}
user = SimpleNamespace(id="user-1")
with (
patch("services.app_service.current_user", user),
patch("services.app_service.db") as mock_db,
patch("services.app_service.naive_utc_now", return_value="now"),
):
# Act
updated = service.update_app(app, args)
# Assert
assert updated is app
assert updated.icon_type == IconType.EMOJI
mock_db.session.commit.assert_called_once()
def test_update_app_should_reject_empty_icon_type(self, service: AppService) -> None:
"""Test update_app rejects an explicit empty icon_type."""
app = cast(
App,
SimpleNamespace(
name="old",
description="old",
icon_type=IconType.EMOJI,
icon="a",
icon_background="#111",
use_icon_as_answer_icon=False,
max_active_requests=1,
),
)
args = {
"name": "new",
"description": "new-desc",
"icon_type": "",
"icon": "new-icon",
"icon_background": "#222",
"use_icon_as_answer_icon": True,
"max_active_requests": 5,
}
user = SimpleNamespace(id="user-1")
with (
patch("services.app_service.current_user", user),
patch("services.app_service.db") as mock_db,
):
with pytest.raises(ValueError):
service.update_app(app, args)
mock_db.session.commit.assert_not_called()
class TestAppServiceDeleteAndMeta:
"""Test suite for delete and metadata methods."""
def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None:
"""Test delete_app removes app, runs cleanup, and triggers async deletion task."""
# Arrange
app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1"))
with (
patch("services.app_service.db") as mock_db,
patch(
"services.app_service.FeatureService.get_system_features",
return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)),
),
patch("services.app_service.EnterpriseService") as mock_enterprise,
patch(
"services.app_service.dify_config",
new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"),
),
patch("services.app_service.BillingService") as mock_billing,
patch("services.app_service.remove_app_and_related_data_task") as mock_task,
):
# Act
service.delete_app(app)
# Assert
mock_db.session.delete.assert_called_once_with(app)
mock_db.session.commit.assert_called_once()
mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1")
mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1")
mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1")
def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None:
"""Test get_app_meta extracts builtin and API tool icons from workflow graph."""
# Arrange
workflow = SimpleNamespace(
graph_dict={
"nodes": [
{
"data": {
"type": "tool",
"provider_type": "builtin",
"provider_id": "builtin-provider",
"tool_name": "tool_builtin",
}
},
{
"data": {
"type": "tool",
"provider_type": "api",
"provider_id": "api-provider-id",
"tool_name": "tool_api",
}
},
]
}
)
app = cast(
App,
SimpleNamespace(
mode=AppMode.WORKFLOW.value,
workflow=workflow,
app_model_config=None,
tenant_id="tenant-1",
icon_type="emoji",
icon_background="#fff",
),
)
provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"}))
with (
patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")),
patch("services.app_service.db") as mock_db,
):
query = MagicMock()
query.where.return_value = query
query.first.return_value = provider
mock_db.session.query.return_value = query
# Act
meta = service.get_app_meta(app)
# Assert
assert meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon")
assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"}
def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None:
"""Test get_app_meta falls back to default icon when API provider lookup fails."""
# Arrange
app_model_config = SimpleNamespace(
agent_mode_dict={
"tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}]
}
)
app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None))
with (
patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")),
patch("services.app_service.db") as mock_db,
):
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
mock_db.session.query.return_value = query
# Act
meta = service.get_app_meta(app)
# Assert
assert meta["tool_icons"]["t"] == {"background": "#252525", "content": "\ud83d\ude01"}
def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None:
"""Test get_app_meta returns empty metadata when workflow/model config is absent."""
# Arrange
workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None))
chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None))
# Act
workflow_meta = service.get_app_meta(workflow_app)
chat_meta = service.get_app_meta(chat_app)
# Assert
assert workflow_meta == {"tool_icons": {}}
assert chat_meta == {"tool_icons": {}}
class TestAppServiceCodeLookup:
"""Test suite for app code lookup methods."""
def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None:
"""Test get_app_code_by_id raises when site is missing."""
# Arrange
with patch("services.app_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
mock_db.session.query.return_value = query
# Act & Assert
with pytest.raises(ValueError, match="not found"):
AppService.get_app_code_by_id("app-1")
def test_get_app_code_by_id_should_return_code(self) -> None:
"""Test get_app_code_by_id returns site code."""
# Arrange
site = SimpleNamespace(code="code-1")
with patch("services.app_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = site
mock_db.session.query.return_value = query
# Act
result = AppService.get_app_code_by_id("app-1")
# Assert
assert result == "code-1"
def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None:
"""Test get_app_id_by_code raises when code does not exist."""
# Arrange
with patch("services.app_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
mock_db.session.query.return_value = query
# Act & Assert
with pytest.raises(ValueError, match="not found"):
AppService.get_app_id_by_code("missing")
def test_get_app_id_by_code_should_return_app_id(self) -> None:
"""Test get_app_id_by_code returns linked app id."""
# Arrange
site = SimpleNamespace(app_id="app-1")
with patch("services.app_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = site
mock_db.session.query.return_value = query
# Act
result = AppService.get_app_id_by_code("code-1")
# Assert
assert result == "app-1"

View File

@ -4,7 +4,7 @@ from unittest.mock import Mock, create_autospec
import pytest
from redis.exceptions import LockNotOwnedError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService, SegmentService
@ -71,7 +71,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned(
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.data_source_type = "upload_file"
dataset.indexing_technique = "high_quality" # so we skip re-initialization branch
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # so we skip re-initialization branch
# Minimal knowledge_config stub that satisfies pre-lock code
info_list = types.SimpleNamespace(data_source_type="upload_file")
@ -80,7 +80,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned(
doc_form=IndexStructureType.QA_INDEX,
original_document_id=None, # go into "new document" branch
data_source=data_source,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model=None,
embedding_model_provider=None,
retrieval_model=None,
@ -126,7 +126,7 @@ def test_add_segment_ignores_lock_not_owned(
dataset = create_autospec(Dataset, instance=True)
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.indexing_technique = "economy" # skip embedding/token calculation branch
dataset.indexing_technique = IndexTechniqueType.ECONOMY # skip embedding/token calculation branch
document = create_autospec(Document, instance=True)
document.id = "doc-1"
@ -169,7 +169,7 @@ def test_multi_create_segment_ignores_lock_not_owned(
dataset = create_autospec(Dataset, instance=True)
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.indexing_technique = "economy" # again, skip high_quality path
dataset.indexing_technique = IndexTechniqueType.ECONOMY # again, skip high_quality path
document = create_autospec(Document, instance=True)
document.id = "doc-1"

View File

@ -11,7 +11,7 @@ from unittest.mock import MagicMock
import pytest
import services.summary_index_service as summary_module
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.enums import SegmentStatus, SummaryStatus
from services.summary_index_service import SummaryIndexService
@ -27,7 +27,7 @@ class _SessionContext:
return None
def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock:
def _dataset(*, indexing_technique: str = IndexTechniqueType.HIGH_QUALITY) -> MagicMock:
dataset = MagicMock(name="dataset")
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
@ -169,7 +169,8 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N
def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None:
vector_cls = MagicMock()
monkeypatch.setattr(summary_module, "Vector", vector_cls)
SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy"))
dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY)
SummaryIndexService.vectorize_summary(_summary_record(), _segment(), dataset)
vector_cls.assert_not_called()
@ -621,7 +622,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo
def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset(indexing_technique="economy")
dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY)
document = MagicMock(spec=summary_module.DatasetDocument)
document.id = "doc-1"
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
@ -778,7 +779,7 @@ def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mo
def test_enable_summaries_for_segments_skips_non_high_quality() -> None:
SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy"))
SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique=IndexTechniqueType.ECONOMY))
def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None:
@ -932,9 +933,8 @@ def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mon
def test_update_summary_for_segment_skip_conditions() -> None:
assert (
SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None
)
economy_dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY)
assert SummaryIndexService.update_summary_for_segment(_segment(), economy_dataset, "x") is None
seg = _segment(has_document=True)
seg.document.doc_form = IndexStructureType.QA_INDEX
assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None

View File

@ -9,7 +9,7 @@ from unittest.mock import MagicMock
import pytest
import services.vector_service as vector_service_module
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from services.vector_service import VectorService
@ -32,7 +32,7 @@ class _ParentDocStub:
def _make_dataset(
*,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
tenant_id: str = "tenant-1",
dataset_id: str = "dataset-1",
@ -192,7 +192,7 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex
dataset = _make_dataset(
doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
embedding_model_provider="openai",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
segment = _make_segment()
@ -241,7 +241,7 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p
dataset = _make_dataset(
doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
embedding_model_provider=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
segment = _make_segment()
@ -329,7 +329,7 @@ def test_create_segments_vector_parent_child_missing_processing_rule_raises(monk
def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(
doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)
segment = _make_segment()
dataset_document = MagicMock()
@ -348,7 +348,7 @@ def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch
def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
segment = _make_segment()
vector_instance = MagicMock()
@ -364,7 +364,7 @@ def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.Monk
def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="economy")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
segment = _make_segment()
keyword_instance = MagicMock()
@ -380,7 +380,7 @@ def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypat
def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="economy")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
segment = _make_segment()
keyword_instance = MagicMock()
@ -473,7 +473,7 @@ def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest
def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
child_chunk = MagicMock()
child_chunk.content = "child"
child_chunk.index_node_id = "id"
@ -489,7 +489,7 @@ def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.M
def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="economy")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
vector_cls = MagicMock()
monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
@ -505,7 +505,7 @@ def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch)
def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
new_chunk = MagicMock()
new_chunk.content = "n"
@ -536,7 +536,7 @@ def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pyte
def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="economy")
dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
vector_cls = MagicMock()
monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
VectorService.update_child_chunk_vector([], [], [], dataset)
@ -561,7 +561,7 @@ def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch
def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="economy", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY, is_multimodal=True)
segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}])
vector_cls = MagicMock()
@ -575,7 +575,7 @@ def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pyt
def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}])
vector_cls = MagicMock()
@ -591,7 +591,7 @@ def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pyt
def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids(
monkeypatch: pytest.MonkeyPatch,
) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}])
vector_instance = MagicMock(name="vector_instance")
@ -612,7 +612,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids(
def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}])
vector_instance = MagicMock()
@ -630,7 +630,7 @@ def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch
def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files(
monkeypatch: pytest.MonkeyPatch,
) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}])
vector_instance = MagicMock()
@ -663,7 +663,7 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up
def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops(
monkeypatch: pytest.MonkeyPatch,
) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=False)
segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}])
vector_instance = MagicMock()
@ -683,7 +683,7 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops
def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}])
vector_instance = MagicMock()

View File

@ -1,379 +0,0 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import NotFound, Unauthorized
from models import Account, AccountStatus
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback"
TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token"
TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data"
def _account(**kwargs: Any) -> Account:
return cast(Account, SimpleNamespace(**kwargs))
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
# Arrange
mocked_db = mocker.patch("services.webapp_auth_service.db")
mocked_db.session = MagicMock()
return mocked_db
def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
# Act + Assert
with pytest.raises(AccountNotFoundError):
WebAppAuthService.authenticate("user@example.com", "pwd")
def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
# Act + Assert
with pytest.raises(AccountLoginError, match="Account is banned"):
WebAppAuthService.authenticate("user@example.com", "pwd")
@pytest.mark.parametrize("password_value", [None, "hash"])
def test_authenticate_should_raise_password_error_when_password_is_invalid(
password_value: str | None,
mocker: MockerFixture,
) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
mocker.patch("services.webapp_auth_service.compare_password", return_value=False)
# Act + Assert
with pytest.raises(AccountPasswordError, match="Invalid email or password"):
WebAppAuthService.authenticate("user@example.com", "pwd")
def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
mocker.patch("services.webapp_auth_service.compare_password", return_value=True)
# Act
result = WebAppAuthService.authenticate("user@example.com", "pwd")
# Assert
assert result is account
def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None:
# Arrange
account = _account(id="a1", email="u@example.com")
mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token")
# Act
result = WebAppAuthService.login(account)
# Assert
assert result == "jwt-token"
mock_get_token.assert_called_once_with(account=account)
def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
# Act
result = WebAppAuthService.get_user_through_email("missing@example.com")
# Assert
assert result is None
def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.BANNED)
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
# Act + Assert
with pytest.raises(Unauthorized, match="Account is banned"):
WebAppAuthService.get_user_through_email("user@example.com")
def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE)
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
# Act
result = WebAppAuthService.get_user_through_email("user@example.com")
# Assert
assert result is account
def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Email must be provided"):
WebAppAuthService.send_email_code_login_email(account=None, email=None)
def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account(
mocker: MockerFixture,
) -> None:
# Arrange
account = _account(email="user@example.com")
mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6])
mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1")
mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
# Act
result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US")
# Assert
assert result == "token-1"
mock_generate_token.assert_called_once()
assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"}
mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456")
def test_send_email_code_login_email_should_send_mail_for_email_without_account(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0])
mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2")
mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
# Act
result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans")
# Assert
assert result == "token-2"
mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000")
def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
# Arrange
mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"})
# Act
result = WebAppAuthService.get_email_code_login_data("token-abc")
# Assert
assert result == {"code": "123"}
mock_get_data.assert_called_once_with("token-abc", "email_code_login")
def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
# Arrange
mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token")
# Act
WebAppAuthService.revoke_email_code_login_token("token-xyz")
# Assert
mock_revoke.assert_called_once_with("token-xyz", "email_code_login")
def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act + Assert
with pytest.raises(NotFound, match="Site not found"):
WebAppAuthService.create_end_user("app-code", "user@example.com")
def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None:
# Arrange
site = SimpleNamespace(app_id="app-1")
app_query = MagicMock()
app_query.where.return_value.first.return_value = None
mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None]
# Act + Assert
with pytest.raises(NotFound, match="App not found"):
WebAppAuthService.create_end_user("app-code", "user@example.com")
def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None:
# Arrange
site = SimpleNamespace(app_id="app-1")
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model]
# Act
result = WebAppAuthService.create_end_user("app-code", "user@example.com")
# Assert
assert result.tenant_id == "tenant-1"
assert result.app_id == "app-1"
assert result.session_id == "user@example.com"
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None:
# Arrange
account = _account(id="a1", email="user@example.com")
mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60)
mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1")
# Act
token = WebAppAuthService._get_account_jwt_token(account)
# Assert
assert token == "jwt-1"
payload = mock_issue.call_args.args[0]
assert payload["user_id"] == "a1"
assert payload["session_id"] == "user@example.com"
assert payload["token_source"] == "webapp_login_token"
assert payload["auth_type"] == "internal"
assert payload["exp"] > int(datetime.now(UTC).timestamp())
@pytest.mark.parametrize(
("access_mode", "expected"),
[
("private", True),
("private_all", True),
("public", False),
],
)
def test_is_app_require_permission_check_should_use_access_mode_when_provided(
access_mode: str,
expected: bool,
) -> None:
# Arrange
# Act
result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode)
# Assert
assert result is expected
def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Either app_code or app_id must be provided"):
WebAppAuthService.is_app_require_permission_check()
def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None)
# Act + Assert
with pytest.raises(ValueError, match="App ID could not be determined"):
WebAppAuthService.is_app_require_permission_check(app_code="app-code")
def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
mocker.patch(
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=SimpleNamespace(access_mode="private"),
)
# Act
result = WebAppAuthService.is_app_require_permission_check(app_code="app-code")
# Assert
assert result is True
def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=SimpleNamespace(access_mode="public"),
)
# Act
result = WebAppAuthService.is_app_require_permission_check(app_id="app-1")
# Assert
assert result is False
@pytest.mark.parametrize(
("access_mode", "expected"),
[
("public", WebAppAuthType.PUBLIC),
("private", WebAppAuthType.INTERNAL),
("private_all", WebAppAuthType.INTERNAL),
("sso_verified", WebAppAuthType.EXTERNAL),
],
)
def test_get_app_auth_type_should_map_access_modes_correctly(
access_mode: str,
expected: WebAppAuthType,
) -> None:
# Arrange
# Act
result = WebAppAuthService.get_app_auth_type(access_mode=access_mode)
# Assert
assert result == expected
def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None:
# Arrange
mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
mocker.patch(
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=SimpleNamespace(access_mode="private_all"),
)
# Act
result = WebAppAuthService.get_app_auth_type(app_code="app-code")
# Assert
assert result == WebAppAuthType.INTERNAL
def test_get_app_auth_type_should_raise_when_no_input_provided() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"):
WebAppAuthService.get_app_auth_type()
def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None:
# Arrange
# Act + Assert
with pytest.raises(ValueError, match="Could not determine app authentication type"):
WebAppAuthService.get_app_auth_type(access_mode="unknown")

View File

@ -1,300 +0,0 @@
from __future__ import annotations
import json
import uuid
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from dify_graph.enums import WorkflowExecutionStatus
from models import App, WorkflowAppLog
from models.enums import AppTriggerType, CreatorUserRole
from services.workflow_app_service import LogView, WorkflowAppService
@pytest.fixture
def service() -> WorkflowAppService:
# Arrange
return WorkflowAppService()
@pytest.fixture
def app_model() -> App:
# Arrange
return cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1"))
def _workflow_app_log(**kwargs: Any) -> WorkflowAppLog:
return cast(WorkflowAppLog, SimpleNamespace(**kwargs))
def test_log_view_details_should_return_wrapped_details_and_proxy_attributes() -> None:
# Arrange
log = _workflow_app_log(id="log-1", status="succeeded")
view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}})
# Act
details = view.details
proxied_status = view.status
# Assert
assert details == {"trigger_metadata": {"type": "plugin"}}
assert proxied_status == "succeeded"
def test_get_paginate_workflow_app_logs_should_return_paginated_summary_when_detail_false(
service: WorkflowAppService,
app_model: App,
) -> None:
# Arrange
session = MagicMock()
log_1 = SimpleNamespace(id="log-1")
log_2 = SimpleNamespace(id="log-2")
session.scalar.return_value = 3
session.scalars.return_value.all.return_value = [log_1, log_2]
# Act
result = service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
page=1,
limit=2,
detail=False,
)
# Assert
assert result["page"] == 1
assert result["limit"] == 2
assert result["total"] == 3
assert result["has_more"] is True
assert len(result["data"]) == 2
assert isinstance(result["data"][0], LogView)
assert result["data"][0].details is None
def test_get_paginate_workflow_app_logs_should_return_detailed_rows_when_detail_true(
service: WorkflowAppService,
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
session = MagicMock()
session.scalar.side_effect = [1]
log_1 = SimpleNamespace(id="log-1")
session.execute.return_value.all.return_value = [(log_1, '{"type":"trigger_plugin"}')]
mock_handle = mocker.patch.object(
service,
"handle_trigger_metadata",
return_value={"type": "trigger_plugin", "icon": "url"},
)
# Act
result = service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
keyword="run-1",
status=WorkflowExecutionStatus.SUCCEEDED,
created_at_before=None,
created_at_after=None,
page=1,
limit=20,
detail=True,
)
# Assert
assert result["total"] == 1
assert len(result["data"]) == 1
assert result["data"][0].details == {"trigger_metadata": {"type": "trigger_plugin", "icon": "url"}}
mock_handle.assert_called_once()
def test_get_paginate_workflow_app_logs_should_raise_when_account_filter_email_not_found(
service: WorkflowAppService,
app_model: App,
) -> None:
# Arrange
session = MagicMock()
session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="Account not found: account@example.com"):
service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
created_by_account="account@example.com",
)
def test_get_paginate_workflow_app_logs_should_filter_by_account_when_account_exists(
service: WorkflowAppService,
app_model: App,
) -> None:
# Arrange
session = MagicMock()
session.scalar.side_effect = [SimpleNamespace(id="account-1"), 0]
session.scalars.return_value.all.return_value = []
# Act
result = service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
created_by_account="account@example.com",
)
# Assert
assert result["total"] == 0
assert result["data"] == []
def test_get_paginate_workflow_archive_logs_should_return_paginated_archive_items(
service: WorkflowAppService,
app_model: App,
) -> None:
# Arrange
session = MagicMock()
log_account = SimpleNamespace(
id="log-1",
created_by="acc-1",
created_by_role=CreatorUserRole.ACCOUNT,
workflow_run_summary={"run": "1"},
trigger_metadata='{"type":"trigger-webhook"}',
log_created_at="2026-01-01",
)
log_end_user = SimpleNamespace(
id="log-2",
created_by="end-1",
created_by_role=CreatorUserRole.END_USER,
workflow_run_summary={"run": "2"},
trigger_metadata='{"type":"trigger-webhook"}',
log_created_at="2026-01-02",
)
log_unknown = SimpleNamespace(
id="log-3",
created_by="other",
created_by_role="system",
workflow_run_summary={"run": "3"},
trigger_metadata='{"type":"trigger-webhook"}',
log_created_at="2026-01-03",
)
session.scalar.return_value = 3
session.scalars.side_effect = [
SimpleNamespace(all=lambda: [log_account, log_end_user, log_unknown]),
SimpleNamespace(all=lambda: [SimpleNamespace(id="acc-1", email="a@example.com")]),
SimpleNamespace(all=lambda: [SimpleNamespace(id="end-1", session_id="session-1")]),
]
# Act
result = service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,
page=1,
limit=20,
)
# Assert
assert result["total"] == 3
assert len(result["data"]) == 3
assert result["data"][0]["created_by_account"].id == "acc-1"
assert result["data"][1]["created_by_end_user"].id == "end-1"
assert result["data"][2]["created_by_account"] is None
assert result["data"][2]["created_by_end_user"] is None
def test_handle_trigger_metadata_should_return_empty_dict_when_metadata_missing(
service: WorkflowAppService,
) -> None:
# Arrange
# Act
result = service.handle_trigger_metadata("tenant-1", None)
# Assert
assert result == {}
def test_handle_trigger_metadata_should_enrich_plugin_icons_for_trigger_plugin(
service: WorkflowAppService,
mocker: MockerFixture,
) -> None:
# Arrange
meta = {
"type": AppTriggerType.TRIGGER_PLUGIN.value,
"icon_filename": "light.png",
"icon_dark_filename": "dark.png",
}
mock_icon = mocker.patch(
"services.workflow_app_service.PluginService.get_plugin_icon_url",
side_effect=["https://cdn/light.png", "https://cdn/dark.png"],
)
# Act
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
# Assert
assert result["icon"] == "https://cdn/light.png"
assert result["icon_dark"] == "https://cdn/dark.png"
assert mock_icon.call_count == 2
def test_handle_trigger_metadata_should_return_non_plugin_metadata_without_icon_lookup(
service: WorkflowAppService,
mocker: MockerFixture,
) -> None:
# Arrange
meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value}
mock_icon = mocker.patch("services.workflow_app_service.PluginService.get_plugin_icon_url")
# Act
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
# Assert
assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value
mock_icon.assert_not_called()
@pytest.mark.parametrize(
("value", "expected"),
[
(None, None),
("", None),
('{"k":"v"}', {"k": "v"}),
("not-json", None),
({"raw": True}, {"raw": True}),
],
)
def test_safe_json_loads_should_handle_various_inputs(
value: object,
expected: object,
service: WorkflowAppService,
) -> None:
# Arrange
# Act
result = service._safe_json_loads(value)
# Assert
assert result == expected
def test_safe_parse_uuid_should_return_none_for_short_or_invalid_values(service: WorkflowAppService) -> None:
# Arrange
# Act
short_result = service._safe_parse_uuid("short")
invalid_result = service._safe_parse_uuid("x" * 40)
# Assert
assert short_result is None
assert invalid_result is None
def test_safe_parse_uuid_should_return_uuid_for_valid_uuid_string(service: WorkflowAppService) -> None:
# Arrange
raw_uuid = str(uuid.uuid4())
# Act
result = service._safe_parse_uuid(raw_uuid)
# Assert
assert result is not None
assert str(result) == raw_uuid

View File

@ -1,452 +0,0 @@
from unittest.mock import Mock
from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from services.tools.tools_transform_service import ToolTransformService
class TestToolTransformService:
"""Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""
def test_convert_tool_with_parameter_override(self):
"""Test that runtime parameters correctly override base parameters"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
base_param2 = Mock(spec=ToolParameter)
base_param2.name = "param2"
base_param2.form = ToolParameter.ToolParameterForm.FORM
base_param2.type = "string"
base_param2.label = "Base Param 2"
# Create mock runtime parameters that override base parameters
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1" # Different label to verify override
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1, base_param2]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.author == "test_author"
assert result.name == "test_tool"
assert result.parameters is not None
assert len(result.parameters) == 2
# Find the overridden parameter
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
assert overridden_param is not None
assert overridden_param.label == "Runtime Param 1" # Should be runtime version
# Find the non-overridden parameter
original_param = next((p for p in result.parameters if p.name == "param2"), None)
assert original_param is not None
assert original_param.label == "Base Param 2" # Should be base version
def test_convert_tool_with_additional_runtime_parameters(self):
"""Test that additional runtime parameters are added to the final list"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
# Create mock runtime parameters - one that overrides and one that's new
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1"
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "runtime_only"
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
runtime_param2.type = "string"
runtime_param2.label = "Runtime Only Param"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 2
# Check that both parameters are present
param_names = [p.name for p in result.parameters]
assert "param1" in param_names
assert "runtime_only" in param_names
# Verify the overridden parameter has runtime version
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
assert overridden_param is not None
assert overridden_param.label == "Runtime Param 1"
# Verify the new runtime parameter is included
new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
assert new_param is not None
assert new_param.label == "Runtime Only Param"
def test_convert_tool_with_non_form_runtime_parameters(self):
"""Test that non-FORM runtime parameters are not added as new parameters"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
# Create mock runtime parameters with different forms
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1"
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "llm_param"
runtime_param2.form = ToolParameter.ToolParameterForm.LLM
runtime_param2.type = "string"
runtime_param2.label = "LLM Param"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 1 # Only the FORM parameter should be present
# Check that only the FORM parameter is present
param_names = [p.name for p in result.parameters]
assert "param1" in param_names
assert "llm_param" not in param_names
def test_convert_tool_with_empty_parameters(self):
"""Test conversion with empty base and runtime parameters"""
# Create mock tool with no parameters
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = []
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = []
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 0
def test_convert_tool_with_none_parameters(self):
"""Test conversion when base parameters is None"""
# Create mock tool with None parameters
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = None
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = []
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 0
def test_convert_tool_parameter_order_preserved(self):
"""Test that parameter order is preserved correctly"""
# Create mock base parameters in specific order
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
base_param2 = Mock(spec=ToolParameter)
base_param2.name = "param2"
base_param2.form = ToolParameter.ToolParameterForm.FORM
base_param2.type = "string"
base_param2.label = "Base Param 2"
base_param3 = Mock(spec=ToolParameter)
base_param3.name = "param3"
base_param3.form = ToolParameter.ToolParameterForm.FORM
base_param3.type = "string"
base_param3.label = "Base Param 3"
# Create runtime parameter that overrides middle parameter
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "param2"
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
runtime_param2.type = "string"
runtime_param2.label = "Runtime Param 2"
# Create new runtime parameter
runtime_param4 = Mock(spec=ToolParameter)
runtime_param4.name = "param4"
runtime_param4.form = ToolParameter.ToolParameterForm.FORM
runtime_param4.type = "string"
runtime_param4.label = "Runtime Param 4"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 4
# Check that order is maintained: base parameters first, then new runtime parameters
param_names = [p.name for p in result.parameters]
assert param_names == ["param1", "param2", "param3", "param4"]
# Verify that param2 was overridden with runtime version
param2 = result.parameters[1]
assert param2.name == "param2"
assert param2.label == "Runtime Param 2"
class TestWorkflowProviderToUserProvider:
"""Test cases for ToolTransformService.workflow_provider_to_user_provider method"""
def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
"""Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller
workflow_app_id = "app_123"
provider_id = "provider_123"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "test_author"
mock_controller.entity.identity.name = "test_workflow_tool"
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.icon_dark = None
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
# Call the method
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=["label1", "label2"],
workflow_app_id=workflow_app_id,
)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.author == "test_author"
assert result.name == "test_workflow_tool"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == workflow_app_id
assert result.labels == ["label1", "label2"]
assert result.is_team_authorization is True
assert result.plugin_id is None
assert result.plugin_unique_identifier is None
assert result.tools == []
def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
"""Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller
provider_id = "provider_123"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "test_author"
mock_controller.entity.identity.name = "test_workflow_tool"
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.icon_dark = None
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
# Call the method without workflow_app_id
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=["label1"],
)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.workflow_app_id is None
assert result.labels == ["label1"]
def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
"""Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller
provider_id = "provider_123"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "test_author"
mock_controller.entity.identity.name = "test_workflow_tool"
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.icon_dark = None
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
# Call the method with explicit None values
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=None,
workflow_app_id=None,
)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.workflow_app_id is None
assert result.labels == []
def test_workflow_provider_to_user_provider_preserves_other_fields(self):
"""Test that workflow_provider_to_user_provider preserves all other entity fields."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller with various fields
workflow_app_id = "app_456"
provider_id = "provider_456"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "another_author"
mock_controller.entity.identity.name = "another_workflow_tool"
mock_controller.entity.identity.description = I18nObject(
en_US="Another description", zh_Hans="Another description"
)
mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"}
mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.label = I18nObject(
en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"
)
# Call the method
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=["automation", "workflow"],
workflow_app_id=workflow_app_id,
)
# Verify all fields are preserved correctly
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.author == "another_author"
assert result.name == "another_workflow_tool"
assert result.description.en_US == "Another description"
assert result.description.zh_Hans == "Another description"
assert result.icon == {"type": "emoji", "content": "⚙️"}
assert result.icon_dark == {"type": "emoji", "content": "🔧"}
assert result.label.en_US == "Another Workflow Tool"
assert result.label.zh_Hans == "Another Workflow Tool"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == workflow_app_id
assert result.labels == ["automation", "workflow"]
assert result.masked_credentials == {}
assert result.is_team_authorization is True
assert result.allow_delete is True
assert result.plugin_id is None
assert result.plugin_unique_identifier is None
assert result.tools == []

View File

@ -121,7 +121,7 @@ import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.models.document import Document
from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment
from services.vector_service import VectorService
@ -153,7 +153,7 @@ class VectorServiceTestDataFactory:
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
index_struct_dict: dict | None = None,
@ -494,7 +494,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique="high_quality"
doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique=IndexTechniqueType.HIGH_QUALITY
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -535,7 +535,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form="parent_child_model", indexing_technique="high_quality"
doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -568,7 +568,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form="parent_child_model", indexing_technique="high_quality"
doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -591,7 +591,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form="parent_child_model", indexing_technique="high_quality"
doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -616,7 +616,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form="parent_child_model", indexing_technique="economy"
doc_form="parent_child_model", indexing_technique=IndexTechniqueType.ECONOMY
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -669,7 +669,7 @@ class TestVectorService:
store when using high_quality indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -695,7 +695,7 @@ class TestVectorService:
index when using economy indexing with keywords.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -731,7 +731,7 @@ class TestVectorService:
index when using economy indexing without keywords.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -895,7 +895,7 @@ class TestVectorService:
when using high_quality indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -923,7 +923,7 @@ class TestVectorService:
using economy indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -951,7 +951,7 @@ class TestVectorService:
when there are new chunks, updated chunks, and deleted chunks.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="new-chunk-1")
@ -993,7 +993,7 @@ class TestVectorService:
add_texts is called, not delete_by_ids.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -1019,7 +1019,7 @@ class TestVectorService:
delete_by_ids is called, not add_texts.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -1045,7 +1045,7 @@ class TestVectorService:
using economy indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -1075,7 +1075,7 @@ class TestVectorService:
when using high_quality indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()
@ -1099,7 +1099,7 @@ class TestVectorService:
using economy indexing.
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()

View File

@ -16,7 +16,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.enums import DataSourceType
from tasks.clean_dataset_task import clean_dataset_task
@ -184,7 +184,7 @@ class TestErrorHandling:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -229,7 +229,7 @@ class TestPipelineAndWorkflowDeletion:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -265,7 +265,7 @@ class TestPipelineAndWorkflowDeletion:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -321,7 +321,7 @@ class TestSegmentAttachmentCleanup:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -366,7 +366,7 @@ class TestSegmentAttachmentCleanup:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -408,7 +408,7 @@ class TestEdgeCases:
clean_dataset_task(
dataset_id=dataset_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -445,7 +445,7 @@ class TestIndexProcessorParameters:
- Dataset object with correct attributes is passed
"""
# Arrange
indexing_technique = "high_quality"
indexing_technique = IndexTechniqueType.HIGH_QUALITY
index_struct = '{"type": "paragraph"}'
# Act

View File

@ -15,7 +15,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
@ -209,7 +209,7 @@ def mock_dataset(dataset_id, tenant_id):
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset

View File

@ -49,9 +49,12 @@ vi.mock('@/service/use-tools', () => ({
// Mock Toast - need to verify notification calls
const mockToastNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: (options: { type: string, message: string }) => mockToastNotify(options),
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
success: (message: string) => mockToastNotify({ type: 'success', message }),
error: (message: string) => mockToastNotify({ type: 'error', message }),
warning: (message: string) => mockToastNotify({ type: 'warning', message }),
info: (message: string) => mockToastNotify({ type: 'info', message }),
},
}))

View File

@ -33,9 +33,12 @@ vi.mock('@/service/use-tools', () => ({
}))
const mockToastNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: (options: { type: string, message: string }) => mockToastNotify(options),
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
success: (message: string) => mockToastNotify({ type: 'success', message }),
error: (message: string) => mockToastNotify({ type: 'error', message }),
warning: (message: string) => mockToastNotify({ type: 'warning', message }),
info: (message: string) => mockToastNotify({ type: 'info', message }),
},
}))

View File

@ -3,7 +3,7 @@ import type { InputVar, Variable } from '@/app/components/workflow/types'
import type { PublishWorkflowParams } from '@/types/workflow'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Toast from '@/app/components/base/toast'
import { toast } from '@/app/components/base/ui/toast'
import { useAppContext } from '@/context/app-context'
import { useRouter } from '@/next/navigation'
import { createWorkflowToolProvider, saveWorkflowToolProvider } from '@/service/tools'
@ -188,14 +188,11 @@ export function useConfigureButton(options: UseConfigureButtonOptions) {
invalidateAllWorkflowTools()
onRefreshData?.()
invalidateDetail(workflowAppId)
Toast.notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
toast.success(t('api.actionSuccess', { ns: 'common' }))
setShowModal(false)
}
catch (e) {
Toast.notify({ type: 'error', message: (e as Error).message })
toast.error((e as Error).message)
}
}
@ -209,14 +206,11 @@ export function useConfigureButton(options: UseConfigureButtonOptions) {
onRefreshData?.()
invalidateAllWorkflowTools()
invalidateDetail(workflowAppId)
Toast.notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
toast.success(t('api.actionSuccess', { ns: 'common' }))
setShowModal(false)
}
catch (e) {
Toast.notify({ type: 'error', message: (e as Error).message })
toast.error((e as Error).message)
}
}

View File

@ -12,8 +12,8 @@ import Drawer from '@/app/components/base/drawer-plus'
import EmojiPicker from '@/app/components/base/emoji-picker'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import Toast from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { toast } from '@/app/components/base/ui/toast'
import LabelSelector from '@/app/components/tools/labels/selector'
import ConfirmModal from '@/app/components/tools/workflow-tool/confirm-modal'
import MethodSelector from '@/app/components/tools/workflow-tool/method-selector'
@ -129,10 +129,7 @@ const WorkflowToolAsModal: FC<Props> = ({
errorMessage = t('createTool.nameForToolCall', { ns: 'tools' }) + t('createTool.nameForToolCallTip', { ns: 'tools' })
if (errorMessage) {
Toast.notify({
type: 'error',
message: errorMessage,
})
toast.error(errorMessage)
return
}

View File

@ -0,0 +1,260 @@
import { render, screen } from '@testing-library/react'
import CandidateNodeMain from '../candidate-node-main'
import { CUSTOM_NODE } from '../constants'
import { CUSTOM_NOTE_NODE } from '../note-node/constants'
import { BlockEnum } from '../types'
import { createNode } from './fixtures'
const mockUseEventListener = vi.hoisted(() => vi.fn())
const mockUseStoreApi = vi.hoisted(() => vi.fn())
const mockUseReactFlow = vi.hoisted(() => vi.fn())
const mockUseViewport = vi.hoisted(() => vi.fn())
const mockUseStore = vi.hoisted(() => vi.fn())
const mockUseWorkflowStore = vi.hoisted(() => vi.fn())
const mockUseHooks = vi.hoisted(() => vi.fn())
const mockCustomNode = vi.hoisted(() => vi.fn())
const mockCustomNoteNode = vi.hoisted(() => vi.fn())
const mockGetIterationStartNode = vi.hoisted(() => vi.fn())
const mockGetLoopStartNode = vi.hoisted(() => vi.fn())
vi.mock('ahooks', () => ({
useEventListener: (...args: unknown[]) => mockUseEventListener(...args),
}))
vi.mock('reactflow', () => ({
useStoreApi: () => mockUseStoreApi(),
useReactFlow: () => mockUseReactFlow(),
useViewport: () => mockUseViewport(),
Position: {
Left: 'left',
Right: 'right',
},
}))
vi.mock('@/app/components/workflow/store', () => ({
useStore: (selector: (state: { mousePosition: {
pageX: number
pageY: number
elementX: number
elementY: number
} }) => unknown) => mockUseStore(selector),
useWorkflowStore: () => mockUseWorkflowStore(),
}))
vi.mock('@/app/components/workflow/hooks', () => ({
useNodesInteractions: () => mockUseHooks().useNodesInteractions(),
useNodesSyncDraft: () => mockUseHooks().useNodesSyncDraft(),
useWorkflowHistory: () => mockUseHooks().useWorkflowHistory(),
useAutoGenerateWebhookUrl: () => mockUseHooks().useAutoGenerateWebhookUrl(),
WorkflowHistoryEvent: {
NodeAdd: 'NodeAdd',
NoteAdd: 'NoteAdd',
},
}))
vi.mock('@/app/components/workflow/nodes', () => ({
__esModule: true,
default: (props: { id: string }) => {
mockCustomNode(props)
return <div data-testid="candidate-custom-node">{props.id}</div>
},
}))
vi.mock('@/app/components/workflow/note-node', () => ({
__esModule: true,
default: (props: { id: string }) => {
mockCustomNoteNode(props)
return <div data-testid="candidate-note-node">{props.id}</div>
},
}))
vi.mock('@/app/components/workflow/utils', () => ({
getIterationStartNode: (...args: unknown[]) => mockGetIterationStartNode(...args),
getLoopStartNode: (...args: unknown[]) => mockGetLoopStartNode(...args),
}))
describe('CandidateNodeMain', () => {
const mockSetNodes = vi.fn()
const mockHandleNodeSelect = vi.fn()
const mockSaveStateToHistory = vi.fn()
const mockHandleSyncWorkflowDraft = vi.fn()
const mockAutoGenerateWebhookUrl = vi.fn()
const mockWorkflowStoreSetState = vi.fn()
const createNodesInteractions = () => ({
handleNodeSelect: mockHandleNodeSelect,
})
const createWorkflowHistory = () => ({
saveStateToHistory: mockSaveStateToHistory,
})
const createNodesSyncDraft = () => ({
handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft,
})
const createAutoGenerateWebhookUrl = () => mockAutoGenerateWebhookUrl
const eventHandlers: Partial<Record<'click' | 'contextmenu', (event: { preventDefault: () => void }) => void>> = {}
let nodes = [createNode({ id: 'existing-node' })]
beforeEach(() => {
vi.clearAllMocks()
nodes = [createNode({ id: 'existing-node' })]
eventHandlers.click = undefined
eventHandlers.contextmenu = undefined
mockUseEventListener.mockImplementation((event: 'click' | 'contextmenu', handler: (event: { preventDefault: () => void }) => void) => {
eventHandlers[event] = handler
})
mockUseStoreApi.mockReturnValue({
getState: () => ({
getNodes: () => nodes,
setNodes: mockSetNodes,
}),
})
mockUseReactFlow.mockReturnValue({
screenToFlowPosition: ({ x, y }: { x: number, y: number }) => ({ x: x + 10, y: y + 20 }),
})
mockUseViewport.mockReturnValue({ zoom: 1.5 })
mockUseStore.mockImplementation((selector: (state: { mousePosition: {
pageX: number
pageY: number
elementX: number
elementY: number
} }) => unknown) => selector({
mousePosition: {
pageX: 100,
pageY: 200,
elementX: 30,
elementY: 40,
},
}))
mockUseWorkflowStore.mockReturnValue({
setState: mockWorkflowStoreSetState,
})
mockUseHooks.mockReturnValue({
useNodesInteractions: createNodesInteractions,
useWorkflowHistory: createWorkflowHistory,
useNodesSyncDraft: createNodesSyncDraft,
useAutoGenerateWebhookUrl: createAutoGenerateWebhookUrl,
})
mockHandleSyncWorkflowDraft.mockImplementation((_isSync: boolean, _force: boolean, options?: { onSuccess?: () => void }) => {
options?.onSuccess?.()
})
mockGetIterationStartNode.mockReturnValue(createNode({ id: 'iteration-start' }))
mockGetLoopStartNode.mockReturnValue(createNode({ id: 'loop-start' }))
})
it('should render the candidate node and commit a webhook node on click', () => {
const candidateNode = createNode({
id: 'candidate-webhook',
type: CUSTOM_NODE,
data: {
type: BlockEnum.TriggerWebhook,
title: 'Webhook Candidate',
_isCandidate: true,
},
})
const { container } = render(<CandidateNodeMain candidateNode={candidateNode} />)
expect(screen.getByTestId('candidate-custom-node')).toHaveTextContent('candidate-webhook')
expect(container.firstChild).toHaveStyle({
left: '30px',
top: '40px',
transform: 'scale(1.5)',
})
eventHandlers.click?.({ preventDefault: vi.fn() })
expect(mockSetNodes).toHaveBeenCalledWith(expect.arrayContaining([
expect.objectContaining({ id: 'existing-node' }),
expect.objectContaining({
id: 'candidate-webhook',
position: { x: 110, y: 220 },
data: expect.objectContaining({ _isCandidate: false }),
}),
]))
expect(mockSaveStateToHistory).toHaveBeenCalledWith('NodeAdd', { nodeId: 'candidate-webhook' })
expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ candidateNode: undefined })
expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledWith(true, true, expect.objectContaining({
onSuccess: expect.any(Function),
}))
expect(mockAutoGenerateWebhookUrl).toHaveBeenCalledWith('candidate-webhook')
expect(mockHandleNodeSelect).not.toHaveBeenCalled()
})
it('should save note candidates as notes and select the inserted note', () => {
const candidateNode = createNode({
id: 'candidate-note',
type: CUSTOM_NOTE_NODE,
data: {
type: BlockEnum.Code,
title: 'Note Candidate',
_isCandidate: true,
},
})
render(<CandidateNodeMain candidateNode={candidateNode} />)
expect(screen.getByTestId('candidate-note-node')).toHaveTextContent('candidate-note')
eventHandlers.click?.({ preventDefault: vi.fn() })
expect(mockSaveStateToHistory).toHaveBeenCalledWith('NoteAdd', { nodeId: 'candidate-note' })
expect(mockHandleNodeSelect).toHaveBeenCalledWith('candidate-note')
})
it('should append iteration and loop start helper nodes for control-flow candidates', () => {
const iterationNode = createNode({
id: 'candidate-iteration',
type: CUSTOM_NODE,
data: {
type: BlockEnum.Iteration,
title: 'Iteration Candidate',
_isCandidate: true,
},
})
const loopNode = createNode({
id: 'candidate-loop',
type: CUSTOM_NODE,
data: {
type: BlockEnum.Loop,
title: 'Loop Candidate',
_isCandidate: true,
},
})
const { rerender } = render(<CandidateNodeMain candidateNode={iterationNode} />)
eventHandlers.click?.({ preventDefault: vi.fn() })
expect(mockGetIterationStartNode).toHaveBeenCalledWith('candidate-iteration')
expect(mockSetNodes.mock.calls[0][0]).toEqual(expect.arrayContaining([
expect.objectContaining({ id: 'candidate-iteration' }),
expect.objectContaining({ id: 'iteration-start' }),
]))
rerender(<CandidateNodeMain candidateNode={loopNode} />)
eventHandlers.click?.({ preventDefault: vi.fn() })
expect(mockGetLoopStartNode).toHaveBeenCalledWith('candidate-loop')
expect(mockSetNodes.mock.calls[1][0]).toEqual(expect.arrayContaining([
expect.objectContaining({ id: 'candidate-loop' }),
expect.objectContaining({ id: 'loop-start' }),
]))
})
it('should clear the candidate node on contextmenu', () => {
const candidateNode = createNode({
id: 'candidate-context',
type: CUSTOM_NODE,
data: {
type: BlockEnum.Code,
title: 'Context Candidate',
_isCandidate: true,
},
})
render(<CandidateNodeMain candidateNode={candidateNode} />)
eventHandlers.contextmenu?.({ preventDefault: vi.fn() })
expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ candidateNode: undefined })
})
})

View File

@ -0,0 +1,235 @@
import type { ReactNode } from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import { Position } from 'reactflow'
import { ErrorHandleTypeEnum } from '@/app/components/workflow/nodes/_base/components/error-handle/types'
import CustomEdge from '../custom-edge'
import { BlockEnum, NodeRunningStatus } from '../types'
const mockUseAvailableBlocks = vi.hoisted(() => vi.fn())
const mockUseNodesInteractions = vi.hoisted(() => vi.fn())
const mockBlockSelector = vi.hoisted(() => vi.fn())
const mockGradientRender = vi.hoisted(() => vi.fn())
vi.mock('reactflow', () => ({
BaseEdge: (props: {
id: string
path: string
style: {
stroke: string
strokeWidth: number
opacity: number
strokeDasharray?: string
}
}) => (
<div
data-testid="base-edge"
data-id={props.id}
data-path={props.path}
data-stroke={props.style.stroke}
data-stroke-width={props.style.strokeWidth}
data-opacity={props.style.opacity}
data-dasharray={props.style.strokeDasharray}
/>
),
EdgeLabelRenderer: ({ children }: { children?: ReactNode }) => <div data-testid="edge-label">{children}</div>,
getBezierPath: () => ['M 0 0', 24, 48],
Position: {
Right: 'right',
Left: 'left',
},
}))
vi.mock('@/app/components/workflow/hooks', () => ({
useAvailableBlocks: (...args: unknown[]) => mockUseAvailableBlocks(...args),
useNodesInteractions: () => mockUseNodesInteractions(),
}))
vi.mock('@/app/components/workflow/block-selector', () => ({
__esModule: true,
default: (props: {
open: boolean
onOpenChange: (open: boolean) => void
onSelect: (nodeType: string, pluginDefaultValue?: Record<string, unknown>) => void
availableBlocksTypes: string[]
triggerClassName?: () => string
}) => {
mockBlockSelector(props)
return (
<button
type="button"
data-testid="block-selector"
data-trigger-class={props.triggerClassName?.()}
onClick={() => {
props.onOpenChange(true)
props.onSelect('llm', { provider: 'openai' })
}}
>
{props.availableBlocksTypes.join(',')}
</button>
)
},
}))
vi.mock('@/app/components/workflow/custom-edge-linear-gradient-render', () => ({
__esModule: true,
default: (props: {
id: string
startColor: string
stopColor: string
}) => {
mockGradientRender(props)
return <div data-testid="edge-gradient">{props.id}</div>
},
}))
describe('CustomEdge', () => {
const mockHandleNodeAdd = vi.fn()
beforeEach(() => {
vi.clearAllMocks()
mockUseNodesInteractions.mockReturnValue({
handleNodeAdd: mockHandleNodeAdd,
})
mockUseAvailableBlocks.mockImplementation((nodeType: BlockEnum) => {
if (nodeType === BlockEnum.Code)
return { availablePrevBlocks: ['code', 'llm'] }
return { availableNextBlocks: ['llm', 'tool'] }
})
})
it('should render a gradient edge and insert a node between the source and target', () => {
render(
<CustomEdge
id="edge-1"
source="source-node"
sourceHandleId="source"
target="target-node"
targetHandleId="target"
sourceX={100}
sourceY={120}
sourcePosition={Position.Right}
targetX={300}
targetY={220}
targetPosition={Position.Left}
selected={false}
data={{
sourceType: BlockEnum.Start,
targetType: BlockEnum.Code,
_sourceRunningStatus: NodeRunningStatus.Succeeded,
_targetRunningStatus: NodeRunningStatus.Failed,
_hovering: true,
_waitingRun: true,
_dimmed: true,
_isTemp: true,
isInIteration: true,
isInLoop: true,
} as never}
/>,
)
expect(screen.getByTestId('edge-gradient')).toHaveTextContent('edge-1')
expect(mockGradientRender).toHaveBeenCalledWith(expect.objectContaining({
id: 'edge-1',
startColor: 'var(--color-workflow-link-line-success-handle)',
stopColor: 'var(--color-workflow-link-line-error-handle)',
}))
expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'url(#edge-1)')
expect(screen.getByTestId('base-edge')).toHaveAttribute('data-opacity', '0.3')
expect(screen.getByTestId('base-edge')).toHaveAttribute('data-dasharray', '8 8')
expect(screen.getByTestId('block-selector')).toHaveTextContent('llm')
expect(screen.getByTestId('block-selector').parentElement).toHaveStyle({
transform: 'translate(-50%, -50%) translate(24px, 48px)',
opacity: '0.7',
})
fireEvent.click(screen.getByTestId('block-selector'))
expect(mockHandleNodeAdd).toHaveBeenCalledWith(
{
nodeType: 'llm',
pluginDefaultValue: { provider: 'openai' },
},
{
prevNodeId: 'source-node',
prevNodeSourceHandle: 'source',
nextNodeId: 'target-node',
nextNodeTargetHandle: 'target',
},
)
})
it('should prefer the running stroke color when the edge is selected', () => {
render(
<CustomEdge
id="edge-selected"
source="source-node"
target="target-node"
sourceX={0}
sourceY={0}
sourcePosition={Position.Right}
targetX={100}
targetY={100}
targetPosition={Position.Left}
selected
data={{
sourceType: BlockEnum.Start,
targetType: BlockEnum.Code,
_sourceRunningStatus: NodeRunningStatus.Succeeded,
_targetRunningStatus: NodeRunningStatus.Running,
} as never}
/>,
)
expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'var(--color-workflow-link-line-handle)')
})
it('should use the fail-branch running color while the connected node is hovering', () => {
render(
<CustomEdge
id="edge-hover"
source="source-node"
sourceHandleId={ErrorHandleTypeEnum.failBranch}
target="target-node"
sourceX={0}
sourceY={0}
sourcePosition={Position.Right}
targetX={100}
targetY={100}
targetPosition={Position.Left}
selected={false}
data={{
sourceType: BlockEnum.Start,
targetType: BlockEnum.Code,
_connectedNodeIsHovering: true,
} as never}
/>,
)
expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'var(--color-workflow-link-line-failure-handle)')
})
it('should fall back to the default edge color when no highlight state is active', () => {
render(
<CustomEdge
id="edge-default"
source="source-node"
target="target-node"
sourceX={0}
sourceY={0}
sourcePosition={Position.Right}
targetX={100}
targetY={100}
targetPosition={Position.Left}
selected={false}
data={{
sourceType: BlockEnum.Start,
targetType: BlockEnum.Code,
} as never}
/>,
)
expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'var(--color-workflow-link-line-normal)')
expect(screen.getByTestId('block-selector')).toHaveAttribute('data-trigger-class', 'hover:scale-150 transition-all')
})
})

View File

@ -0,0 +1,114 @@
import type { Node } from '../types'
import { fireEvent, render, screen } from '@testing-library/react'
import NodeContextmenu from '../node-contextmenu'
const mockUseClickAway = vi.hoisted(() => vi.fn())
const mockUseNodes = vi.hoisted(() => vi.fn())
const mockUsePanelInteractions = vi.hoisted(() => vi.fn())
const mockUseStore = vi.hoisted(() => vi.fn())
const mockPanelOperatorPopup = vi.hoisted(() => vi.fn())
vi.mock('ahooks', () => ({
useClickAway: (...args: unknown[]) => mockUseClickAway(...args),
}))
vi.mock('@/app/components/workflow/store/workflow/use-nodes', () => ({
__esModule: true,
default: () => mockUseNodes(),
}))
vi.mock('@/app/components/workflow/hooks', () => ({
usePanelInteractions: () => mockUsePanelInteractions(),
}))
vi.mock('@/app/components/workflow/store', () => ({
useStore: (selector: (state: { nodeMenu?: { nodeId: string, left: number, top: number } }) => unknown) => mockUseStore(selector),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup', () => ({
__esModule: true,
default: (props: {
id: string
data: Node['data']
showHelpLink: boolean
onClosePopup: () => void
}) => {
mockPanelOperatorPopup(props)
return (
<button type="button" onClick={props.onClosePopup}>
{props.id}
:
{props.data.title}
</button>
)
},
}))
describe('NodeContextmenu', () => {
const mockHandleNodeContextmenuCancel = vi.fn()
let nodeMenu: { nodeId: string, left: number, top: number } | undefined
let nodes: Node[]
let clickAwayHandler: (() => void) | undefined
beforeEach(() => {
vi.clearAllMocks()
nodeMenu = undefined
nodes = [{
id: 'node-1',
type: 'custom',
position: { x: 0, y: 0 },
data: {
title: 'Node 1',
desc: '',
type: 'code' as never,
},
} as Node]
clickAwayHandler = undefined
mockUseClickAway.mockImplementation((handler: () => void) => {
clickAwayHandler = handler
})
mockUseNodes.mockImplementation(() => nodes)
mockUsePanelInteractions.mockReturnValue({
handleNodeContextmenuCancel: mockHandleNodeContextmenuCancel,
})
mockUseStore.mockImplementation((selector: (state: { nodeMenu?: { nodeId: string, left: number, top: number } }) => unknown) => selector({ nodeMenu }))
})
it('should stay hidden when the node menu is absent', () => {
render(<NodeContextmenu />)
expect(screen.queryByRole('button')).not.toBeInTheDocument()
expect(mockPanelOperatorPopup).not.toHaveBeenCalled()
})
it('should stay hidden when the referenced node cannot be found', () => {
nodeMenu = { nodeId: 'missing-node', left: 80, top: 120 }
render(<NodeContextmenu />)
expect(screen.queryByRole('button')).not.toBeInTheDocument()
expect(mockPanelOperatorPopup).not.toHaveBeenCalled()
})
it('should render the popup at the stored position and close on popup/click-away actions', () => {
nodeMenu = { nodeId: 'node-1', left: 80, top: 120 }
const { container } = render(<NodeContextmenu />)
expect(screen.getByRole('button')).toHaveTextContent('node-1:Node 1')
expect(mockPanelOperatorPopup).toHaveBeenCalledWith(expect.objectContaining({
id: 'node-1',
data: expect.objectContaining({ title: 'Node 1' }),
showHelpLink: true,
}))
expect(container.firstChild).toHaveStyle({
left: '80px',
top: '120px',
})
fireEvent.click(screen.getByRole('button'))
clickAwayHandler?.()
expect(mockHandleNodeContextmenuCancel).toHaveBeenCalledTimes(2)
})
})

View File

@ -0,0 +1,151 @@
import type { ReactNode } from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import PanelContextmenu from '../panel-contextmenu'
const mockUseClickAway = vi.hoisted(() => vi.fn())
const mockUseTranslation = vi.hoisted(() => vi.fn())
const mockUseStore = vi.hoisted(() => vi.fn())
const mockUseNodesInteractions = vi.hoisted(() => vi.fn())
const mockUsePanelInteractions = vi.hoisted(() => vi.fn())
const mockUseWorkflowStartRun = vi.hoisted(() => vi.fn())
const mockUseOperator = vi.hoisted(() => vi.fn())
const mockUseDSL = vi.hoisted(() => vi.fn())
vi.mock('ahooks', () => ({
useClickAway: (...args: unknown[]) => mockUseClickAway(...args),
}))
vi.mock('react-i18next', () => ({
useTranslation: () => mockUseTranslation(),
}))
vi.mock('@/app/components/workflow/store', () => ({
useStore: (selector: (state: {
panelMenu?: { left: number, top: number }
clipboardElements: unknown[]
setShowImportDSLModal: (visible: boolean) => void
}) => unknown) => mockUseStore(selector),
}))
vi.mock('@/app/components/workflow/hooks', () => ({
useNodesInteractions: () => mockUseNodesInteractions(),
usePanelInteractions: () => mockUsePanelInteractions(),
useWorkflowStartRun: () => mockUseWorkflowStartRun(),
useDSL: () => mockUseDSL(),
}))
vi.mock('@/app/components/workflow/operator/hooks', () => ({
useOperator: () => mockUseOperator(),
}))
vi.mock('@/app/components/workflow/operator/add-block', () => ({
__esModule: true,
default: ({ renderTrigger }: { renderTrigger: () => ReactNode }) => (
<div data-testid="add-block">{renderTrigger()}</div>
),
}))
vi.mock('@/app/components/base/divider', () => ({
__esModule: true,
default: ({ className }: { className?: string }) => <div data-testid="divider" className={className} />,
}))
vi.mock('@/app/components/workflow/shortcuts-name', () => ({
__esModule: true,
default: ({ keys }: { keys: string[] }) => <span data-testid={`shortcut-${keys.join('-')}`}>{keys.join('+')}</span>,
}))
describe('PanelContextmenu', () => {
const mockHandleNodesPaste = vi.fn()
const mockHandlePaneContextmenuCancel = vi.fn()
const mockHandleStartWorkflowRun = vi.fn()
const mockHandleAddNote = vi.fn()
const mockExportCheck = vi.fn()
const mockSetShowImportDSLModal = vi.fn()
let panelMenu: { left: number, top: number } | undefined
let clipboardElements: unknown[]
let clickAwayHandler: (() => void) | undefined
beforeEach(() => {
vi.clearAllMocks()
panelMenu = undefined
clipboardElements = []
clickAwayHandler = undefined
mockUseClickAway.mockImplementation((handler: () => void) => {
clickAwayHandler = handler
})
mockUseTranslation.mockReturnValue({
t: (key: string) => key,
})
mockUseStore.mockImplementation((selector: (state: {
panelMenu?: { left: number, top: number }
clipboardElements: unknown[]
setShowImportDSLModal: (visible: boolean) => void
}) => unknown) => selector({
panelMenu,
clipboardElements,
setShowImportDSLModal: mockSetShowImportDSLModal,
}))
mockUseNodesInteractions.mockReturnValue({
handleNodesPaste: mockHandleNodesPaste,
})
mockUsePanelInteractions.mockReturnValue({
handlePaneContextmenuCancel: mockHandlePaneContextmenuCancel,
})
mockUseWorkflowStartRun.mockReturnValue({
handleStartWorkflowRun: mockHandleStartWorkflowRun,
})
mockUseOperator.mockReturnValue({
handleAddNote: mockHandleAddNote,
})
mockUseDSL.mockReturnValue({
exportCheck: mockExportCheck,
})
})
it('should stay hidden when the panel menu is absent', () => {
render(<PanelContextmenu />)
expect(screen.queryByTestId('add-block')).not.toBeInTheDocument()
})
it('should keep paste disabled when the clipboard is empty', () => {
panelMenu = { left: 24, top: 48 }
render(<PanelContextmenu />)
fireEvent.click(screen.getByText('common.pasteHere'))
expect(mockHandleNodesPaste).not.toHaveBeenCalled()
expect(mockHandlePaneContextmenuCancel).not.toHaveBeenCalled()
})
it('should render actions, position the menu, and execute each action', () => {
panelMenu = { left: 24, top: 48 }
clipboardElements = [{ id: 'copied-node' }]
const { container } = render(<PanelContextmenu />)
expect(screen.getByTestId('add-block')).toHaveTextContent('common.addBlock')
expect(screen.getByTestId('shortcut-alt-r')).toHaveTextContent('alt+r')
expect(screen.getByTestId('shortcut-ctrl-v')).toHaveTextContent('ctrl+v')
expect(container.firstChild).toHaveStyle({
left: '24px',
top: '48px',
})
fireEvent.click(screen.getByText('nodes.note.addNote'))
fireEvent.click(screen.getByText('common.run'))
fireEvent.click(screen.getByText('common.pasteHere'))
fireEvent.click(screen.getByText('export'))
fireEvent.click(screen.getByText('common.importDSL'))
clickAwayHandler?.()
expect(mockHandleAddNote).toHaveBeenCalledTimes(1)
expect(mockHandleStartWorkflowRun).toHaveBeenCalledTimes(1)
expect(mockHandleNodesPaste).toHaveBeenCalledTimes(1)
expect(mockExportCheck).toHaveBeenCalledTimes(1)
expect(mockSetShowImportDSLModal).toHaveBeenCalledWith(true)
expect(mockHandlePaneContextmenuCancel).toHaveBeenCalledTimes(4)
})
})

View File

@ -1,7 +1,7 @@
import type { EventEmitter } from 'ahooks/lib/useEventEmitter'
import type { EventEmitterValue } from '@/context/event-emitter'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { ToastContext } from '@/app/components/base/toast/context'
import { toast } from '@/app/components/base/ui/toast'
import { EventEmitterContext } from '@/context/event-emitter'
import { DSLImportStatus } from '@/models/app'
import UpdateDSLModal from '../update-dsl-modal'
@ -16,10 +16,17 @@ class MockFileReader {
}
vi.stubGlobal('FileReader', MockFileReader as unknown as typeof FileReader)
const mockNotify = vi.fn()
const mockEmit = vi.fn()
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
error: vi.fn(),
info: vi.fn(),
success: vi.fn(),
warning: vi.fn(),
},
}))
const mockImportDSL = vi.fn()
const mockImportDSLConfirm = vi.fn()
vi.mock('@/service/apps', () => ({
@ -59,6 +66,7 @@ vi.mock('@/app/components/app/create-from-dsl-modal/uploader', () => ({
}))
describe('UpdateDSLModal', () => {
const mockToastError = vi.mocked(toast.error)
const defaultProps = {
onCancel: vi.fn(),
onBackup: vi.fn(),
@ -91,11 +99,9 @@ describe('UpdateDSLModal', () => {
const eventEmitter = { emit: mockEmit } as unknown as EventEmitter<EventEmitterValue>
return render(
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
<EventEmitterContext.Provider value={{ eventEmitter }}>
<UpdateDSLModal {...props} />
</EventEmitterContext.Provider>
</ToastContext.Provider>,
<EventEmitterContext.Provider value={{ eventEmitter }}>
<UpdateDSLModal {...props} />
</EventEmitterContext.Provider>,
)
}
@ -152,9 +158,7 @@ describe('UpdateDSLModal', () => {
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' }))
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
}))
expect(mockToastError).toHaveBeenCalled()
})
})
@ -233,9 +237,7 @@ describe('UpdateDSLModal', () => {
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' }))
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
}))
expect(mockToastError).toHaveBeenCalled()
})
expect(mockImportDSL).not.toHaveBeenCalled()
@ -254,9 +256,7 @@ describe('UpdateDSLModal', () => {
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' }))
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
}))
expect(mockToastError).toHaveBeenCalled()
})
})
@ -274,9 +274,7 @@ describe('UpdateDSLModal', () => {
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' }))
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
}))
expect(mockToastError).toHaveBeenCalled()
})
})
@ -305,9 +303,7 @@ describe('UpdateDSLModal', () => {
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Confirm' }))
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
}))
expect(mockToastError).toHaveBeenCalled()
})
})
@ -334,9 +330,7 @@ describe('UpdateDSLModal', () => {
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Confirm' }))
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
}))
expect(mockToastError).toHaveBeenCalled()
})
})
@ -365,9 +359,7 @@ describe('UpdateDSLModal', () => {
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Confirm' }))
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
}))
expect(mockToastError).toHaveBeenCalled()
})
})
})

View File

@ -114,9 +114,12 @@ vi.mock('@/service/use-tools', () => ({
useInvalidateAllMCPTools: vi.fn(),
}))
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: (payload: unknown) => mockNotify(payload),
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
success: (message: string) => mockNotify({ type: 'success', message }),
error: (message: string) => mockNotify({ type: 'error', message }),
warning: (message: string) => mockNotify({ type: 'warning', message }),
info: (message: string) => mockNotify({ type: 'info', message }),
},
}))

View File

@ -16,7 +16,7 @@ import {
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import Toast from '@/app/components/base/toast'
import { toast } from '@/app/components/base/ui/toast'
import SearchBox from '@/app/components/plugins/marketplace/search-box'
import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal'
import AllTools from '@/app/components/workflow/block-selector/all-tools'
@ -137,10 +137,7 @@ const ToolPicker: FC<Props> = ({
const doCreateCustomToolCollection = async (data: CustomCollectionBackend) => {
await createCustomCollection(data)
Toast.notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
toast.success(t('api.actionSuccess', { ns: 'common' }))
hideEditCustomCollectionModal()
handleAddedCustomTool()
}

View File

@ -60,9 +60,12 @@ vi.mock('@/service/use-workflow', () => ({
}),
}))
vi.mock('../../../base/toast', () => ({
default: {
notify: (payload: unknown) => mockNotify(payload),
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
success: (message: string) => mockNotify({ type: 'success', message }),
error: (message: string) => mockNotify({ type: 'error', message }),
warning: (message: string) => mockNotify({ type: 'warning', message }),
info: (message: string) => mockNotify({ type: 'info', message }),
},
}))

View File

@ -46,10 +46,13 @@ vi.mock('../../hooks/use-dynamic-test-run-options', () => ({
useDynamicTestRunOptions: () => mockDynamicOptions,
}))
vi.mock('@/app/components/base/toast/context', () => ({
useToastContext: () => ({
notify: mockNotify,
}),
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
success: (message: string) => mockNotify({ type: 'success', message }),
error: (message: string) => mockNotify({ type: 'error', message }),
warning: (message: string) => mockNotify({ type: 'warning', message }),
info: (message: string) => mockNotify({ type: 'info', message }),
},
}))
vi.mock('@/app/components/base/amplitude', () => ({

View File

@ -4,11 +4,11 @@ import {
} from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import { toast } from '@/app/components/base/ui/toast'
import useTheme from '@/hooks/use-theme'
import { useInvalidAllLastRun, useRestoreWorkflow } from '@/service/use-workflow'
import { getFlowPrefix } from '@/service/utils'
import { cn } from '@/utils/classnames'
import Toast from '../../base/toast'
import {
useWorkflowRefreshDraft,
useWorkflowRun,
@ -65,18 +65,12 @@ const HeaderInRestoring = ({
workflowStore.setState({ isRestoring: false })
workflowStore.setState({ backupDraft: undefined })
handleRefreshWorkflowDraft()
Toast.notify({
type: 'success',
message: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }),
})
toast.success(t('versionHistory.action.restoreSuccess', { ns: 'workflow' }))
deleteAllInspectVars()
invalidAllLastRun()
}
catch {
Toast.notify({
type: 'error',
message: t('versionHistory.action.restoreFailure', { ns: 'workflow' }),
})
toast.error(t('versionHistory.action.restoreFailure', { ns: 'workflow' }))
}
finally {
onRestoreSettled?.()

View File

@ -5,7 +5,7 @@ import { useCallback, useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import { StopCircle } from '@/app/components/base/icons/src/vender/line/mediaAndDevices'
import { useToastContext } from '@/app/components/base/toast/context'
import { toast } from '@/app/components/base/ui/toast'
import { useWorkflowRun, useWorkflowRunValidation, useWorkflowStartRun } from '@/app/components/workflow/hooks'
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
import { useStore } from '@/app/components/workflow/store'
@ -41,7 +41,6 @@ const RunMode = ({
const dynamicOptions = useDynamicTestRunOptions()
const testRunMenuRef = useRef<TestRunMenuRef>(null)
const { notify } = useToastContext()
useEffect(() => {
// @ts-expect-error - Dynamic property for backward compatibility with keyboard shortcuts
@ -66,7 +65,7 @@ const RunMode = ({
isValid = false
})
if (!isValid) {
notify({ type: 'error', message: t('panel.checklistTip', { ns: 'workflow' }) })
toast.error(t('panel.checklistTip', { ns: 'workflow' }))
return
}
@ -98,7 +97,7 @@ const RunMode = ({
// Placeholder for trigger-specific execution logic for schedule, webhook, plugin types
console.log('TODO: Handle trigger execution for type:', option.type, 'nodeId:', option.nodeId)
}
}, [warningNodes, notify, t, handleWorkflowStartRunInWorkflow, handleWorkflowTriggerScheduleRunInWorkflow, handleWorkflowTriggerWebhookRunInWorkflow, handleWorkflowTriggerPluginRunInWorkflow, handleWorkflowRunAllTriggersInWorkflow])
}, [warningNodes, t, handleWorkflowStartRunInWorkflow, handleWorkflowTriggerScheduleRunInWorkflow, handleWorkflowTriggerWebhookRunInWorkflow, handleWorkflowTriggerPluginRunInWorkflow, handleWorkflowRunAllTriggersInWorkflow])
const { eventEmitter } = useEventEmitterContextContext()
eventEmitter?.useSubscription((v: any) => {

View File

@ -0,0 +1,61 @@
import { render } from '@testing-library/react'
import HelpLine from '../index'
const mockUseViewport = vi.hoisted(() => vi.fn())
const mockUseStore = vi.hoisted(() => vi.fn())
vi.mock('reactflow', () => ({
useViewport: () => mockUseViewport(),
}))
vi.mock('@/app/components/workflow/store', () => ({
useStore: (selector: (state: {
helpLineHorizontal?: { top: number, left: number, width: number }
helpLineVertical?: { top: number, left: number, height: number }
}) => unknown) => mockUseStore(selector),
}))
describe('HelpLine', () => {
let helpLineHorizontal: { top: number, left: number, width: number } | undefined
let helpLineVertical: { top: number, left: number, height: number } | undefined
beforeEach(() => {
vi.clearAllMocks()
helpLineHorizontal = undefined
helpLineVertical = undefined
mockUseViewport.mockReturnValue({ x: 10, y: 20, zoom: 2 })
mockUseStore.mockImplementation((selector: (state: {
helpLineHorizontal?: { top: number, left: number, width: number }
helpLineVertical?: { top: number, left: number, height: number }
}) => unknown) => selector({
helpLineHorizontal,
helpLineVertical,
}))
})
it('should render nothing when both help lines are absent', () => {
const { container } = render(<HelpLine />)
expect(container).toBeEmptyDOMElement()
})
it('should render the horizontal and vertical guide lines using viewport offsets and zoom', () => {
helpLineHorizontal = { top: 30, left: 40, width: 50 }
helpLineVertical = { top: 60, left: 70, height: 80 }
const { container } = render(<HelpLine />)
const [horizontal, vertical] = Array.from(container.querySelectorAll('div'))
expect(horizontal).toHaveStyle({
top: '80px',
left: '90px',
width: '100px',
})
expect(vertical).toHaveStyle({
top: '140px',
left: '150px',
height: '160px',
})
})
})

View File

@ -89,8 +89,13 @@ vi.mock('../index', () => ({
useNodesMetaData: () => ({ nodes: [], nodesMap: mockNodesMap }),
}))
vi.mock('@/app/components/base/toast/context', () => ({
useToastContext: () => ({ notify: vi.fn() }),
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
success: vi.fn(),
error: vi.fn(),
warning: vi.fn(),
info: vi.fn(),
},
}))
vi.mock('@/context/i18n', () => ({

View File

@ -0,0 +1,171 @@
import type { ModelConfig, VisionSetting } from '@/app/components/workflow/types'
import { act, renderHook } from '@testing-library/react'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { Resolution } from '@/types/app'
import useConfigVision from '../use-config-vision'
const mockUseTextGenerationCurrentProviderAndModelAndModelList = vi.hoisted(() => vi.fn())
const mockUseIsChatMode = vi.hoisted(() => vi.fn())
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
useTextGenerationCurrentProviderAndModelAndModelList: (...args: unknown[]) =>
mockUseTextGenerationCurrentProviderAndModelAndModelList(...args),
}))
vi.mock('../use-workflow', () => ({
useIsChatMode: () => mockUseIsChatMode(),
}))
const createModel = (overrides: Partial<ModelConfig> = {}): ModelConfig => ({
provider: 'openai',
name: 'gpt-4o',
mode: 'chat',
completion_params: [],
...overrides,
})
const createVisionPayload = (overrides: Partial<{ enabled: boolean, configs?: VisionSetting }> = {}) => ({
enabled: false,
...overrides,
})
describe('useConfigVision', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseIsChatMode.mockReturnValue(false)
mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({
currentModel: {
features: [],
},
})
})
it('should expose vision capability and enable default chat configs for vision models', () => {
const onChange = vi.fn()
mockUseIsChatMode.mockReturnValue(true)
mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({
currentModel: {
features: [ModelFeatureEnum.vision],
},
})
const { result } = renderHook(() => useConfigVision(createModel(), {
payload: createVisionPayload(),
onChange,
}))
expect(result.current.isVisionModel).toBe(true)
act(() => {
result.current.handleVisionResolutionEnabledChange(true)
})
expect(onChange).toHaveBeenCalledWith({
enabled: true,
configs: {
detail: Resolution.high,
variable_selector: ['sys', 'files'],
},
})
})
it('should clear configs when disabling vision resolution', () => {
const onChange = vi.fn()
const { result } = renderHook(() => useConfigVision(createModel(), {
payload: createVisionPayload({
enabled: true,
configs: {
detail: Resolution.low,
variable_selector: ['node', 'files'],
},
}),
onChange,
}))
act(() => {
result.current.handleVisionResolutionEnabledChange(false)
})
expect(onChange).toHaveBeenCalledWith({
enabled: false,
})
})
it('should update the resolution config payload directly', () => {
const onChange = vi.fn()
const config: VisionSetting = {
detail: Resolution.low,
variable_selector: ['upstream', 'images'],
}
const { result } = renderHook(() => useConfigVision(createModel(), {
payload: createVisionPayload({ enabled: true }),
onChange,
}))
act(() => {
result.current.handleVisionResolutionChange(config)
})
expect(onChange).toHaveBeenCalledWith({
enabled: true,
configs: config,
})
})
it('should disable vision settings when the selected model is no longer a vision model', () => {
const onChange = vi.fn()
const { result } = renderHook(() => useConfigVision(createModel(), {
payload: createVisionPayload({
enabled: true,
configs: {
detail: Resolution.high,
variable_selector: ['sys', 'files'],
},
}),
onChange,
}))
act(() => {
result.current.handleModelChanged()
})
expect(onChange).toHaveBeenCalledWith({
enabled: false,
})
})
it('should reset enabled vision configs when the model changes but still supports vision', () => {
const onChange = vi.fn()
mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({
currentModel: {
features: [ModelFeatureEnum.vision],
},
})
const { result } = renderHook(() => useConfigVision(createModel(), {
payload: createVisionPayload({
enabled: true,
configs: {
detail: Resolution.low,
variable_selector: ['old', 'files'],
},
}),
onChange,
}))
act(() => {
result.current.handleModelChanged()
})
expect(onChange).toHaveBeenCalledWith({
enabled: true,
configs: {
detail: Resolution.high,
variable_selector: [],
},
})
})
})

View File

@ -0,0 +1,146 @@
import { renderHook } from '@testing-library/react'
import { BlockEnum } from '../../types'
import { useDynamicTestRunOptions } from '../use-dynamic-test-run-options'
const mockUseTranslation = vi.hoisted(() => vi.fn())
const mockUseNodes = vi.hoisted(() => vi.fn())
const mockUseStore = vi.hoisted(() => vi.fn())
const mockUseAllTriggerPlugins = vi.hoisted(() => vi.fn())
const mockGetWorkflowEntryNode = vi.hoisted(() => vi.fn())
vi.mock('react-i18next', () => ({
useTranslation: () => mockUseTranslation(),
}))
vi.mock('@/app/components/workflow/store/workflow/use-nodes', () => ({
__esModule: true,
default: () => mockUseNodes(),
}))
vi.mock('@/app/components/workflow/store', () => ({
useStore: (selector: (state: {
buildInTools: unknown[]
customTools: unknown[]
workflowTools: unknown[]
mcpTools: unknown[]
}) => unknown) => mockUseStore(selector),
}))
vi.mock('@/service/use-triggers', () => ({
useAllTriggerPlugins: () => mockUseAllTriggerPlugins(),
}))
vi.mock('@/app/components/workflow/utils/workflow-entry', () => ({
getWorkflowEntryNode: (...args: unknown[]) => mockGetWorkflowEntryNode(...args),
}))
describe('useDynamicTestRunOptions', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseTranslation.mockReturnValue({
t: (key: string) => key,
})
mockUseStore.mockImplementation((selector: (state: {
buildInTools: unknown[]
customTools: unknown[]
workflowTools: unknown[]
mcpTools: unknown[]
}) => unknown) => selector({
buildInTools: [],
customTools: [],
workflowTools: [],
mcpTools: [],
}))
mockUseAllTriggerPlugins.mockReturnValue({
data: [{
name: 'plugin-provider',
icon: '/plugin-icon.png',
}],
})
})
it('should build user input, trigger options, and a run-all option from workflow nodes', () => {
mockUseNodes.mockReturnValue([
{
id: 'start-1',
data: { type: BlockEnum.Start, title: 'User Input' },
},
{
id: 'schedule-1',
data: { type: BlockEnum.TriggerSchedule, title: 'Daily Schedule' },
},
{
id: 'webhook-1',
data: { type: BlockEnum.TriggerWebhook, title: 'Webhook Trigger' },
},
{
id: 'plugin-1',
data: {
type: BlockEnum.TriggerPlugin,
title: '',
plugin_name: 'Plugin Trigger',
provider_id: 'plugin-provider',
},
},
])
const { result } = renderHook(() => useDynamicTestRunOptions())
expect(result.current.userInput).toEqual(expect.objectContaining({
id: 'start-1',
type: 'user_input',
name: 'User Input',
nodeId: 'start-1',
enabled: true,
}))
expect(result.current.triggers).toEqual([
expect.objectContaining({
id: 'schedule-1',
type: 'schedule',
name: 'Daily Schedule',
nodeId: 'schedule-1',
}),
expect.objectContaining({
id: 'webhook-1',
type: 'webhook',
name: 'Webhook Trigger',
nodeId: 'webhook-1',
}),
expect.objectContaining({
id: 'plugin-1',
type: 'plugin',
name: 'Plugin Trigger',
nodeId: 'plugin-1',
}),
])
expect(result.current.runAll).toEqual(expect.objectContaining({
id: 'run-all',
type: 'all',
relatedNodeIds: ['schedule-1', 'webhook-1', 'plugin-1'],
}))
})
it('should fall back to the workflow entry node and omit run-all when only one trigger exists', () => {
mockUseNodes.mockReturnValue([
{
id: 'webhook-1',
data: { type: BlockEnum.TriggerWebhook, title: 'Webhook Trigger' },
},
])
mockGetWorkflowEntryNode.mockReturnValue({
id: 'fallback-start',
data: { type: BlockEnum.Start, title: '' },
})
const { result } = renderHook(() => useDynamicTestRunOptions())
expect(result.current.userInput).toEqual(expect.objectContaining({
id: 'fallback-start',
type: 'user_input',
name: 'blocks.start',
nodeId: 'fallback-start',
}))
expect(result.current.triggers).toHaveLength(1)
expect(result.current.runAll).toBeUndefined()
})
})

Some files were not shown because too many files have changed in this diff Show More