Mirror of https://github.com/langgenius/dify.git (synced 2026-05-09 12:59:18 +08:00)

commit 77e7f0a7de
Merge branch 'test/workflow-part-8' into test/workflow-app

@@ -10,7 +10,7 @@ from configs import dify_config
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.index_processor.constant.built_in_field import BuiltInField
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.models.document import ChildDocument, Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment

@@ -86,7 +86,7 @@ def migrate_annotation_vector_database():
             dataset = Dataset(
                 id=app.id,
                 tenant_id=app.tenant_id,
-                indexing_technique="high_quality",
+                indexing_technique=IndexTechniqueType.HIGH_QUALITY,
                 embedding_model_provider=dataset_collection_binding.provider_name,
                 embedding_model=dataset_collection_binding.model_name,
                 collection_binding_id=dataset_collection_binding.id,

@@ -178,7 +178,9 @@ def migrate_knowledge_vector_database():
     while True:
         try:
             stmt = (
-                select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
+                select(Dataset)
+                .where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY)
+                .order_by(Dataset.created_at.desc())
             )
 
             datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
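
Note on the recurring change in this commit: every bare "high_quality" / "economy" literal becomes an IndexTechniqueType member imported from core.rag.index_processor.constant.index_type. For those comparisons to keep matching values loaded from the database as plain strings, the enum presumably derives from str; a minimal sketch of that assumption (illustrative, not the project's actual definition):

from enum import StrEnum

class IndexTechniqueType(StrEnum):
    HIGH_QUALITY = "high_quality"
    ECONOMY = "economy"

# StrEnum members are str instances, so equality with raw DB strings still holds:
assert IndexTechniqueType.HIGH_QUALITY == "high_quality"
assert "economy" == IndexTechniqueType.ECONOMY
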
@@ -3,7 +3,7 @@ from typing import Any, cast
 from flask import request
 from flask_restx import Resource, fields, marshal, marshal_with
 from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
+from sqlalchemy import func, select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services

@@ -29,6 +29,7 @@ from core.provider_manager import ProviderManager
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db

@@ -355,7 +356,7 @@ class DatasetListApi(Resource):
 
         for item in data:
             # convert embedding_model_provider to plugin standard format
-            if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
+            if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
                 item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
                 item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                 if item_model in model_names:

@@ -436,7 +437,7 @@ class DatasetApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             if dataset.embedding_model_provider:
                 provider_id = ModelProviderID(dataset.embedding_model_provider)
                 data["embedding_model_provider"] = str(provider_id)

@@ -454,7 +455,7 @@ class DatasetApi(Resource):
             for embedding_model in embedding_models:
                 model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
 
-            if data["indexing_technique"] == "high_quality":
+            if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
                 item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
                 if item_model in model_names:
                     data["embedding_available"] = True

@@ -485,7 +486,7 @@ class DatasetApi(Resource):
         current_user, current_tenant_id = current_account_with_tenant()
         # check embedding model setting
         if (
-            payload.indexing_technique == "high_quality"
+            payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY
             and payload.embedding_model_provider is not None
             and payload.embedding_model is not None
         ):

@@ -738,20 +739,23 @@ class DatasetIndexingStatusApi(Resource):
         documents_status = []
         for document in documents:
             completed_segments = (
-                db.session.query(DocumentSegment)
-                .where(
-                    DocumentSegment.completed_at.isnot(None),
-                    DocumentSegment.document_id == str(document.id),
-                    DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                db.session.scalar(
+                    select(func.count(DocumentSegment.id)).where(
+                        DocumentSegment.completed_at.isnot(None),
+                        DocumentSegment.document_id == str(document.id),
+                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    )
                 )
-                .count()
+                or 0
             )
             total_segments = (
-                db.session.query(DocumentSegment)
-                .where(
-                    DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
+                db.session.scalar(
+                    select(func.count(DocumentSegment.id)).where(
+                        DocumentSegment.document_id == str(document.id),
+                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    )
                 )
-                .count()
+                or 0
             )
             # Create a dictionary with document attributes and additional fields
             document_dict = {
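
The two segment counts above follow the same SQLAlchemy 2.0 migration: legacy Query.count(), which wraps the statement in a subquery, becomes an explicit select(func.count(...)) executed with Session.scalar(), and `or 0` guards the None that scalar() can return. A self-contained sketch of the pattern, using a trimmed stand-in model rather than the real DocumentSegment:

from sqlalchemy import String, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class DocumentSegment(Base):  # trimmed stand-in for the real model
    __tablename__ = "document_segments"
    id: Mapped[int] = mapped_column(primary_key=True)
    document_id: Mapped[str] = mapped_column(String(36))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([DocumentSegment(document_id="d1"), DocumentSegment(document_id="d1")])
    session.commit()

    # Legacy 1.x style, as removed above: count() wraps the query in a subquery.
    legacy = session.query(DocumentSegment).where(DocumentSegment.document_id == "d1").count()

    # 2.0 style, as added above: one flat aggregate; scalar() may return None, hence `or 0`.
    modern = session.scalar(
        select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == "d1")
    ) or 0

    assert legacy == modern == 2
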
@@ -802,9 +806,12 @@ class DatasetApiKeyApi(Resource):
         _, current_tenant_id = current_account_with_tenant()
 
         current_key_count = (
-            db.session.query(ApiToken)
-            .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
-            .count()
+            db.session.scalar(
+                select(func.count(ApiToken.id)).where(
+                    ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id
+                )
+            )
+            or 0
         )
 
         if current_key_count >= self.max_keys:

@@ -839,14 +846,14 @@ class DatasetApiDeleteApi(Resource):
     def delete(self, api_key_id):
         _, current_tenant_id = current_account_with_tenant()
         api_key_id = str(api_key_id)
-        key = (
-            db.session.query(ApiToken)
+        key = db.session.scalar(
+            select(ApiToken)
             .where(
                 ApiToken.tenant_id == current_tenant_id,
                 ApiToken.type == self.resource_type,
                 ApiToken.id == api_key_id,
             )
-            .first()
+            .limit(1)
         )
 
         if key is None:

@@ -857,7 +864,7 @@ class DatasetApiDeleteApi(Resource):
         assert key is not None  # nosec - for type checker only
         ApiTokenCache.delete(key.token, key.type)
 
-        db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
+        db.session.delete(key)
         db.session.commit()
 
         return {"result": "success"}, 204
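
The delete path above changes in two coordinated steps: the token is fetched as an ORM instance via Session.scalar(select(...).limit(1)) instead of Query.first(), and the bulk query(...).delete() becomes db.session.delete(key), which routes the deletion through the unit of work (so ORM-level cascades and events fire) and avoids re-querying by id. A sketch of the resulting shape, reusing names from the hunk above (ApiToken and db belong to the application and are not defined here):

from sqlalchemy import select

key = db.session.scalar(
    select(ApiToken)
    .where(
        ApiToken.tenant_id == current_tenant_id,
        ApiToken.type == self.resource_type,
        ApiToken.id == api_key_id,
    )
    .limit(1)
)
if key is not None:
    db.session.delete(key)  # unit-of-work delete of the already-loaded instance
    db.session.commit()
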
@@ -27,6 +27,7 @@ from core.model_manager import ModelManager
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
 from extensions.ext_database import db

@@ -449,7 +450,7 @@ class DatasetInitApi(Resource):
             raise Forbidden()
 
         knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
-        if knowledge_config.indexing_technique == "high_quality":
+        if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
                 raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
         try:

@@ -463,7 +464,7 @@
             is_multimodal = DatasetService.check_is_multimodal_model(
                 current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
             )
-            knowledge_config.is_multimodal = is_multimodal
+            knowledge_config.is_multimodal = is_multimodal  # pyrefly: ignore[bad-assignment]
         except InvokeAuthorizationError:
             raise ProviderNotInitializeError(
                 "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."

@@ -1337,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource):
             raise BadRequest("document_list cannot be empty.")
 
         # Check if dataset configuration supports summary generation
-        if dataset.indexing_technique != "high_quality":
+        if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
             raise ValueError(
                 f"Summary generation is only available for 'high_quality' indexing technique. "
                 f"Current indexing technique: {dataset.indexing_technique}"

@@ -26,6 +26,7 @@ from controllers.console.wraps import (
 )
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client

@@ -279,7 +280,7 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             # check embedding model setting
             try:
                 model_manager = ModelManager()

@@ -333,7 +334,7 @@ class DatasetDocumentSegmentAddApi(Resource):
         if not current_user.is_dataset_editor:
             raise Forbidden()
         # check embedding model setting
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(

@@ -383,7 +384,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound("Document not found.")
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             # check embedding model setting
             try:
                 model_manager = ModelManager()

@@ -401,10 +402,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
                 raise ProviderNotInitializeError(ex.description)
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")

@@ -447,10 +448,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")

@@ -494,7 +495,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         payload = BatchImportPayload.model_validate(console_ns.payload or {})
         upload_file_id = payload.upload_file_id
 
-        upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
+        upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
         if not upload_file:
             raise NotFound("UploadFile not found.")
 

@@ -559,17 +560,17 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")
         if not current_user.is_dataset_editor:
             raise Forbidden()
         # check embedding model setting
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(

@@ -616,10 +617,10 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")

@@ -666,10 +667,10 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")

@@ -714,24 +715,24 @@ class ChildChunkUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")
         # check child chunk
         child_chunk_id = str(child_chunk_id)
-        child_chunk = (
-            db.session.query(ChildChunk)
+        child_chunk = db.session.scalar(
+            select(ChildChunk)
             .where(
                 ChildChunk.id == str(child_chunk_id),
                 ChildChunk.tenant_id == current_tenant_id,
                 ChildChunk.segment_id == segment.id,
                 ChildChunk.document_id == document_id,
             )
-            .first()
+            .limit(1)
         )
         if not child_chunk:
             raise NotFound("Child chunk not found.")

@@ -771,24 +772,24 @@ class ChildChunkUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")
         # check child chunk
         child_chunk_id = str(child_chunk_id)
-        child_chunk = (
-            db.session.query(ChildChunk)
+        child_chunk = db.session.scalar(
+            select(ChildChunk)
             .where(
                 ChildChunk.id == str(child_chunk_id),
                 ChildChunk.tenant_id == current_tenant_id,
                 ChildChunk.segment_id == segment.id,
                 ChildChunk.document_id == document_id,
             )
-            .first()
+            .limit(1)
         )
         if not child_chunk:
             raise NotFound("Child chunk not found.")

@@ -15,6 +15,7 @@ from controllers.service_api.wraps import (
     cloud_edition_billing_rate_limit_check,
 )
 from core.provider_manager import ProviderManager
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from fields.dataset_fields import dataset_detail_fields
 from fields.tag_fields import DataSetTag

@@ -153,9 +154,14 @@ class DatasetListApi(DatasetApiResource):
 
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
-            if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:  # type: ignore
-                item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))  # type: ignore
-                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"  # type: ignore
+            if (
+                item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY  # pyrefly: ignore[bad-index]
+                and item["embedding_model_provider"]  # pyrefly: ignore[bad-index]
+            ):
+                item["embedding_model_provider"] = str(  # pyrefly: ignore[unsupported-operation]
+                    ModelProviderID(item["embedding_model_provider"])  # pyrefly: ignore[bad-index]
+                )
+                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"  # pyrefly: ignore[bad-index]
                 if item_model in model_names:
                     item["embedding_available"] = True  # type: ignore
                 else:

@@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource):
         for embedding_model in embedding_models:
             model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
 
-        if data.get("indexing_technique") == "high_quality":
+        if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY:
             item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
             if item_model in model_names:
                 data["embedding_available"] = True

@@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource):
         # check embedding model setting
         embedding_model_provider = payload.embedding_model_provider
         embedding_model = payload.embedding_model
-        if payload.indexing_technique == "high_quality" or embedding_model_provider:
+        if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider:
             if embedding_model_provider and embedding_model:
                 DatasetService.check_embedding_model_setting(
                     dataset.tenant_id, embedding_model_provider, embedding_model

@@ -17,6 +17,7 @@ from controllers.service_api.wraps import (
 )
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db
 from fields.segment_fields import child_chunk_fields, segment_fields

@@ -103,7 +104,7 @@ class SegmentApi(DatasetApiResource):
         if not document.enabled:
             raise NotFound("Document is disabled.")
         # check embedding model setting
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(

@@ -157,7 +158,7 @@
         if not document:
             raise NotFound("Document not found.")
         # check embedding model setting
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(

@@ -262,7 +263,7 @@ class DatasetSegmentApi(DatasetApiResource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound("Document not found.")
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             # check embedding model setting
             try:
                 model_manager = ModelManager()

@@ -358,7 +359,7 @@ class ChildChunkApi(DatasetApiResource):
             raise NotFound("Segment not found.")
 
         # check embedding model setting
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(

@@ -4,6 +4,7 @@ from sqlalchemy import select
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from extensions.ext_database import db
 from models.dataset import Dataset
 from models.enums import CollectionBindingType, ConversationFromSource

@@ -50,7 +51,7 @@ class AnnotationReplyFeature:
             dataset = Dataset(
                 id=app_record.id,
                 tenant_id=app_record.tenant_id,
-                indexing_technique="high_quality",
+                indexing_technique=IndexTechniqueType.HIGH_QUALITY,
                 embedding_model_provider=embedding_provider_name,
                 embedding_model=embedding_model_name,
                 collection_binding_id=dataset_collection_binding.id,

@@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import ChildDocument, Document

@@ -271,7 +271,7 @@ class IndexingRunner:
         doc_form: str | None = None,
         doc_language: str = "English",
         dataset_id: str | None = None,
-        indexing_technique: str = "economy",
+        indexing_technique: str = IndexTechniqueType.ECONOMY,
     ) -> IndexingEstimate:
         """
         Estimate the indexing for the document.

@@ -289,7 +289,7 @@
             dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
             if not dataset:
                 raise ValueError("Dataset not found.")
-            if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
+            if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
                 if dataset.embedding_model_provider:
                     embedding_model_instance = self.model_manager.get_model_instance(
                         tenant_id=tenant_id,
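
The hunk above also collapses two literal comparisons into a single membership test; with a str-based enum the two forms are equivalent, since set membership compares by hash and equality:

# Before: two comparisons against the literal
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
    ...

# After: one membership test against the enum member
if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
    ...
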
@@ -303,7 +303,7 @@
                         model_type=ModelType.TEXT_EMBEDDING,
                     )
             else:
-                if indexing_technique == "high_quality":
+                if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
                     embedding_model_instance = self.model_manager.get_default_model_instance(
                         tenant_id=tenant_id,
                         model_type=ModelType.TEXT_EMBEDDING,

@@ -573,7 +573,7 @@
         """
 
         embedding_model_instance = None
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             embedding_model_instance = self.model_manager.get_model_instance(
                 tenant_id=dataset.tenant_id,
                 provider=dataset.embedding_model_provider,

@@ -587,7 +587,7 @@
         create_keyword_thread = None
         if (
             dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
-            and dataset.indexing_technique == "economy"
+            and dataset.indexing_technique == IndexTechniqueType.ECONOMY
         ):
             # create keyword index
             create_keyword_thread = threading.Thread(

@@ -597,7 +597,7 @@
             create_keyword_thread.start()
 
         max_workers = 10
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                 futures = []
 

@@ -628,7 +628,7 @@
                 tokens += future.result()
         if (
             dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
-            and dataset.indexing_technique == "economy"
+            and dataset.indexing_technique == IndexTechniqueType.ECONOMY
             and create_keyword_thread is not None
         ):
             create_keyword_thread.join()

@@ -654,7 +654,7 @@
             raise ValueError("no dataset found")
         keyword = Keyword(dataset)
         keyword.create(documents)
-        if dataset.indexing_technique != "high_quality":
+        if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
             document_ids = [document.metadata["doc_id"] for document in documents]
             db.session.query(DocumentSegment).where(
                 DocumentSegment.document_id == document_id,

@@ -764,7 +764,7 @@
     ) -> list[Document]:
         # get embedding model instance
         embedding_model_instance = None
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             if dataset.embedding_model_provider:
                 embedding_model_instance = self.model_manager.get_model_instance(
                     tenant_id=dataset.tenant_id,

@@ -6,6 +6,7 @@ from typing import Any
 from sqlalchemy import func, select
 
 from core.model_manager import ModelManager
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.models.document import AttachmentDocument, Document
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db

@@ -71,7 +72,7 @@ class DatasetDocumentStore:
         if max_position is None:
             max_position = 0
         embedding_model = None
-        if self._dataset.indexing_technique == "high_quality":
+        if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             model_manager = ModelManager()
             embedding_model = model_manager.get_model_instance(
                 tenant_id=self._dataset.tenant_id,

@@ -9,6 +9,7 @@ from flask import current_app
 from sqlalchemy import delete, func, select
 
 from core.db.session_factory import session_factory
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
 from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
 from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview

@@ -159,7 +160,7 @@ class IndexProcessor:
         tenant_id = dataset.tenant_id
 
         preview_output = self.format_preview(chunk_structure, chunks)
-        if indexing_technique != "high_quality":
+        if indexing_technique != IndexTechniqueType.HIGH_QUALITY:
             return preview_output
 
         if not summary_index_setting or not summary_index_setting.get("enable"):

@@ -22,7 +22,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.constant.doc_type import DocType
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
 from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
 from core.rag.retrieval.retrieval_methods import RetrievalMethod

@@ -117,7 +117,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
         with_keywords: bool = True,
         **kwargs,
     ) -> None:
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             vector.create(documents)
             if multimodal_documents and dataset.is_multimodal:

@@ -155,7 +155,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
             # Delete all summaries for the dataset
             SummaryIndexService.delete_summaries_for_segments(dataset, None)
 
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             if node_ids:
                 vector.delete_by_ids(node_ids)

@@ -253,12 +253,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
         doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
         # add document segments
         doc_store.add_documents(docs=documents, save_child=False)
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             vector.create(documents)
             if all_multimodal_documents and dataset.is_multimodal:
                 vector.create_multimodal(all_multimodal_documents)
-        elif dataset.indexing_technique == "economy":
+        elif dataset.indexing_technique == IndexTechniqueType.ECONOMY:
             keyword = Keyword(dataset)
             keyword.add_texts(documents)
 

@@ -18,7 +18,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.constant.doc_type import DocType
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
 from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
 from core.rag.retrieval.retrieval_methods import RetrievalMethod

@@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         with_keywords: bool = True,
         **kwargs,
     ) -> None:
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             for document in documents:
                 child_documents = document.children

@@ -166,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
             # Delete all summaries for the dataset
             SummaryIndexService.delete_summaries_for_segments(dataset, None)
 
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             delete_child_chunks = kwargs.get("delete_child_chunks") or False
             precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
             vector = Vector(dataset)

@@ -332,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
         # add document segments
         doc_store.add_documents(docs=documents, save_child=True)
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             all_child_documents = []
             all_multimodal_documents = []
             for doc in documents:

@@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
 from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
 from core.rag.retrieval.retrieval_methods import RetrievalMethod

@@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor):
         with_keywords: bool = True,
         **kwargs,
     ) -> None:
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             vector.create(documents)
             if multimodal_documents and dataset.is_multimodal:

@@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor):
         # save node to document segment
         doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
         doc_store.add_documents(docs=documents, save_child=False)
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             vector.create(documents)
         else:

@@ -675,7 +675,7 @@ class DatasetRetrieval:
                 # get top k
                 top_k = retrieval_model_config["top_k"]
                 # get retrieval method
-                if selected_dataset.indexing_technique == "economy":
+                if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY:
                     retrieval_method = RetrievalMethod.KEYWORD_SEARCH
                 else:
                     retrieval_method = retrieval_model_config["search_method"]

@@ -752,7 +752,7 @@ class DatasetRetrieval:
                     "The configured knowledge base list have different indexing technique, please set reranking model."
                 )
             index_type = available_datasets[0].indexing_technique
-            if index_type == "high_quality":
+            if index_type == IndexTechniqueType.HIGH_QUALITY:
                 embedding_model_check = all(
                     item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
                 )

@@ -1068,7 +1068,7 @@ class DatasetRetrieval:
             else default_retrieval_model
         )
 
-        if dataset.indexing_technique == "economy":
+        if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
            # use keyword table query
            documents = RetrievalService.retrieve(
                retrieval_method=RetrievalMethod.KEYWORD_SEARCH,

@@ -2,6 +2,7 @@ import concurrent.futures
 import logging
 
 from core.db.session_factory import session_factory
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
 from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
 from services.summary_index_service import SummaryIndexService

@@ -21,7 +22,7 @@ class SummaryIndex:
         if is_preview:
             with session_factory.create_session() as session:
                 dataset = session.query(Dataset).filter_by(id=dataset_id).first()
-                if not dataset or dataset.indexing_technique != "high_quality":
+                if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
                     return
 
         if summary_index_setting is None:

@@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
 from core.model_manager import ModelManager
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.models.document import Document as RagDocument
 from core.rag.rerank.rerank_model import RerankModelRunner
 from core.rag.retrieval.retrieval_methods import RetrievalMethod

@@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                 # get retrieval model , if the model is not setting , using default
                 retrieval_model = dataset.retrieval_model or default_retrieval_model
 
-                if dataset.indexing_technique == "economy":
+                if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
                     # use keyword table query
                     documents = RetrievalService.retrieve(
                         retrieval_method=RetrievalMethod.KEYWORD_SEARCH,

@@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.rag.entities.context_entities import DocumentContext
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.models.document import Document as RetrievalDocument
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod

@@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
         # get retrieval model , if the model is not setting , using default
         retrieval_model = dataset.retrieval_model or default_retrieval_model
         retrieval_resource_list: list[RetrievalSourceMetadata] = []
-        if dataset.indexing_technique == "economy":
+        if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
             # use keyword table query
             documents = RetrievalService.retrieve(
                 retrieval_method=RetrievalMethod.KEYWORD_SEARCH,

@@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             for hit_callback in self.hit_callbacks:
                 hit_callback.on_tool_end(documents)
             document_score_list = {}
-            if dataset.indexing_technique != "economy":
+            if dataset.indexing_technique != IndexTechniqueType.ECONOMY:
                 for item in documents:
                     if item.metadata is not None and item.metadata.get("score"):
                         document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

@@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
 
 from configs import dify_config
 from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.signature import sign_upload_file

@@ -137,7 +137,7 @@ class Dataset(Base):
         default=DatasetPermissionEnum.ONLY_ME,
     )
     data_source_type = mapped_column(EnumText(DataSourceType, length=255))
-    indexing_technique: Mapped[str | None] = mapped_column(String(255))
+    indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255))
     index_struct = mapped_column(LongText, nullable=True)
     created_by = mapped_column(StringUUID, nullable=False)
     created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
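
The Dataset.indexing_technique column above narrows its Python-side type from str | None to IndexTechniqueType | None while keeping a 255-character string column in the database. EnumText is a project-specific column type; assuming it persists the enum's value, the closest stock-SQLAlchemy approximation looks like this (illustrative only, not the project's implementation):

import enum
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class IndexTechniqueType(enum.StrEnum):
    HIGH_QUALITY = "high_quality"
    ECONOMY = "economy"

class Base(DeclarativeBase):
    pass

class Dataset(Base):  # trimmed stand-in for the real model
    __tablename__ = "datasets"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Store the member *value* ("high_quality") as VARCHAR(255), not the member name:
    indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(
        sa.Enum(
            IndexTechniqueType,
            values_callable=lambda e: [m.value for m in e],
            native_enum=False,
            length=255,
        ),
        nullable=True,
    )
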
@@ -940,7 +940,9 @@ class AccountTrialAppRecord(Base):
 class ExporleBanner(TypeBase):
     __tablename__ = "exporle_banners"
     __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
-    id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
+    )
     content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
     link: Mapped[str] = mapped_column(String(255), nullable=False)
     sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)

@@ -1849,7 +1851,9 @@ class AppAnnotationHitHistory(TypeBase):
         sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
     )
 
-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     source: Mapped[str] = mapped_column(LongText, nullable=False)
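
Both model tweaks above replace a single dataclass default= with the insert_default / default_factory pair from SQLAlchemy 2.0's dataclass mapping: default_factory populates the attribute when the object is constructed in Python (so the id exists before any flush), while insert_default keeps a Core-level default for rows created outside the dataclass path, such as bulk inserts. A minimal runnable sketch with a hypothetical model:

import uuid
from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

class Base(MappedAsDataclass, DeclarativeBase):
    pass

class Record(Base):  # hypothetical model, not one of the classes in this diff
    __tablename__ = "records"
    id: Mapped[str] = mapped_column(
        String(36),
        primary_key=True,
        insert_default=lambda: str(uuid.uuid4()),   # Core default, used on plain INSERTs
        default_factory=lambda: str(uuid.uuid4()),  # dataclass default, set at construction
        init=False,                                 # keep id out of __init__
    )

r = Record()
assert r.id  # already populated at construction time, before any flush
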
@ -21,7 +21,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
@ -228,7 +228,7 @@ class DatasetService:
|
||||
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
|
||||
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
|
||||
embedding_model = None
|
||||
if indexing_technique == "high_quality":
|
||||
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
model_manager = ModelManager()
|
||||
if embedding_model_provider and embedding_model_name:
|
||||
# check if embedding model setting is valid
|
||||
@ -254,7 +254,10 @@ class DatasetService:
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
dataset = Dataset(name=name, indexing_technique=indexing_technique)
|
||||
dataset = Dataset(
|
||||
name=name,
|
||||
indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None,
|
||||
)
|
||||
# dataset = Dataset(name=name, provider=provider, config=config)
|
||||
dataset.description = description
|
||||
dataset.created_by = account.id
|
||||
@ -349,7 +352,7 @@ class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_model_setting(dataset):
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
@ -717,13 +720,13 @@ class DatasetService:
|
||||
if "indexing_technique" not in data:
|
||||
return None
|
||||
if dataset.indexing_technique != data["indexing_technique"]:
|
||||
if data["indexing_technique"] == "economy":
|
||||
if data["indexing_technique"] == IndexTechniqueType.ECONOMY:
|
||||
# Remove embedding model configuration for economy mode
|
||||
filtered_data["embedding_model"] = None
|
||||
filtered_data["embedding_model_provider"] = None
|
||||
filtered_data["collection_binding_id"] = None
|
||||
return "remove"
|
||||
elif data["indexing_technique"] == "high_quality":
|
||||
elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
|
||||
# Configure embedding model for high quality mode
|
||||
DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
|
||||
return "add"
|
||||
@ -953,8 +956,8 @@ class DatasetService:
|
||||
dataset = session.merge(dataset)
|
||||
if not has_published:
|
||||
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||
dataset.indexing_technique = knowledge_configuration.indexing_technique
|
||||
if knowledge_configuration.indexing_technique == "high_quality":
|
||||
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
|
||||
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id, # ignore type error
|
||||
@ -976,7 +979,7 @@ class DatasetService:
|
||||
embedding_model_name,
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
elif knowledge_configuration.indexing_technique == "economy":
|
||||
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
|
||||
dataset.keyword_number = knowledge_configuration.keyword_number
|
||||
else:
|
||||
raise ValueError("Invalid index method")
|
||||
@ -991,9 +994,9 @@ class DatasetService:
|
||||
action = None
|
||||
if dataset.indexing_technique != knowledge_configuration.indexing_technique:
|
||||
# if update indexing_technique
|
||||
if knowledge_configuration.indexing_technique == "economy":
|
||||
if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
|
||||
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
|
||||
elif knowledge_configuration.indexing_technique == "high_quality":
|
||||
elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
action = "add"
|
||||
# get embedding model setting
|
||||
try:
|
||||
@ -1018,7 +1021,7 @@ class DatasetService:
|
||||
)
|
||||
dataset.is_multimodal = is_multimodal
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
dataset.indexing_technique = knowledge_configuration.indexing_technique
|
||||
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
@ -1029,7 +1032,7 @@ class DatasetService:
|
||||
else:
|
||||
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
|
||||
# Skip embedding model checks if not provided in the update request
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
skip_embedding_update = False
|
||||
try:
|
||||
# Handle existing model provider
|
||||
@ -1089,7 +1092,7 @@ class DatasetService:
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
elif dataset.indexing_technique == "economy":
|
||||
elif dataset.indexing_technique == IndexTechniqueType.ECONOMY:
|
||||
if dataset.keyword_number != knowledge_configuration.keyword_number:
|
||||
dataset.keyword_number = knowledge_configuration.keyword_number
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
@ -1907,8 +1910,8 @@ class DocumentService:
|
||||
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
|
||||
raise ValueError("Indexing technique is invalid")
|
||||
|
||||
dataset.indexing_technique = knowledge_config.indexing_technique
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique)
|
||||
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
model_manager = ModelManager()
|
||||
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
|
||||
dataset_embedding_model = knowledge_config.embedding_model
|
||||
@ -2689,7 +2692,7 @@ class DocumentService:
|
||||
|
||||
dataset_collection_binding_id = None
|
||||
retrieval_model = None
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
assert knowledge_config.embedding_model_provider
|
||||
assert knowledge_config.embedding_model
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
@ -2712,7 +2715,7 @@ class DocumentService:
|
||||
tenant_id=tenant_id,
|
||||
name="",
|
||||
data_source_type=knowledge_config.data_source.info_list.data_source_type,
|
||||
indexing_technique=knowledge_config.indexing_technique,
|
||||
indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique),
|
||||
created_by=account.id,
|
||||
embedding_model=knowledge_config.embedding_model,
|
||||
embedding_model_provider=knowledge_config.embedding_model_provider,
|
||||
@ -3125,7 +3128,7 @@ class SegmentService:
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
@ -3208,7 +3211,7 @@ class SegmentService:
|
||||
try:
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
@ -3230,7 +3233,7 @@ class SegmentService:
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == "high_quality" and embedding_model:
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model:
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(
|
||||
@ -3345,7 +3348,7 @@ class SegmentService:
|
||||
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
|
||||
# regenerate child chunks
|
||||
# get embedding model instance
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
# check embedding model setting
|
||||
model_manager = ModelManager()
|
||||
|
||||
@ -3382,7 +3385,7 @@ class SegmentService:
|
||||
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist
|
||||
# summary_index_setting is only needed for LLM generation, not for manual summary vectorization
|
||||
# Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
# Query existing summary from database
|
||||
from models.dataset import DocumentSegmentSummary
|
||||
|
||||
@ -3409,7 +3412,7 @@ class SegmentService:
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
@ -3449,7 +3452,7 @@ class SegmentService:
|
||||
db.session.commit()
|
||||
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
|
||||
# get embedding model instance
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
# check embedding model setting
|
||||
model_manager = ModelManager()
|
||||
|
||||
@ -3481,7 +3484,7 @@ class SegmentService:
|
||||
# update segment vector index
|
||||
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
||||
# Handle summary index when content changed
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
from models.dataset import DocumentSegmentSummary
|
||||
|
||||
existing_summary = (
|
||||
|
||||
@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.helper.name_generator import generate_incremental_name
from core.plugin.entities.plugin import PluginDependency
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
@ -311,13 +312,13 @@ class RagPipelineDslService:
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique),
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
if knowledge_configuration.indexing_technique == "high_quality":
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
@ -343,7 +344,7 @@ class RagPipelineDslService:
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
@ -443,18 +444,18 @@ class RagPipelineDslService:
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique),
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
@ -480,7 +481,7 @@ class RagPipelineDslService:
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
@ -772,7 +773,7 @@ class RagPipelineDslService:
)
case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE:
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
if knowledge_index_entity.indexing_technique == "high_quality":
if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if knowledge_index_entity.embedding_model_provider:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(

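In RagPipelineDslService, note the two distinct call sites: comparisons use the member directly, while assignments wrap the incoming DSL value in IndexTechniqueType(...). Assuming the StrEnum sketched above, the constructor validates the raw string and fails fast on unknown values instead of persisting them silently:

from core.rag.index_processor.constant.index_type import IndexTechniqueType

technique = IndexTechniqueType("economy")  # coerces the raw DSL string into an enum member
assert technique is IndexTechniqueType.ECONOMY

try:
    IndexTechniqueType("hybrid")  # unexpected values now raise at import time
except ValueError as exc:
    print(exc)  # "'hybrid' is not a valid IndexTechniqueType"
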
@ -9,7 +9,7 @@ from flask_login import current_user

from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
@ -105,29 +105,29 @@ class RagPipelineTransformService:
if doc_form == IndexStructureType.PARAGRAPH_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.file-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.file-general-economy.yml
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case DataSourceType.NOTION_IMPORT:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.notion-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.notion-general-economy.yml
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case DataSourceType.WEBSITE_CRAWL:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.website-crawl-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.website-crawl-general-economy.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
@ -170,11 +170,11 @@ class RagPipelineTransformService:
):
knowledge_configuration_dict = node.get("data", {})

if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH
knowledge_configuration.retrieval_model = retrieval_model
else:

@ -12,6 +12,7 @@ from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.models.document import Document
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
@ -140,7 +141,7 @@ class SummaryIndexService:
session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one.
If not provided, creates a new session and commits automatically.
"""
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.warning(
"Summary vectorization skipped for dataset %s: indexing_technique is not high_quality",
dataset.id,
@ -724,7 +725,7 @@ class SummaryIndexService:
List of created DocumentSegmentSummary instances
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
"Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'",
dataset.id,
@ -851,7 +852,7 @@ class SummaryIndexService:
)

# Remove from vector database (but keep records)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
try:
@ -889,7 +890,7 @@ class SummaryIndexService:
segment_ids: List of segment IDs to enable summaries for. If None, enable all.
"""
# Only enable summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return

with session_factory.create_session() as session:
@ -981,7 +982,7 @@ class SummaryIndexService:
return

# Delete from vector database
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
vector = Vector(dataset)
@ -1012,7 +1013,7 @@ class SummaryIndexService:
Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality
"""
# Only update summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return None

# When user manually provides summary, allow saving even if summary_index_setting doesn't exist

@ -4,7 +4,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, Document
@ -45,7 +45,7 @@ class VectorService:
if not processing_rule:
raise ValueError("No processing rule found.")
# get embedding model instance
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()

@ -112,7 +112,7 @@ class VectorService:
"dataset_id": segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
vector.delete_by_ids([segment.index_node_id])
@ -197,7 +197,7 @@ class VectorService:
"dataset_id": child_segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts([child_document], duplicate_check=True)
@ -237,7 +237,7 @@ class VectorService:
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
if delete_node_ids:
@ -252,7 +252,7 @@ class VectorService:

@classmethod
def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return

attachments = segment.attachments

@ -5,6 +5,7 @@ import click
from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -36,7 +37,7 @@ def add_annotation_to_index_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound

from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

@ -5,6 +5,7 @@ import click
from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
collection_binding_id=dataset_collection_binding.id,
)


@ -7,6 +7,7 @@ from sqlalchemy import exists, select

from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
collection_binding_id=app_annotation_setting.collection_binding_id,
)


@ -7,6 +7,7 @@ from sqlalchemy import select

from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@ -64,7 +65,7 @@ def enable_annotation_reply_task(
old_dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=old_dataset_collection_binding.provider_name,
embedding_model=old_dataset_collection_binding.model_name,
collection_binding_id=old_dataset_collection_binding.id,
@ -93,7 +94,7 @@ def enable_annotation_reply_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,

@ -5,6 +5,7 @@ import click
from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -37,7 +38,7 @@ def update_annotation_to_index_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

@ -11,7 +11,7 @@ from sqlalchemy import func

from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -120,7 +120,7 @@ def batch_create_segment_to_index_task(

document_segments = []
embedding_model = None
if dataset_config["indexing_technique"] == "high_quality":
if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset_config["tenant_id"],

@ -10,7 +10,7 @@ from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
@ -127,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
logger.warning("Dataset %s not found after indexing", dataset_id)
return

if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_index_setting = dataset.summary_index_setting
if summary_index_setting and summary_index_setting.get("enable"):
# expire all session to get latest document's indexing status

@ -7,6 +7,7 @@ import click
from celery import shared_task

from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids:
return

# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
click.style(
f"Skipping summary generation for dataset {dataset_id}: "

@ -9,7 +9,7 @@ from celery import shared_task
from sqlalchemy import or_, select

from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
@ -53,7 +53,7 @@ def regenerate_summary_index_task(
return

# Only regenerate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
click.style(
f"Skipping summary regeneration for dataset {dataset_id}: "

@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest
from faker import Faker

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
@ -39,7 +39,7 @@ class TestGetAvailableDatasetsIntegration:
provider="dify",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -460,7 +460,7 @@ class TestKnowledgeRetrievalIntegration:
provider="dify",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
db_session_with_containers.add(dataset)


@ -13,6 +13,7 @@ import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
from models.enums import DataSourceType
@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory:
name=name,
description="Test description",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=permission,
provider="vendor",

@ -1245,3 +1245,51 @@ class TestAppService:
assert paginated_apps is not None
assert paginated_apps.total == 1
assert all("50%" in app.name for app in paginated_apps.items)

def test_get_app_code_by_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_code_by_id raises ValueError when site is missing."""
from uuid import uuid4

from services.app_service import AppService

with pytest.raises(ValueError, match="not found"):
AppService.get_app_code_by_id(str(uuid4()))

def test_get_app_id_by_code_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_id_by_code raises ValueError when code does not exist."""
from services.app_service import AppService

with pytest.raises(ValueError, match="not found"):
AppService.get_app_id_by_code("nonexistent-code")

def test_get_app_meta_returns_empty_when_workflow_missing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_meta returns empty tool_icons when workflow is None."""
from types import SimpleNamespace

from services.app_service import AppService

app_service = AppService()
workflow_app = SimpleNamespace(mode="workflow", workflow=None)

meta = app_service.get_app_meta(workflow_app)
assert meta == {"tool_icons": {}}

def test_get_app_meta_returns_empty_when_model_config_missing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_meta returns empty tool_icons when app_model_config is None."""
from types import SimpleNamespace

from services.app_service import AppService

app_service = AppService()
chat_app = SimpleNamespace(mode="chat", app_model_config=None)

meta = app_service.get_app_meta(chat_app)
assert meta == {"tool_icons": {}}

@ -9,6 +9,7 @@ from uuid import uuid4

import pytest

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory:
name=name,
description="desc",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=permission,
provider="vendor",

@ -11,7 +11,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -63,7 +63,7 @@ class DatasetServiceIntegrationDataFactory:
name: str = "Test Dataset",
description: str | None = "Test description",
provider: str = "vendor",
indexing_technique: str | None = "high_quality",
indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY,
permission: str = DatasetPermissionEnum.ONLY_ME,
retrieval_model: dict | None = None,
embedding_model_provider: str | None = None,
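
One typing detail in the factory signature above: the annotation stays `str | None` while the default becomes an enum member. That is only consistent because a string-valued enum member is itself a str instance (a quick check, assuming the StrEnum definition sketched earlier):

from core.rag.index_processor.constant.index_type import IndexTechniqueType

assert isinstance(IndexTechniqueType.HIGH_QUALITY, str)  # satisfies `str | None`
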
@ -157,13 +157,13 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="Economy Dataset",
description=None,
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
account=account,
)

# Assert
db_session_with_containers.refresh(result)
assert result.indexing_technique == "economy"
assert result.indexing_technique == IndexTechniqueType.ECONOMY
assert result.embedding_model_provider is None
assert result.embedding_model is None

@ -181,13 +181,13 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="High Quality Dataset",
description=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
account=account,
)

# Assert
db_session_with_containers.refresh(result)
assert result.indexing_technique == "high_quality"
assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model_name
mock_model_manager.return_value.get_default_model_instance.assert_called_once_with(
@ -273,7 +273,7 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="Dataset With Reranking",
description=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
account=account,
retrieval_model=retrieval_model,
)
@ -306,7 +306,7 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="Custom Embedding Dataset",
description=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
account=account,
embedding_model_provider=embedding_provider,
embedding_model_name=embedding_model_name,
@ -314,7 +314,7 @@ class TestDatasetServiceCreateDataset:

# Assert
db_session_with_containers.refresh(result)
assert result.indexing_technique == "high_quality"
assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert result.embedding_model_provider == embedding_provider
assert result.embedding_model == embedding_model_name
mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name)
@ -589,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure="text_model",
)
DatasetServiceIntegrationDataFactory.create_document(
@ -685,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration:
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0},
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=str(uuid4()),
)
update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": {
"search_method": "full_text_search",
"top_k": 10,

@ -3,7 +3,7 @@
from unittest.mock import patch
from uuid import uuid4

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom
@ -109,7 +109,7 @@ class TestDatasetServiceDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure=None,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),
@ -208,7 +208,7 @@ class TestDatasetServiceDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure=None,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),

@ -12,6 +12,7 @@ from uuid import uuid4

from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory:
name=f"Test Dataset {uuid4()}",
description="Test description",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=DatasetPermissionEnum.ONLY_ME,
provider="vendor",

@ -15,6 +15,7 @@ from uuid import uuid4

from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory:
name=name,
description="desc",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=permission,
provider="vendor",

@ -4,6 +4,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, ExternalKnowledgeBindings
@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory:
provider: str = "vendor",
name: str = "old_name",
description: str = "old_description",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
retrieval_model: str = "old_model",
permission: str = "only_me",
embedding_model_provider: str | None = None,
@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset:
update_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset:

assert dataset.name == "new_name"
assert dataset.description == "new_description"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.retrieval_model == "new_model"
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset:
update_data = {
"name": "new_name",
"description": None,
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
"embedding_model_provider": None,
"embedding_model": None,
@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
)

update_data = {
"indexing_technique": "economy",
"indexing_technique": IndexTechniqueType.ECONOMY,
"retrieval_model": "new_model",
}

@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset:
mock_task.delay.assert_called_once_with(dataset.id, "remove")

db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "economy"
assert dataset.indexing_technique == IndexTechniqueType.ECONOMY
assert dataset.embedding_model is None
assert dataset.embedding_model_provider is None
assert dataset.collection_binding_id is None
@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)

embedding_model = Mock()
@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset:
binding.id = str(uuid4())

update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"retrieval_model": "new_model",
@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset:
mock_task.delay.assert_called_once_with(dataset.id, "add")

db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == binding.id
@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset:

update_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
}

@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset:
db_session_with_containers.refresh(dataset)

assert dataset.name == "new_name"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.collection_binding_id == existing_binding_id
@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -449,7 +450,7 @@ class TestDatasetServiceUpdateDataset:
binding.id = str(uuid4())

update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-3-small",
"retrieval_model": "new_model",
@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)

update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "invalid_provider",
"embedding_model": "invalid_model",
"retrieval_model": "new_model",

@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType, TagType
@ -102,7 +103,7 @@ class TestTagService:
provider="vendor",
permission="only_me",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
tenant_id=tenant_id,
created_by=mock_external_service_dependencies["current_user"].id,
)

@ -1,6 +1,9 @@
from __future__ import annotations

import json
import uuid
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import patch

import pytest
@ -8,14 +11,14 @@ from faker import Faker
from sqlalchemy.orm import Session

from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowAppLogCreatedFrom
from services.account_service import AccountService, TenantService

# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
from services.workflow_app_service import WorkflowAppService
from services.workflow_app_service import LogView, WorkflowAppService
from tests.test_containers_integration_tests.helpers import generate_valid_password


@ -1525,3 +1528,168 @@ class TestWorkflowAppService:

# Should not find tenant2's data when searching from tenant1's context
assert result_cross_tenant["total"] == 0

def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()

with pytest.raises(ValueError, match="Account not found: nonexistent@example.com"):
service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account="nonexistent@example.com",
)

def test_get_paginate_workflow_app_logs_filters_by_account(
self, db_session_with_containers, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()
workflow, workflow_run, _log = self._create_test_workflow_data(db_session_with_containers, app, account)

result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account=account.email,
)

assert result["total"] >= 0
assert isinstance(result["data"], list)

def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()

end_user = EndUser(
tenant_id=app.tenant_id,
app_id=app.id,
type="browser",
is_anonymous=False,
session_id="session-1",
)
db_session_with_containers.add(end_user)
db_session_with_containers.commit()

now = datetime.now(UTC)
archive_defaults = {
"workflow_id": str(uuid.uuid4()),
"run_version": "1.0.0",
"run_status": WorkflowExecutionStatus.SUCCEEDED,
"run_triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
"run_error": None,
"run_elapsed_time": 1.0,
"run_total_tokens": 0,
"run_total_steps": 0,
"run_created_at": now,
"run_finished_at": now,
"run_exceptions_count": 0,
"trigger_metadata": '{"type":"trigger-webhook"}',
"log_created_at": now,
"log_created_from": WorkflowAppLogCreatedFrom.SERVICE_API,
}
archive_account = WorkflowArchiveLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_run_id=str(uuid.uuid4()),
log_id=str(uuid.uuid4()),
created_by=account.id,
created_by_role=CreatorUserRole.ACCOUNT,
**archive_defaults,
)
archive_end_user = WorkflowArchiveLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_run_id=str(uuid.uuid4()),
log_id=str(uuid.uuid4()),
created_by=end_user.id,
created_by_role=CreatorUserRole.END_USER,
**archive_defaults,
)
db_session_with_containers.add_all([archive_account, archive_end_user])
db_session_with_containers.commit()

result = service.get_paginate_workflow_archive_logs(
session=db_session_with_containers,
app_model=app,
page=1,
limit=20,
)

assert result["total"] == 2
assert len(result["data"]) == 2
account_item = next(d for d in result["data"] if d["created_by_account"] is not None)
end_user_item = next(d for d in result["data"] if d["created_by_end_user"] is not None)
assert account_item["created_by_account"].id == account.id
assert end_user_item["created_by_end_user"].id == end_user.id


class TestLogView:
def test_details_and_proxy_attributes(self):
log = SimpleNamespace(id="log-1", status="succeeded")
view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}})

assert view.details == {"trigger_metadata": {"type": "plugin"}}
assert view.status == "succeeded"

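The TestLogView case just above pins down attribute proxying: `view.details` resolves on the view itself, while `view.status` falls through to the wrapped log row. That behaviour is consistent with a thin wrapper along these lines (a sketch under that assumption, not the actual services.workflow_app_service implementation):

from dataclasses import dataclass, field
from typing import Any


@dataclass
class LogView:
    """Wraps a workflow app log row and carries extra computed details."""

    log: Any
    details: dict = field(default_factory=dict)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, so `details` wins over the log.
        return getattr(self.log, name)
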
class TestHandleTriggerMetadata:
def test_returns_empty_dict_when_metadata_missing(self):
service = WorkflowAppService()
assert service.handle_trigger_metadata("tenant-1", None) == {}

def test_enriches_plugin_icons(self):
service = WorkflowAppService()
meta = {
"type": AppTriggerType.TRIGGER_PLUGIN.value,
"icon_filename": "light.png",
"icon_dark_filename": "dark.png",
}
with patch(
"services.workflow_app_service.PluginService.get_plugin_icon_url",
side_effect=["https://cdn/light.png", "https://cdn/dark.png"],
) as mock_icon:
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))

assert result["icon"] == "https://cdn/light.png"
assert result["icon_dark"] == "https://cdn/dark.png"
assert mock_icon.call_count == 2

def test_non_plugin_metadata_without_icon_lookup(self):
service = WorkflowAppService()
meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value}
with patch("services.workflow_app_service.PluginService.get_plugin_icon_url") as mock_icon:
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))

assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value
mock_icon.assert_not_called()

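The three TestHandleTriggerMetadata cases above pin down the contract: missing metadata yields {}, plugin triggers get their stored icon filenames resolved to URLs (two lookups, one per variant), and other trigger types skip the lookup entirely. A sketch consistent with that contract — the PluginService.get_plugin_icon_url signature used here is an assumption:

import json

from models.enums import AppTriggerType
from services.plugin.plugin_service import PluginService


def handle_trigger_metadata(tenant_id: str, trigger_metadata: str | None) -> dict:
    """Sketch: decode trigger metadata; enrich plugin triggers with icon URLs."""
    if not trigger_metadata:
        return {}
    try:
        meta = json.loads(trigger_metadata)
    except ValueError:
        return {}
    if meta.get("type") == AppTriggerType.TRIGGER_PLUGIN.value:
        # One lookup per icon variant, matching the two patched calls in the test.
        meta["icon"] = PluginService.get_plugin_icon_url(tenant_id, meta.get("icon_filename"))
        meta["icon_dark"] = PluginService.get_plugin_icon_url(tenant_id, meta.get("icon_dark_filename"))
    return meta
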
class TestSafeJsonLoads:
@pytest.mark.parametrize(
("value", "expected"),
[
(None, None),
("", None),
('{"k":"v"}', {"k": "v"}),
("not-json", None),
({"raw": True}, {"raw": True}),
],
)
def test_handles_various_inputs(self, value, expected):
assert WorkflowAppService._safe_json_loads(value) == expected

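The parametrized TestSafeJsonLoads table above implies a best-effort decoder: dicts pass through untouched, while empty input and unparseable strings become None. A sketch matching those five cases (not the actual implementation):

import json
from typing import Any


def _safe_json_loads(value: Any) -> Any:
    """Best-effort JSON decode that never raises."""
    if isinstance(value, dict):
        return value  # already decoded
    if not value:
        return None  # None or empty string
    try:
        return json.loads(value)
    except (TypeError, ValueError):
        return None
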
class TestSafeParseUuid:
def test_returns_none_for_short_or_invalid_values(self):
service = WorkflowAppService()
assert service._safe_parse_uuid("short") is None
assert service._safe_parse_uuid("x" * 40) is None

def test_returns_uuid_for_valid_string(self):
service = WorkflowAppService()
raw = str(uuid.uuid4())
result = service._safe_parse_uuid(raw)
assert result is not None
assert str(result) == raw

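Likewise for TestSafeParseUuid above: short or malformed values come back as None and a well-formed UUID string round-trips. A sketch matching that behaviour (the exact length guard is an assumption suggested by the "short" case):

import uuid


def _safe_parse_uuid(value: str | None) -> uuid.UUID | None:
    """Parse a UUID string, returning None instead of raising on bad input."""
    if not value or len(value) < 32:  # shorter than any valid UUID representation
        return None
    try:
        return uuid.UUID(value)
    except ValueError:
        return None
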
@ -1,12 +1,24 @@
from __future__ import annotations

from unittest.mock import Mock, patch

import pytest
from faker import Faker
from sqlalchemy.orm import Session

from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
from core.tools.entities.tool_entities import (
ApiProviderSchemaType,
ToolDescription,
ToolEntity,
ToolIdentity,
ToolParameter,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType,
)
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
@ -786,3 +798,192 @@ class TestToolTransformService:
assert result is not None
assert result == mock_controller
mock_from_db.assert_called_once_with(provider)


def _mock_tool(*, base_params, runtime_params):
"""Helper to build a Mock tool with real entity objects.

Tool is abstract and requires runtime behaviour (fork_tool_runtime,
get_runtime_parameters), so it stays as a Mock. Everything else uses
real Pydantic instances.
"""
entity = ToolEntity(
identity=ToolIdentity(
author="test_author",
name="test_tool",
label=I18nObject(en_US="Test Tool"),
provider="test_provider",
),
parameters=base_params or [],
description=ToolDescription(
human=I18nObject(en_US="Test description"),
llm="Test description for LLM",
),
output_schema={},
)
mock_tool = Mock(spec=Tool)
mock_tool.entity = entity
mock_tool.get_runtime_parameters.return_value = runtime_params
mock_tool.fork_tool_runtime.return_value = mock_tool
return mock_tool


def _param(name, *, form=ToolParameter.ToolParameterForm.FORM, label=None):
return ToolParameter(
name=name,
label=I18nObject(en_US=label or name),
human_description=I18nObject(en_US=name),
type=ToolParameter.ToolParameterType.STRING,
form=form,
)


class TestConvertToolEntityToApiEntity:
"""Tests for ToolTransformService.convert_tool_entity_to_api_entity."""

def test_parameter_override(self):
base = [_param("param1", label="Base 1"), _param("param2", label="Base 2")]
runtime = [_param("param1", label="Runtime 1")]
tool = _mock_tool(base_params=base, runtime_params=runtime)

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 2
assert next(p for p in result.parameters if p.name == "param1").label.en_US == "Runtime 1"
assert next(p for p in result.parameters if p.name == "param2").label.en_US == "Base 2"

def test_additional_runtime_parameters(self):
base = [_param("param1", label="Base 1")]
runtime = [_param("param1", label="Runtime 1"), _param("runtime_only", label="Runtime Only")]
tool = _mock_tool(base_params=base, runtime_params=runtime)

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert len(result.parameters) == 2
names = [p.name for p in result.parameters]
assert "param1" in names
assert "runtime_only" in names

def test_non_form_runtime_parameters_excluded(self):
base = [_param("param1")]
runtime = [
_param("param1", label="Runtime 1"),
_param("llm_param", form=ToolParameter.ToolParameterForm.LLM),
]
tool = _mock_tool(base_params=base, runtime_params=runtime)

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert len(result.parameters) == 1
assert result.parameters[0].name == "param1"

def test_empty_parameters(self):
tool = _mock_tool(base_params=[], runtime_params=[])

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 0

def test_none_parameters(self):
tool = _mock_tool(base_params=None, runtime_params=[])

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 0

def test_parameter_order_preserved(self):
base = [_param("p1", label="B1"), _param("p2", label="B2"), _param("p3", label="B3")]
runtime = [_param("p2", label="R2"), _param("p4", label="R4")]
tool = _mock_tool(base_params=base, runtime_params=runtime)

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert [p.name for p in result.parameters] == ["p1", "p2", "p3", "p4"]
assert result.parameters[1].label.en_US == "R2"


class TestWorkflowProviderToUserProvider:
"""Tests for ToolTransformService.workflow_provider_to_user_provider."""

@staticmethod
def _make_controller(provider_id="provider_123", **identity_overrides):
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

defaults = {
"author": "test_author",
"name": "test_workflow_tool",
"description": I18nObject(en_US="Test description"),
"icon": '{"type": "emoji", "content": "🔧"}',
"icon_dark": None,
"label": I18nObject(en_US="Test Workflow Tool"),
}
defaults.update(identity_overrides)
identity = ToolProviderIdentity(**defaults)
entity = ToolProviderEntity(identity=identity)
return WorkflowToolProviderController(entity=entity, provider_id=provider_id)

def test_with_workflow_app_id(self):
ctrl = self._make_controller()

result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["l1", "l2"],
workflow_app_id="app_123",
)

assert isinstance(result, ToolProviderApiEntity)
assert result.id == "provider_123"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == "app_123"
assert result.labels == ["l1", "l2"]
assert result.is_team_authorization is True

def test_without_workflow_app_id(self):
ctrl = self._make_controller()

result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["l1"],
)

assert result.workflow_app_id is None

def test_workflow_app_id_none_explicit(self):
ctrl = self._make_controller()

result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=None,
workflow_app_id=None,
)

assert result.workflow_app_id is None
assert result.labels == []

def test_preserves_other_fields(self):
ctrl = self._make_controller(
"provider_456",
author="another_author",
name="another_workflow_tool",
description=I18nObject(en_US="Another desc", zh_Hans="Another desc"),
icon='{"type": "emoji", "content": "⚙️"}',
icon_dark='{"type": "emoji", "content": "🔧"}',
label=I18nObject(en_US="Another Tool", zh_Hans="Another Tool"),
)

result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["automation"],
workflow_app_id="app_456",
)

assert result.id == "provider_456"
assert result.author == "another_author"
assert result.name == "another_workflow_tool"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == "app_456"
assert result.is_team_authorization is True
assert result.allow_delete is True

@ -4,7 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

@ -19,7 +19,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -142,7 +142,7 @@ class TestBatchCreateSegmentToIndexTask:
name=fake.company(),
description=fake.text(),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
created_by=account.id,

@ -18,7 +18,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
@ -154,7 +154,7 @@ class TestCleanDatasetTask:
tenant_id=tenant.id,
name="test_dataset",
description="Test dataset for cleanup testing",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid.uuid4()),
created_by=account.id,
@ -870,7 +870,7 @@ class TestCleanDatasetTask:
tenant_id=tenant.id,
name=long_name,
description=long_description,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph", "max_length": 10000}',
collection_binding_id=str(uuid.uuid4()),
created_by=account.id,

@ -12,7 +12,7 @@ from uuid import uuid4
import pytest
from faker import Faker

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -121,7 +121,7 @@ class TestCreateSegmentToIndexTask:
description=fake.text(max_nb_chars=100),
tenant_id=tenant_id,
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
created_by=account_id,

@ -8,6 +8,7 @@ import pytest
from faker import Faker

from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch

from faker import Faker

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Dataset, Document, DocumentSegment, Tenant
from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask:
dataset.provider = "vendor"
dataset.permission = "only_me"
dataset.data_source_type = DataSourceType.UPLOAD_FILE
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.index_struct = '{"type": "paragraph"}'
dataset.created_by = account.id
dataset.created_at = fake.date_time_this_year()

@ -15,7 +15,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -100,7 +100,7 @@ class TestDisableSegmentFromIndexTask:
name=fake.sentence(nb_words=3),
description=fake.text(max_nb_chars=200),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch
from faker import Faker
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Dataset, DocumentSegment
from models import Document as DatasetDocument
from models.dataset import DatasetProcessRule
@ -103,7 +103,7 @@ class TestDisableSegmentsFromIndexTask:
provider="vendor",
permission="only_me",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
updated_by=account.id,
embedding_model="text-embedding-ada-002",

@ -14,7 +14,7 @@ from uuid import uuid4
import pytest

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
@ -57,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory:
name=f"dataset-{uuid4()}",
description="sync test dataset",
data_source_type=DataSourceType.NOTION_IMPORT,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
)
db_session_with_containers.add(dataset)

@ -5,6 +5,7 @@ import pytest
from faker import Faker

from core.entities.document_task import DocumentTask
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
@ -99,7 +100,7 @@ class TestDocumentIndexingTasks:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -181,7 +182,7 @@ class TestDocumentIndexingTasks:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
@ -64,7 +64,7 @@ class TestDocumentIndexingUpdateTask:
name=fake.company(),
description=fake.text(max_nb_chars=64),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

@ -4,7 +4,7 @@ import pytest
from faker import Faker

from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -110,7 +110,7 @@ class TestDuplicateDocumentIndexingTasks:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -245,7 +245,7 @@ class TestDuplicateDocumentIndexingTasks:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

@ -4,7 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -81,7 +81,7 @@ class TestEnableSegmentsToIndexTask:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

@ -1476,8 +1476,8 @@ class TestDatasetIndexingStatusApi:
return_value=MagicMock(all=lambda: [document]),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=3,
),
):
response, status = method(api, "dataset-1")
@ -1526,13 +1526,6 @@ class TestDatasetIndexingStatusApi:
document.error = None
document.stopped_at = None

# First count = completed segments, second = total segments
query_mock = MagicMock()
query_mock.where.side_effect = [
MagicMock(count=lambda: 2),
MagicMock(count=lambda: 5),
]

with (
app.test_request_context("/"),
patch(
@ -1544,8 +1537,8 @@ class TestDatasetIndexingStatusApi:
return_value=MagicMock(all=lambda: [document]),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=query_mock,
"controllers.console.datasets.datasets.db.session.scalar",
side_effect=[2, 5],
),
):
response, status = method(api, "dataset-1")
@ -1591,8 +1584,8 @@ class TestDatasetApiKeyApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=3,
),
patch(
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
@ -1625,8 +1618,8 @@ class TestDatasetApiKeyApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=10,
),
):
with pytest.raises(BadRequest) as exc_info:
@ -1653,8 +1646,8 @@ class TestDatasetApiDeleteApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=mock_key,
),
patch(
"controllers.console.datasets.datasets.db.session.commit",
@ -1681,8 +1674,8 @@ class TestDatasetApiDeleteApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=None,
),
):
with pytest.raises(NotFound):

@ -526,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -621,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.redis_client.setnx",
@ -706,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=None,
),
):
with pytest.raises(NotFound):
@ -738,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
):
with pytest.raises(ValueError):
@ -770,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.redis_client.setnx",
@ -831,8 +831,8 @@ class TestChildChunkAddApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -880,8 +880,8 @@ class TestChildChunkAddApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -924,11 +924,8 @@ class TestChildChunkUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
side_effect=[
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
],
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, child_chunk],
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -970,11 +967,8 @@ class TestChildChunkUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
side_effect=[
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
],
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, child_chunk],
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -1180,8 +1174,8 @@ class TestSegmentOperationCases:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
):
with pytest.raises(NotFound):
@ -1215,8 +1209,8 @@ class TestSegmentOperationCases:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",

@ -4,6 +4,7 @@ from unittest.mock import Mock, patch
import pytest

from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.models.document import AttachmentDocument, Document
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@ -21,7 +22,7 @@ class TestParagraphIndexProcessor:
dataset = Mock()
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.is_multimodal = True
return dataset

@ -167,7 +168,7 @@ class TestParagraphIndexProcessor:
def test_load_uses_keyword_add_texts_with_keywords_when_economy(
self, processor: ParagraphIndexProcessor, dataset: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
docs = [Document(page_content="chunk", metadata={})]

with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
@ -178,7 +179,7 @@ class TestParagraphIndexProcessor:
def test_load_uses_keyword_add_texts_without_keywords_when_economy(
self, processor: ParagraphIndexProcessor, dataset: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
docs = [Document(page_content="chunk", metadata={})]

with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
@ -208,7 +209,7 @@ class TestParagraphIndexProcessor:
def test_clean_economy_deletes_summaries_and_keywords(
self, processor: ParagraphIndexProcessor, dataset: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY

with (
patch(
@ -222,7 +223,7 @@ class TestParagraphIndexProcessor:
mock_keyword_cls.return_value.delete.assert_called_once()

def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
processor.clean(dataset, ["node-2"], with_keywords=True)

@ -267,7 +268,7 @@ class TestParagraphIndexProcessor:
def test_index_list_chunks_economy(
self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
with (
patch(
"core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",

@ -4,6 +4,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest

from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from services.entities.knowledge_entities.knowledge_entities import ParentMode
@ -19,7 +20,7 @@ class TestParentChildIndexProcessor:
dataset = Mock()
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.is_multimodal = True
return dataset


@ -6,6 +6,7 @@ import pytest
from werkzeug.datastructures import FileStorage

from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
from core.rag.models.document import AttachmentDocument, Document

@ -33,7 +34,7 @@ class TestQAIndexProcessor:
dataset = Mock()
dataset.id = "dataset-1"
dataset.tenant_id = "tenant-1"
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.is_multimodal = True
return dataset

@ -207,7 +208,7 @@ class TestQAIndexProcessor:
vector.create_multimodal.assert_called_once_with(multimodal_docs)

def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
docs = [Document(page_content="Q1", metadata={"answer": "A1"})]

with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls:
@ -298,7 +299,7 @@ class TestQAIndexProcessor:
def test_index_requires_high_quality(
self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock
) -> None:
dataset.indexing_technique = "economy"
dataset.indexing_technique = IndexTechniqueType.ECONOMY
qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")])

with (

@ -61,7 +61,7 @@ from core.indexing_runner import (
DocumentIsPausedError,
IndexingRunner,
)
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.models.document import ChildDocument, Document
from dify_graph.model_runtime.entities.model_entities import ModelType
from libs.datetime_utils import naive_utc_now
@ -76,7 +76,7 @@ from models.dataset import Document as DatasetDocument
def create_mock_dataset(
dataset_id: str | None = None,
tenant_id: str | None = None,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
) -> Mock:
@ -458,7 +458,7 @@ class TestIndexingRunnerTransform:
dataset = Mock(spec=Dataset)
dataset.id = str(uuid.uuid4())
dataset.tenant_id = str(uuid.uuid4())
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset
@ -521,7 +521,7 @@ class TestIndexingRunnerTransform:
"""Test transformation with economy indexing (no embeddings)."""
# Arrange
runner = IndexingRunner()
sample_dataset.indexing_technique = "economy"
sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY

mock_processor = MagicMock()
transformed_docs = [
@ -605,7 +605,7 @@ class TestIndexingRunnerLoad:
dataset = Mock(spec=Dataset)
dataset.id = str(uuid.uuid4())
dataset.tenant_id = str(uuid.uuid4())
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset
@ -674,7 +674,7 @@ class TestIndexingRunnerLoad:
"""Test loading with economy indexing (keyword only)."""
# Arrange
runner = IndexingRunner()
sample_dataset.indexing_technique = "economy"
sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY

mock_processor = MagicMock()

@ -701,7 +701,7 @@ class TestIndexingRunnerLoad:
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
sample_dataset.indexing_technique = "high_quality"
sample_dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY

# Add child documents
for doc in sample_documents:
@ -795,7 +795,7 @@ class TestIndexingRunnerRun:
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = doc.dataset_id
mock_dataset.tenant_id = doc.tenant_id
mock_dataset.indexing_technique = "economy"
mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset

mock_process_rule = Mock(spec=DatasetProcessRule)
@ -949,7 +949,7 @@ class TestIndexingRunnerRun:
mock_dependencies["db"].session.get.side_effect = get_side_effect

mock_dataset = Mock(spec=Dataset)
mock_dataset.indexing_technique = "economy"
mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset

mock_process_rule = Mock(spec=DatasetProcessRule)

@ -5,6 +5,7 @@ from unittest.mock import Mock
import pytest

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
@ -78,7 +79,7 @@ def sample_node_data():
type="knowledge-index",
chunk_structure="general_structure",
index_chunk_variable_selector=["start", "chunks"],
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
summary_index_setting=None,
)


@ -15,6 +15,7 @@ from datetime import UTC, datetime
from unittest.mock import patch
from uuid import uuid4

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import (
AppDatasetJoin,
ChildChunk,
@ -67,14 +68,14 @@ class TestDatasetModelValidation:
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
description="Test description",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
)

# Assert
assert dataset.description == "Test description"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"

@ -86,21 +87,21 @@ class TestDatasetModelValidation:
name="High Quality Dataset",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
dataset_economy = Dataset(
tenant_id=str(uuid4()),
name="Economy Dataset",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)

# Assert
assert dataset_high_quality.indexing_technique == "high_quality"
assert dataset_economy.indexing_technique == "economy"
assert "high_quality" in Dataset.INDEXING_TECHNIQUE_LIST
assert "economy" in Dataset.INDEXING_TECHNIQUE_LIST
assert dataset_high_quality.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset_economy.indexing_technique == IndexTechniqueType.ECONOMY
assert IndexTechniqueType.HIGH_QUALITY in Dataset.INDEXING_TECHNIQUE_LIST
assert IndexTechniqueType.ECONOMY in Dataset.INDEXING_TECHNIQUE_LIST

def test_dataset_provider_validation(self):
"""Test dataset provider values."""
@ -983,7 +984,7 @@ class TestModelIntegration:
name="Test Dataset",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
dataset.id = dataset_id

@ -1019,7 +1020,7 @@ class TestModelIntegration:
assert document.dataset_id == dataset_id
assert segment.dataset_id == dataset_id
assert segment.document_id == document_id
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert document.word_count == 100
assert segment.status == SegmentStatus.COMPLETED


@ -97,6 +97,7 @@ from unittest.mock import Mock, create_autospec, patch
import pytest
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -149,7 +150,7 @@ class DatasetUpdateDeleteTestDataFactory:
name: str = "Test Dataset",
description: str = "Test description",
tenant_id: str = "tenant-123",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str | None = "openai",
embedding_model: str | None = "text-embedding-ada-002",
collection_binding_id: str | None = "binding-123",
@ -237,7 +238,7 @@ class DatasetUpdateDeleteTestDataFactory:
@staticmethod
def create_knowledge_configuration_mock(
chunk_structure: str = "tree",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
keyword_number: int = 10,
@ -630,12 +631,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
chunk_structure="tree",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)

knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
chunk_structure="list",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)
@ -671,7 +672,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:

# Assert
assert dataset.chunk_structure == "list"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == "binding-123"
@ -698,12 +699,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
chunk_structure="tree", # Existing structure
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)

knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
chunk_structure="list", # Different structure
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)

mock_session.merge.return_value = dataset
@ -735,11 +736,11 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
indexing_technique="high_quality", # Current technique
indexing_technique=IndexTechniqueType.HIGH_QUALITY, # Current technique
)

knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
indexing_technique="economy", # Trying to change to economy
indexing_technique=IndexTechniqueType.ECONOMY, # Trying to change to economy
)

mock_session.merge.return_value = dataset

@ -111,7 +111,7 @@ from unittest.mock import Mock, patch
import pytest

from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset, DatasetProcessRule, Document
from services.dataset_service import DatasetService, DocumentService
@ -154,7 +154,7 @@ class DocumentValidationTestDataFactory:
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
doc_form: str | None = None,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
**kwargs,
@ -190,7 +190,7 @@ class DocumentValidationTestDataFactory:
data_source: DataSource | None = None,
process_rule: ProcessRule | None = None,
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
**kwargs,
) -> Mock:
"""
@ -448,7 +448,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)
@ -481,7 +481,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
- No errors are raised
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)

# Act (should not raise)
DatasetService.check_dataset_model_setting(dataset)
@ -503,7 +503,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="invalid-model",
)
@ -533,7 +533,7 @@ class TestDatasetServiceCheckDatasetModelSetting:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)

@ -2,7 +2,7 @@ from unittest.mock import MagicMock, Mock, patch

import pytest

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
from models.enums import SegmentType
@ -111,7 +111,7 @@ class SegmentTestDataFactory:
def create_dataset_mock(
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model: str = "text-embedding-ada-002",
embedding_model_provider: str = "openai",
**kwargs,
@ -163,7 +163,7 @@ class TestSegmentServiceCreateSegment:
"""Test successful creation of a segment."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = {"content": "New segment content", "keywords": ["test", "segment"]}

mock_query = MagicMock()
@ -212,7 +212,7 @@ class TestSegmentServiceCreateSegment:
"""Test creation of segment with QA model (requires answer)."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]}

mock_query = MagicMock()
@ -247,7 +247,7 @@ class TestSegmentServiceCreateSegment:
"""Test creation of segment with high quality indexing technique."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
args = {"content": "New segment content", "keywords": ["test"]}

mock_query = MagicMock()
@ -289,7 +289,7 @@ class TestSegmentServiceCreateSegment:
"""Test segment creation when vector indexing fails."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = {"content": "New segment content", "keywords": ["test"]}

mock_query = MagicMock()
@ -342,7 +342,7 @@ class TestSegmentServiceUpdateSegment:
# Arrange
segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10)
document = SegmentTestDataFactory.create_document_mock(word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = SegmentUpdateArgs(content="Updated content", keywords=["updated"])

mock_db_session.query.return_value.where.return_value.first.return_value = segment
@ -431,7 +431,7 @@ class TestSegmentServiceUpdateSegment:
# Arrange
segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10)
document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)
args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"])

mock_db_session.query.return_value.where.return_value.first.return_value = segment

@ -1,214 +0,0 @@
|
||||
"""
|
||||
Unit tests for services.advanced_prompt_template_service
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
||||
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_CONTEXT,
|
||||
CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
CONTEXT,
|
||||
)
|
||||
from models.model import AppMode
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
||||
class TestAdvancedPromptTemplateService:
|
||||
"""Test suite for AdvancedPromptTemplateService."""
|
||||
|
||||
def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None:
|
||||
"""Test baichuan model names use baichuan context prompt."""
|
||||
# Arrange
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "chat",
|
||||
"model_name": "Baichuan2-13B",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert
|
||||
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT)
|
||||
|
||||
def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None:
|
||||
"""Test non-baichuan model names use common prompt."""
|
||||
# Arrange
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-4",
|
||||
"has_context": "false",
|
||||
}
|
||||
original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert
|
||||
assert result == original_config
|
||||
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None:
|
||||
"""Test invalid app mode returns empty dict."""
|
||||
# Arrange
|
||||
app_mode = "invalid"
|
||||
model_mode = "chat"
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true")
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None:
|
||||
"""Test context is prepended for completion prompt when has_context is true."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")
|
||||
|
||||
# Assert
|
||||
assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT)
|
||||
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None:
|
||||
"""Test context is prepended for chat prompt when has_context is true."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")
|
||||
|
||||
# Assert
|
||||
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT)
|
||||
assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None:
|
||||
"""Test chat prompt remains unchanged when has_context is false."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false")
|
||||
|
||||
# Assert
|
||||
assert result == original_config
|
||||
assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None:
|
||||
"""Test completion app mode with completion model returns completion prompt."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false")
|
||||
|
||||
# Assert
|
||||
assert result == original_config
|
||||
assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None:
|
||||
"""Test invalid model mode returns empty dict."""
|
||||
# Arrange
|
||||
app_mode = AppMode.CHAT
|
||||
model_mode = "invalid"
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false")
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None:
|
||||
"""Test helper keeps completion prompt unchanged when context is disabled."""
|
||||
# Arrange
|
||||
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT)
|
||||
|
||||
# Assert
|
||||
assert result["completion_prompt_config"]["prompt"]["text"] == original_text
|
||||
|
||||
def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None:
|
||||
"""Test helper keeps chat prompt unchanged when context is disabled."""
|
||||
# Arrange
|
||||
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT)
|
||||
|
||||
# Assert
|
||||
assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text
|
||||
|
||||
def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None:
|
||||
"""Test baichuan chat/completion returns the expected config."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")
|
||||
|
||||
# Assert
|
||||
assert result == original_config
|
||||
assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None:
|
||||
"""Test baichuan completion/chat returns the expected config."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false")
|
||||
|
||||
# Assert
|
||||
assert result == original_config
|
||||
assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None:
|
||||
"""Test baichuan completion/completion prepends baichuan context when enabled."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")
|
||||
|
||||
# Assert
|
||||
assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT)
|
||||
assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None:
|
||||
"""Test baichuan chat/chat prepends baichuan context when enabled."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")
|
||||
|
||||
# Assert
|
||||
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT)
|
||||
assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None:
|
||||
"""Test invalid baichuan mode combinations return empty dict."""
|
||||
# Arrange
|
||||
app_mode = "invalid"
|
||||
model_mode = "invalid"
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true")
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
@ -1,683 +0,0 @@
"""Unit tests for services.app_service."""

import json
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock, patch

import pytest

from core.errors.error import ProviderTokenNotInitError
from models import Account, Tenant
from models.model import App, AppMode, IconType
from services.app_service import AppService


@pytest.fixture
def service() -> AppService:
    """Provide AppService instance."""
    return AppService()


@pytest.fixture
def account() -> Account:
    """Create account object for create_app tests."""
    tenant = Tenant(name="Tenant")
    tenant.id = "tenant-1"
    result = Account(name="Account User", email="account@example.com")
    result.id = "acc-1"
    result._current_tenant = tenant
    return result


@pytest.fixture
def default_args() -> dict:
    """Create default create_app args."""
    return {
        "name": "Test App",
        "mode": AppMode.CHAT.value,
        "icon": "🤖",
        "icon_background": "#FFFFFF",
    }


@pytest.fixture
def app_template() -> dict:
    """Create basic app template for create_app tests."""
    return {
        AppMode.CHAT: {
            "app": {},
            "model_config": {
                "model": {
                    "provider": "provider-a",
                    "name": "model-a",
                    "mode": "chat",
                    "completion_params": {},
                }
            },
        }
    }


def _make_current_user() -> Account:
    user = Account(name="Tester", email="tester@example.com")
    user.id = "user-1"
    tenant = Tenant(name="Tenant")
    tenant.id = "tenant-1"
    user._current_tenant = tenant
    return user


class TestAppServicePagination:
    """Test suite for get_paginate_apps."""

    def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None:
        """Test pagination returns None when tag filter has no targets."""
        # Arrange
        args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]}

        with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]):
            # Act
            result = service.get_paginate_apps("user-1", "tenant-1", args)

            # Assert
            assert result is None

    def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None:
        """Test pagination delegates to db.paginate when filters are valid."""
        # Arrange
        args = {
            "mode": "workflow",
            "is_created_by_me": True,
            "name": "My_App%",
            "tag_ids": ["tag-1"],
            "page": 2,
            "limit": 10,
        }
        expected_pagination = MagicMock()

        with (
            patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]),
            patch("libs.helper.escape_like_pattern", return_value="escaped"),
            patch("services.app_service.db") as mock_db,
        ):
            mock_db.paginate.return_value = expected_pagination

            # Act
            result = service.get_paginate_apps("user-1", "tenant-1", args)

            # Assert
            assert result is expected_pagination
            mock_db.paginate.assert_called_once()


class TestAppServiceCreate:
    """Test suite for create_app."""

    def test_create_app_should_create_with_matching_default_model(
        self,
        service: AppService,
        account: Account,
        default_args: dict,
        app_template: dict,
    ) -> None:
        """Test create_app uses matching default model and persists app config."""
        # Arrange
        app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
        app_model_config = SimpleNamespace(id="cfg-1")
        model_instance = SimpleNamespace(
            model_name="model-a",
            provider="provider-a",
            model_type_instance=MagicMock(),
            credentials={"k": "v"},
        )

        with (
            patch("services.app_service.default_app_templates", app_template),
            patch("services.app_service.App", return_value=app_instance),
            patch("services.app_service.AppModelConfig", return_value=app_model_config),
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.app_service.db") as mock_db,
            patch("services.app_service.app_was_created") as mock_event,
            patch("services.app_service.FeatureService.get_system_features") as mock_features,
            patch("services.app_service.BillingService") as mock_billing,
            patch("services.app_service.dify_config") as mock_config,
        ):
            manager = mock_model_manager.return_value
            manager.get_default_model_instance.return_value = model_instance
            mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
            mock_config.BILLING_ENABLED = True

            # Act
            result = service.create_app("tenant-1", default_args, account)

            # Assert
            assert result is app_instance
            assert app_instance.app_model_config_id == "cfg-1"
            mock_db.session.add.assert_any_call(app_instance)
            mock_db.session.add.assert_any_call(app_model_config)
            assert mock_db.session.flush.call_count == 2
            mock_db.session.commit.assert_called_once()
            mock_event.send.assert_called_once_with(app_instance, account=account)
            mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1")

    def test_create_app_should_raise_when_model_schema_missing(
        self,
        service: AppService,
        account: Account,
        default_args: dict,
        app_template: dict,
    ) -> None:
        """Test create_app raises ValueError when non-matching model has no schema."""
        # Arrange
        app_instance = SimpleNamespace(id="app-1")
        model_instance = SimpleNamespace(
            model_name="model-b",
            provider="provider-b",
            model_type_instance=MagicMock(),
            credentials={"k": "v"},
        )
        model_instance.model_type_instance.get_model_schema.return_value = None

        with (
            patch("services.app_service.default_app_templates", app_template),
            patch("services.app_service.App", return_value=app_instance),
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.app_service.db") as mock_db,
        ):
            manager = mock_model_manager.return_value
            manager.get_default_model_instance.return_value = model_instance

            # Act & Assert
            with pytest.raises(ValueError, match="model schema not found"):
                service.create_app("tenant-1", default_args, account)
            mock_db.session.commit.assert_not_called()

    def test_create_app_should_fallback_to_default_provider_when_model_missing(
        self,
        service: AppService,
        account: Account,
        default_args: dict,
        app_template: dict,
    ) -> None:
        """Test create_app falls back to provider/model name when no default model instance is available."""
        # Arrange
        app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
        app_model_config = SimpleNamespace(id="cfg-1")

        with (
            patch("services.app_service.default_app_templates", app_template),
            patch("services.app_service.App", return_value=app_instance),
            patch("services.app_service.AppModelConfig", return_value=app_model_config),
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.app_service.db") as mock_db,
            patch("services.app_service.app_was_created") as mock_event,
            patch("services.app_service.FeatureService.get_system_features") as mock_features,
            patch("services.app_service.EnterpriseService") as mock_enterprise,
            patch("services.app_service.dify_config") as mock_config,
        ):
            manager = mock_model_manager.return_value
            manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready")
            manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model")
            mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
            mock_config.BILLING_ENABLED = False

            # Act
            result = service.create_app("tenant-1", default_args, account)

            # Assert
            assert result is app_instance
            mock_event.send.assert_called_once_with(app_instance, account=account)
            mock_db.session.commit.assert_called_once()
            mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private")

    def test_create_app_should_log_and_fallback_on_unexpected_model_error(
        self,
        service: AppService,
        account: Account,
        default_args: dict,
        app_template: dict,
    ) -> None:
        """Test unexpected model manager errors are logged and fallback provider is used."""
        # Arrange
        app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
        app_model_config = SimpleNamespace(id="cfg-1")

        with (
            patch("services.app_service.default_app_templates", app_template),
            patch("services.app_service.App", return_value=app_instance),
            patch("services.app_service.AppModelConfig", return_value=app_model_config),
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.app_service.db"),
            patch("services.app_service.app_was_created"),
            patch(
                "services.app_service.FeatureService.get_system_features",
                return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)),
            ),
            patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)),
            patch("services.app_service.logger") as mock_logger,
        ):
            manager = mock_model_manager.return_value
            manager.get_default_model_instance.side_effect = RuntimeError("boom")
            manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model")

            # Act
            result = service.create_app("tenant-1", default_args, account)

            # Assert
            assert result is app_instance
            mock_logger.exception.assert_called_once()


class TestAppServiceGetAndUpdate:
    """Test suite for app retrieval and update methods."""

    def test_get_app_should_return_original_when_not_agent_app(self, service: AppService) -> None:
        """Test get_app returns original app for non-agent modes."""
        # Arrange
        app = MagicMock()
        app.mode = AppMode.CHAT
        app.is_agent = False

        with patch("services.app_service.current_user", _make_current_user()):
            # Act
            result = service.get_app(app)

            # Assert
            assert result is app

    def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None:
        """Test get_app returns app when agent mode has no model config."""
        # Arrange
        app = MagicMock()
        app.id = "app-1"
        app.mode = AppMode.AGENT_CHAT
        app.is_agent = False
        app.app_model_config = None

        with patch("services.app_service.current_user", _make_current_user()):
            # Act
            result = service.get_app(app)

            # Assert
            assert result is app

    def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None:
        """Test get_app decrypts and masks secret tool parameters."""
        # Arrange
        tool = {
            "provider_type": "builtin",
            "provider_id": "provider-1",
            "tool_name": "tool-a",
            "tool_parameters": {"secret": "encrypted"},
            "extra": True,
        }
        model_config = MagicMock()
        model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]}

        app = MagicMock()
        app.id = "app-1"
        app.mode = AppMode.AGENT_CHAT
        app.is_agent = False
        app.app_model_config = model_config

        manager = MagicMock()
        manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"}
        manager.mask_tool_parameters.return_value = {"secret": "***"}

        with (
            patch("services.app_service.current_user", _make_current_user()),
            patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()),
            patch("services.app_service.ToolParameterConfigurationManager", return_value=manager),
        ):
            # Act
            result = service.get_app(app)

            # Assert
            assert result.app_model_config is model_config
            assert tool["tool_parameters"] == {"secret": "***"}
            assert json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"}

    def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None:
        """Test get_app logs and continues when masking fails."""
        # Arrange
        tool = {
            "provider_type": "builtin",
            "provider_id": "provider-1",
            "tool_name": "tool-a",
            "tool_parameters": {"secret": "encrypted"},
            "extra": True,
        }
        model_config = MagicMock()
        model_config.agent_mode_dict = {"tools": [tool]}

        app = MagicMock()
        app.id = "app-1"
        app.mode = AppMode.AGENT_CHAT
        app.is_agent = False
        app.app_model_config = model_config

        with (
            patch("services.app_service.current_user", _make_current_user()),
            patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")),
            patch("services.app_service.logger") as mock_logger,
        ):
            # Act
            result = service.get_app(app)

            # Assert
            assert result.app_model_config is model_config
            mock_logger.exception.assert_called_once()

    def test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None:
        """Test update methods set fields and commit changes."""
        # Arrange
        app = cast(
            App,
            SimpleNamespace(
                name="old",
                description="old",
                icon_type="emoji",
                icon="a",
                icon_background="#111",
                enable_site=True,
                enable_api=True,
            ),
        )
        args = {
            "name": "new",
            "description": "new-desc",
            "icon_type": "image",
            "icon": "new-icon",
            "icon_background": "#222",
            "use_icon_as_answer_icon": True,
            "max_active_requests": 5,
        }
        user = SimpleNamespace(id="user-1")

        with (
            patch("services.app_service.current_user", user),
            patch("services.app_service.db") as mock_db,
            patch("services.app_service.naive_utc_now", return_value="now"),
        ):
            # Act
            updated = service.update_app(app, args)
            renamed = service.update_app_name(app, "rename")
            iconed = service.update_app_icon(app, "icon-2", "#333")
            site_same = service.update_app_site_status(app, app.enable_site)
            api_same = service.update_app_api_status(app, app.enable_api)
            site_changed = service.update_app_site_status(app, False)
            api_changed = service.update_app_api_status(app, False)

            # Assert
            assert updated is app
            assert updated.icon_type == IconType.IMAGE
            assert renamed is app
            assert iconed is app
            assert site_same is app
            assert api_same is app
            assert site_changed is app
            assert api_changed is app
            assert mock_db.session.commit.call_count >= 5

    def test_update_app_should_preserve_icon_type_when_not_provided(self, service: AppService) -> None:
        """Test update_app keeps the existing icon_type when the payload omits it."""
        # Arrange
        app = cast(
            App,
            SimpleNamespace(
                name="old",
                description="old",
                icon_type=IconType.EMOJI,
                icon="a",
                icon_background="#111",
                use_icon_as_answer_icon=False,
                max_active_requests=1,
            ),
        )
        args = {
            "name": "new",
            "description": "new-desc",
            "icon_type": None,
            "icon": "new-icon",
            "icon_background": "#222",
            "use_icon_as_answer_icon": True,
            "max_active_requests": 5,
        }
        user = SimpleNamespace(id="user-1")

        with (
            patch("services.app_service.current_user", user),
            patch("services.app_service.db") as mock_db,
            patch("services.app_service.naive_utc_now", return_value="now"),
        ):
            # Act
            updated = service.update_app(app, args)

            # Assert
            assert updated is app
            assert updated.icon_type == IconType.EMOJI
            mock_db.session.commit.assert_called_once()

    def test_update_app_should_reject_empty_icon_type(self, service: AppService) -> None:
        """Test update_app rejects an explicit empty icon_type."""
        app = cast(
            App,
            SimpleNamespace(
                name="old",
                description="old",
                icon_type=IconType.EMOJI,
                icon="a",
                icon_background="#111",
                use_icon_as_answer_icon=False,
                max_active_requests=1,
            ),
        )
        args = {
            "name": "new",
            "description": "new-desc",
            "icon_type": "",
            "icon": "new-icon",
            "icon_background": "#222",
            "use_icon_as_answer_icon": True,
            "max_active_requests": 5,
        }
        user = SimpleNamespace(id="user-1")

        with (
            patch("services.app_service.current_user", user),
            patch("services.app_service.db") as mock_db,
        ):
            with pytest.raises(ValueError):
                service.update_app(app, args)

            mock_db.session.commit.assert_not_called()


class TestAppServiceDeleteAndMeta:
    """Test suite for delete and metadata methods."""

    def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None:
        """Test delete_app removes app, runs cleanup, and triggers async deletion task."""
        # Arrange
        app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1"))

        with (
            patch("services.app_service.db") as mock_db,
            patch(
                "services.app_service.FeatureService.get_system_features",
                return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)),
            ),
            patch("services.app_service.EnterpriseService") as mock_enterprise,
            patch(
                "services.app_service.dify_config",
                new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"),
            ),
            patch("services.app_service.BillingService") as mock_billing,
            patch("services.app_service.remove_app_and_related_data_task") as mock_task,
        ):
            # Act
            service.delete_app(app)

            # Assert
            mock_db.session.delete.assert_called_once_with(app)
            mock_db.session.commit.assert_called_once()
            mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1")
            mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1")
            mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1")

    def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None:
        """Test get_app_meta extracts builtin and API tool icons from workflow graph."""
        # Arrange
        workflow = SimpleNamespace(
            graph_dict={
                "nodes": [
                    {
                        "data": {
                            "type": "tool",
                            "provider_type": "builtin",
                            "provider_id": "builtin-provider",
                            "tool_name": "tool_builtin",
                        }
                    },
                    {
                        "data": {
                            "type": "tool",
                            "provider_type": "api",
                            "provider_id": "api-provider-id",
                            "tool_name": "tool_api",
                        }
                    },
                ]
            }
        )
        app = cast(
            App,
            SimpleNamespace(
                mode=AppMode.WORKFLOW.value,
                workflow=workflow,
                app_model_config=None,
                tenant_id="tenant-1",
                icon_type="emoji",
                icon_background="#fff",
            ),
        )

        provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"}))

        with (
            patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")),
            patch("services.app_service.db") as mock_db,
        ):
            query = MagicMock()
            query.where.return_value = query
            query.first.return_value = provider
            mock_db.session.query.return_value = query

            # Act
            meta = service.get_app_meta(app)

            # Assert
            assert meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon")
            assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"}

    def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None:
        """Test get_app_meta falls back to default icon when API provider lookup fails."""
        # Arrange
        app_model_config = SimpleNamespace(
            agent_mode_dict={
                "tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}]
            }
        )
        app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None))

        with (
            patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")),
            patch("services.app_service.db") as mock_db,
        ):
            query = MagicMock()
            query.where.return_value = query
            query.first.return_value = None
            mock_db.session.query.return_value = query

            # Act
            meta = service.get_app_meta(app)

            # Assert
            assert meta["tool_icons"]["t"] == {"background": "#252525", "content": "😁"}

    def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None:
        """Test get_app_meta returns empty metadata when workflow/model config is absent."""
        # Arrange
        workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None))
        chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None))

        # Act
        workflow_meta = service.get_app_meta(workflow_app)
        chat_meta = service.get_app_meta(chat_app)

        # Assert
        assert workflow_meta == {"tool_icons": {}}
        assert chat_meta == {"tool_icons": {}}


class TestAppServiceCodeLookup:
    """Test suite for app code lookup methods."""

    def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None:
        """Test get_app_code_by_id raises when site is missing."""
        # Arrange
        with patch("services.app_service.db") as mock_db:
            query = MagicMock()
            query.where.return_value = query
            query.first.return_value = None
            mock_db.session.query.return_value = query

            # Act & Assert
            with pytest.raises(ValueError, match="not found"):
                AppService.get_app_code_by_id("app-1")

    def test_get_app_code_by_id_should_return_code(self) -> None:
        """Test get_app_code_by_id returns site code."""
        # Arrange
        site = SimpleNamespace(code="code-1")
        with patch("services.app_service.db") as mock_db:
            query = MagicMock()
            query.where.return_value = query
            query.first.return_value = site
            mock_db.session.query.return_value = query

            # Act
            result = AppService.get_app_code_by_id("app-1")

            # Assert
            assert result == "code-1"

    def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None:
        """Test get_app_id_by_code raises when code does not exist."""
        # Arrange
        with patch("services.app_service.db") as mock_db:
            query = MagicMock()
            query.where.return_value = query
            query.first.return_value = None
            mock_db.session.query.return_value = query

            # Act & Assert
            with pytest.raises(ValueError, match="not found"):
                AppService.get_app_id_by_code("missing")

    def test_get_app_id_by_code_should_return_app_id(self) -> None:
        """Test get_app_id_by_code returns linked app id."""
        # Arrange
        site = SimpleNamespace(app_id="app-1")
        with patch("services.app_service.db") as mock_db:
            query = MagicMock()
            query.where.return_value = query
            query.first.return_value = site
            mock_db.session.query.return_value = query

            # Act
            result = AppService.get_app_id_by_code("code-1")

            # Assert
            assert result == "app-1"
@ -4,7 +4,7 @@ from unittest.mock import Mock, create_autospec
import pytest
from redis.exceptions import LockNotOwnedError

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService, SegmentService
@ -71,7 +71,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned(
    dataset.id = "ds-1"
    dataset.tenant_id = fake_current_user.current_tenant_id
    dataset.data_source_type = "upload_file"
    dataset.indexing_technique = "high_quality"  # so we skip re-initialization branch
    dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY  # so we skip re-initialization branch

    # Minimal knowledge_config stub that satisfies pre-lock code
    info_list = types.SimpleNamespace(data_source_type="upload_file")
@ -80,7 +80,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned(
        doc_form=IndexStructureType.QA_INDEX,
        original_document_id=None,  # go into "new document" branch
        data_source=data_source,
        indexing_technique="high_quality",
        indexing_technique=IndexTechniqueType.HIGH_QUALITY,
        embedding_model=None,
        embedding_model_provider=None,
        retrieval_model=None,
@ -126,7 +126,7 @@ def test_add_segment_ignores_lock_not_owned(
    dataset = create_autospec(Dataset, instance=True)
    dataset.id = "ds-1"
    dataset.tenant_id = fake_current_user.current_tenant_id
    dataset.indexing_technique = "economy"  # skip embedding/token calculation branch
    dataset.indexing_technique = IndexTechniqueType.ECONOMY  # skip embedding/token calculation branch

    document = create_autospec(Document, instance=True)
    document.id = "doc-1"
@ -169,7 +169,7 @@ def test_multi_create_segment_ignores_lock_not_owned(
    dataset = create_autospec(Dataset, instance=True)
    dataset.id = "ds-1"
    dataset.tenant_id = fake_current_user.current_tenant_id
    dataset.indexing_technique = "economy"  # again, skip high_quality path
    dataset.indexing_technique = IndexTechniqueType.ECONOMY  # again, skip high_quality path

    document = create_autospec(Document, instance=True)
    document.id = "doc-1"

@ -11,7 +11,7 @@ from unittest.mock import MagicMock
import pytest

import services.summary_index_service as summary_module
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.enums import SegmentStatus, SummaryStatus
from services.summary_index_service import SummaryIndexService

@ -27,7 +27,7 @@ class _SessionContext:
        return None


def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock:
def _dataset(*, indexing_technique: str = IndexTechniqueType.HIGH_QUALITY) -> MagicMock:
    dataset = MagicMock(name="dataset")
    dataset.id = "dataset-1"
    dataset.tenant_id = "tenant-1"
@ -169,7 +169,8 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N
def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None:
    vector_cls = MagicMock()
    monkeypatch.setattr(summary_module, "Vector", vector_cls)
    SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy"))
    dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY)
    SummaryIndexService.vectorize_summary(_summary_record(), _segment(), dataset)
    vector_cls.assert_not_called()


@ -621,7 +622,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo


def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _dataset(indexing_technique="economy")
    dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY)
    document = MagicMock(spec=summary_module.DatasetDocument)
    document.id = "doc-1"
    document.doc_form = IndexStructureType.PARAGRAPH_INDEX
@ -778,7 +779,7 @@ def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mo


def test_enable_summaries_for_segments_skips_non_high_quality() -> None:
    SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy"))
    SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique=IndexTechniqueType.ECONOMY))


def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None:
@ -932,9 +933,8 @@ def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mon


def test_update_summary_for_segment_skip_conditions() -> None:
    assert (
        SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None
    )
    economy_dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY)
    assert SummaryIndexService.update_summary_for_segment(_segment(), economy_dataset, "x") is None
    seg = _segment(has_document=True)
    seg.document.doc_form = IndexStructureType.QA_INDEX
    assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None

@ -9,7 +9,7 @@ from unittest.mock import MagicMock
import pytest

import services.vector_service as vector_service_module
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from services.vector_service import VectorService


@ -32,7 +32,7 @@ class _ParentDocStub:

def _make_dataset(
    *,
    indexing_technique: str = "high_quality",
    indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
    doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
    tenant_id: str = "tenant-1",
    dataset_id: str = "dataset-1",
@ -192,7 +192,7 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex
    dataset = _make_dataset(
        doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
        embedding_model_provider="openai",
        indexing_technique="high_quality",
        indexing_technique=IndexTechniqueType.HIGH_QUALITY,
    )
    segment = _make_segment()

@ -241,7 +241,7 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p
    dataset = _make_dataset(
        doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
        embedding_model_provider=None,
        indexing_technique="high_quality",
        indexing_technique=IndexTechniqueType.HIGH_QUALITY,
    )
    segment = _make_segment()

@ -329,7 +329,7 @@ def test_create_segments_vector_parent_child_missing_processing_rule_raises(monk
def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(
        doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
        indexing_technique="economy",
        indexing_technique=IndexTechniqueType.ECONOMY,
    )
    segment = _make_segment()
    dataset_document = MagicMock()
@ -348,7 +348,7 @@ def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch


def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="high_quality")
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
    segment = _make_segment()

    vector_instance = MagicMock()
@ -364,7 +364,7 @@ def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.Monk


def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="economy")
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
    segment = _make_segment()

    keyword_instance = MagicMock()
@ -380,7 +380,7 @@ def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypat


def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="economy")
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
    segment = _make_segment()

    keyword_instance = MagicMock()
@ -473,7 +473,7 @@ def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest


def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="high_quality")
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY)
    child_chunk = MagicMock()
    child_chunk.content = "child"
    child_chunk.index_node_id = "id"
@ -489,7 +489,7 @@ def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.M


def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="economy")
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
    vector_cls = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", vector_cls)

@ -505,7 +505,7 @@ def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch)


def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="high_quality")
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY)

    new_chunk = MagicMock()
    new_chunk.content = "n"
@ -536,7 +536,7 @@ def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pyte


def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="economy")
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY)
    vector_cls = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
    VectorService.update_child_chunk_vector([], [], [], dataset)
@ -561,7 +561,7 @@ def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch


def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="economy", is_multimodal=True)
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY, is_multimodal=True)
    segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}])

    vector_cls = MagicMock()
@ -575,7 +575,7 @@ def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pyt


def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
    segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}])

    vector_cls = MagicMock()
@ -591,7 +591,7 @@ def test_update_multimodel_vector_returns_when_no_actual_change(monkeypat
def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
    segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}])

    vector_instance = MagicMock(name="vector_instance")
@ -612,7 +612,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids(


def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
    segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}])

    vector_instance = MagicMock()
@ -630,7 +630,7 @@ def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch
def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
    segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}])

    vector_instance = MagicMock()
@ -663,7 +663,7 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up
def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False)
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=False)
    segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}])

    vector_instance = MagicMock()
@ -683,7 +683,7 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops


def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
    dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True)
    segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}])

    vector_instance = MagicMock()

@ -1,379 +0,0 @@
from __future__ import annotations

from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock

import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import NotFound, Unauthorized

from models import Account, AccountStatus
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType

ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback"
TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token"
TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data"


def _account(**kwargs: Any) -> Account:
    return cast(Account, SimpleNamespace(**kwargs))


@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
    # Arrange
    mocked_db = mocker.patch("services.webapp_auth_service.db")
    mocked_db.session = MagicMock()
    return mocked_db


def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None:
    # Arrange
    mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)

    # Act + Assert
    with pytest.raises(AccountNotFoundError):
        WebAppAuthService.authenticate("user@example.com", "pwd")


def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None:
    # Arrange
    account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt")
    mocker.patch(
        ACCOUNT_LOOKUP_PATH,
        return_value=account,
    )

    # Act + Assert
    with pytest.raises(AccountLoginError, match="Account is banned"):
        WebAppAuthService.authenticate("user@example.com", "pwd")


@pytest.mark.parametrize("password_value", [None, "hash"])
def test_authenticate_should_raise_password_error_when_password_is_invalid(
    password_value: str | None,
    mocker: MockerFixture,
) -> None:
    # Arrange
    account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt")
    mocker.patch(
        ACCOUNT_LOOKUP_PATH,
        return_value=account,
    )
    mocker.patch("services.webapp_auth_service.compare_password", return_value=False)

    # Act + Assert
    with pytest.raises(AccountPasswordError, match="Invalid email or password"):
        WebAppAuthService.authenticate("user@example.com", "pwd")


def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None:
    # Arrange
    account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt")
    mocker.patch(
        ACCOUNT_LOOKUP_PATH,
        return_value=account,
    )
    mocker.patch("services.webapp_auth_service.compare_password", return_value=True)

    # Act
    result = WebAppAuthService.authenticate("user@example.com", "pwd")

    # Assert
    assert result is account


def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None:
    # Arrange
    account = _account(id="a1", email="u@example.com")
    mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token")

    # Act
    result = WebAppAuthService.login(account)

    # Assert
    assert result == "jwt-token"
    mock_get_token.assert_called_once_with(account=account)


def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None:
    # Arrange
    mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)

    # Act
    result = WebAppAuthService.get_user_through_email("missing@example.com")

    # Assert
    assert result is None


def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None:
    # Arrange
    account = SimpleNamespace(status=AccountStatus.BANNED)
    mocker.patch(
        ACCOUNT_LOOKUP_PATH,
        return_value=account,
    )

    # Act + Assert
    with pytest.raises(Unauthorized, match="Account is banned"):
        WebAppAuthService.get_user_through_email("user@example.com")


def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None:
    # Arrange
    account = SimpleNamespace(status=AccountStatus.ACTIVE)
    mocker.patch(
        ACCOUNT_LOOKUP_PATH,
        return_value=account,
    )

    # Act
    result = WebAppAuthService.get_user_through_email("user@example.com")

    # Assert
    assert result is account


def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None:
    # Arrange
    # Act + Assert
    with pytest.raises(ValueError, match="Email must be provided"):
        WebAppAuthService.send_email_code_login_email(account=None, email=None)


def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account(
    mocker: MockerFixture,
) -> None:
    # Arrange
    account = _account(email="user@example.com")
    mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6])
    mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1")
    mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")

    # Act
    result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US")

    # Assert
    assert result == "token-1"
    mock_generate_token.assert_called_once()
    assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"}
    mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456")


def test_send_email_code_login_email_should_send_mail_for_email_without_account(
    mocker: MockerFixture,
) -> None:
    # Arrange
    mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0])
    mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2")
    mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")

    # Act
    result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans")

    # Assert
    assert result == "token-2"
    mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000")


def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
    # Arrange
    mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"})

    # Act
    result = WebAppAuthService.get_email_code_login_data("token-abc")

    # Assert
    assert result == {"code": "123"}
    mock_get_data.assert_called_once_with("token-abc", "email_code_login")


def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
    # Arrange
    mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token")

    # Act
    WebAppAuthService.revoke_email_code_login_token("token-xyz")

    # Assert
    mock_revoke.assert_called_once_with("token-xyz", "email_code_login")


def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None:
    # Arrange
    mock_db.session.query.return_value.where.return_value.first.return_value = None

    # Act + Assert
    with pytest.raises(NotFound, match="Site not found"):
        WebAppAuthService.create_end_user("app-code", "user@example.com")


def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None:
    # Arrange
    site = SimpleNamespace(app_id="app-1")
    app_query = MagicMock()
    app_query.where.return_value.first.return_value = None
    mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None]

    # Act + Assert
    with pytest.raises(NotFound, match="App not found"):
        WebAppAuthService.create_end_user("app-code", "user@example.com")


def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None:
    # Arrange
    site = SimpleNamespace(app_id="app-1")
    app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
    mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model]

    # Act
    result = WebAppAuthService.create_end_user("app-code", "user@example.com")

    # Assert
    assert result.tenant_id == "tenant-1"
    assert result.app_id == "app-1"
    assert result.session_id == "user@example.com"
    mock_db.session.add.assert_called_once()
    mock_db.session.commit.assert_called_once()


def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None:
    # Arrange
    account = _account(id="a1", email="user@example.com")
    mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60)
    mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1")

    # Act
    token = WebAppAuthService._get_account_jwt_token(account)

    # Assert
    assert token == "jwt-1"
    payload = mock_issue.call_args.args[0]
    assert payload["user_id"] == "a1"
    assert payload["session_id"] == "user@example.com"
    assert payload["token_source"] == "webapp_login_token"
    assert payload["auth_type"] == "internal"
    assert payload["exp"] > int(datetime.now(UTC).timestamp())


@pytest.mark.parametrize(
    ("access_mode", "expected"),
    [
        ("private", True),
        ("private_all", True),
        ("public", False),
    ],
)
def test_is_app_require_permission_check_should_use_access_mode_when_provided(
    access_mode: str,
    expected: bool,
) -> None:
    # Arrange
    # Act
    result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode)

    # Assert
    assert result is expected


def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None:
    # Arrange
    # Act + Assert
    with pytest.raises(ValueError, match="Either app_code or app_id must be provided"):
        WebAppAuthService.is_app_require_permission_check()


def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None:
    # Arrange
    mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None)

    # Act + Assert
    with pytest.raises(ValueError, match="App ID could not be determined"):
        WebAppAuthService.is_app_require_permission_check(app_code="app-code")


def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it(
    mocker: MockerFixture,
) -> None:
    # Arrange
    mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
    mocker.patch(
        "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
        return_value=SimpleNamespace(access_mode="private"),
    )

    # Act
    result = WebAppAuthService.is_app_require_permission_check(app_code="app-code")

    # Assert
    assert result is True


def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it(
    mocker: MockerFixture,
) -> None:
    # Arrange
    mocker.patch(
        "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
        return_value=SimpleNamespace(access_mode="public"),
    )

    # Act
    result = WebAppAuthService.is_app_require_permission_check(app_id="app-1")

    # Assert
    assert result is False


@pytest.mark.parametrize(
    ("access_mode", "expected"),
    [
        ("public", WebAppAuthType.PUBLIC),
        ("private", WebAppAuthType.INTERNAL),
        ("private_all", WebAppAuthType.INTERNAL),
        ("sso_verified", WebAppAuthType.EXTERNAL),
    ],
)
def test_get_app_auth_type_should_map_access_modes_correctly(
    access_mode: str,
    expected: WebAppAuthType,
) -> None:
    # Arrange
    # Act
    result = WebAppAuthService.get_app_auth_type(access_mode=access_mode)

    # Assert
    assert result == expected


def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None:
    # Arrange
    mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
    mocker.patch(
        "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
        return_value=SimpleNamespace(access_mode="private_all"),
    )

    # Act
    result = WebAppAuthService.get_app_auth_type(app_code="app-code")

    # Assert
    assert result == WebAppAuthType.INTERNAL


def test_get_app_auth_type_should_raise_when_no_input_provided() -> None:
    # Arrange
    # Act + Assert
    with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"):
        WebAppAuthService.get_app_auth_type()


def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None:
    # Arrange
    # Act + Assert
    with pytest.raises(ValueError, match="Could not determine app authentication type"):
        WebAppAuthService.get_app_auth_type(access_mode="unknown")
@ -1,300 +0,0 @@
from __future__ import annotations

import json
import uuid
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock

import pytest
from pytest_mock import MockerFixture

from dify_graph.enums import WorkflowExecutionStatus
from models import App, WorkflowAppLog
from models.enums import AppTriggerType, CreatorUserRole
from services.workflow_app_service import LogView, WorkflowAppService


@pytest.fixture
def service() -> WorkflowAppService:
    # Arrange
    return WorkflowAppService()


@pytest.fixture
def app_model() -> App:
    # Arrange
    return cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1"))


def _workflow_app_log(**kwargs: Any) -> WorkflowAppLog:
    return cast(WorkflowAppLog, SimpleNamespace(**kwargs))


def test_log_view_details_should_return_wrapped_details_and_proxy_attributes() -> None:
    # Arrange
    log = _workflow_app_log(id="log-1", status="succeeded")
    view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}})

    # Act
    details = view.details
    proxied_status = view.status

    # Assert
    assert details == {"trigger_metadata": {"type": "plugin"}}
    assert proxied_status == "succeeded"


def test_get_paginate_workflow_app_logs_should_return_paginated_summary_when_detail_false(
    service: WorkflowAppService,
    app_model: App,
) -> None:
    # Arrange
    session = MagicMock()
    log_1 = SimpleNamespace(id="log-1")
    log_2 = SimpleNamespace(id="log-2")
    session.scalar.return_value = 3
    session.scalars.return_value.all.return_value = [log_1, log_2]

    # Act
    result = service.get_paginate_workflow_app_logs(
        session=session,
        app_model=app_model,
        page=1,
        limit=2,
        detail=False,
    )

    # Assert
    assert result["page"] == 1
    assert result["limit"] == 2
    assert result["total"] == 3
    assert result["has_more"] is True
    assert len(result["data"]) == 2
    assert isinstance(result["data"][0], LogView)
    assert result["data"][0].details is None


def test_get_paginate_workflow_app_logs_should_return_detailed_rows_when_detail_true(
    service: WorkflowAppService,
    app_model: App,
    mocker: MockerFixture,
) -> None:
    # Arrange
    session = MagicMock()
    session.scalar.side_effect = [1]
    log_1 = SimpleNamespace(id="log-1")
    session.execute.return_value.all.return_value = [(log_1, '{"type":"trigger_plugin"}')]
    mock_handle = mocker.patch.object(
        service,
        "handle_trigger_metadata",
        return_value={"type": "trigger_plugin", "icon": "url"},
    )

    # Act
    result = service.get_paginate_workflow_app_logs(
        session=session,
        app_model=app_model,
        keyword="run-1",
        status=WorkflowExecutionStatus.SUCCEEDED,
        created_at_before=None,
        created_at_after=None,
        page=1,
        limit=20,
        detail=True,
    )

    # Assert
    assert result["total"] == 1
    assert len(result["data"]) == 1
    assert result["data"][0].details == {"trigger_metadata": {"type": "trigger_plugin", "icon": "url"}}
    mock_handle.assert_called_once()


def test_get_paginate_workflow_app_logs_should_raise_when_account_filter_email_not_found(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Account not found: account@example.com"):
|
||||
service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
created_by_account="account@example.com",
|
||||
)
|
||||
|
||||
|
||||
def test_get_paginate_workflow_app_logs_should_filter_by_account_when_account_exists(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [SimpleNamespace(id="account-1"), 0]
|
||||
session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
created_by_account="account@example.com",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 0
|
||||
assert result["data"] == []
|
||||
|
||||
|
||||
def test_get_paginate_workflow_archive_logs_should_return_paginated_archive_items(
|
||||
service: WorkflowAppService,
|
||||
app_model: App,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
log_account = SimpleNamespace(
|
||||
id="log-1",
|
||||
created_by="acc-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
workflow_run_summary={"run": "1"},
|
||||
trigger_metadata='{"type":"trigger-webhook"}',
|
||||
log_created_at="2026-01-01",
|
||||
)
|
||||
log_end_user = SimpleNamespace(
|
||||
id="log-2",
|
||||
created_by="end-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
workflow_run_summary={"run": "2"},
|
||||
trigger_metadata='{"type":"trigger-webhook"}',
|
||||
log_created_at="2026-01-02",
|
||||
)
|
||||
log_unknown = SimpleNamespace(
|
||||
id="log-3",
|
||||
created_by="other",
|
||||
created_by_role="system",
|
||||
workflow_run_summary={"run": "3"},
|
||||
trigger_metadata='{"type":"trigger-webhook"}',
|
||||
log_created_at="2026-01-03",
|
||||
)
|
||||
session.scalar.return_value = 3
|
||||
session.scalars.side_effect = [
|
||||
SimpleNamespace(all=lambda: [log_account, log_end_user, log_unknown]),
|
||||
SimpleNamespace(all=lambda: [SimpleNamespace(id="acc-1", email="a@example.com")]),
|
||||
SimpleNamespace(all=lambda: [SimpleNamespace(id="end-1", session_id="session-1")]),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_archive_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
page=1,
|
||||
limit=20,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["total"] == 3
|
||||
assert len(result["data"]) == 3
|
||||
assert result["data"][0]["created_by_account"].id == "acc-1"
|
||||
assert result["data"][1]["created_by_end_user"].id == "end-1"
|
||||
assert result["data"][2]["created_by_account"] is None
|
||||
assert result["data"][2]["created_by_end_user"] is None
|
||||
|
||||
|
||||
def test_handle_trigger_metadata_should_return_empty_dict_when_metadata_missing(
|
||||
service: WorkflowAppService,
|
||||
) -> None:
|
||||
# Arrange
|
||||
# Act
|
||||
result = service.handle_trigger_metadata("tenant-1", None)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_handle_trigger_metadata_should_enrich_plugin_icons_for_trigger_plugin(
|
||||
service: WorkflowAppService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
meta = {
|
||||
"type": AppTriggerType.TRIGGER_PLUGIN.value,
|
||||
"icon_filename": "light.png",
|
||||
"icon_dark_filename": "dark.png",
|
||||
}
|
||||
mock_icon = mocker.patch(
|
||||
"services.workflow_app_service.PluginService.get_plugin_icon_url",
|
||||
side_effect=["https://cdn/light.png", "https://cdn/dark.png"],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
|
||||
|
||||
# Assert
|
||||
assert result["icon"] == "https://cdn/light.png"
|
||||
assert result["icon_dark"] == "https://cdn/dark.png"
|
||||
assert mock_icon.call_count == 2
|
||||
|
||||
|
||||
def test_handle_trigger_metadata_should_return_non_plugin_metadata_without_icon_lookup(
|
||||
service: WorkflowAppService,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value}
|
||||
mock_icon = mocker.patch("services.workflow_app_service.PluginService.get_plugin_icon_url")
|
||||
|
||||
# Act
|
||||
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
|
||||
|
||||
# Assert
|
||||
assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value
|
||||
mock_icon.assert_not_called()
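
Read together, the three handle_trigger_metadata tests describe the method's contract: empty input yields {}, plugin-trigger metadata gets light/dark icon URLs resolved through PluginService.get_plugin_icon_url, and every other trigger type passes through untouched. A sketch consistent with those assertions (the icon lookup's argument order is assumed, and the real method may handle more edge cases):

# Hypothetical sketch of the contract the tests assert; not the real code.
# Assumes services.workflow_app_service imports PluginService.
def handle_trigger_metadata_sketch(tenant_id: str, raw_metadata: str | None) -> dict:
    if not raw_metadata:
        return {}
    meta = json.loads(raw_metadata)
    if meta.get("type") == AppTriggerType.TRIGGER_PLUGIN.value:
        # Only plugin triggers need icon enrichment (two lookups: light + dark).
        meta["icon"] = PluginService.get_plugin_icon_url(tenant_id, meta["icon_filename"])  # args assumed
        meta["icon_dark"] = PluginService.get_plugin_icon_url(tenant_id, meta["icon_dark_filename"])  # args assumed
    return meta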


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        (None, None),
        ("", None),
        ('{"k":"v"}', {"k": "v"}),
        ("not-json", None),
        ({"raw": True}, {"raw": True}),
    ],
)
def test_safe_json_loads_should_handle_various_inputs(
    value: object,
    expected: object,
    service: WorkflowAppService,
) -> None:
    # Arrange
    # Act
    result = service._safe_json_loads(value)

    # Assert
    assert result == expected
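
The parametrized table above fully determines _safe_json_loads' observable behavior; a sketch that satisfies every row (hypothetical, the service's private helper may be written differently):

# Hypothetical sketch: falsy inputs and invalid JSON yield None, valid JSON
# strings are decoded, non-string values (e.g. an already-parsed dict) pass through.
def _safe_json_loads_sketch(value):
    if not isinstance(value, str):
        return value or None  # None stays None, {"raw": True} passes through
    if not value:
        return None  # "" -> None
    try:
        return json.loads(value)  # '{"k":"v"}' -> {"k": "v"}
    except (TypeError, ValueError):
        return None  # "not-json" -> None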


def test_safe_parse_uuid_should_return_none_for_short_or_invalid_values(service: WorkflowAppService) -> None:
    # Arrange
    # Act
    short_result = service._safe_parse_uuid("short")
    invalid_result = service._safe_parse_uuid("x" * 40)

    # Assert
    assert short_result is None
    assert invalid_result is None


def test_safe_parse_uuid_should_return_uuid_for_valid_uuid_string(service: WorkflowAppService) -> None:
    # Arrange
    raw_uuid = str(uuid.uuid4())

    # Act
    result = service._safe_parse_uuid(raw_uuid)

    # Assert
    assert result is not None
    assert str(result) == raw_uuid
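
Likewise, the two _safe_parse_uuid tests describe a guard that never raises: anything uuid.UUID rejects must come back as None, and a canonical uuid4 string must round-trip through str() unchanged. A sketch that satisfies both (hypothetical):

# Hypothetical sketch; uuid.UUID("short") and uuid.UUID("x" * 40) both raise
# ValueError, so both inputs fall through to None here.
def _safe_parse_uuid_sketch(value: str | None) -> uuid.UUID | None:
    if not value:
        return None
    try:
        return uuid.UUID(value)
    except (TypeError, ValueError):
        return None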
@ -1,452 +0,0 @@
from unittest.mock import Mock

from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from services.tools.tools_transform_service import ToolTransformService


class TestToolTransformService:
    """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""

    def test_convert_tool_with_parameter_override(self):
        """Test that runtime parameters correctly override base parameters"""
        # Create mock base parameters
        base_param1 = Mock(spec=ToolParameter)
        base_param1.name = "param1"
        base_param1.form = ToolParameter.ToolParameterForm.FORM
        base_param1.type = "string"
        base_param1.label = "Base Param 1"

        base_param2 = Mock(spec=ToolParameter)
        base_param2.name = "param2"
        base_param2.form = ToolParameter.ToolParameterForm.FORM
        base_param2.type = "string"
        base_param2.label = "Base Param 2"

        # Create mock runtime parameters that override base parameters
        runtime_param1 = Mock(spec=ToolParameter)
        runtime_param1.name = "param1"
        runtime_param1.form = ToolParameter.ToolParameterForm.FORM
        runtime_param1.type = "string"
        runtime_param1.label = "Runtime Param 1"  # Different label to verify override

        # Create mock tool
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = [base_param1, base_param2]
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = [runtime_param1]

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.author == "test_author"
        assert result.name == "test_tool"
        assert result.parameters is not None
        assert len(result.parameters) == 2

        # Find the overridden parameter
        overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
        assert overridden_param is not None
        assert overridden_param.label == "Runtime Param 1"  # Should be runtime version

        # Find the non-overridden parameter
        original_param = next((p for p in result.parameters if p.name == "param2"), None)
        assert original_param is not None
        assert original_param.label == "Base Param 2"  # Should be base version

    def test_convert_tool_with_additional_runtime_parameters(self):
        """Test that additional runtime parameters are added to the final list"""
        # Create mock base parameters
        base_param1 = Mock(spec=ToolParameter)
        base_param1.name = "param1"
        base_param1.form = ToolParameter.ToolParameterForm.FORM
        base_param1.type = "string"
        base_param1.label = "Base Param 1"

        # Create mock runtime parameters - one that overrides and one that's new
        runtime_param1 = Mock(spec=ToolParameter)
        runtime_param1.name = "param1"
        runtime_param1.form = ToolParameter.ToolParameterForm.FORM
        runtime_param1.type = "string"
        runtime_param1.label = "Runtime Param 1"

        runtime_param2 = Mock(spec=ToolParameter)
        runtime_param2.name = "runtime_only"
        runtime_param2.form = ToolParameter.ToolParameterForm.FORM
        runtime_param2.type = "string"
        runtime_param2.label = "Runtime Only Param"

        # Create mock tool
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = [base_param1]
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 2

        # Check that both parameters are present
        param_names = [p.name for p in result.parameters]
        assert "param1" in param_names
        assert "runtime_only" in param_names

        # Verify the overridden parameter has runtime version
        overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
        assert overridden_param is not None
        assert overridden_param.label == "Runtime Param 1"

        # Verify the new runtime parameter is included
        new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
        assert new_param is not None
        assert new_param.label == "Runtime Only Param"

    def test_convert_tool_with_non_form_runtime_parameters(self):
        """Test that non-FORM runtime parameters are not added as new parameters"""
        # Create mock base parameters
        base_param1 = Mock(spec=ToolParameter)
        base_param1.name = "param1"
        base_param1.form = ToolParameter.ToolParameterForm.FORM
        base_param1.type = "string"
        base_param1.label = "Base Param 1"

        # Create mock runtime parameters with different forms
        runtime_param1 = Mock(spec=ToolParameter)
        runtime_param1.name = "param1"
        runtime_param1.form = ToolParameter.ToolParameterForm.FORM
        runtime_param1.type = "string"
        runtime_param1.label = "Runtime Param 1"

        runtime_param2 = Mock(spec=ToolParameter)
        runtime_param2.name = "llm_param"
        runtime_param2.form = ToolParameter.ToolParameterForm.LLM
        runtime_param2.type = "string"
        runtime_param2.label = "LLM Param"

        # Create mock tool
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = [base_param1]
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 1  # Only the FORM parameter should be present

        # Check that only the FORM parameter is present
        param_names = [p.name for p in result.parameters]
        assert "param1" in param_names
        assert "llm_param" not in param_names

    def test_convert_tool_with_empty_parameters(self):
        """Test conversion with empty base and runtime parameters"""
        # Create mock tool with no parameters
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = []
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = []

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 0

    def test_convert_tool_with_none_parameters(self):
        """Test conversion when base parameters is None"""
        # Create mock tool with None parameters
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = None
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = []

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 0

    def test_convert_tool_parameter_order_preserved(self):
        """Test that parameter order is preserved correctly"""
        # Create mock base parameters in specific order
        base_param1 = Mock(spec=ToolParameter)
        base_param1.name = "param1"
        base_param1.form = ToolParameter.ToolParameterForm.FORM
        base_param1.type = "string"
        base_param1.label = "Base Param 1"

        base_param2 = Mock(spec=ToolParameter)
        base_param2.name = "param2"
        base_param2.form = ToolParameter.ToolParameterForm.FORM
        base_param2.type = "string"
        base_param2.label = "Base Param 2"

        base_param3 = Mock(spec=ToolParameter)
        base_param3.name = "param3"
        base_param3.form = ToolParameter.ToolParameterForm.FORM
        base_param3.type = "string"
        base_param3.label = "Base Param 3"

        # Create runtime parameter that overrides middle parameter
        runtime_param2 = Mock(spec=ToolParameter)
        runtime_param2.name = "param2"
        runtime_param2.form = ToolParameter.ToolParameterForm.FORM
        runtime_param2.type = "string"
        runtime_param2.label = "Runtime Param 2"

        # Create new runtime parameter
        runtime_param4 = Mock(spec=ToolParameter)
        runtime_param4.name = "param4"
        runtime_param4.form = ToolParameter.ToolParameterForm.FORM
        runtime_param4.type = "string"
        runtime_param4.label = "Runtime Param 4"

        # Create mock tool
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 4

        # Check that order is maintained: base parameters first, then new runtime parameters
        param_names = [p.name for p in result.parameters]
        assert param_names == ["param1", "param2", "param3", "param4"]

        # Verify that param2 was overridden with runtime version
        param2 = result.parameters[1]
        assert param2.name == "param2"
        assert param2.label == "Runtime Param 2"
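
Taken together, the six tests above pin down the parameter-merge rule. A compact sketch of that rule (a hypothetical helper, not the actual service code; it reuses the ToolParameter import from the top of this file):

# Merge rule implied by the tests: runtime parameters replace same-named
# base parameters in place; brand-new runtime FORM parameters are appended
# after the base ones; non-FORM runtime parameters (e.g. LLM) never produce
# new entries; a None base list behaves like an empty list.
def merge_parameters_sketch(base_params, runtime_params):
    base_params = base_params or []
    runtime_by_name = {p.name: p for p in runtime_params}
    merged = [runtime_by_name.get(p.name, p) for p in base_params]
    known = {p.name for p in base_params}
    for p in runtime_params:
        if p.name not in known and p.form == ToolParameter.ToolParameterForm.FORM:
            merged.append(p)
    return merged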


class TestWorkflowProviderToUserProvider:
    """Test cases for ToolTransformService.workflow_provider_to_user_provider method"""

    def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
        """Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        # Create mock workflow tool provider controller
        workflow_app_id = "app_123"
        provider_id = "provider_123"
        mock_controller = Mock(spec=WorkflowToolProviderController)
        mock_controller.provider_id = provider_id
        mock_controller.entity = Mock()
        mock_controller.entity.identity = Mock()
        mock_controller.entity.identity.author = "test_author"
        mock_controller.entity.identity.name = "test_workflow_tool"
        mock_controller.entity.identity.description = I18nObject(en_US="Test description")
        mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
        mock_controller.entity.identity.icon_dark = None
        mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")

        # Call the method
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["label1", "label2"],
            workflow_app_id=workflow_app_id,
        )

        # Verify the result
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.author == "test_author"
        assert result.name == "test_workflow_tool"
        assert result.type == ToolProviderType.WORKFLOW
        assert result.workflow_app_id == workflow_app_id
        assert result.labels == ["label1", "label2"]
        assert result.is_team_authorization is True
        assert result.plugin_id is None
        assert result.plugin_unique_identifier is None
        assert result.tools == []

    def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
        """Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        # Create mock workflow tool provider controller
        provider_id = "provider_123"
        mock_controller = Mock(spec=WorkflowToolProviderController)
        mock_controller.provider_id = provider_id
        mock_controller.entity = Mock()
        mock_controller.entity.identity = Mock()
        mock_controller.entity.identity.author = "test_author"
        mock_controller.entity.identity.name = "test_workflow_tool"
        mock_controller.entity.identity.description = I18nObject(en_US="Test description")
        mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
        mock_controller.entity.identity.icon_dark = None
        mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")

        # Call the method without workflow_app_id
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["label1"],
        )

        # Verify the result
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.workflow_app_id is None
        assert result.labels == ["label1"]

    def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
        """Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        # Create mock workflow tool provider controller
        provider_id = "provider_123"
        mock_controller = Mock(spec=WorkflowToolProviderController)
        mock_controller.provider_id = provider_id
        mock_controller.entity = Mock()
        mock_controller.entity.identity = Mock()
        mock_controller.entity.identity.author = "test_author"
        mock_controller.entity.identity.name = "test_workflow_tool"
        mock_controller.entity.identity.description = I18nObject(en_US="Test description")
        mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
        mock_controller.entity.identity.icon_dark = None
        mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")

        # Call the method with explicit None values
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=None,
            workflow_app_id=None,
        )

        # Verify the result
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.workflow_app_id is None
        assert result.labels == []

    def test_workflow_provider_to_user_provider_preserves_other_fields(self):
        """Test that workflow_provider_to_user_provider preserves all other entity fields."""
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        # Create mock workflow tool provider controller with various fields
        workflow_app_id = "app_456"
        provider_id = "provider_456"
        mock_controller = Mock(spec=WorkflowToolProviderController)
        mock_controller.provider_id = provider_id
        mock_controller.entity = Mock()
        mock_controller.entity.identity = Mock()
        mock_controller.entity.identity.author = "another_author"
        mock_controller.entity.identity.name = "another_workflow_tool"
        mock_controller.entity.identity.description = I18nObject(
            en_US="Another description", zh_Hans="Another description"
        )
        mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"}
        mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"}
        mock_controller.entity.identity.label = I18nObject(
            en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"
        )

        # Call the method
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["automation", "workflow"],
            workflow_app_id=workflow_app_id,
        )

        # Verify all fields are preserved correctly
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.author == "another_author"
        assert result.name == "another_workflow_tool"
        assert result.description.en_US == "Another description"
        assert result.description.zh_Hans == "Another description"
        assert result.icon == {"type": "emoji", "content": "⚙️"}
        assert result.icon_dark == {"type": "emoji", "content": "🔧"}
        assert result.label.en_US == "Another Workflow Tool"
        assert result.label.zh_Hans == "Another Workflow Tool"
        assert result.type == ToolProviderType.WORKFLOW
        assert result.workflow_app_id == workflow_app_id
        assert result.labels == ["automation", "workflow"]
        assert result.masked_credentials == {}
        assert result.is_team_authorization is True
        assert result.allow_delete is True
        assert result.plugin_id is None
        assert result.plugin_unique_identifier is None
        assert result.tools == []
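
The four provider tests above fix the mapping's defaults: labels fall back to [], workflow_app_id passes through (possibly None), and a workflow provider always reports type WORKFLOW with team authorization and no plugin identity. A hypothetical sketch of that mapping, with the constructor keywords assumed from the asserted attributes (the real ToolTransformService code may build the entity differently):

# Hypothetical sketch; field names mirror the attributes asserted above.
def workflow_provider_to_user_provider_sketch(provider_controller, labels=None, workflow_app_id=None):
    identity = provider_controller.entity.identity
    return ToolProviderApiEntity(
        id=provider_controller.provider_id,
        author=identity.author,
        name=identity.name,
        description=identity.description,
        icon=identity.icon,
        icon_dark=identity.icon_dark,
        label=identity.label,
        type=ToolProviderType.WORKFLOW,
        workflow_app_id=workflow_app_id,  # passes through, possibly None
        labels=labels or [],              # None collapses to []
        masked_credentials={},
        is_team_authorization=True,
        allow_delete=True,
        tools=[],
    )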
@ -121,7 +121,7 @@ import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.models.document import Document
from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment
from services.vector_service import VectorService
@ -153,7 +153,7 @@ class VectorServiceTestDataFactory:
        dataset_id: str = "dataset-123",
        tenant_id: str = "tenant-123",
        doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
        indexing_technique: str = "high_quality",
        indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
        embedding_model_provider: str = "openai",
        embedding_model: str = "text-embedding-ada-002",
        index_struct_dict: dict | None = None,
@ -494,7 +494,7 @@ class TestVectorService:
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(
            doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique="high_quality"
            doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique=IndexTechniqueType.HIGH_QUALITY
        )

        segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -535,7 +535,7 @@ class TestVectorService:
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(
            doc_form="parent_child_model", indexing_technique="high_quality"
            doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY
        )

        segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -568,7 +568,7 @@ class TestVectorService:
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(
            doc_form="parent_child_model", indexing_technique="high_quality"
            doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY
        )

        segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -591,7 +591,7 @@ class TestVectorService:
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(
            doc_form="parent_child_model", indexing_technique="high_quality"
            doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY
        )

        segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -616,7 +616,7 @@ class TestVectorService:
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(
            doc_form="parent_child_model", indexing_technique="economy"
            doc_form="parent_child_model", indexing_technique=IndexTechniqueType.ECONOMY
        )

        segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -669,7 +669,7 @@ class TestVectorService:
        store when using high_quality indexing.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)

        segment = VectorServiceTestDataFactory.create_document_segment_mock()

@ -695,7 +695,7 @@ class TestVectorService:
        index when using economy indexing with keywords.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)

        segment = VectorServiceTestDataFactory.create_document_segment_mock()

@ -731,7 +731,7 @@ class TestVectorService:
        index when using economy indexing without keywords.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)

        segment = VectorServiceTestDataFactory.create_document_segment_mock()

@ -895,7 +895,7 @@ class TestVectorService:
        when using high_quality indexing.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)

        child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()

@ -923,7 +923,7 @@ class TestVectorService:
        using economy indexing.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)

        child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()

@ -951,7 +951,7 @@ class TestVectorService:
        when there are new chunks, updated chunks, and deleted chunks.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)

        new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="new-chunk-1")

@ -993,7 +993,7 @@ class TestVectorService:
        add_texts is called, not delete_by_ids.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)

        new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()

@ -1019,7 +1019,7 @@ class TestVectorService:
        delete_by_ids is called, not add_texts.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)

        delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()

@ -1045,7 +1045,7 @@ class TestVectorService:
        using economy indexing.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)

        new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()

@ -1075,7 +1075,7 @@ class TestVectorService:
        when using high_quality indexing.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY)

        child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()

@ -1099,7 +1099,7 @@ class TestVectorService:
        using economy indexing.
        """
        # Arrange
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy")
        dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY)

        child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock()


@ -16,7 +16,7 @@ from unittest.mock import MagicMock, patch

import pytest

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.enums import DataSourceType
from tasks.clean_dataset_task import clean_dataset_task

@ -184,7 +184,7 @@ class TestErrorHandling:
        clean_dataset_task(
            dataset_id=dataset_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
            index_struct='{"type": "paragraph"}',
            collection_binding_id=collection_binding_id,
            doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -229,7 +229,7 @@ class TestPipelineAndWorkflowDeletion:
        clean_dataset_task(
            dataset_id=dataset_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
            index_struct='{"type": "paragraph"}',
            collection_binding_id=collection_binding_id,
            doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -265,7 +265,7 @@ class TestPipelineAndWorkflowDeletion:
        clean_dataset_task(
            dataset_id=dataset_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
            index_struct='{"type": "paragraph"}',
            collection_binding_id=collection_binding_id,
            doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -321,7 +321,7 @@ class TestSegmentAttachmentCleanup:
        clean_dataset_task(
            dataset_id=dataset_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
            index_struct='{"type": "paragraph"}',
            collection_binding_id=collection_binding_id,
            doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -366,7 +366,7 @@ class TestSegmentAttachmentCleanup:
        clean_dataset_task(
            dataset_id=dataset_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
            index_struct='{"type": "paragraph"}',
            collection_binding_id=collection_binding_id,
            doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -408,7 +408,7 @@ class TestEdgeCases:
        clean_dataset_task(
            dataset_id=dataset_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
            index_struct='{"type": "paragraph"}',
            collection_binding_id=collection_binding_id,
            doc_form=IndexStructureType.PARAGRAPH_INDEX,
@ -445,7 +445,7 @@ class TestIndexProcessorParameters:
        - Dataset object with correct attributes is passed
        """
        # Arrange
        indexing_technique = "high_quality"
        indexing_technique = IndexTechniqueType.HIGH_QUALITY
        index_struct = '{"type": "paragraph"}'

        # Act

@ -15,7 +15,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest

from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
@ -209,7 +209,7 @@ def mock_dataset(dataset_id, tenant_id):
    dataset = Mock(spec=Dataset)
    dataset.id = dataset_id
    dataset.tenant_id = tenant_id
    dataset.indexing_technique = "high_quality"
    dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
    dataset.embedding_model_provider = "openai"
    dataset.embedding_model = "text-embedding-ada-002"
    return dataset
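
All of these hunks swap bare "high_quality"/"economy" literals for IndexTechniqueType members. Inferring from the values being replaced, the enum is presumably a string-valued enum along these lines (the real definition in core/rag/index_processor/constant/index_type.py may carry additional members):

# Presumed shape, inferred from the literals replaced above; not the real file.
from enum import StrEnum  # Python 3.11+; a (str, Enum) subclass is equivalent

class IndexTechniqueTypeSketch(StrEnum):
    HIGH_QUALITY = "high_quality"
    ECONOMY = "economy"

Because the members compare equal to their string values, existing comparisons like Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY keep matching rows stored as "high_quality".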
@ -49,9 +49,12 @@ vi.mock('@/service/use-tools', () => ({

// Mock Toast - need to verify notification calls
const mockToastNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
  default: {
    notify: (options: { type: string, message: string }) => mockToastNotify(options),
vi.mock('@/app/components/base/ui/toast', () => ({
  toast: {
    success: (message: string) => mockToastNotify({ type: 'success', message }),
    error: (message: string) => mockToastNotify({ type: 'error', message }),
    warning: (message: string) => mockToastNotify({ type: 'warning', message }),
    info: (message: string) => mockToastNotify({ type: 'info', message }),
  },
}))


@ -33,9 +33,12 @@ vi.mock('@/service/use-tools', () => ({
}))

const mockToastNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
  default: {
    notify: (options: { type: string, message: string }) => mockToastNotify(options),
vi.mock('@/app/components/base/ui/toast', () => ({
  toast: {
    success: (message: string) => mockToastNotify({ type: 'success', message }),
    error: (message: string) => mockToastNotify({ type: 'error', message }),
    warning: (message: string) => mockToastNotify({ type: 'warning', message }),
    info: (message: string) => mockToastNotify({ type: 'info', message }),
  },
}))


@ -3,7 +3,7 @@ import type { InputVar, Variable } from '@/app/components/workflow/types'
import type { PublishWorkflowParams } from '@/types/workflow'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Toast from '@/app/components/base/toast'
import { toast } from '@/app/components/base/ui/toast'
import { useAppContext } from '@/context/app-context'
import { useRouter } from 'next/navigation'
import { createWorkflowToolProvider, saveWorkflowToolProvider } from '@/service/tools'
@ -188,14 +188,11 @@ export function useConfigureButton(options: UseConfigureButtonOptions) {
      invalidateAllWorkflowTools()
      onRefreshData?.()
      invalidateDetail(workflowAppId)
      Toast.notify({
        type: 'success',
        message: t('api.actionSuccess', { ns: 'common' }),
      })
      toast.success(t('api.actionSuccess', { ns: 'common' }))
      setShowModal(false)
    }
    catch (e) {
      Toast.notify({ type: 'error', message: (e as Error).message })
      toast.error((e as Error).message)
    }
  }

@ -209,14 +206,11 @@ export function useConfigureButton(options: UseConfigureButtonOptions) {
      onRefreshData?.()
      invalidateAllWorkflowTools()
      invalidateDetail(workflowAppId)
      Toast.notify({
        type: 'success',
        message: t('api.actionSuccess', { ns: 'common' }),
      })
      toast.success(t('api.actionSuccess', { ns: 'common' }))
      setShowModal(false)
    }
    catch (e) {
      Toast.notify({ type: 'error', message: (e as Error).message })
      toast.error((e as Error).message)
    }
  }


@ -12,8 +12,8 @@ import Drawer from '@/app/components/base/drawer-plus'
import EmojiPicker from '@/app/components/base/emoji-picker'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import Toast from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { toast } from '@/app/components/base/ui/toast'
import LabelSelector from '@/app/components/tools/labels/selector'
import ConfirmModal from '@/app/components/tools/workflow-tool/confirm-modal'
import MethodSelector from '@/app/components/tools/workflow-tool/method-selector'
@ -129,10 +129,7 @@ const WorkflowToolAsModal: FC<Props> = ({
      errorMessage = t('createTool.nameForToolCall', { ns: 'tools' }) + t('createTool.nameForToolCallTip', { ns: 'tools' })

    if (errorMessage) {
      Toast.notify({
        type: 'error',
        message: errorMessage,
      })
      toast.error(errorMessage)
      return
    }

@ -0,0 +1,260 @@
import { render, screen } from '@testing-library/react'
import CandidateNodeMain from '../candidate-node-main'
import { CUSTOM_NODE } from '../constants'
import { CUSTOM_NOTE_NODE } from '../note-node/constants'
import { BlockEnum } from '../types'
import { createNode } from './fixtures'

const mockUseEventListener = vi.hoisted(() => vi.fn())
const mockUseStoreApi = vi.hoisted(() => vi.fn())
const mockUseReactFlow = vi.hoisted(() => vi.fn())
const mockUseViewport = vi.hoisted(() => vi.fn())
const mockUseStore = vi.hoisted(() => vi.fn())
const mockUseWorkflowStore = vi.hoisted(() => vi.fn())
const mockUseHooks = vi.hoisted(() => vi.fn())
const mockCustomNode = vi.hoisted(() => vi.fn())
const mockCustomNoteNode = vi.hoisted(() => vi.fn())
const mockGetIterationStartNode = vi.hoisted(() => vi.fn())
const mockGetLoopStartNode = vi.hoisted(() => vi.fn())

vi.mock('ahooks', () => ({
  useEventListener: (...args: unknown[]) => mockUseEventListener(...args),
}))

vi.mock('reactflow', () => ({
  useStoreApi: () => mockUseStoreApi(),
  useReactFlow: () => mockUseReactFlow(),
  useViewport: () => mockUseViewport(),
  Position: {
    Left: 'left',
    Right: 'right',
  },
}))

vi.mock('@/app/components/workflow/store', () => ({
  useStore: (selector: (state: { mousePosition: {
    pageX: number
    pageY: number
    elementX: number
    elementY: number
  } }) => unknown) => mockUseStore(selector),
  useWorkflowStore: () => mockUseWorkflowStore(),
}))

vi.mock('@/app/components/workflow/hooks', () => ({
  useNodesInteractions: () => mockUseHooks().useNodesInteractions(),
  useNodesSyncDraft: () => mockUseHooks().useNodesSyncDraft(),
  useWorkflowHistory: () => mockUseHooks().useWorkflowHistory(),
  useAutoGenerateWebhookUrl: () => mockUseHooks().useAutoGenerateWebhookUrl(),
  WorkflowHistoryEvent: {
    NodeAdd: 'NodeAdd',
    NoteAdd: 'NoteAdd',
  },
}))

vi.mock('@/app/components/workflow/nodes', () => ({
  __esModule: true,
  default: (props: { id: string }) => {
    mockCustomNode(props)
    return <div data-testid="candidate-custom-node">{props.id}</div>
  },
}))

vi.mock('@/app/components/workflow/note-node', () => ({
  __esModule: true,
  default: (props: { id: string }) => {
    mockCustomNoteNode(props)
    return <div data-testid="candidate-note-node">{props.id}</div>
  },
}))

vi.mock('@/app/components/workflow/utils', () => ({
  getIterationStartNode: (...args: unknown[]) => mockGetIterationStartNode(...args),
  getLoopStartNode: (...args: unknown[]) => mockGetLoopStartNode(...args),
}))

describe('CandidateNodeMain', () => {
  const mockSetNodes = vi.fn()
  const mockHandleNodeSelect = vi.fn()
  const mockSaveStateToHistory = vi.fn()
  const mockHandleSyncWorkflowDraft = vi.fn()
  const mockAutoGenerateWebhookUrl = vi.fn()
  const mockWorkflowStoreSetState = vi.fn()
  const createNodesInteractions = () => ({
    handleNodeSelect: mockHandleNodeSelect,
  })
  const createWorkflowHistory = () => ({
    saveStateToHistory: mockSaveStateToHistory,
  })
  const createNodesSyncDraft = () => ({
    handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft,
  })
  const createAutoGenerateWebhookUrl = () => mockAutoGenerateWebhookUrl
  const eventHandlers: Partial<Record<'click' | 'contextmenu', (event: { preventDefault: () => void }) => void>> = {}
  let nodes = [createNode({ id: 'existing-node' })]

  beforeEach(() => {
    vi.clearAllMocks()
    nodes = [createNode({ id: 'existing-node' })]
    eventHandlers.click = undefined
    eventHandlers.contextmenu = undefined

    mockUseEventListener.mockImplementation((event: 'click' | 'contextmenu', handler: (event: { preventDefault: () => void }) => void) => {
      eventHandlers[event] = handler
    })
    mockUseStoreApi.mockReturnValue({
      getState: () => ({
        getNodes: () => nodes,
        setNodes: mockSetNodes,
      }),
    })
    mockUseReactFlow.mockReturnValue({
      screenToFlowPosition: ({ x, y }: { x: number, y: number }) => ({ x: x + 10, y: y + 20 }),
    })
    mockUseViewport.mockReturnValue({ zoom: 1.5 })
    mockUseStore.mockImplementation((selector: (state: { mousePosition: {
      pageX: number
      pageY: number
      elementX: number
      elementY: number
    } }) => unknown) => selector({
      mousePosition: {
        pageX: 100,
        pageY: 200,
        elementX: 30,
        elementY: 40,
      },
    }))
    mockUseWorkflowStore.mockReturnValue({
      setState: mockWorkflowStoreSetState,
    })
    mockUseHooks.mockReturnValue({
      useNodesInteractions: createNodesInteractions,
      useWorkflowHistory: createWorkflowHistory,
      useNodesSyncDraft: createNodesSyncDraft,
      useAutoGenerateWebhookUrl: createAutoGenerateWebhookUrl,
    })
    mockHandleSyncWorkflowDraft.mockImplementation((_isSync: boolean, _force: boolean, options?: { onSuccess?: () => void }) => {
      options?.onSuccess?.()
    })
    mockGetIterationStartNode.mockReturnValue(createNode({ id: 'iteration-start' }))
    mockGetLoopStartNode.mockReturnValue(createNode({ id: 'loop-start' }))
  })

  it('should render the candidate node and commit a webhook node on click', () => {
    const candidateNode = createNode({
      id: 'candidate-webhook',
      type: CUSTOM_NODE,
      data: {
        type: BlockEnum.TriggerWebhook,
        title: 'Webhook Candidate',
        _isCandidate: true,
      },
    })

    const { container } = render(<CandidateNodeMain candidateNode={candidateNode} />)

    expect(screen.getByTestId('candidate-custom-node')).toHaveTextContent('candidate-webhook')
    expect(container.firstChild).toHaveStyle({
      left: '30px',
      top: '40px',
      transform: 'scale(1.5)',
    })

    eventHandlers.click?.({ preventDefault: vi.fn() })

    expect(mockSetNodes).toHaveBeenCalledWith(expect.arrayContaining([
      expect.objectContaining({ id: 'existing-node' }),
      expect.objectContaining({
        id: 'candidate-webhook',
        position: { x: 110, y: 220 },
        data: expect.objectContaining({ _isCandidate: false }),
      }),
    ]))
    expect(mockSaveStateToHistory).toHaveBeenCalledWith('NodeAdd', { nodeId: 'candidate-webhook' })
    expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ candidateNode: undefined })
    expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledWith(true, true, expect.objectContaining({
      onSuccess: expect.any(Function),
    }))
    expect(mockAutoGenerateWebhookUrl).toHaveBeenCalledWith('candidate-webhook')
    expect(mockHandleNodeSelect).not.toHaveBeenCalled()
  })

  it('should save note candidates as notes and select the inserted note', () => {
    const candidateNode = createNode({
      id: 'candidate-note',
      type: CUSTOM_NOTE_NODE,
      data: {
        type: BlockEnum.Code,
        title: 'Note Candidate',
        _isCandidate: true,
      },
    })

    render(<CandidateNodeMain candidateNode={candidateNode} />)

    expect(screen.getByTestId('candidate-note-node')).toHaveTextContent('candidate-note')

    eventHandlers.click?.({ preventDefault: vi.fn() })

    expect(mockSaveStateToHistory).toHaveBeenCalledWith('NoteAdd', { nodeId: 'candidate-note' })
    expect(mockHandleNodeSelect).toHaveBeenCalledWith('candidate-note')
  })

  it('should append iteration and loop start helper nodes for control-flow candidates', () => {
    const iterationNode = createNode({
      id: 'candidate-iteration',
      type: CUSTOM_NODE,
      data: {
        type: BlockEnum.Iteration,
        title: 'Iteration Candidate',
        _isCandidate: true,
      },
    })
    const loopNode = createNode({
      id: 'candidate-loop',
      type: CUSTOM_NODE,
      data: {
        type: BlockEnum.Loop,
        title: 'Loop Candidate',
        _isCandidate: true,
      },
    })

    const { rerender } = render(<CandidateNodeMain candidateNode={iterationNode} />)

    eventHandlers.click?.({ preventDefault: vi.fn() })
    expect(mockGetIterationStartNode).toHaveBeenCalledWith('candidate-iteration')
    expect(mockSetNodes.mock.calls[0][0]).toEqual(expect.arrayContaining([
      expect.objectContaining({ id: 'candidate-iteration' }),
      expect.objectContaining({ id: 'iteration-start' }),
    ]))

    rerender(<CandidateNodeMain candidateNode={loopNode} />)
    eventHandlers.click?.({ preventDefault: vi.fn() })

    expect(mockGetLoopStartNode).toHaveBeenCalledWith('candidate-loop')
    expect(mockSetNodes.mock.calls[1][0]).toEqual(expect.arrayContaining([
      expect.objectContaining({ id: 'candidate-loop' }),
      expect.objectContaining({ id: 'loop-start' }),
    ]))
  })

  it('should clear the candidate node on contextmenu', () => {
    const candidateNode = createNode({
      id: 'candidate-context',
      type: CUSTOM_NODE,
      data: {
        type: BlockEnum.Code,
        title: 'Context Candidate',
        _isCandidate: true,
      },
    })

    render(<CandidateNodeMain candidateNode={candidateNode} />)

    eventHandlers.contextmenu?.({ preventDefault: vi.fn() })

    expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ candidateNode: undefined })
  })
})
235 web/app/components/workflow/__tests__/custom-edge.spec.tsx Normal file
@ -0,0 +1,235 @@
import type { ReactNode } from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import { Position } from 'reactflow'
import { ErrorHandleTypeEnum } from '@/app/components/workflow/nodes/_base/components/error-handle/types'
import CustomEdge from '../custom-edge'
import { BlockEnum, NodeRunningStatus } from '../types'

const mockUseAvailableBlocks = vi.hoisted(() => vi.fn())
const mockUseNodesInteractions = vi.hoisted(() => vi.fn())
const mockBlockSelector = vi.hoisted(() => vi.fn())
const mockGradientRender = vi.hoisted(() => vi.fn())

vi.mock('reactflow', () => ({
  BaseEdge: (props: {
    id: string
    path: string
    style: {
      stroke: string
      strokeWidth: number
      opacity: number
      strokeDasharray?: string
    }
  }) => (
    <div
      data-testid="base-edge"
      data-id={props.id}
      data-path={props.path}
      data-stroke={props.style.stroke}
      data-stroke-width={props.style.strokeWidth}
      data-opacity={props.style.opacity}
      data-dasharray={props.style.strokeDasharray}
    />
  ),
  EdgeLabelRenderer: ({ children }: { children?: ReactNode }) => <div data-testid="edge-label">{children}</div>,
  getBezierPath: () => ['M 0 0', 24, 48],
  Position: {
    Right: 'right',
    Left: 'left',
  },
}))

vi.mock('@/app/components/workflow/hooks', () => ({
  useAvailableBlocks: (...args: unknown[]) => mockUseAvailableBlocks(...args),
  useNodesInteractions: () => mockUseNodesInteractions(),
}))

vi.mock('@/app/components/workflow/block-selector', () => ({
  __esModule: true,
  default: (props: {
    open: boolean
    onOpenChange: (open: boolean) => void
    onSelect: (nodeType: string, pluginDefaultValue?: Record<string, unknown>) => void
    availableBlocksTypes: string[]
    triggerClassName?: () => string
  }) => {
    mockBlockSelector(props)
    return (
      <button
        type="button"
        data-testid="block-selector"
        data-trigger-class={props.triggerClassName?.()}
        onClick={() => {
          props.onOpenChange(true)
          props.onSelect('llm', { provider: 'openai' })
        }}
      >
        {props.availableBlocksTypes.join(',')}
      </button>
    )
  },
}))

vi.mock('@/app/components/workflow/custom-edge-linear-gradient-render', () => ({
  __esModule: true,
  default: (props: {
    id: string
    startColor: string
    stopColor: string
  }) => {
    mockGradientRender(props)
    return <div data-testid="edge-gradient">{props.id}</div>
  },
}))

describe('CustomEdge', () => {
  const mockHandleNodeAdd = vi.fn()

  beforeEach(() => {
    vi.clearAllMocks()
    mockUseNodesInteractions.mockReturnValue({
      handleNodeAdd: mockHandleNodeAdd,
    })
    mockUseAvailableBlocks.mockImplementation((nodeType: BlockEnum) => {
      if (nodeType === BlockEnum.Code)
        return { availablePrevBlocks: ['code', 'llm'] }

      return { availableNextBlocks: ['llm', 'tool'] }
    })
  })

  it('should render a gradient edge and insert a node between the source and target', () => {
    render(
      <CustomEdge
        id="edge-1"
        source="source-node"
        sourceHandleId="source"
        target="target-node"
        targetHandleId="target"
        sourceX={100}
        sourceY={120}
        sourcePosition={Position.Right}
        targetX={300}
        targetY={220}
        targetPosition={Position.Left}
        selected={false}
        data={{
          sourceType: BlockEnum.Start,
          targetType: BlockEnum.Code,
          _sourceRunningStatus: NodeRunningStatus.Succeeded,
          _targetRunningStatus: NodeRunningStatus.Failed,
          _hovering: true,
          _waitingRun: true,
          _dimmed: true,
          _isTemp: true,
          isInIteration: true,
          isInLoop: true,
        } as never}
      />,
    )

    expect(screen.getByTestId('edge-gradient')).toHaveTextContent('edge-1')
    expect(mockGradientRender).toHaveBeenCalledWith(expect.objectContaining({
      id: 'edge-1',
      startColor: 'var(--color-workflow-link-line-success-handle)',
      stopColor: 'var(--color-workflow-link-line-error-handle)',
    }))
    expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'url(#edge-1)')
    expect(screen.getByTestId('base-edge')).toHaveAttribute('data-opacity', '0.3')
    expect(screen.getByTestId('base-edge')).toHaveAttribute('data-dasharray', '8 8')
    expect(screen.getByTestId('block-selector')).toHaveTextContent('llm')
    expect(screen.getByTestId('block-selector').parentElement).toHaveStyle({
      transform: 'translate(-50%, -50%) translate(24px, 48px)',
      opacity: '0.7',
    })

    fireEvent.click(screen.getByTestId('block-selector'))

    expect(mockHandleNodeAdd).toHaveBeenCalledWith(
      {
        nodeType: 'llm',
        pluginDefaultValue: { provider: 'openai' },
      },
      {
        prevNodeId: 'source-node',
        prevNodeSourceHandle: 'source',
        nextNodeId: 'target-node',
        nextNodeTargetHandle: 'target',
      },
    )
  })

  it('should prefer the running stroke color when the edge is selected', () => {
    render(
      <CustomEdge
        id="edge-selected"
        source="source-node"
        target="target-node"
        sourceX={0}
        sourceY={0}
        sourcePosition={Position.Right}
        targetX={100}
        targetY={100}
        targetPosition={Position.Left}
        selected
        data={{
          sourceType: BlockEnum.Start,
          targetType: BlockEnum.Code,
          _sourceRunningStatus: NodeRunningStatus.Succeeded,
          _targetRunningStatus: NodeRunningStatus.Running,
        } as never}
      />,
    )

    expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'var(--color-workflow-link-line-handle)')
  })

  it('should use the fail-branch running color while the connected node is hovering', () => {
    render(
      <CustomEdge
        id="edge-hover"
        source="source-node"
        sourceHandleId={ErrorHandleTypeEnum.failBranch}
        target="target-node"
        sourceX={0}
        sourceY={0}
        sourcePosition={Position.Right}
        targetX={100}
        targetY={100}
        targetPosition={Position.Left}
        selected={false}
        data={{
          sourceType: BlockEnum.Start,
          targetType: BlockEnum.Code,
          _connectedNodeIsHovering: true,
        } as never}
      />,
    )

    expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'var(--color-workflow-link-line-failure-handle)')
  })

  it('should fall back to the default edge color when no highlight state is active', () => {
    render(
      <CustomEdge
        id="edge-default"
        source="source-node"
        target="target-node"
        sourceX={0}
        sourceY={0}
        sourcePosition={Position.Right}
        targetX={100}
        targetY={100}
        targetPosition={Position.Left}
        selected={false}
        data={{
          sourceType: BlockEnum.Start,
          targetType: BlockEnum.Code,
        } as never}
      />,
    )

    expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'var(--color-workflow-link-line-normal)')
    expect(screen.getByTestId('block-selector')).toHaveAttribute('data-trigger-class', 'hover:scale-150 transition-all')
  })
})
web/app/components/workflow/__tests__/node-contextmenu.spec.tsx
Normal file
@ -0,0 +1,114 @@
import type { Node } from '../types'
import { fireEvent, render, screen } from '@testing-library/react'
import NodeContextmenu from '../node-contextmenu'

const mockUseClickAway = vi.hoisted(() => vi.fn())
const mockUseNodes = vi.hoisted(() => vi.fn())
const mockUsePanelInteractions = vi.hoisted(() => vi.fn())
const mockUseStore = vi.hoisted(() => vi.fn())
const mockPanelOperatorPopup = vi.hoisted(() => vi.fn())

vi.mock('ahooks', () => ({
  useClickAway: (...args: unknown[]) => mockUseClickAway(...args),
}))

vi.mock('@/app/components/workflow/store/workflow/use-nodes', () => ({
  __esModule: true,
  default: () => mockUseNodes(),
}))

vi.mock('@/app/components/workflow/hooks', () => ({
  usePanelInteractions: () => mockUsePanelInteractions(),
}))

vi.mock('@/app/components/workflow/store', () => ({
  useStore: (selector: (state: { nodeMenu?: { nodeId: string, left: number, top: number } }) => unknown) => mockUseStore(selector),
}))

vi.mock('@/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup', () => ({
  __esModule: true,
  default: (props: {
    id: string
    data: Node['data']
    showHelpLink: boolean
    onClosePopup: () => void
  }) => {
    mockPanelOperatorPopup(props)
    return (
      <button type="button" onClick={props.onClosePopup}>
        {props.id}
        :
        {props.data.title}
      </button>
    )
  },
}))

describe('NodeContextmenu', () => {
  const mockHandleNodeContextmenuCancel = vi.fn()
  let nodeMenu: { nodeId: string, left: number, top: number } | undefined
  let nodes: Node[]
  let clickAwayHandler: (() => void) | undefined

  beforeEach(() => {
    vi.clearAllMocks()
    nodeMenu = undefined
    nodes = [{
      id: 'node-1',
      type: 'custom',
      position: { x: 0, y: 0 },
      data: {
        title: 'Node 1',
        desc: '',
        type: 'code' as never,
      },
    } as Node]
    clickAwayHandler = undefined

    mockUseClickAway.mockImplementation((handler: () => void) => {
      clickAwayHandler = handler
    })
    mockUseNodes.mockImplementation(() => nodes)
    mockUsePanelInteractions.mockReturnValue({
      handleNodeContextmenuCancel: mockHandleNodeContextmenuCancel,
    })
    mockUseStore.mockImplementation((selector: (state: { nodeMenu?: { nodeId: string, left: number, top: number } }) => unknown) => selector({ nodeMenu }))
  })

  it('should stay hidden when the node menu is absent', () => {
    render(<NodeContextmenu />)

    expect(screen.queryByRole('button')).not.toBeInTheDocument()
    expect(mockPanelOperatorPopup).not.toHaveBeenCalled()
  })

  it('should stay hidden when the referenced node cannot be found', () => {
    nodeMenu = { nodeId: 'missing-node', left: 80, top: 120 }

    render(<NodeContextmenu />)

    expect(screen.queryByRole('button')).not.toBeInTheDocument()
    expect(mockPanelOperatorPopup).not.toHaveBeenCalled()
  })

  it('should render the popup at the stored position and close on popup/click-away actions', () => {
    nodeMenu = { nodeId: 'node-1', left: 80, top: 120 }
    const { container } = render(<NodeContextmenu />)

    expect(screen.getByRole('button')).toHaveTextContent('node-1:Node 1')
    expect(mockPanelOperatorPopup).toHaveBeenCalledWith(expect.objectContaining({
      id: 'node-1',
      data: expect.objectContaining({ title: 'Node 1' }),
      showHelpLink: true,
    }))
    expect(container.firstChild).toHaveStyle({
      left: '80px',
      top: '120px',
    })

    fireEvent.click(screen.getByRole('button'))
    clickAwayHandler?.()

    expect(mockHandleNodeContextmenuCancel).toHaveBeenCalledTimes(2)
  })
})
web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx
Normal file
@ -0,0 +1,151 @@
import type { ReactNode } from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import PanelContextmenu from '../panel-contextmenu'

const mockUseClickAway = vi.hoisted(() => vi.fn())
const mockUseTranslation = vi.hoisted(() => vi.fn())
const mockUseStore = vi.hoisted(() => vi.fn())
const mockUseNodesInteractions = vi.hoisted(() => vi.fn())
const mockUsePanelInteractions = vi.hoisted(() => vi.fn())
const mockUseWorkflowStartRun = vi.hoisted(() => vi.fn())
const mockUseOperator = vi.hoisted(() => vi.fn())
const mockUseDSL = vi.hoisted(() => vi.fn())

vi.mock('ahooks', () => ({
  useClickAway: (...args: unknown[]) => mockUseClickAway(...args),
}))

vi.mock('react-i18next', () => ({
  useTranslation: () => mockUseTranslation(),
}))

vi.mock('@/app/components/workflow/store', () => ({
  useStore: (selector: (state: {
    panelMenu?: { left: number, top: number }
    clipboardElements: unknown[]
    setShowImportDSLModal: (visible: boolean) => void
  }) => unknown) => mockUseStore(selector),
}))

vi.mock('@/app/components/workflow/hooks', () => ({
  useNodesInteractions: () => mockUseNodesInteractions(),
  usePanelInteractions: () => mockUsePanelInteractions(),
  useWorkflowStartRun: () => mockUseWorkflowStartRun(),
  useDSL: () => mockUseDSL(),
}))

vi.mock('@/app/components/workflow/operator/hooks', () => ({
  useOperator: () => mockUseOperator(),
}))

vi.mock('@/app/components/workflow/operator/add-block', () => ({
  __esModule: true,
  default: ({ renderTrigger }: { renderTrigger: () => ReactNode }) => (
    <div data-testid="add-block">{renderTrigger()}</div>
  ),
}))

vi.mock('@/app/components/base/divider', () => ({
  __esModule: true,
  default: ({ className }: { className?: string }) => <div data-testid="divider" className={className} />,
}))

vi.mock('@/app/components/workflow/shortcuts-name', () => ({
  __esModule: true,
  default: ({ keys }: { keys: string[] }) => <span data-testid={`shortcut-${keys.join('-')}`}>{keys.join('+')}</span>,
}))

describe('PanelContextmenu', () => {
  const mockHandleNodesPaste = vi.fn()
  const mockHandlePaneContextmenuCancel = vi.fn()
  const mockHandleStartWorkflowRun = vi.fn()
  const mockHandleAddNote = vi.fn()
  const mockExportCheck = vi.fn()
  const mockSetShowImportDSLModal = vi.fn()
  let panelMenu: { left: number, top: number } | undefined
  let clipboardElements: unknown[]
  let clickAwayHandler: (() => void) | undefined

  beforeEach(() => {
    vi.clearAllMocks()
    panelMenu = undefined
    clipboardElements = []
    clickAwayHandler = undefined

    mockUseClickAway.mockImplementation((handler: () => void) => {
      clickAwayHandler = handler
    })
    mockUseTranslation.mockReturnValue({
      t: (key: string) => key,
    })
    mockUseStore.mockImplementation((selector: (state: {
      panelMenu?: { left: number, top: number }
      clipboardElements: unknown[]
      setShowImportDSLModal: (visible: boolean) => void
    }) => unknown) => selector({
      panelMenu,
      clipboardElements,
      setShowImportDSLModal: mockSetShowImportDSLModal,
    }))
    mockUseNodesInteractions.mockReturnValue({
      handleNodesPaste: mockHandleNodesPaste,
    })
    mockUsePanelInteractions.mockReturnValue({
      handlePaneContextmenuCancel: mockHandlePaneContextmenuCancel,
    })
    mockUseWorkflowStartRun.mockReturnValue({
      handleStartWorkflowRun: mockHandleStartWorkflowRun,
    })
    mockUseOperator.mockReturnValue({
      handleAddNote: mockHandleAddNote,
    })
    mockUseDSL.mockReturnValue({
      exportCheck: mockExportCheck,
    })
  })

  it('should stay hidden when the panel menu is absent', () => {
    render(<PanelContextmenu />)

    expect(screen.queryByTestId('add-block')).not.toBeInTheDocument()
  })

  it('should keep paste disabled when the clipboard is empty', () => {
    panelMenu = { left: 24, top: 48 }

    render(<PanelContextmenu />)

    fireEvent.click(screen.getByText('common.pasteHere'))

    expect(mockHandleNodesPaste).not.toHaveBeenCalled()
    expect(mockHandlePaneContextmenuCancel).not.toHaveBeenCalled()
  })

  it('should render actions, position the menu, and execute each action', () => {
    panelMenu = { left: 24, top: 48 }
    clipboardElements = [{ id: 'copied-node' }]
    const { container } = render(<PanelContextmenu />)

    expect(screen.getByTestId('add-block')).toHaveTextContent('common.addBlock')
    expect(screen.getByTestId('shortcut-alt-r')).toHaveTextContent('alt+r')
    expect(screen.getByTestId('shortcut-ctrl-v')).toHaveTextContent('ctrl+v')
    expect(container.firstChild).toHaveStyle({
      left: '24px',
      top: '48px',
    })

    fireEvent.click(screen.getByText('nodes.note.addNote'))
    fireEvent.click(screen.getByText('common.run'))
    fireEvent.click(screen.getByText('common.pasteHere'))
    fireEvent.click(screen.getByText('export'))
    fireEvent.click(screen.getByText('common.importDSL'))
    clickAwayHandler?.()

    expect(mockHandleAddNote).toHaveBeenCalledTimes(1)
    expect(mockHandleStartWorkflowRun).toHaveBeenCalledTimes(1)
    expect(mockHandleNodesPaste).toHaveBeenCalledTimes(1)
    expect(mockExportCheck).toHaveBeenCalledTimes(1)
    expect(mockSetShowImportDSLModal).toHaveBeenCalledWith(true)
    expect(mockHandlePaneContextmenuCancel).toHaveBeenCalledTimes(4)
  })
})
@ -1,7 +1,7 @@
import type { EventEmitter } from 'ahooks/lib/useEventEmitter'
import type { EventEmitterValue } from '@/context/event-emitter'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { ToastContext } from '@/app/components/base/toast/context'
import { toast } from '@/app/components/base/ui/toast'
import { EventEmitterContext } from '@/context/event-emitter'
import { DSLImportStatus } from '@/models/app'
import UpdateDSLModal from '../update-dsl-modal'
@ -16,10 +16,17 @@ class MockFileReader {
}

vi.stubGlobal('FileReader', MockFileReader as unknown as typeof FileReader)

const mockNotify = vi.fn()
const mockEmit = vi.fn()

vi.mock('@/app/components/base/ui/toast', () => ({
  toast: {
    error: vi.fn(),
    info: vi.fn(),
    success: vi.fn(),
    warning: vi.fn(),
  },
}))

const mockImportDSL = vi.fn()
const mockImportDSLConfirm = vi.fn()
vi.mock('@/service/apps', () => ({
@ -59,6 +66,7 @@ vi.mock('@/app/components/app/create-from-dsl-modal/uploader', () => ({
}))

describe('UpdateDSLModal', () => {
  const mockToastError = vi.mocked(toast.error)
  const defaultProps = {
    onCancel: vi.fn(),
    onBackup: vi.fn(),
@ -91,11 +99,9 @@ describe('UpdateDSLModal', () => {
    const eventEmitter = { emit: mockEmit } as unknown as EventEmitter<EventEmitterValue>

    return render(
      <ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
        <EventEmitterContext.Provider value={{ eventEmitter }}>
          <UpdateDSLModal {...props} />
        </EventEmitterContext.Provider>
      </ToastContext.Provider>,
      <EventEmitterContext.Provider value={{ eventEmitter }}>
        <UpdateDSLModal {...props} />
      </EventEmitterContext.Provider>,
    )
  }

@ -152,9 +158,7 @@ describe('UpdateDSLModal', () => {
    fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' }))

    await waitFor(() => {
      expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
        type: 'error',
      }))
      expect(mockToastError).toHaveBeenCalled()
    })
  })

@ -233,9 +237,7 @@ describe('UpdateDSLModal', () => {
    fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' }))

    await waitFor(() => {
      expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
        type: 'error',
      }))
      expect(mockToastError).toHaveBeenCalled()
    })
    expect(mockImportDSL).not.toHaveBeenCalled()

@ -254,9 +256,7 @@ describe('UpdateDSLModal', () => {
    fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' }))

    await waitFor(() => {
      expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
        type: 'error',
      }))
      expect(mockToastError).toHaveBeenCalled()
    })
  })

@ -274,9 +274,7 @@ describe('UpdateDSLModal', () => {
    fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' }))

    await waitFor(() => {
      expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
        type: 'error',
      }))
      expect(mockToastError).toHaveBeenCalled()
    })
  })

@ -305,9 +303,7 @@ describe('UpdateDSLModal', () => {
    fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Confirm' }))

    await waitFor(() => {
      expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
        type: 'error',
      }))
      expect(mockToastError).toHaveBeenCalled()
    })
  })

@ -334,9 +330,7 @@ describe('UpdateDSLModal', () => {
    fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Confirm' }))

    await waitFor(() => {
      expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
        type: 'error',
      }))
      expect(mockToastError).toHaveBeenCalled()
    })
  })

@ -365,9 +359,7 @@ describe('UpdateDSLModal', () => {
    fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Confirm' }))

    await waitFor(() => {
      expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
        type: 'error',
      }))
      expect(mockToastError).toHaveBeenCalled()
    })
  })
})

@ -114,9 +114,12 @@ vi.mock('@/service/use-tools', () => ({
  useInvalidateAllMCPTools: vi.fn(),
}))

vi.mock('@/app/components/base/toast', () => ({
  default: {
    notify: (payload: unknown) => mockNotify(payload),
vi.mock('@/app/components/base/ui/toast', () => ({
  toast: {
    success: (message: string) => mockNotify({ type: 'success', message }),
    error: (message: string) => mockNotify({ type: 'error', message }),
    warning: (message: string) => mockNotify({ type: 'warning', message }),
    info: (message: string) => mockNotify({ type: 'info', message }),
  },
}))

@ -16,7 +16,7 @@ import {
  PortalToFollowElemContent,
  PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import Toast from '@/app/components/base/toast'
import { toast } from '@/app/components/base/ui/toast'
import SearchBox from '@/app/components/plugins/marketplace/search-box'
import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal'
import AllTools from '@/app/components/workflow/block-selector/all-tools'
@ -137,10 +137,7 @@ const ToolPicker: FC<Props> = ({

  const doCreateCustomToolCollection = async (data: CustomCollectionBackend) => {
    await createCustomCollection(data)
    Toast.notify({
      type: 'success',
      message: t('api.actionSuccess', { ns: 'common' }),
    })
    toast.success(t('api.actionSuccess', { ns: 'common' }))
    hideEditCustomCollectionModal()
    handleAddedCustomTool()
  }

@ -60,9 +60,12 @@ vi.mock('@/service/use-workflow', () => ({
  }),
}))

vi.mock('../../../base/toast', () => ({
  default: {
    notify: (payload: unknown) => mockNotify(payload),
vi.mock('@/app/components/base/ui/toast', () => ({
  toast: {
    success: (message: string) => mockNotify({ type: 'success', message }),
    error: (message: string) => mockNotify({ type: 'error', message }),
    warning: (message: string) => mockNotify({ type: 'warning', message }),
    info: (message: string) => mockNotify({ type: 'info', message }),
  },
}))

@ -46,10 +46,13 @@ vi.mock('../../hooks/use-dynamic-test-run-options', () => ({
  useDynamicTestRunOptions: () => mockDynamicOptions,
}))

vi.mock('@/app/components/base/toast/context', () => ({
  useToastContext: () => ({
    notify: mockNotify,
  }),
vi.mock('@/app/components/base/ui/toast', () => ({
  toast: {
    success: (message: string) => mockNotify({ type: 'success', message }),
    error: (message: string) => mockNotify({ type: 'error', message }),
    warning: (message: string) => mockNotify({ type: 'warning', message }),
    info: (message: string) => mockNotify({ type: 'info', message }),
  },
}))

vi.mock('@/app/components/base/amplitude', () => ({

@ -4,11 +4,11 @@ import {
} from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import { toast } from '@/app/components/base/ui/toast'
import useTheme from '@/hooks/use-theme'
import { useInvalidAllLastRun, useRestoreWorkflow } from '@/service/use-workflow'
import { getFlowPrefix } from '@/service/utils'
import { cn } from '@/utils/classnames'
import Toast from '../../base/toast'
import {
  useWorkflowRefreshDraft,
  useWorkflowRun,
@ -65,18 +65,12 @@ const HeaderInRestoring = ({
      workflowStore.setState({ isRestoring: false })
      workflowStore.setState({ backupDraft: undefined })
      handleRefreshWorkflowDraft()
      Toast.notify({
        type: 'success',
        message: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }),
      })
      toast.success(t('versionHistory.action.restoreSuccess', { ns: 'workflow' }))
      deleteAllInspectVars()
      invalidAllLastRun()
    }
    catch {
      Toast.notify({
        type: 'error',
        message: t('versionHistory.action.restoreFailure', { ns: 'workflow' }),
      })
      toast.error(t('versionHistory.action.restoreFailure', { ns: 'workflow' }))
    }
    finally {
      onRestoreSettled?.()

@ -5,7 +5,7 @@ import { useCallback, useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import { StopCircle } from '@/app/components/base/icons/src/vender/line/mediaAndDevices'
import { useToastContext } from '@/app/components/base/toast/context'
import { toast } from '@/app/components/base/ui/toast'
import { useWorkflowRun, useWorkflowRunValidation, useWorkflowStartRun } from '@/app/components/workflow/hooks'
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
import { useStore } from '@/app/components/workflow/store'
@ -41,7 +41,6 @@ const RunMode = ({

  const dynamicOptions = useDynamicTestRunOptions()
  const testRunMenuRef = useRef<TestRunMenuRef>(null)
  const { notify } = useToastContext()

  useEffect(() => {
    // @ts-expect-error - Dynamic property for backward compatibility with keyboard shortcuts
@ -66,7 +65,7 @@ const RunMode = ({
      isValid = false
    })
    if (!isValid) {
      notify({ type: 'error', message: t('panel.checklistTip', { ns: 'workflow' }) })
      toast.error(t('panel.checklistTip', { ns: 'workflow' }))
      return
    }

@ -98,7 +97,7 @@ const RunMode = ({
      // Placeholder for trigger-specific execution logic for schedule, webhook, plugin types
      console.log('TODO: Handle trigger execution for type:', option.type, 'nodeId:', option.nodeId)
    }
  }, [warningNodes, notify, t, handleWorkflowStartRunInWorkflow, handleWorkflowTriggerScheduleRunInWorkflow, handleWorkflowTriggerWebhookRunInWorkflow, handleWorkflowTriggerPluginRunInWorkflow, handleWorkflowRunAllTriggersInWorkflow])
  }, [warningNodes, t, handleWorkflowStartRunInWorkflow, handleWorkflowTriggerScheduleRunInWorkflow, handleWorkflowTriggerWebhookRunInWorkflow, handleWorkflowTriggerPluginRunInWorkflow, handleWorkflowRunAllTriggersInWorkflow])

  const { eventEmitter } = useEventEmitterContextContext()
  eventEmitter?.useSubscription((v: any) => {

@ -0,0 +1,61 @@
import { render } from '@testing-library/react'
import HelpLine from '../index'

const mockUseViewport = vi.hoisted(() => vi.fn())
const mockUseStore = vi.hoisted(() => vi.fn())

vi.mock('reactflow', () => ({
  useViewport: () => mockUseViewport(),
}))

vi.mock('@/app/components/workflow/store', () => ({
  useStore: (selector: (state: {
    helpLineHorizontal?: { top: number, left: number, width: number }
    helpLineVertical?: { top: number, left: number, height: number }
  }) => unknown) => mockUseStore(selector),
}))

describe('HelpLine', () => {
  let helpLineHorizontal: { top: number, left: number, width: number } | undefined
  let helpLineVertical: { top: number, left: number, height: number } | undefined

  beforeEach(() => {
    vi.clearAllMocks()
    helpLineHorizontal = undefined
    helpLineVertical = undefined

    mockUseViewport.mockReturnValue({ x: 10, y: 20, zoom: 2 })
    mockUseStore.mockImplementation((selector: (state: {
      helpLineHorizontal?: { top: number, left: number, width: number }
      helpLineVertical?: { top: number, left: number, height: number }
    }) => unknown) => selector({
      helpLineHorizontal,
      helpLineVertical,
    }))
  })

  it('should render nothing when both help lines are absent', () => {
    const { container } = render(<HelpLine />)

    expect(container).toBeEmptyDOMElement()
  })

  it('should render the horizontal and vertical guide lines using viewport offsets and zoom', () => {
    helpLineHorizontal = { top: 30, left: 40, width: 50 }
    helpLineVertical = { top: 60, left: 70, height: 80 }

    const { container } = render(<HelpLine />)
    const [horizontal, vertical] = Array.from(container.querySelectorAll('div'))

    expect(horizontal).toHaveStyle({
      top: '80px',
      left: '90px',
      width: '100px',
    })
    expect(vertical).toHaveStyle({
      top: '140px',
      left: '150px',
      height: '160px',
    })
  })
})
@ -89,8 +89,13 @@ vi.mock('../index', () => ({
  useNodesMetaData: () => ({ nodes: [], nodesMap: mockNodesMap }),
}))

vi.mock('@/app/components/base/toast/context', () => ({
  useToastContext: () => ({ notify: vi.fn() }),
vi.mock('@/app/components/base/ui/toast', () => ({
  toast: {
    success: vi.fn(),
    error: vi.fn(),
    warning: vi.fn(),
    info: vi.fn(),
  },
}))

vi.mock('@/context/i18n', () => ({

@ -0,0 +1,171 @@
import type { ModelConfig, VisionSetting } from '@/app/components/workflow/types'
import { act, renderHook } from '@testing-library/react'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { Resolution } from '@/types/app'
import useConfigVision from '../use-config-vision'

const mockUseTextGenerationCurrentProviderAndModelAndModelList = vi.hoisted(() => vi.fn())
const mockUseIsChatMode = vi.hoisted(() => vi.fn())

vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
  useTextGenerationCurrentProviderAndModelAndModelList: (...args: unknown[]) =>
    mockUseTextGenerationCurrentProviderAndModelAndModelList(...args),
}))

vi.mock('../use-workflow', () => ({
  useIsChatMode: () => mockUseIsChatMode(),
}))

const createModel = (overrides: Partial<ModelConfig> = {}): ModelConfig => ({
  provider: 'openai',
  name: 'gpt-4o',
  mode: 'chat',
  completion_params: [],
  ...overrides,
})

const createVisionPayload = (overrides: Partial<{ enabled: boolean, configs?: VisionSetting }> = {}) => ({
  enabled: false,
  ...overrides,
})

describe('useConfigVision', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockUseIsChatMode.mockReturnValue(false)
    mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({
      currentModel: {
        features: [],
      },
    })
  })

  it('should expose vision capability and enable default chat configs for vision models', () => {
    const onChange = vi.fn()
    mockUseIsChatMode.mockReturnValue(true)
    mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({
      currentModel: {
        features: [ModelFeatureEnum.vision],
      },
    })

    const { result } = renderHook(() => useConfigVision(createModel(), {
      payload: createVisionPayload(),
      onChange,
    }))

    expect(result.current.isVisionModel).toBe(true)

    act(() => {
      result.current.handleVisionResolutionEnabledChange(true)
    })

    expect(onChange).toHaveBeenCalledWith({
      enabled: true,
      configs: {
        detail: Resolution.high,
        variable_selector: ['sys', 'files'],
      },
    })
  })

  it('should clear configs when disabling vision resolution', () => {
    const onChange = vi.fn()

    const { result } = renderHook(() => useConfigVision(createModel(), {
      payload: createVisionPayload({
        enabled: true,
        configs: {
          detail: Resolution.low,
          variable_selector: ['node', 'files'],
        },
      }),
      onChange,
    }))

    act(() => {
      result.current.handleVisionResolutionEnabledChange(false)
    })

    expect(onChange).toHaveBeenCalledWith({
      enabled: false,
    })
  })

  it('should update the resolution config payload directly', () => {
    const onChange = vi.fn()
    const config: VisionSetting = {
      detail: Resolution.low,
      variable_selector: ['upstream', 'images'],
    }

    const { result } = renderHook(() => useConfigVision(createModel(), {
      payload: createVisionPayload({ enabled: true }),
      onChange,
    }))

    act(() => {
      result.current.handleVisionResolutionChange(config)
    })

    expect(onChange).toHaveBeenCalledWith({
      enabled: true,
      configs: config,
    })
  })

  it('should disable vision settings when the selected model is no longer a vision model', () => {
    const onChange = vi.fn()

    const { result } = renderHook(() => useConfigVision(createModel(), {
      payload: createVisionPayload({
        enabled: true,
        configs: {
          detail: Resolution.high,
          variable_selector: ['sys', 'files'],
        },
      }),
      onChange,
    }))

    act(() => {
      result.current.handleModelChanged()
    })

    expect(onChange).toHaveBeenCalledWith({
      enabled: false,
    })
  })

  it('should reset enabled vision configs when the model changes but still supports vision', () => {
    const onChange = vi.fn()
    mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({
      currentModel: {
        features: [ModelFeatureEnum.vision],
      },
    })

    const { result } = renderHook(() => useConfigVision(createModel(), {
      payload: createVisionPayload({
        enabled: true,
        configs: {
          detail: Resolution.low,
          variable_selector: ['old', 'files'],
        },
      }),
      onChange,
    }))

    act(() => {
      result.current.handleModelChanged()
    })

    expect(onChange).toHaveBeenCalledWith({
      enabled: true,
      configs: {
        detail: Resolution.high,
        variable_selector: [],
      },
    })
  })
})
@ -0,0 +1,146 @@
import { renderHook } from '@testing-library/react'
import { BlockEnum } from '../../types'
import { useDynamicTestRunOptions } from '../use-dynamic-test-run-options'

const mockUseTranslation = vi.hoisted(() => vi.fn())
const mockUseNodes = vi.hoisted(() => vi.fn())
const mockUseStore = vi.hoisted(() => vi.fn())
const mockUseAllTriggerPlugins = vi.hoisted(() => vi.fn())
const mockGetWorkflowEntryNode = vi.hoisted(() => vi.fn())

vi.mock('react-i18next', () => ({
  useTranslation: () => mockUseTranslation(),
}))

vi.mock('@/app/components/workflow/store/workflow/use-nodes', () => ({
  __esModule: true,
  default: () => mockUseNodes(),
}))

vi.mock('@/app/components/workflow/store', () => ({
  useStore: (selector: (state: {
    buildInTools: unknown[]
    customTools: unknown[]
    workflowTools: unknown[]
    mcpTools: unknown[]
  }) => unknown) => mockUseStore(selector),
}))

vi.mock('@/service/use-triggers', () => ({
  useAllTriggerPlugins: () => mockUseAllTriggerPlugins(),
}))

vi.mock('@/app/components/workflow/utils/workflow-entry', () => ({
  getWorkflowEntryNode: (...args: unknown[]) => mockGetWorkflowEntryNode(...args),
}))

describe('useDynamicTestRunOptions', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockUseTranslation.mockReturnValue({
      t: (key: string) => key,
    })
    mockUseStore.mockImplementation((selector: (state: {
      buildInTools: unknown[]
      customTools: unknown[]
      workflowTools: unknown[]
      mcpTools: unknown[]
    }) => unknown) => selector({
      buildInTools: [],
      customTools: [],
      workflowTools: [],
      mcpTools: [],
    }))
    mockUseAllTriggerPlugins.mockReturnValue({
      data: [{
        name: 'plugin-provider',
        icon: '/plugin-icon.png',
      }],
    })
  })

  it('should build user input, trigger options, and a run-all option from workflow nodes', () => {
    mockUseNodes.mockReturnValue([
      {
        id: 'start-1',
        data: { type: BlockEnum.Start, title: 'User Input' },
      },
      {
        id: 'schedule-1',
        data: { type: BlockEnum.TriggerSchedule, title: 'Daily Schedule' },
      },
      {
        id: 'webhook-1',
        data: { type: BlockEnum.TriggerWebhook, title: 'Webhook Trigger' },
      },
      {
        id: 'plugin-1',
        data: {
          type: BlockEnum.TriggerPlugin,
          title: '',
          plugin_name: 'Plugin Trigger',
          provider_id: 'plugin-provider',
        },
      },
    ])

    const { result } = renderHook(() => useDynamicTestRunOptions())

    expect(result.current.userInput).toEqual(expect.objectContaining({
      id: 'start-1',
      type: 'user_input',
      name: 'User Input',
      nodeId: 'start-1',
      enabled: true,
    }))
    expect(result.current.triggers).toEqual([
      expect.objectContaining({
        id: 'schedule-1',
        type: 'schedule',
        name: 'Daily Schedule',
        nodeId: 'schedule-1',
      }),
      expect.objectContaining({
        id: 'webhook-1',
        type: 'webhook',
        name: 'Webhook Trigger',
        nodeId: 'webhook-1',
      }),
      expect.objectContaining({
        id: 'plugin-1',
        type: 'plugin',
        name: 'Plugin Trigger',
        nodeId: 'plugin-1',
      }),
    ])
    expect(result.current.runAll).toEqual(expect.objectContaining({
      id: 'run-all',
      type: 'all',
      relatedNodeIds: ['schedule-1', 'webhook-1', 'plugin-1'],
    }))
  })

  it('should fall back to the workflow entry node and omit run-all when only one trigger exists', () => {
    mockUseNodes.mockReturnValue([
      {
        id: 'webhook-1',
        data: { type: BlockEnum.TriggerWebhook, title: 'Webhook Trigger' },
      },
    ])
    mockGetWorkflowEntryNode.mockReturnValue({
      id: 'fallback-start',
      data: { type: BlockEnum.Start, title: '' },
    })

    const { result } = renderHook(() => useDynamicTestRunOptions())

    expect(result.current.userInput).toEqual(expect.objectContaining({
      id: 'fallback-start',
      type: 'user_input',
      name: 'blocks.start',
      nodeId: 'fallback-start',
    }))
    expect(result.current.triggers).toHaveLength(1)
    expect(result.current.runAll).toBeUndefined()
  })
})