mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:06:51 +08:00
merge evaluation fe
This commit is contained in:
commit
fcd2b5fef4
22
.github/workflows/db-migration-test.yml
vendored
22
.github/workflows/db-migration-test.yml
vendored
@ -110,6 +110,28 @@ jobs:
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
|
||||
|
||||
# hoverkraft-tech/compose-action@v2.6.0 only waits for `docker compose up -d`
|
||||
# to return (container processes started); it does not wait on healthcheck
|
||||
# status. mysql:8.0's first-time init takes 15-30s, so without an explicit
|
||||
# wait the migration runs while InnoDB is still initialising and gets
|
||||
# killed with "Lost connection during query". Poll a real SELECT until it
|
||||
# succeeds.
|
||||
- name: Wait for MySQL to accept queries
|
||||
run: |
|
||||
set +e
|
||||
for i in $(seq 1 60); do
|
||||
if docker run --rm --network host mysql:8.0 \
|
||||
mysql -h 127.0.0.1 -P 3306 -uroot -pdifyai123456 \
|
||||
-e 'SELECT 1' >/dev/null 2>&1; then
|
||||
echo "MySQL ready after ${i}s"
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "MySQL not ready after 60s; dumping container logs:"
|
||||
docker compose -f docker/docker-compose.middleware.yaml --profile mysql logs --tail=200 db_mysql
|
||||
exit 1
|
||||
|
||||
- name: Run DB Migration
|
||||
env:
|
||||
DEBUG: true
|
||||
|
||||
2
.github/workflows/web-e2e.yml
vendored
2
.github/workflows/web-e2e.yml
vendored
@ -13,7 +13,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
name: Web Full-Stack E2E
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
@ -468,15 +468,98 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Update a document from an uploaded file for canonical and deprecated routes."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id_str = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args: dict[str, object] = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if "file" in request.files:
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
|
||||
)
|
||||
class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Deprecated resource aliases for file document updates."""
|
||||
|
||||
@service_api_ns.doc("update_document_by_file")
|
||||
@service_api_ns.doc(description="Update an existing document by uploading a file")
|
||||
@service_api_ns.doc("update_document_by_file_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating an existing document by uploading a file. "
|
||||
"Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
|
||||
)
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
@ -487,82 +570,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if "file" in request.files:
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
|
||||
return documents_and_batch_fields, 200
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by file through the deprecated file-update aliases."""
|
||||
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
|
||||
@ -876,6 +886,22 @@ class DocumentApi(DatasetApiResource):
|
||||
|
||||
return response
|
||||
|
||||
@service_api_ns.doc("update_document_by_file")
|
||||
@service_api_ns.doc(description="Update an existing document by uploading a file")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document updated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by file on the canonical document resource."""
|
||||
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
@service_api_ns.doc("delete_document")
|
||||
@service_api_ns.doc(description="Delete a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
|
||||
@ -39,6 +39,58 @@ class AbstractVectorFactory(ABC):
|
||||
return index_struct_dict
|
||||
|
||||
|
||||
class _LazyEmbeddings(Embeddings):
|
||||
"""Lazy proxy that defers materializing the real embedding model.
|
||||
|
||||
Constructing the real embeddings (via ``ModelManager.get_model_instance``)
|
||||
transitively calls ``FeatureService.get_features`` → ``BillingService``
|
||||
HTTP GETs (see ``provider_manager.py``). Cleanup paths
|
||||
(``delete_by_ids`` / ``delete`` / ``text_exists``) do not need embeddings
|
||||
at all, so deferring this until an ``embed_*`` method is actually invoked
|
||||
keeps cleanup tasks resilient to transient billing-API failures and avoids
|
||||
leaving stranded ``document_segments`` / ``child_chunks`` whenever billing
|
||||
hiccups.
|
||||
|
||||
Existing callers that perform create / search operations are unaffected:
|
||||
the first ``embed_*`` call materializes the underlying model and the
|
||||
behavior is identical from that point on.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self._dataset = dataset
|
||||
self._real: Embeddings | None = None
|
||||
|
||||
def _ensure(self) -> Embeddings:
|
||||
if self._real is None:
|
||||
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
provider=self._dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=self._dataset.embedding_model,
|
||||
)
|
||||
self._real = CacheEmbedding(embedding_model)
|
||||
return self._real
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return self._ensure().embed_documents(texts)
|
||||
|
||||
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
|
||||
return self._ensure().embed_multimodal_documents(multimodel_documents)
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
return self._ensure().embed_query(text)
|
||||
|
||||
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||
return self._ensure().embed_multimodal_query(multimodel_document)
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return await self._ensure().aembed_documents(texts)
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
return await self._ensure().aembed_query(text)
|
||||
|
||||
|
||||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: list | None = None):
|
||||
if attributes is None:
|
||||
@ -60,7 +112,11 @@ class Vector:
|
||||
"original_chunk_id",
|
||||
]
|
||||
self._dataset = dataset
|
||||
self._embeddings = self._get_embeddings()
|
||||
# Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists)
|
||||
# never transitively trigger billing API calls during ``Vector(dataset)``
|
||||
# construction. The real embedding model is materialized only when an
|
||||
# ``embed_*`` method is actually invoked (i.e. create / search paths).
|
||||
self._embeddings: Embeddings = _LazyEmbeddings(dataset)
|
||||
self._attributes = attributes
|
||||
self._vector_processor = self._init_vector()
|
||||
|
||||
|
||||
@ -2182,7 +2182,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
|
||||
return result
|
||||
|
||||
|
||||
class UploadFile(Base):
|
||||
class UploadFile(TypeBase):
|
||||
__tablename__ = "upload_files"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="upload_file_pkey"),
|
||||
@ -2190,9 +2190,12 @@ class UploadFile(Base):
|
||||
)
|
||||
|
||||
# NOTE: The `id` field is generated within the application to minimize extra roundtrips
|
||||
# (especially when generating `source_url`).
|
||||
# The `server_default` serves as a fallback mechanism.
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
# (especially when generating `source_url`) and keep model metadata portable across databases.
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
init=False,
|
||||
default_factory=lambda: str(uuid4()),
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
|
||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
@ -2200,16 +2203,6 @@ class UploadFile(Base):
|
||||
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
extension: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
mime_type: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
|
||||
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
|
||||
# Its value is derived from the `CreatorUserRole` enumeration.
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(
|
||||
EnumText(CreatorUserRole, length=255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'account'"),
|
||||
default=CreatorUserRole.ACCOUNT,
|
||||
)
|
||||
|
||||
# The `created_by` field stores the ID of the entity that created this upload file.
|
||||
#
|
||||
# If `created_by_role` is `ACCOUNT`, it corresponds to `Account.id`.
|
||||
@ -2228,10 +2221,18 @@ class UploadFile(Base):
|
||||
# `used` may indicate whether the file has been utilized by another service.
|
||||
used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
|
||||
# Its value is derived from the `CreatorUserRole` enumeration.
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(
|
||||
EnumText(CreatorUserRole, length=255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'account'"),
|
||||
default=CreatorUserRole.ACCOUNT,
|
||||
)
|
||||
# `used_by` may indicate the ID of the user who utilized this file.
|
||||
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
|
||||
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True, default=None)
|
||||
hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
source_url: Mapped[str] = mapped_column(LongText, default="")
|
||||
|
||||
def __init__(
|
||||
|
||||
@ -50,7 +50,7 @@ from libs.uuid_utils import uuidv7
|
||||
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import AppMode, UploadFile
|
||||
from .model import AppMode
|
||||
|
||||
|
||||
from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase
|
||||
@ -64,6 +64,10 @@ from .account import Account
|
||||
from .base import Base, DefaultFieldsDCMixin, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
|
||||
|
||||
# UploadFile uses TypeBase while workflow execution offload models use Base, so relationships
|
||||
# must target the class object directly instead of relying on string lookup across registries.
|
||||
from .model import UploadFile
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
from .utils.file_input_compat import (
|
||||
build_file_from_mapping_without_lookup,
|
||||
@ -1145,8 +1149,6 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
|
||||
@staticmethod
|
||||
def _load_full_content(session: orm.Session, file_id: str, storage: Storage):
|
||||
from .model import UploadFile
|
||||
|
||||
stmt = sa.select(UploadFile).where(UploadFile.id == file_id)
|
||||
file = session.scalars(stmt).first()
|
||||
assert file is not None, f"UploadFile with id {file_id} should exist but not"
|
||||
@ -1240,10 +1242,11 @@ class WorkflowNodeExecutionOffload(Base):
|
||||
)
|
||||
|
||||
file: Mapped[Optional["UploadFile"]] = orm.relationship(
|
||||
UploadFile,
|
||||
foreign_keys=[file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
primaryjoin="WorkflowNodeExecutionOffload.file_id == UploadFile.id",
|
||||
primaryjoin=lambda: orm.foreign(WorkflowNodeExecutionOffload.file_id) == UploadFile.id,
|
||||
)
|
||||
|
||||
|
||||
@ -2017,10 +2020,11 @@ class WorkflowDraftVariableFile(Base):
|
||||
|
||||
# Relationship to UploadFile
|
||||
upload_file: Mapped["UploadFile"] = orm.relationship(
|
||||
UploadFile,
|
||||
foreign_keys=[upload_file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
primaryjoin="WorkflowDraftVariableFile.upload_file_id == UploadFile.id",
|
||||
primaryjoin=lambda: orm.foreign(WorkflowDraftVariableFile.upload_file_id) == UploadFile.id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -61,13 +61,31 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
|
||||
# check segment is exist
|
||||
if index_node_ids:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
# Wrap vector / keyword index cleanup in try/except so that a transient
|
||||
# failure here (e.g. billing API hiccup propagated via FeatureService when
|
||||
# ModelManager is initialized inside ``Vector(dataset)``) does not abort
|
||||
# the entire task and leave document_segments / child_chunks / image_files
|
||||
# / metadata bindings stranded in PG. Mirrors the pattern already used in
|
||||
# ``clean_dataset_task`` so the document row's hard delete (already
|
||||
# committed by the caller) does not produce orphan PG rows just because
|
||||
# the vector backend or one of its transitive dependencies was unhappy.
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to clean vector / keyword index in clean_document_task, "
|
||||
"document_id=%s, dataset_id=%s, index_node_ids_count=%d. "
|
||||
"Continuing with PG / storage cleanup; vector orphans can be reaped later.",
|
||||
document_id,
|
||||
dataset_id,
|
||||
len(index_node_ids),
|
||||
)
|
||||
|
||||
total_image_files = []
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
|
||||
@ -40,12 +40,29 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
total_index_node_ids.extend([segment.index_node_id for segment in segments])
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
# Wrap vector / keyword index cleanup in try/except so that a transient
|
||||
# failure here (e.g. billing API hiccup propagated via FeatureService when
|
||||
# ``ModelManager`` is initialized inside ``Vector(dataset)``) does not abort
|
||||
# the task and leave the already-deleted documents' segments stranded in PG.
|
||||
# The Document rows are hard-deleted in the previous session block, so any
|
||||
# exception escaping this task would produce orphans that no later request
|
||||
# can reference back. Mirrors the pattern in ``clean_dataset_task``.
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to clean vector / keyword index in clean_notion_document_task, "
|
||||
"dataset_id=%s, document_ids=%s, index_node_ids_count=%d. "
|
||||
"Continuing with segment deletion; vector orphans can be reaped later.",
|
||||
dataset_id,
|
||||
document_ids,
|
||||
len(total_index_node_ids),
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
|
||||
|
||||
@ -11,6 +11,7 @@ from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -171,14 +172,13 @@ def process_tenant_plugin_autoupgrade_check_task(
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
_ = manager.upgrade_plugin(
|
||||
# Use the service that downloads and uploads the package to the daemon
|
||||
# first; calling manager.upgrade_plugin directly skips that step and the
|
||||
# daemon fails because the package never reaches its local bucket.
|
||||
_ = PluginService.upgrade_plugin_with_marketplace(
|
||||
tenant_id,
|
||||
original_unique_identifier,
|
||||
new_unique_identifier,
|
||||
PluginInstallationSource.Marketplace,
|
||||
{
|
||||
"plugin_unique_identifier": new_unique_identifier,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red"))
|
||||
|
||||
@ -655,14 +655,25 @@ class TestCleanNotionDocumentTask:
|
||||
# Note: This test successfully verifies database operations.
|
||||
# IndexProcessor verification would require more sophisticated mocking.
|
||||
|
||||
def test_clean_notion_document_task_database_transaction_rollback(
|
||||
def test_clean_notion_document_task_continues_when_index_processor_fails(
|
||||
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test cleanup task behavior when database operations fail.
|
||||
Index processor failure (e.g. transient billing API error propagated via
|
||||
``FeatureService`` when ``Vector(dataset)`` lazily resolves the embedding
|
||||
model) must NOT abort the cleanup task. The Document rows have already
|
||||
been hard-deleted in the first session block before vector cleanup runs,
|
||||
so any uncaught exception escaping the task would strand
|
||||
``DocumentSegment`` rows in PG with no parent ``Document``.
|
||||
|
||||
This test verifies that the task properly handles database errors
|
||||
and maintains data consistency.
|
||||
Contract: the task swallows the index_processor exception, logs it, and
|
||||
proceeds to delete the segments — leaving PG consistent. (Vector orphans,
|
||||
if any, can be reaped later by an offline scanner.)
|
||||
|
||||
Regression guard for the production incident where ``clean_document_task``
|
||||
/ ``clean_notion_document_task`` failed with
|
||||
``ValueError("Unable to retrieve billing information...")`` and left
|
||||
tens of thousands of orphan segments per affected tenant.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
@ -725,17 +736,28 @@ class TestCleanNotionDocumentTask:
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock index processor to raise an exception
|
||||
# Simulate the production failure mode: index_processor.clean() raises a
|
||||
# ValueError mirroring ``BillingService._send_request`` returning non-200.
|
||||
mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_index_processor.clean.side_effect = Exception("Index processor error")
|
||||
mock_index_processor.clean.side_effect = ValueError(
|
||||
"Unable to retrieve billing information. Please try again later or contact support."
|
||||
)
|
||||
|
||||
# Execute cleanup task - current implementation propagates the exception
|
||||
with pytest.raises(Exception, match="Index processor error"):
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
# Execute cleanup task — must NOT raise even though clean() raises.
|
||||
# Before the safety-net wrapper this would have re-raised the ValueError,
|
||||
# aborting the task and leaving DocumentSegment stranded in PG.
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
|
||||
# Note: This test demonstrates the task's error handling capability.
|
||||
# Even with external service errors, the database operations complete successfully.
|
||||
# In a production environment, proper error handling would determine transaction rollback behavior.
|
||||
# Vector cleanup was attempted exactly once.
|
||||
mock_index_processor.clean.assert_called_once()
|
||||
|
||||
# The crucial assertion: despite the index processor failure, the
|
||||
# final session block (line 51-52, ``DELETE FROM document_segments``)
|
||||
# still ran and committed. This is what the wrapper buys us — without
|
||||
# it the production incident left tens of thousands of orphan segments
|
||||
# per affected tenant. Aligns with the assertion shape used by the
|
||||
# happy-path test (``test_clean_notion_document_task_success``).
|
||||
assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0
|
||||
|
||||
def test_clean_notion_document_task_with_large_number_of_documents(
|
||||
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
|
||||
|
||||
@ -23,6 +23,7 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.service_api.dataset.document import (
|
||||
DeprecatedDocumentAddByTextApi,
|
||||
DeprecatedDocumentUpdateByFileApi,
|
||||
DeprecatedDocumentUpdateByTextApi,
|
||||
DocumentAddByFileApi,
|
||||
DocumentAddByTextApi,
|
||||
@ -32,7 +33,6 @@ from controllers.service_api.dataset.document import (
|
||||
DocumentListQuery,
|
||||
DocumentTextCreatePayload,
|
||||
DocumentTextUpdate,
|
||||
DocumentUpdateByFileApi,
|
||||
DocumentUpdateByTextApi,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
@ -1095,8 +1095,8 @@ class TestArchivedDocumentImmutableError:
|
||||
assert error.code == 403
|
||||
|
||||
|
||||
class TestDocumentTextRouteDeprecation:
|
||||
"""Test that legacy underscore text routes stay marked deprecated."""
|
||||
class TestDocumentRouteDeprecation:
|
||||
"""Test that legacy document routes stay marked deprecated."""
|
||||
|
||||
def test_create_by_text_legacy_alias_is_deprecated(self):
|
||||
"""Ensure only the legacy create-by-text alias is marked deprecated."""
|
||||
@ -1108,10 +1108,15 @@ class TestDocumentTextRouteDeprecation:
|
||||
assert DeprecatedDocumentUpdateByTextApi.post.__apidoc__["deprecated"] is True
|
||||
assert DocumentUpdateByTextApi.post.__apidoc__.get("deprecated") is not True
|
||||
|
||||
def test_update_by_file_legacy_aliases_are_deprecated(self):
|
||||
"""Ensure only the legacy file-update aliases are marked deprecated."""
|
||||
assert DeprecatedDocumentUpdateByFileApi.post.__apidoc__["deprecated"] is True
|
||||
assert DocumentApi.patch.__apidoc__.get("deprecated") is not True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Endpoint tests for DocumentUpdateByTextApi, DocumentAddByFileApi,
|
||||
# DocumentUpdateByFileApi.
|
||||
# and the canonical/deprecated document file update routes.
|
||||
#
|
||||
# These controllers use ``@cloud_edition_billing_resource_check`` (does NOT
|
||||
# preserve ``__wrapped__``) and ``@cloud_edition_billing_rate_limit_check``
|
||||
@ -1359,13 +1364,52 @@ class TestDocumentAddByFileApiPost:
|
||||
api.post(tenant_id=mock_tenant.id, dataset_id=mock_dataset.id)
|
||||
|
||||
|
||||
class TestDocumentUpdateByFileApiPost:
|
||||
"""Test suite for DocumentUpdateByFileApi.post() endpoint.
|
||||
class TestDocumentUpdateByFileApiPatch:
|
||||
"""Test suite for the canonical document file update endpoint.
|
||||
|
||||
``post`` is wrapped by ``@cloud_edition_billing_resource_check`` and
|
||||
``patch`` is wrapped by ``@cloud_edition_billing_resource_check`` and
|
||||
``@cloud_edition_billing_rate_limit_check``.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize("route_name", ["update_by_file", "update-by-file"])
|
||||
@patch("controllers.service_api.dataset.document._update_document_by_file")
|
||||
@patch("controllers.service_api.wraps.FeatureService")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
def test_update_by_file_deprecated_aliases_delegate_to_shared_handler(
|
||||
self,
|
||||
mock_validate_token,
|
||||
mock_feature_svc,
|
||||
mock_update_document_by_file,
|
||||
route_name,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
"""Test legacy POST aliases still dispatch while marked deprecated."""
|
||||
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_update_document_by_file.return_value = ({"document": {"id": "doc-1"}, "batch": "batch-1"}, 200)
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}/{route_name}",
|
||||
method="POST",
|
||||
headers={"Authorization": "Bearer test_token"},
|
||||
):
|
||||
api = DeprecatedDocumentUpdateByFileApi()
|
||||
response, status = api.post(
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
document_id=doc_id,
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert response["batch"] == "batch-1"
|
||||
mock_update_document_by_file.assert_called_once_with(
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
document_id=doc_id,
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.document.db")
|
||||
@patch("controllers.service_api.wraps.FeatureService")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@ -1387,15 +1431,15 @@ class TestDocumentUpdateByFileApiPost:
|
||||
doc_id = str(uuid.uuid4())
|
||||
data = {"file": (BytesIO(b"content"), "test.pdf", "application/pdf")}
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_file",
|
||||
method="POST",
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}",
|
||||
method="PATCH",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
headers={"Authorization": "Bearer test_token"},
|
||||
):
|
||||
api = DocumentUpdateByFileApi()
|
||||
api = DocumentApi()
|
||||
with pytest.raises(ValueError, match="Dataset does not exist"):
|
||||
api.post(
|
||||
api.patch(
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
document_id=doc_id,
|
||||
@ -1423,15 +1467,15 @@ class TestDocumentUpdateByFileApiPost:
|
||||
doc_id = str(uuid.uuid4())
|
||||
data = {"file": (BytesIO(b"content"), "test.pdf", "application/pdf")}
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_file",
|
||||
method="POST",
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}",
|
||||
method="PATCH",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
headers={"Authorization": "Bearer test_token"},
|
||||
):
|
||||
api = DocumentUpdateByFileApi()
|
||||
api = DocumentApi()
|
||||
with pytest.raises(ValueError, match="External datasets"):
|
||||
api.post(
|
||||
api.patch(
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
document_id=doc_id,
|
||||
@ -1482,14 +1526,14 @@ class TestDocumentUpdateByFileApiPost:
|
||||
doc_id = str(uuid.uuid4())
|
||||
data = {"file": (BytesIO(b"file content"), "test.pdf", "application/pdf")}
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_file",
|
||||
method="POST",
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}",
|
||||
method="PATCH",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
headers={"Authorization": "Bearer test_token"},
|
||||
):
|
||||
api = DocumentUpdateByFileApi()
|
||||
response, status = api.post(
|
||||
api = DocumentApi()
|
||||
response, status = api.patch(
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
document_id=doc_id,
|
||||
|
||||
@ -146,10 +146,7 @@ def test_get_vector_factory_entry_point_overrides_builtin(vector_factory_module,
|
||||
def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
|
||||
dataset = SimpleNamespace(id="dataset-1")
|
||||
|
||||
with (
|
||||
patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"),
|
||||
patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"),
|
||||
):
|
||||
with patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"):
|
||||
default_vector = vector_factory_module.Vector(dataset)
|
||||
custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"])
|
||||
|
||||
@ -166,10 +163,57 @@ def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
|
||||
"original_chunk_id",
|
||||
]
|
||||
assert custom_vector._attributes == ["doc_id"]
|
||||
assert default_vector._embeddings == "embeddings"
|
||||
# ``_embeddings`` is now a lazy proxy that defers materializing the real
|
||||
# embedding model until ``embed_*`` is invoked, so cleanup paths never
|
||||
# trigger billing/feature-service calls during ``Vector(dataset)``
|
||||
# construction. See ``_LazyEmbeddings``.
|
||||
assert isinstance(default_vector._embeddings, vector_factory_module._LazyEmbeddings)
|
||||
assert default_vector._vector_processor == "processor"
|
||||
|
||||
|
||||
def test_lazy_embeddings_defer_real_load_until_first_embed_call(vector_factory_module, monkeypatch):
|
||||
"""``Vector(dataset)`` must not transitively call ``ModelManager`` during
|
||||
construction. The real embedding model should only be materialized on the
|
||||
first ``embed_*`` call (i.e. create / search paths) so cleanup paths
|
||||
(``delete_by_ids`` / ``delete``) remain resilient to billing-API failures.
|
||||
"""
|
||||
for_tenant_mock = MagicMock(side_effect=AssertionError("ModelManager.for_tenant must not be called eagerly"))
|
||||
monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock)
|
||||
|
||||
dataset = SimpleNamespace(
|
||||
tenant_id="tenant-1",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-3-small",
|
||||
)
|
||||
|
||||
proxy = vector_factory_module._LazyEmbeddings(dataset)
|
||||
|
||||
# Construction alone does not trigger ModelManager / FeatureService / BillingService.
|
||||
for_tenant_mock.assert_not_called()
|
||||
|
||||
# Exercising an embed_* method materializes the real model exactly once.
|
||||
inner_model = MagicMock()
|
||||
inner_model.embed_documents.return_value = [[0.1, 0.2]]
|
||||
cached_embedding_mock = MagicMock(return_value=inner_model)
|
||||
real_for_tenant = MagicMock()
|
||||
real_for_tenant.get_model_instance.return_value = "embedding-model-instance"
|
||||
monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", MagicMock(return_value=real_for_tenant))
|
||||
monkeypatch.setattr(vector_factory_module, "CacheEmbedding", cached_embedding_mock)
|
||||
|
||||
result = proxy.embed_documents(["hello"])
|
||||
|
||||
assert result == [[0.1, 0.2]]
|
||||
cached_embedding_mock.assert_called_once_with("embedding-model-instance")
|
||||
inner_model.embed_documents.assert_called_once_with(["hello"])
|
||||
|
||||
# Subsequent calls reuse the materialized model (no re-resolution).
|
||||
inner_model.embed_documents.reset_mock()
|
||||
cached_embedding_mock.reset_mock()
|
||||
proxy.embed_documents(["world"])
|
||||
cached_embedding_mock.assert_not_called()
|
||||
inner_model.embed_documents.assert_called_once_with(["world"])
|
||||
|
||||
|
||||
def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch):
|
||||
calls = {"vector_type": None, "init_args": None}
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ class TestWorkflowModelValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=graph,
|
||||
features=features,
|
||||
@ -58,7 +58,7 @@ class TestWorkflowModelValidation:
|
||||
# Assert
|
||||
assert workflow.tenant_id == tenant_id
|
||||
assert workflow.app_id == app_id
|
||||
assert workflow.type == WorkflowType.WORKFLOW.value
|
||||
assert workflow.type == WorkflowType.WORKFLOW
|
||||
assert workflow.version == "draft"
|
||||
assert workflow.graph == graph
|
||||
assert workflow.created_by == created_by
|
||||
@ -68,7 +68,7 @@ class TestWorkflowModelValidation:
|
||||
def test_workflow_type_enum_values(self):
|
||||
"""Test WorkflowType enum values."""
|
||||
# Assert
|
||||
assert WorkflowType.WORKFLOW.value == "workflow"
|
||||
assert WorkflowType.WORKFLOW == "workflow"
|
||||
assert WorkflowType.CHAT.value == "chat"
|
||||
assert WorkflowType.RAG_PIPELINE.value == "rag-pipeline"
|
||||
|
||||
@ -89,7 +89,7 @@ class TestWorkflowModelValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_data),
|
||||
features="{}",
|
||||
@ -114,7 +114,7 @@ class TestWorkflowModelValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph="{}",
|
||||
features=json.dumps(features_data),
|
||||
@ -138,7 +138,7 @@ class TestWorkflowModelValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="v1.0",
|
||||
graph="{}",
|
||||
features="{}",
|
||||
@ -176,11 +176,11 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
@ -188,9 +188,9 @@ class TestWorkflowRunStateTransitions:
|
||||
assert workflow_run.tenant_id == tenant_id
|
||||
assert workflow_run.app_id == app_id
|
||||
assert workflow_run.workflow_id == workflow_id
|
||||
assert workflow_run.type == WorkflowType.WORKFLOW.value
|
||||
assert workflow_run.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING.value
|
||||
assert workflow_run.type == WorkflowType.WORKFLOW
|
||||
assert workflow_run.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
assert workflow_run.created_by == created_by
|
||||
|
||||
def test_workflow_run_state_transition_running_to_succeeded(self):
|
||||
@ -200,21 +200,21 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.END_USER.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act
|
||||
workflow_run.status = WorkflowExecutionStatus.SUCCEEDED.value
|
||||
workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
workflow_run.finished_at = datetime.now(UTC)
|
||||
workflow_run.elapsed_time = 2.5
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert workflow_run.finished_at is not None
|
||||
assert workflow_run.elapsed_time == 2.5
|
||||
|
||||
@ -225,21 +225,21 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act
|
||||
workflow_run.status = WorkflowExecutionStatus.FAILED.value
|
||||
workflow_run.status = WorkflowExecutionStatus.FAILED
|
||||
workflow_run.error = "Node execution failed: Invalid input"
|
||||
workflow_run.finished_at = datetime.now(UTC)
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.FAILED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.FAILED
|
||||
assert workflow_run.error == "Node execution failed: Invalid input"
|
||||
assert workflow_run.finished_at is not None
|
||||
|
||||
@ -250,20 +250,20 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act
|
||||
workflow_run.status = WorkflowExecutionStatus.STOPPED.value
|
||||
workflow_run.status = WorkflowExecutionStatus.STOPPED
|
||||
workflow_run.finished_at = datetime.now(UTC)
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.STOPPED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.STOPPED
|
||||
assert workflow_run.finished_at is not None
|
||||
|
||||
def test_workflow_run_state_transition_running_to_paused(self):
|
||||
@ -273,19 +273,19 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.END_USER.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED.value
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
assert workflow_run.finished_at is None # Not finished when paused
|
||||
|
||||
def test_workflow_run_state_transition_paused_to_running(self):
|
||||
@ -295,19 +295,19 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.PAUSED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING.value
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
|
||||
def test_workflow_run_with_partial_succeeded_status(self):
|
||||
"""Test workflow run with partial-succeeded status."""
|
||||
@ -316,17 +316,17 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
exceptions_count=2,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED
|
||||
assert workflow_run.exceptions_count == 2
|
||||
|
||||
def test_workflow_run_with_inputs_and_outputs(self):
|
||||
@ -340,11 +340,11 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.END_USER.value,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by=str(uuid4()),
|
||||
inputs=json.dumps(inputs),
|
||||
outputs=json.dumps(outputs),
|
||||
@ -362,11 +362,11 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
graph=json.dumps(graph),
|
||||
)
|
||||
@ -391,11 +391,11 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=created_by,
|
||||
total_tokens=1500,
|
||||
total_steps=5,
|
||||
@ -410,7 +410,7 @@ class TestWorkflowRunStateTransitions:
|
||||
assert result["tenant_id"] == tenant_id
|
||||
assert result["app_id"] == app_id
|
||||
assert result["workflow_id"] == workflow_id
|
||||
assert result["status"] == WorkflowExecutionStatus.SUCCEEDED.value
|
||||
assert result["status"] == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert result["total_tokens"] == 1500
|
||||
assert result["total_steps"] == 5
|
||||
|
||||
@ -422,18 +422,18 @@ class TestWorkflowRunStateTransitions:
|
||||
"tenant_id": str(uuid4()),
|
||||
"app_id": str(uuid4()),
|
||||
"workflow_id": str(uuid4()),
|
||||
"type": WorkflowType.WORKFLOW.value,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
"type": WorkflowType.WORKFLOW,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
"version": "v1.0",
|
||||
"graph": {"nodes": [], "edges": []},
|
||||
"inputs": {"query": "test"},
|
||||
"status": WorkflowExecutionStatus.SUCCEEDED.value,
|
||||
"status": WorkflowExecutionStatus.SUCCEEDED,
|
||||
"outputs": {"result": "success"},
|
||||
"error": None,
|
||||
"elapsed_time": 3.5,
|
||||
"total_tokens": 2000,
|
||||
"total_steps": 10,
|
||||
"created_by_role": CreatorUserRole.ACCOUNT.value,
|
||||
"created_by_role": CreatorUserRole.ACCOUNT,
|
||||
"created_by": str(uuid4()),
|
||||
"created_at": datetime.now(UTC),
|
||||
"finished_at": datetime.now(UTC),
|
||||
@ -446,7 +446,7 @@ class TestWorkflowRunStateTransitions:
|
||||
# Assert
|
||||
assert workflow_run.id == data["id"]
|
||||
assert workflow_run.workflow_id == data["workflow_id"]
|
||||
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert workflow_run.total_tokens == 2000
|
||||
|
||||
|
||||
@ -467,14 +467,14 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=workflow_run_id,
|
||||
index=1,
|
||||
node_id="start",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Start Node",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
@ -498,15 +498,15 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=2,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
node_id=current_node_id,
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="LLM Node",
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
@ -528,8 +528,8 @@ class TestNodeExecutionRelationships:
|
||||
node_id="llm_test",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="Test LLM",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
@ -549,14 +549,14 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="llm_1",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="LLM Node",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
inputs=json.dumps(inputs),
|
||||
outputs=json.dumps(outputs),
|
||||
@ -575,24 +575,24 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="code_1",
|
||||
node_type=BuiltinNodeTypes.CODE,
|
||||
title="Code Node",
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act - transition to succeeded
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
node_execution.elapsed_time = 1.2
|
||||
node_execution.finished_at = datetime.now(UTC)
|
||||
|
||||
# Assert
|
||||
assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert node_execution.elapsed_time == 1.2
|
||||
assert node_execution.finished_at is not None
|
||||
|
||||
@ -606,20 +606,20 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=3,
|
||||
node_id="code_1",
|
||||
node_type=BuiltinNodeTypes.CODE,
|
||||
title="Code Node",
|
||||
status=WorkflowNodeExecutionStatus.FAILED.value,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert node_execution.status == WorkflowNodeExecutionStatus.FAILED.value
|
||||
assert node_execution.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert node_execution.error == error_message
|
||||
|
||||
def test_node_execution_with_metadata(self):
|
||||
@ -637,14 +637,14 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="llm_1",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="LLM Node",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
execution_metadata=json.dumps(metadata),
|
||||
)
|
||||
@ -660,14 +660,14 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="start",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Start",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
execution_metadata=None,
|
||||
)
|
||||
@ -696,14 +696,14 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id=f"{node_type}_1",
|
||||
node_type=node_type,
|
||||
title=title,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
@ -734,7 +734,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_config),
|
||||
features="{}",
|
||||
@ -761,7 +761,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_config),
|
||||
features="{}",
|
||||
@ -802,7 +802,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_config),
|
||||
features="{}",
|
||||
@ -835,11 +835,11 @@ class TestGraphConfigurationValidation:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
graph=json.dumps(original_graph),
|
||||
)
|
||||
@ -872,7 +872,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_config),
|
||||
features="{}",
|
||||
@ -912,7 +912,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_config),
|
||||
features="{}",
|
||||
@ -933,7 +933,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=None,
|
||||
features="{}",
|
||||
@ -956,11 +956,11 @@ class TestGraphConfigurationValidation:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
inputs=None,
|
||||
)
|
||||
@ -978,11 +978,11 @@ class TestGraphConfigurationValidation:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
outputs=None,
|
||||
)
|
||||
@ -1000,14 +1000,14 @@ class TestGraphConfigurationValidation:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="start",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Start",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
inputs=None,
|
||||
)
|
||||
@ -1025,14 +1025,14 @@ class TestGraphConfigurationValidation:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="start",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Start",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
70
api/tests/unit_tests/oss/__mock/baidu_obs.py
Normal file
70
api/tests/unit_tests/oss/__mock/baidu_obs.py
Normal file
@ -0,0 +1,70 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from baidubce.services.bos.bos_client import BosClient
|
||||
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
get_example_bucket,
|
||||
get_example_data,
|
||||
get_example_filename,
|
||||
get_example_filepath,
|
||||
)
|
||||
|
||||
|
||||
class MockBaiduObsClass:
|
||||
def __init__(self, config=None):
|
||||
self.bucket_name = get_example_bucket()
|
||||
self.key = get_example_filename()
|
||||
self.content = get_example_data()
|
||||
self.filepath = get_example_filepath()
|
||||
|
||||
def put_object(self, bucket_name, key, data, content_length=None, content_md5=None, **kwargs):
|
||||
assert bucket_name == self.bucket_name
|
||||
assert key == self.key
|
||||
assert data == self.content
|
||||
assert content_length == len(self.content)
|
||||
expected_md5 = base64.standard_b64encode(hashlib.md5(self.content).digest())
|
||||
assert content_md5 == expected_md5
|
||||
|
||||
def get_object(self, bucket_name, key, **kwargs):
|
||||
assert bucket_name == self.bucket_name
|
||||
assert key == self.key
|
||||
return SimpleNamespace(data=BytesIO(self.content))
|
||||
|
||||
def get_object_to_file(self, bucket_name, key, file_name, **kwargs):
|
||||
assert bucket_name == self.bucket_name
|
||||
assert key == self.key
|
||||
assert file_name == self.filepath
|
||||
|
||||
def get_object_meta_data(self, bucket_name, key, **kwargs):
|
||||
assert bucket_name == self.bucket_name
|
||||
assert key == self.key
|
||||
return SimpleNamespace(status=200)
|
||||
|
||||
def delete_object(self, bucket_name, key, **kwargs):
|
||||
assert bucket_name == self.bucket_name
|
||||
assert key == self.key
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_baidu_obs_mock(monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(BosClient, "__init__", MockBaiduObsClass.__init__)
|
||||
monkeypatch.setattr(BosClient, "put_object", MockBaiduObsClass.put_object)
|
||||
monkeypatch.setattr(BosClient, "get_object", MockBaiduObsClass.get_object)
|
||||
monkeypatch.setattr(BosClient, "get_object_to_file", MockBaiduObsClass.get_object_to_file)
|
||||
monkeypatch.setattr(BosClient, "get_object_meta_data", MockBaiduObsClass.get_object_meta_data)
|
||||
monkeypatch.setattr(BosClient, "delete_object", MockBaiduObsClass.delete_object)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
1
api/tests/unit_tests/oss/baidu_obs/__init__.py
Normal file
1
api/tests/unit_tests/oss/baidu_obs/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
60
api/tests/unit_tests/oss/baidu_obs/test_baidu_obs.py
Normal file
60
api/tests/unit_tests/oss/baidu_obs/test_baidu_obs.py
Normal file
@ -0,0 +1,60 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from baidubce.auth.bce_credentials import BceCredentials
|
||||
from baidubce.bce_client_configuration import BceClientConfiguration
|
||||
|
||||
from extensions.storage.baidu_obs_storage import BaiduObsStorage
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
BaseStorageTest,
|
||||
get_example_bucket,
|
||||
)
|
||||
|
||||
pytest_plugins = ("tests.unit_tests.oss.__mock.baidu_obs",)
|
||||
|
||||
|
||||
class TestBaiduObs(BaseStorageTest):
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self, setup_baidu_obs_mock):
|
||||
"""Executed before each test method."""
|
||||
with (
|
||||
patch.object(BceCredentials, "__init__", return_value=None),
|
||||
patch.object(BceClientConfiguration, "__init__", return_value=None),
|
||||
):
|
||||
self.storage = BaiduObsStorage()
|
||||
self.storage.bucket_name = get_example_bucket()
|
||||
|
||||
|
||||
class TestBaiduObsConfiguration:
|
||||
def test_init_with_config(self):
|
||||
mock_dify_config = MagicMock()
|
||||
mock_dify_config.BAIDU_OBS_BUCKET_NAME = "test-bucket"
|
||||
mock_dify_config.BAIDU_OBS_ACCESS_KEY = "test-access-key"
|
||||
mock_dify_config.BAIDU_OBS_SECRET_KEY = "test-secret-key"
|
||||
mock_dify_config.BAIDU_OBS_ENDPOINT = "https://bj.bcebos.com"
|
||||
|
||||
mock_credentials = MagicMock(name="credentials")
|
||||
mock_config = MagicMock(name="config")
|
||||
mock_client = MagicMock(name="client")
|
||||
|
||||
with (
|
||||
patch("extensions.storage.baidu_obs_storage.dify_config", mock_dify_config),
|
||||
patch("extensions.storage.baidu_obs_storage.BceCredentials", return_value=mock_credentials) as credentials,
|
||||
patch(
|
||||
"extensions.storage.baidu_obs_storage.BceClientConfiguration", return_value=mock_config
|
||||
) as configuration,
|
||||
patch("extensions.storage.baidu_obs_storage.BosClient", return_value=mock_client) as client_cls,
|
||||
):
|
||||
storage = BaiduObsStorage()
|
||||
|
||||
assert storage.bucket_name == "test-bucket"
|
||||
assert storage.client == mock_client
|
||||
credentials.assert_called_once_with(
|
||||
access_key_id="test-access-key",
|
||||
secret_access_key="test-secret-key",
|
||||
)
|
||||
configuration.assert_called_once_with(
|
||||
credentials=mock_credentials,
|
||||
endpoint="https://bj.bcebos.com",
|
||||
)
|
||||
client_cls.assert_called_once_with(config=mock_config)
|
||||
291
api/tests/unit_tests/tasks/test_clean_document_task.py
Normal file
291
api/tests/unit_tests/tasks/test_clean_document_task.py
Normal file
@ -0,0 +1,291 @@
|
||||
"""
|
||||
Unit tests for clean_document_task.
|
||||
|
||||
Focuses on the resilience contract added by the billing-failure fix:
|
||||
``index_processor.clean()`` is wrapped in ``try/except`` so that a transient
|
||||
failure inside the vector / keyword cleanup (e.g. ``ValueError("Unable to
|
||||
retrieve billing information...")`` raised by ``BillingService._send_request``
|
||||
when ``Vector(dataset)`` transitively triggers ``FeatureService.get_features``)
|
||||
does not abort the entire task and leave PG with stranded ``DocumentSegment``
|
||||
/ ``ChildChunk`` / ``UploadFile`` / ``DatasetMetadataBinding`` rows.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tasks.clean_document_task import clean_document_task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
"""Patch ``session_factory.create_session`` to return per-call mock sessions.
|
||||
|
||||
Each call to ``create_session()`` yields a fresh ``MagicMock`` session so we
|
||||
can assert ``execute()`` calls across the multiple short-lived transactions
|
||||
used by ``clean_document_task``.
|
||||
"""
|
||||
with patch("tasks.clean_document_task.session_factory", autospec=True) as mock_sf:
|
||||
sessions: list[MagicMock] = []
|
||||
|
||||
def _create_session():
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value.all.return_value = []
|
||||
session.scalar.return_value = None
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
sessions.append(session)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = _create_session
|
||||
yield mock_sf, sessions
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage():
|
||||
with patch("tasks.clean_document_task.storage", autospec=True) as mock:
|
||||
mock.delete.return_value = None
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_index_processor_factory():
|
||||
"""Mock ``IndexProcessorFactory`` so we can inject behavior into ``clean``."""
|
||||
with patch("tasks.clean_document_task.IndexProcessorFactory", autospec=True) as factory_cls:
|
||||
processor = MagicMock()
|
||||
processor.clean.return_value = None
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.init_index_processor.return_value = processor
|
||||
factory_cls.return_value = factory_instance
|
||||
|
||||
yield {
|
||||
"factory_cls": factory_cls,
|
||||
"factory_instance": factory_instance,
|
||||
"processor": processor,
|
||||
}
|
||||
|
||||
|
||||
def _build_segment(segment_id: str, content: str = "segment content") -> MagicMock:
|
||||
seg = MagicMock()
|
||||
seg.id = segment_id
|
||||
seg.index_node_id = f"node-{segment_id}"
|
||||
seg.content = content
|
||||
return seg
|
||||
|
||||
|
||||
def _build_dataset(dataset_id: str, tenant_id: str) -> MagicMock:
|
||||
ds = MagicMock()
|
||||
ds.id = dataset_id
|
||||
ds.tenant_id = tenant_id
|
||||
return ds
|
||||
|
||||
|
||||
class TestVectorCleanupResilience:
|
||||
"""Vector / keyword cleanup must not abort the task on transient failure."""
|
||||
|
||||
def test_billing_failure_during_vector_cleanup_does_not_skip_pg_cleanup(
|
||||
self,
|
||||
document_id,
|
||||
dataset_id,
|
||||
tenant_id,
|
||||
mock_session_factory,
|
||||
mock_storage,
|
||||
mock_index_processor_factory,
|
||||
):
|
||||
"""Reproduces the production incident:
|
||||
|
||||
``Vector(dataset)`` transitively calls ``FeatureService.get_features``
|
||||
which calls ``BillingService._send_request("GET", ...)``. When billing
|
||||
returns non-200 it raises ``ValueError("Unable to retrieve billing
|
||||
information...")``. Before the fix this propagated out of
|
||||
``clean_document_task`` and left ``DocumentSegment`` / ``ChildChunk`` /
|
||||
``UploadFile`` / ``DatasetMetadataBinding`` rows orphaned because the
|
||||
already-deleted ``Document`` row had been hard-committed by the caller
|
||||
(``dataset_service.delete_document``) before ``.delay()`` was invoked.
|
||||
|
||||
Contract: a billing failure inside ``index_processor.clean()`` must be
|
||||
caught, logged, and the rest of the task must continue so PG ends up
|
||||
consistent with the deleted ``Document`` even if Qdrant retains
|
||||
orphan vectors that can be reaped later.
|
||||
"""
|
||||
mock_sf, sessions = mock_session_factory
|
||||
|
||||
# First create_session(): Step 1 (load segments + attachments).
|
||||
step1_session = MagicMock()
|
||||
step1_session.scalars.return_value.all.return_value = [
|
||||
_build_segment("seg-1"),
|
||||
_build_segment("seg-2"),
|
||||
]
|
||||
step1_session.execute.return_value.all.return_value = []
|
||||
step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
|
||||
# Second create_session(): Step 2 (vector cleanup). Returns dataset.
|
||||
step2_session = MagicMock()
|
||||
step2_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
|
||||
step2_session.scalars.return_value.all.return_value = []
|
||||
step2_session.execute.return_value.all.return_value = []
|
||||
# Subsequent sessions: Step 3+ (image / segment / file / metadata cleanup).
|
||||
# Default fixture returns empty results which is fine for these short txns.
|
||||
cm1, cm2 = MagicMock(), MagicMock()
|
||||
cm1.__enter__.return_value = step1_session
|
||||
cm1.__exit__.return_value = None
|
||||
cm2.__enter__.return_value = step2_session
|
||||
cm2.__exit__.return_value = None
|
||||
|
||||
def _default_cm():
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value.all.return_value = []
|
||||
session.scalar.return_value = None
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
sessions.append(session)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = [cm1, cm2] + [_default_cm() for _ in range(10)]
|
||||
|
||||
# Simulate the production failure: index_processor.clean() raises ValueError
|
||||
# mirroring BillingService._send_request when billing returns non-200.
|
||||
mock_index_processor_factory["processor"].clean.side_effect = ValueError(
|
||||
"Unable to retrieve billing information. Please try again later or contact support."
|
||||
)
|
||||
|
||||
# Act — must not raise out of the task even though clean() raises.
|
||||
clean_document_task(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
doc_form="paragraph",
|
||||
file_id=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# 1. Vector cleanup was attempted.
|
||||
mock_index_processor_factory["processor"].clean.assert_called_once()
|
||||
# 2. Despite the failure the task continued: at least one DocumentSegment
|
||||
# delete was issued. We use the count of session.execute calls across
|
||||
# later short transactions as a proxy for "Step 3+ executed".
|
||||
execute_calls = sum(s.execute.call_count for s in sessions)
|
||||
assert execute_calls > 0, (
|
||||
"Step 3+ DB cleanup did not run after vector cleanup failure; "
|
||||
"this regression would re-introduce the orphan-segment bug."
|
||||
)
|
||||
|
||||
def test_vector_cleanup_success_path_remains_unaffected(
|
||||
self,
|
||||
document_id,
|
||||
dataset_id,
|
||||
tenant_id,
|
||||
mock_session_factory,
|
||||
mock_storage,
|
||||
mock_index_processor_factory,
|
||||
):
|
||||
"""Backward-compat: the happy path must still call ``clean()`` exactly
|
||||
once with the expected arguments and complete without errors.
|
||||
"""
|
||||
mock_sf, sessions = mock_session_factory
|
||||
|
||||
step1_session = MagicMock()
|
||||
step1_session.scalars.return_value.all.return_value = [_build_segment("seg-1")]
|
||||
step1_session.execute.return_value.all.return_value = []
|
||||
step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
|
||||
step2_session = MagicMock()
|
||||
step2_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
|
||||
step2_session.scalars.return_value.all.return_value = []
|
||||
step2_session.execute.return_value.all.return_value = []
|
||||
cm1, cm2 = MagicMock(), MagicMock()
|
||||
cm1.__enter__.return_value = step1_session
|
||||
cm1.__exit__.return_value = None
|
||||
cm2.__enter__.return_value = step2_session
|
||||
cm2.__exit__.return_value = None
|
||||
|
||||
def _default_cm():
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value.all.return_value = []
|
||||
session.scalar.return_value = None
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
sessions.append(session)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = [cm1, cm2] + [_default_cm() for _ in range(10)]
|
||||
|
||||
clean_document_task(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
doc_form="paragraph",
|
||||
file_id=None,
|
||||
)
|
||||
|
||||
assert mock_index_processor_factory["processor"].clean.call_count == 1
|
||||
# Index cleanup invoked with the expected delete_summaries / delete_child_chunks flags.
|
||||
_, kwargs = mock_index_processor_factory["processor"].clean.call_args
|
||||
assert kwargs.get("with_keywords") is True
|
||||
assert kwargs.get("delete_child_chunks") is True
|
||||
assert kwargs.get("delete_summaries") is True
|
||||
|
||||
def test_no_segments_skips_vector_cleanup(
|
||||
self,
|
||||
document_id,
|
||||
dataset_id,
|
||||
tenant_id,
|
||||
mock_session_factory,
|
||||
mock_storage,
|
||||
mock_index_processor_factory,
|
||||
):
|
||||
"""When the document has no segments (e.g. indexing failed before
|
||||
producing any), vector cleanup must not be attempted — and therefore
|
||||
the new try/except wrapper does not change behavior here.
|
||||
"""
|
||||
mock_sf, sessions = mock_session_factory
|
||||
|
||||
step1_session = MagicMock()
|
||||
step1_session.scalars.return_value.all.return_value = [] # no segments
|
||||
step1_session.execute.return_value.all.return_value = []
|
||||
step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
|
||||
cm1 = MagicMock()
|
||||
cm1.__enter__.return_value = step1_session
|
||||
cm1.__exit__.return_value = None
|
||||
|
||||
def _default_cm():
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value.all.return_value = []
|
||||
session.scalar.return_value = None
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
sessions.append(session)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = [cm1] + [_default_cm() for _ in range(10)]
|
||||
|
||||
clean_document_task(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
doc_form="paragraph",
|
||||
file_id=None,
|
||||
)
|
||||
|
||||
# Vector cleanup is gated on ``index_node_ids``; when there are no
|
||||
# segments the IndexProcessorFactory path is never entered.
|
||||
mock_index_processor_factory["factory_cls"].assert_not_called()
|
||||
@ -0,0 +1,289 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
MODULE = "tasks.process_tenant_plugin_autoupgrade_check_task"
|
||||
|
||||
|
||||
def _make_plugin(plugin_id: str, version: str, source=PluginInstallationSource.Marketplace):
|
||||
"""Build a minimal stand-in for a PluginInstallation entry returned by manager.list_plugins."""
|
||||
return SimpleNamespace(
|
||||
plugin_id=plugin_id,
|
||||
version=version,
|
||||
plugin_unique_identifier=f"{plugin_id}:{version}@deadbeef",
|
||||
source=source,
|
||||
)
|
||||
|
||||
|
||||
def _make_manifest(plugin_id: str, latest_version: str) -> MarketplacePluginSnapshot:
|
||||
org, name = plugin_id.split("/", 1)
|
||||
return MarketplacePluginSnapshot(
|
||||
org=org,
|
||||
name=name,
|
||||
latest_version=latest_version,
|
||||
latest_package_identifier=f"{plugin_id}:{latest_version}@cafe1234",
|
||||
latest_package_url=f"https://marketplace.example/{plugin_id}/{latest_version}.difypkg",
|
||||
)
|
||||
|
||||
|
||||
def _run_task(
|
||||
*,
|
||||
plugins: list,
|
||||
manifests: list[MarketplacePluginSnapshot],
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
exclude_plugins=None,
|
||||
include_plugins=None,
|
||||
):
|
||||
"""
|
||||
Execute the celery task synchronously with mocks for the plugin manager,
|
||||
the marketplace cache and PluginService.upgrade_plugin_with_marketplace.
|
||||
Returns the upgrade-call recorder so each test can assert on it.
|
||||
"""
|
||||
fake_manager = MagicMock()
|
||||
fake_manager.list_plugins.return_value = plugins
|
||||
|
||||
upgrade_calls: list[tuple[str, str, str]] = []
|
||||
|
||||
def _record_upgrade(tenant_id, original, new):
|
||||
upgrade_calls.append((tenant_id, original, new))
|
||||
|
||||
with (
|
||||
patch(f"{MODULE}.PluginInstaller", return_value=fake_manager),
|
||||
patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=manifests),
|
||||
patch(
|
||||
f"{MODULE}.PluginService.upgrade_plugin_with_marketplace",
|
||||
side_effect=_record_upgrade,
|
||||
) as upgrade_mock,
|
||||
):
|
||||
from tasks.process_tenant_plugin_autoupgrade_check_task import (
|
||||
process_tenant_plugin_autoupgrade_check_task,
|
||||
)
|
||||
|
||||
process_tenant_plugin_autoupgrade_check_task(
|
||||
"tenant-1",
|
||||
strategy_setting,
|
||||
0,
|
||||
upgrade_mode,
|
||||
exclude_plugins or [],
|
||||
include_plugins or [],
|
||||
)
|
||||
|
||||
return upgrade_mock, upgrade_calls
|
||||
|
||||
|
||||
class TestUpgradeCallsMarketplaceService:
|
||||
"""
|
||||
Regression test for the bug where the auto-upgrade task called
|
||||
manager.upgrade_plugin directly, which skipped downloading the new package
|
||||
from marketplace and uploading it to the daemon. The daemon then failed with
|
||||
"package file not found" and the upgrade silently never completed.
|
||||
"""
|
||||
|
||||
def test_upgrade_routes_through_plugin_service(self):
|
||||
plugin = _make_plugin("acme/foo", "1.0.0")
|
||||
manifest = _make_manifest("acme/foo", "1.0.1")
|
||||
|
||||
upgrade_mock, calls = _run_task(plugins=[plugin], manifests=[manifest])
|
||||
|
||||
upgrade_mock.assert_called_once()
|
||||
assert calls == [("tenant-1", plugin.plugin_unique_identifier, manifest.latest_package_identifier)]
|
||||
|
||||
def test_does_not_call_manager_upgrade_plugin_directly(self):
|
||||
"""Locks in that we never go back to the broken path that bypassed download/upload."""
|
||||
plugin = _make_plugin("acme/foo", "1.0.0")
|
||||
manifest = _make_manifest("acme/foo", "1.0.1")
|
||||
|
||||
fake_manager = MagicMock()
|
||||
fake_manager.list_plugins.return_value = [plugin]
|
||||
|
||||
with (
|
||||
patch(f"{MODULE}.PluginInstaller", return_value=fake_manager),
|
||||
patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=[manifest]),
|
||||
patch(f"{MODULE}.PluginService.upgrade_plugin_with_marketplace"),
|
||||
):
|
||||
from tasks.process_tenant_plugin_autoupgrade_check_task import (
|
||||
process_tenant_plugin_autoupgrade_check_task,
|
||||
)
|
||||
|
||||
process_tenant_plugin_autoupgrade_check_task(
|
||||
"tenant-1",
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
0,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
fake_manager.upgrade_plugin.assert_not_called()
|
||||
|
||||
|
||||
class TestStrategySetting:
|
||||
def test_disabled_strategy_skips_everything(self):
|
||||
upgrade_mock, _ = _run_task(
|
||||
plugins=[_make_plugin("acme/foo", "1.0.0")],
|
||||
manifests=[_make_manifest("acme/foo", "1.0.1")],
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
|
||||
)
|
||||
upgrade_mock.assert_not_called()
|
||||
|
||||
def test_fix_only_upgrades_patch_version(self):
|
||||
upgrade_mock, calls = _run_task(
|
||||
plugins=[_make_plugin("acme/foo", "1.0.0")],
|
||||
manifests=[_make_manifest("acme/foo", "1.0.5")],
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
)
|
||||
upgrade_mock.assert_called_once()
|
||||
assert calls[0][2].endswith(":1.0.5@cafe1234")
|
||||
|
||||
def test_fix_only_skips_minor_bump(self):
|
||||
upgrade_mock, _ = _run_task(
|
||||
plugins=[_make_plugin("acme/foo", "1.0.0")],
|
||||
manifests=[_make_manifest("acme/foo", "1.1.0")],
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
)
|
||||
upgrade_mock.assert_not_called()
|
||||
|
||||
def test_fix_only_skips_major_bump(self):
|
||||
upgrade_mock, _ = _run_task(
|
||||
plugins=[_make_plugin("acme/foo", "1.0.0")],
|
||||
manifests=[_make_manifest("acme/foo", "2.0.0")],
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
)
|
||||
upgrade_mock.assert_not_called()
|
||||
|
||||
def test_latest_strategy_skips_when_versions_equal(self):
|
||||
upgrade_mock, _ = _run_task(
|
||||
plugins=[_make_plugin("acme/foo", "1.0.0")],
|
||||
manifests=[_make_manifest("acme/foo", "1.0.0")],
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
)
|
||||
upgrade_mock.assert_not_called()
|
||||
|
||||
|
||||
class TestUpgradeMode:
|
||||
def test_mode_all_upgrades_every_marketplace_plugin(self):
|
||||
plugins = [
|
||||
_make_plugin("acme/foo", "1.0.0"),
|
||||
_make_plugin("acme/bar", "2.0.0"),
|
||||
]
|
||||
manifests = [
|
||||
_make_manifest("acme/foo", "1.0.1"),
|
||||
_make_manifest("acme/bar", "2.0.1"),
|
||||
]
|
||||
|
||||
upgrade_mock, calls = _run_task(
|
||||
plugins=plugins,
|
||||
manifests=manifests,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
)
|
||||
|
||||
assert upgrade_mock.call_count == 2
|
||||
upgraded_ids = sorted(c[1] for c in calls)
|
||||
assert upgraded_ids == sorted(p.plugin_unique_identifier for p in plugins)
|
||||
|
||||
def test_mode_all_skips_non_marketplace_sources(self):
|
||||
plugins = [
|
||||
_make_plugin("acme/foo", "1.0.0"),
|
||||
_make_plugin("acme/bar", "2.0.0", source=PluginInstallationSource.Github),
|
||||
]
|
||||
manifests = [
|
||||
_make_manifest("acme/foo", "1.0.1"),
|
||||
_make_manifest("acme/bar", "2.0.1"),
|
||||
]
|
||||
|
||||
upgrade_mock, calls = _run_task(
|
||||
plugins=plugins,
|
||||
manifests=manifests,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
)
|
||||
|
||||
assert upgrade_mock.call_count == 1
|
||||
assert calls[0][1] == plugins[0].plugin_unique_identifier
|
||||
|
||||
def test_mode_partial_only_upgrades_included_plugins(self):
|
||||
plugins = [
|
||||
_make_plugin("acme/foo", "1.0.0"),
|
||||
_make_plugin("acme/bar", "2.0.0"),
|
||||
]
|
||||
manifests = [
|
||||
_make_manifest("acme/foo", "1.0.1"),
|
||||
_make_manifest("acme/bar", "2.0.1"),
|
||||
]
|
||||
|
||||
upgrade_mock, calls = _run_task(
|
||||
plugins=plugins,
|
||||
manifests=manifests,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
|
||||
include_plugins=["acme/foo"],
|
||||
)
|
||||
|
||||
assert upgrade_mock.call_count == 1
|
||||
assert calls[0][1] == plugins[0].plugin_unique_identifier
|
||||
|
||||
def test_mode_exclude_skips_excluded_plugins(self):
|
||||
plugins = [
|
||||
_make_plugin("acme/foo", "1.0.0"),
|
||||
_make_plugin("acme/bar", "2.0.0"),
|
||||
]
|
||||
manifests = [
|
||||
_make_manifest("acme/foo", "1.0.1"),
|
||||
_make_manifest("acme/bar", "2.0.1"),
|
||||
]
|
||||
|
||||
upgrade_mock, calls = _run_task(
|
||||
plugins=plugins,
|
||||
manifests=manifests,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
exclude_plugins=["acme/bar"],
|
||||
)
|
||||
|
||||
assert upgrade_mock.call_count == 1
|
||||
assert calls[0][1] == plugins[0].plugin_unique_identifier
|
||||
|
||||
|
||||
class TestErrorIsolation:
|
||||
def test_one_plugin_failure_does_not_block_others(self):
|
||||
plugins = [
|
||||
_make_plugin("acme/foo", "1.0.0"),
|
||||
_make_plugin("acme/bar", "2.0.0"),
|
||||
]
|
||||
manifests = [
|
||||
_make_manifest("acme/foo", "1.0.1"),
|
||||
_make_manifest("acme/bar", "2.0.1"),
|
||||
]
|
||||
fake_manager = MagicMock()
|
||||
fake_manager.list_plugins.return_value = plugins
|
||||
|
||||
seen: list[str] = []
|
||||
|
||||
def _upgrade(tenant_id, original, new):
|
||||
seen.append(original)
|
||||
if "foo" in original:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with (
|
||||
patch(f"{MODULE}.PluginInstaller", return_value=fake_manager),
|
||||
patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=manifests),
|
||||
patch(f"{MODULE}.PluginService.upgrade_plugin_with_marketplace", side_effect=_upgrade),
|
||||
):
|
||||
from tasks.process_tenant_plugin_autoupgrade_check_task import (
|
||||
process_tenant_plugin_autoupgrade_check_task,
|
||||
)
|
||||
|
||||
process_tenant_plugin_autoupgrade_check_task(
|
||||
"tenant-1",
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
0,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
assert any("foo" in s for s in seen)
|
||||
assert any("bar" in s for s in seen)
|
||||
@ -59,19 +59,25 @@ services:
|
||||
- ${MYSQL_HOST_VOLUME:-./volumes/mysql/data}:/var/lib/mysql
|
||||
ports:
|
||||
- "${EXPOSE_MYSQL_PORT:-3306}:3306"
|
||||
# mysqladmin ping passes during mysql:8.0's TCP-listening stage even while
|
||||
# the server is still finalising init, leading to "Lost connection during
|
||||
# query" on the first real query. Verify with a real SELECT instead.
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"mysqladmin",
|
||||
"ping",
|
||||
"-u",
|
||||
"root",
|
||||
"mysql",
|
||||
"-h",
|
||||
"127.0.0.1",
|
||||
"-uroot",
|
||||
"-p${DB_PASSWORD:-difyai123456}",
|
||||
"-e",
|
||||
"SELECT 1",
|
||||
]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
start_period: 20s
|
||||
|
||||
# The redis cache.
|
||||
redis:
|
||||
|
||||
@ -17,3 +17,10 @@ Feature: Share app publicly
|
||||
Given a workflow app has been published and shared via API
|
||||
When I open the shared app URL
|
||||
Then the shared app page should be accessible
|
||||
|
||||
@unauthenticated
|
||||
Scenario: Run a shared workflow app without authentication
|
||||
Given a workflow app has been published and shared via API
|
||||
When I open the shared app URL
|
||||
And I run the shared workflow app
|
||||
Then the shared workflow run should succeed
|
||||
|
||||
@ -37,3 +37,15 @@ Then('the shared app page should be accessible', async function (this: DifyWorld
|
||||
await expect(this.getPage()).toHaveURL(/\/(workflow|chat)\/[a-zA-Z0-9]+/, { timeout: 15_000 })
|
||||
await expect(this.getPage().locator('body')).toBeVisible({ timeout: 10_000 })
|
||||
})
|
||||
|
||||
When('I run the shared workflow app', async function (this: DifyWorld) {
|
||||
const page = this.getPage()
|
||||
const runButton = page.getByTestId('run-button')
|
||||
|
||||
await expect(runButton).toBeEnabled({ timeout: 15_000 })
|
||||
await runButton.click()
|
||||
})
|
||||
|
||||
Then('the shared workflow run should succeed', async function (this: DifyWorld) {
|
||||
await expect(this.getPage().getByTestId('status-icon-success')).toBeVisible({ timeout: 55_000 })
|
||||
})
|
||||
|
||||
@ -12,8 +12,10 @@ Given('a minimal runnable workflow draft has been synced', async function (this:
|
||||
|
||||
When('I run the workflow', async function (this: DifyWorld) {
|
||||
const page = this.getPage()
|
||||
await page.getByText('Test Run').click()
|
||||
await expect(page.getByText('Running').first()).toBeVisible({ timeout: 15_000 })
|
||||
const testRunButton = page.getByText('Test Run')
|
||||
|
||||
await expect(testRunButton).toBeVisible({ timeout: 15_000 })
|
||||
await testRunButton.click()
|
||||
})
|
||||
|
||||
Then('the workflow run should succeed', async function (this: DifyWorld) {
|
||||
|
||||
@ -3506,11 +3506,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"web/app/components/plugins/reference-setting-modal/auto-update-setting/strategy-picker.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"web/app/components/plugins/reference-setting-modal/auto-update-setting/types.ts": {
|
||||
"erasable-syntax-only/enums": {
|
||||
"count": 2
|
||||
|
||||
@ -154,7 +154,6 @@ const mockSnippet: SnippetDetail = {
|
||||
id: 'snippet-1',
|
||||
name: 'Social Media Repurposer',
|
||||
description: 'Turn one blog post into multiple social media variations.',
|
||||
author: 'Dify',
|
||||
updatedAt: '2026-03-25 10:00',
|
||||
usage: '12',
|
||||
icon: '🤖',
|
||||
|
||||
@ -11,7 +11,6 @@ const mockSnippet: SnippetDetail = {
|
||||
id: 'snippet-1',
|
||||
name: 'Social Media Repurposer',
|
||||
description: 'Turn one blog post into multiple social media variations.',
|
||||
author: 'Dify',
|
||||
updatedAt: '2026-03-25 10:00',
|
||||
usage: '12',
|
||||
icon: '🤖',
|
||||
|
||||
@ -160,8 +160,9 @@ const defaultSnippetData = {
|
||||
icon_url: '',
|
||||
},
|
||||
created_at: 1704067200,
|
||||
updated_at: '2024-01-02 10:00',
|
||||
author: '',
|
||||
created_by: 'user-1',
|
||||
updated_at: 1704153600,
|
||||
updated_by: 'user-2',
|
||||
},
|
||||
],
|
||||
total: 1,
|
||||
@ -321,8 +322,9 @@ describe('List', () => {
|
||||
icon_url: '',
|
||||
},
|
||||
created_at: 1704067200,
|
||||
updated_at: '2024-01-02 10:00',
|
||||
author: '',
|
||||
created_by: 'user-1',
|
||||
updated_at: 1704153600,
|
||||
updated_by: 'user-2',
|
||||
},
|
||||
]
|
||||
defaultSnippetData.pages[0]!.total = 1
|
||||
@ -678,8 +680,8 @@ describe('List', () => {
|
||||
})
|
||||
|
||||
it('should reuse the shared empty state when no snippets are available', () => {
|
||||
defaultSnippetData.pages[0].data = []
|
||||
defaultSnippetData.pages[0].total = 0
|
||||
defaultSnippetData.pages[0]!.data = []
|
||||
defaultSnippetData.pages[0]!.total = 0
|
||||
|
||||
renderList({ pageType: 'snippets' })
|
||||
|
||||
|
||||
@ -532,6 +532,7 @@ describe('useEmbeddedChatbot', () => {
|
||||
})
|
||||
|
||||
it('handleChangeConversation updates current conversation and refetches chat list', async () => {
|
||||
mockStoreState.embeddedConversationId = null
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
act(() => {
|
||||
@ -548,6 +549,39 @@ describe('useEmbeddedChatbot', () => {
|
||||
expect(result.current.clearChatList).toBe(false)
|
||||
})
|
||||
|
||||
// Scenario: URL-provided conversation_id should take precedence over localStorage value.
|
||||
it('should prioritize URL conversation_id over localStorage', async () => {
|
||||
localStorage.setItem(CONVERSATION_ID_INFO, JSON.stringify({
|
||||
'app-1': { 'embedded-user-1': 'stored-conv-id' },
|
||||
}))
|
||||
mockStoreState.embeddedConversationId = 'url-conv-id'
|
||||
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({
|
||||
user_id: 'embedded-user-1',
|
||||
conversation_id: 'url-conv-id',
|
||||
})
|
||||
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.currentConversationId).toBe('url-conv-id')
|
||||
})
|
||||
})
|
||||
|
||||
// Scenario: When no URL conversation_id is provided, fall back to localStorage.
|
||||
it('should fall back to localStorage when no URL conversation_id is provided', async () => {
|
||||
localStorage.setItem(CONVERSATION_ID_INFO, JSON.stringify({
|
||||
'app-1': { DEFAULT: 'stored-conv-id' },
|
||||
}))
|
||||
mockStoreState.embeddedConversationId = null
|
||||
mockStoreState.embeddedUserId = null
|
||||
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.currentConversationId).toBe('stored-conv-id')
|
||||
})
|
||||
})
|
||||
|
||||
it('handleFeedback invokes updateFeedback service successfully', async () => {
|
||||
const { updateFeedback } = await import('@/service/share')
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
@ -113,7 +113,7 @@ export const useEmbeddedChatbot = (appSourceType: AppSourceType, tryAppId?: stri
|
||||
})
|
||||
}, [setConversationIdInfo])
|
||||
const allowResetChat = !conversationId
|
||||
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || conversationId || '', [appId, conversationIdInfo, userId, conversationId])
|
||||
const currentConversationId = useMemo(() => conversationId || conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId, conversationId])
|
||||
const handleConversationIdInfoChange = useCallback((changeConversationId: string) => {
|
||||
if (appId) {
|
||||
let prevValue = conversationIdInfo?.[appId || '']
|
||||
|
||||
@ -43,7 +43,7 @@ describe('OptionListItem', () => {
|
||||
</OptionListItem>,
|
||||
)
|
||||
|
||||
const item = screen.getByRole('listitem')
|
||||
const item = screen.getByRole('button')
|
||||
expect(item).toHaveClass('bg-components-button-ghost-bg-hover')
|
||||
})
|
||||
|
||||
@ -54,7 +54,7 @@ describe('OptionListItem', () => {
|
||||
</OptionListItem>,
|
||||
)
|
||||
|
||||
const item = screen.getByRole('listitem')
|
||||
const item = screen.getByRole('button')
|
||||
expect(item).not.toHaveClass('bg-components-button-ghost-bg-hover')
|
||||
})
|
||||
})
|
||||
@ -100,7 +100,7 @@ describe('OptionListItem', () => {
|
||||
Clickable
|
||||
</OptionListItem>,
|
||||
)
|
||||
fireEvent.click(screen.getByRole('listitem'))
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
@ -111,7 +111,7 @@ describe('OptionListItem', () => {
|
||||
Item
|
||||
</OptionListItem>,
|
||||
)
|
||||
fireEvent.click(screen.getByRole('listitem'))
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(Element.prototype.scrollIntoView).toHaveBeenCalledWith({ behavior: 'smooth' })
|
||||
})
|
||||
@ -126,7 +126,7 @@ describe('OptionListItem', () => {
|
||||
</OptionListItem>,
|
||||
)
|
||||
|
||||
const item = screen.getByRole('listitem')
|
||||
const item = screen.getByRole('button')
|
||||
fireEvent.click(item)
|
||||
fireEvent.click(item)
|
||||
fireEvent.click(item)
|
||||
|
||||
@ -0,0 +1,28 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import OptionList from '../option-list'
|
||||
|
||||
describe('OptionList', () => {
|
||||
it('should render a scrollable list with hidden scrollbar styles', () => {
|
||||
render(
|
||||
<OptionList>
|
||||
<li>Item</li>
|
||||
</OptionList>,
|
||||
)
|
||||
|
||||
const list = screen.getByRole('list')
|
||||
|
||||
expect(list).toHaveClass('overflow-y-auto')
|
||||
expect(list).toHaveClass('[scrollbar-width:none]')
|
||||
expect(list).toHaveClass('[&::-webkit-scrollbar]:hidden')
|
||||
})
|
||||
|
||||
it('should append caller className after default classes', () => {
|
||||
render(
|
||||
<OptionList className="custom-list">
|
||||
<li>Item</li>
|
||||
</OptionList>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('list')).toHaveClass('custom-list')
|
||||
})
|
||||
})
|
||||
@ -1,4 +1,4 @@
|
||||
import type { FC } from 'react'
|
||||
import type { FC, ReactNode } from 'react'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import * as React from 'react'
|
||||
import { useEffect, useRef } from 'react'
|
||||
@ -7,7 +7,8 @@ type OptionListItemProps = {
|
||||
isSelected: boolean
|
||||
onClick: () => void
|
||||
noAutoScroll?: boolean
|
||||
} & React.LiHTMLAttributes<HTMLLIElement>
|
||||
children: ReactNode
|
||||
}
|
||||
|
||||
const OptionListItem: FC<OptionListItemProps> = ({
|
||||
isSelected,
|
||||
@ -25,16 +26,21 @@ const OptionListItem: FC<OptionListItemProps> = ({
|
||||
return (
|
||||
<li
|
||||
ref={listItemRef}
|
||||
className={cn(
|
||||
'flex cursor-pointer items-center justify-center rounded-md px-1.5 py-1 system-xs-medium text-components-button-ghost-text',
|
||||
isSelected ? 'bg-components-button-ghost-bg-hover' : 'hover:bg-components-button-ghost-bg-hover',
|
||||
)}
|
||||
onClick={() => {
|
||||
listItemRef.current?.scrollIntoView({ behavior: 'smooth' })
|
||||
onClick()
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
<button
|
||||
type="button"
|
||||
className={cn(
|
||||
'flex w-full cursor-pointer items-center justify-center rounded-md px-1.5 py-1 system-xs-medium text-components-button-ghost-text outline-hidden',
|
||||
'focus-visible:ring-1 focus-visible:ring-components-input-border-hover focus-visible:ring-inset',
|
||||
isSelected ? 'bg-components-button-ghost-bg-hover' : 'hover:bg-components-button-ghost-bg-hover',
|
||||
)}
|
||||
onClick={() => {
|
||||
listItemRef.current?.scrollIntoView({ behavior: 'smooth' })
|
||||
onClick()
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
</li>
|
||||
)
|
||||
}
|
||||
|
||||
@ -0,0 +1,26 @@
|
||||
import type { HTMLAttributes, ReactNode } from 'react'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import * as React from 'react'
|
||||
|
||||
type OptionListProps = {
|
||||
children: ReactNode
|
||||
} & HTMLAttributes<HTMLUListElement>
|
||||
|
||||
const optionListClassName = cn(
|
||||
'flex h-[208px] flex-col gap-y-0.5 overflow-y-auto pb-[184px]',
|
||||
'[scrollbar-width:none] [&::-webkit-scrollbar]:hidden',
|
||||
)
|
||||
|
||||
const OptionList = ({
|
||||
children,
|
||||
className,
|
||||
...props
|
||||
}: OptionListProps) => {
|
||||
return (
|
||||
<ul className={cn(optionListClassName, className)} {...props}>
|
||||
{children}
|
||||
</ul>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(OptionList)
|
||||
@ -64,13 +64,13 @@ describe('TimePickerOptions', () => {
|
||||
it('should render selected hour in the list', () => {
|
||||
const props = createOptionsProps({ selectedTime: dayjs('2024-01-01 05:30:00') })
|
||||
render(<Options {...props} />)
|
||||
const selectedHour = screen.getAllByRole('listitem').find(item => item.textContent === '05')
|
||||
const selectedHour = screen.getAllByRole('button').find(item => item.textContent === '05')
|
||||
expect(selectedHour)!.toHaveClass('bg-components-button-ghost-bg-hover')
|
||||
})
|
||||
it('should render selected minute in the list', () => {
|
||||
const props = createOptionsProps({ selectedTime: dayjs('2024-01-01 05:30:00') })
|
||||
render(<Options {...props} />)
|
||||
const selectedMinute = screen.getAllByRole('listitem').find(item => item.textContent === '30')
|
||||
const selectedMinute = screen.getAllByRole('button').find(item => item.textContent === '30')
|
||||
expect(selectedMinute)!.toHaveClass('bg-components-button-ghost-bg-hover')
|
||||
})
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type { FC } from 'react'
|
||||
import type { TimeOptionsProps } from '../types'
|
||||
import * as React from 'react'
|
||||
import OptionList from '../common/option-list'
|
||||
import OptionListItem from '../common/option-list-item'
|
||||
import { useTimeOptions } from '../hooks'
|
||||
|
||||
@ -16,7 +17,7 @@ const Options: FC<TimeOptionsProps> = ({
|
||||
return (
|
||||
<div className="grid grid-cols-3 gap-x-1 p-2">
|
||||
{/* Hour */}
|
||||
<ul className="scrollbar-none flex h-[208px] flex-col gap-y-0.5 overflow-y-auto pb-[184px]">
|
||||
<OptionList>
|
||||
{
|
||||
hourOptions.map((hour) => {
|
||||
const isSelected = selectedTime?.format('hh') === hour
|
||||
@ -31,9 +32,9 @@ const Options: FC<TimeOptionsProps> = ({
|
||||
)
|
||||
})
|
||||
}
|
||||
</ul>
|
||||
</OptionList>
|
||||
{/* Minute */}
|
||||
<ul className="scrollbar-none flex h-[208px] flex-col gap-y-0.5 overflow-y-auto pb-[184px]">
|
||||
<OptionList>
|
||||
{
|
||||
(minuteFilter ? minuteFilter(minuteOptions) : minuteOptions).map((minute) => {
|
||||
const isSelected = selectedTime?.format('mm') === minute
|
||||
@ -48,9 +49,9 @@ const Options: FC<TimeOptionsProps> = ({
|
||||
)
|
||||
})
|
||||
}
|
||||
</ul>
|
||||
</OptionList>
|
||||
{/* Period */}
|
||||
<ul className="scrollbar-none flex h-[208px] flex-col gap-y-0.5 overflow-y-auto pb-[184px]">
|
||||
<OptionList>
|
||||
{
|
||||
periodOptions.map((period) => {
|
||||
const isSelected = selectedTime?.format('A') === period
|
||||
@ -66,7 +67,7 @@ const Options: FC<TimeOptionsProps> = ({
|
||||
)
|
||||
})
|
||||
}
|
||||
</ul>
|
||||
</OptionList>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Placement } from '@floating-ui/react'
|
||||
import type { Placement } from '@langgenius/dify-ui/popover'
|
||||
import type { Dayjs } from 'dayjs'
|
||||
|
||||
export enum ViewType {
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type { FC } from 'react'
|
||||
import type { YearAndMonthPickerOptionsProps } from '../types'
|
||||
import * as React from 'react'
|
||||
import OptionList from '../common/option-list'
|
||||
import OptionListItem from '../common/option-list-item'
|
||||
import { useMonths, useYearOptions } from '../hooks'
|
||||
|
||||
@ -16,7 +17,7 @@ const Options: FC<YearAndMonthPickerOptionsProps> = ({
|
||||
return (
|
||||
<div className="grid grid-cols-2 gap-x-1 p-2">
|
||||
{/* Month Picker */}
|
||||
<ul className="scrollbar-none flex h-[208px] flex-col gap-y-0.5 overflow-y-auto pb-[184px]">
|
||||
<OptionList>
|
||||
{
|
||||
months.map((month, index) => {
|
||||
const isSelected = selectedMonth === index
|
||||
@ -31,9 +32,9 @@ const Options: FC<YearAndMonthPickerOptionsProps> = ({
|
||||
)
|
||||
})
|
||||
}
|
||||
</ul>
|
||||
</OptionList>
|
||||
{/* Year Picker */}
|
||||
<ul className="scrollbar-none flex h-[208px] flex-col gap-y-0.5 overflow-y-auto pb-[184px]">
|
||||
<OptionList>
|
||||
{
|
||||
yearOptions.map((year) => {
|
||||
const isSelected = selectedYear === year
|
||||
@ -48,7 +49,7 @@ const Options: FC<YearAndMonthPickerOptionsProps> = ({
|
||||
)
|
||||
})
|
||||
}
|
||||
</ul>
|
||||
</OptionList>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -42,7 +42,6 @@ const LanguageSelect: FC<ILanguageSelectProps> = ({
|
||||
placement="bottom-start"
|
||||
sideOffset={4}
|
||||
popupClassName="w-max"
|
||||
listClassName="no-scrollbar"
|
||||
>
|
||||
{supportedLanguages.map(({ prompt_name }) => (
|
||||
<SelectItem key={prompt_name} value={prompt_name}>
|
||||
|
||||
@ -55,7 +55,14 @@ vi.mock('../../hooks', async () => {
|
||||
})
|
||||
|
||||
vi.mock('../popup-item', () => ({
|
||||
default: ({ model }: { model: Model }) => <div>{model.provider}</div>,
|
||||
default: ({ model }: { model: Model }) => (
|
||||
<div>
|
||||
<span>{model.provider}</span>
|
||||
{model.models.map(modelItem => (
|
||||
<span key={modelItem.model}>{modelItem.model}</span>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
@ -207,6 +214,156 @@ describe('Popup', () => {
|
||||
expect((input as HTMLInputElement).value).toBe('')
|
||||
})
|
||||
|
||||
it('should show matching models when searching by model name', () => {
|
||||
renderPopup(
|
||||
<Popup
|
||||
modelList={[
|
||||
makeModel({
|
||||
models: [makeModelItem({ model: 'gpt-4', label: { en_US: 'GPT-4', zh_Hans: 'GPT-4' } })],
|
||||
}),
|
||||
makeModel({
|
||||
provider: 'anthropic',
|
||||
label: { en_US: 'Anthropic', zh_Hans: 'Anthropic' },
|
||||
models: [makeModelItem({ model: 'claude-3', label: { en_US: 'Claude 3', zh_Hans: 'Claude 3' } })],
|
||||
}),
|
||||
]}
|
||||
onSelect={vi.fn()}
|
||||
onHide={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.change(
|
||||
screen.getByPlaceholderText('datasetSettings.form.searchModel'),
|
||||
{ target: { value: 'claude' } },
|
||||
)
|
||||
|
||||
expect(screen.queryByText('openai')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('anthropic')).toBeInTheDocument()
|
||||
expect(screen.getByText('claude-3')).toBeInTheDocument()
|
||||
expect(screen.queryByText('gpt-4')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('No model found for \u201Cclaude\u201D')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show empty search placeholder when no provider or model name matches', () => {
|
||||
renderPopup(
|
||||
<Popup
|
||||
modelList={[
|
||||
makeModel({
|
||||
label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
|
||||
models: [
|
||||
makeModelItem({ model: 'gpt-4', label: { en_US: 'GPT-4', zh_Hans: 'GPT-4' } }),
|
||||
],
|
||||
}),
|
||||
]}
|
||||
onSelect={vi.fn()}
|
||||
onHide={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.change(
|
||||
screen.getByPlaceholderText('datasetSettings.form.searchModel'),
|
||||
{ target: { value: 'mistral' } },
|
||||
)
|
||||
|
||||
expect(screen.getByText('No model found for \u201Cmistral\u201D'))!.toBeInTheDocument()
|
||||
expect(screen.queryByText('openai')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('gpt-4')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show all models of a provider when searching by provider label', () => {
|
||||
renderPopup(
|
||||
<Popup
|
||||
modelList={[
|
||||
makeModel({
|
||||
provider: 'openai',
|
||||
label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
|
||||
models: [
|
||||
makeModelItem({ model: 'gpt-4', label: { en_US: 'GPT-4', zh_Hans: 'GPT-4' } }),
|
||||
makeModelItem({ model: 'gpt-4o', label: { en_US: 'GPT-4o', zh_Hans: 'GPT-4o' } }),
|
||||
],
|
||||
}),
|
||||
makeModel({
|
||||
provider: 'anthropic',
|
||||
label: { en_US: 'Anthropic', zh_Hans: 'Anthropic' },
|
||||
models: [
|
||||
makeModelItem({ model: 'claude-3', label: { en_US: 'Claude 3', zh_Hans: 'Claude 3' } }),
|
||||
],
|
||||
}),
|
||||
]}
|
||||
onSelect={vi.fn()}
|
||||
onHide={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.change(
|
||||
screen.getByPlaceholderText('datasetSettings.form.searchModel'),
|
||||
{ target: { value: 'openai' } },
|
||||
)
|
||||
|
||||
expect(screen.getByText('openai'))!.toBeInTheDocument()
|
||||
expect(screen.getByText('gpt-4'))!.toBeInTheDocument()
|
||||
expect(screen.getByText('gpt-4o'))!.toBeInTheDocument()
|
||||
expect(screen.queryByText('anthropic')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('claude-3')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should match by model provider key when model label does not contain the search text', () => {
|
||||
renderPopup(
|
||||
<Popup
|
||||
modelList={[
|
||||
makeModel({
|
||||
provider: 'azure_openai',
|
||||
label: { en_US: 'Azure', zh_Hans: 'Azure' },
|
||||
models: [
|
||||
makeModelItem({ model: 'gpt-4', label: { en_US: 'GPT-4', zh_Hans: 'GPT-4' } }),
|
||||
],
|
||||
}),
|
||||
]}
|
||||
onSelect={vi.fn()}
|
||||
onHide={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.change(
|
||||
screen.getByPlaceholderText('datasetSettings.form.searchModel'),
|
||||
{ target: { value: 'openai' } },
|
||||
)
|
||||
|
||||
expect(screen.getByText('azure_openai'))!.toBeInTheDocument()
|
||||
expect(screen.getByText('gpt-4'))!.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should still apply scope features when matching by provider label', () => {
|
||||
mockSupportFunctionCall.mockReturnValue(false)
|
||||
|
||||
renderPopup(
|
||||
<Popup
|
||||
modelList={[
|
||||
makeModel({
|
||||
provider: 'openai',
|
||||
label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
|
||||
models: [
|
||||
makeModelItem({ model: 'gpt-4', features: [ModelFeatureEnum.vision] }),
|
||||
makeModelItem({ model: 'gpt-4-tool', features: [ModelFeatureEnum.toolCall] }),
|
||||
],
|
||||
}),
|
||||
]}
|
||||
onSelect={vi.fn()}
|
||||
onHide={vi.fn()}
|
||||
scopeFeatures={[ModelFeatureEnum.toolCall]}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.change(
|
||||
screen.getByPlaceholderText('datasetSettings.form.searchModel'),
|
||||
{ target: { value: 'openai' } },
|
||||
)
|
||||
|
||||
expect(screen.getByText('No model found for \u201Copenai\u201D'))!.toBeInTheDocument()
|
||||
expect(screen.queryByText('gpt-4')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('gpt-4-tool')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show compatible-only helper text when no scope features are applied', () => {
|
||||
renderPopup(
|
||||
<Popup
|
||||
@ -219,8 +376,8 @@ describe('Popup', () => {
|
||||
expect(screen.queryByText('common.modelProvider.selector.onlyCompatibleModelsShown')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show compatible-only helper banner when scope features are applied', () => {
|
||||
const { container } = renderPopup(
|
||||
it('should show compatible-only helper text when scope features are applied', () => {
|
||||
renderPopup(
|
||||
<Popup
|
||||
modelList={[makeModel()]}
|
||||
onSelect={vi.fn()}
|
||||
@ -231,7 +388,26 @@ describe('Popup', () => {
|
||||
|
||||
expect(screen.getByTestId('compatible-models-banner'))!.toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.selector.onlyCompatibleModelsShown'))!.toBeInTheDocument()
|
||||
expect(container.querySelector('.i-ri-information-2-fill'))!.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should keep search and footer outside the scrollable model list', () => {
|
||||
renderPopup(
|
||||
<Popup
|
||||
modelList={[makeModel()]}
|
||||
onSelect={vi.fn()}
|
||||
onHide={vi.fn()}
|
||||
scopeFeatures={[ModelFeatureEnum.vision]}
|
||||
/>,
|
||||
)
|
||||
|
||||
const scrollRegion = screen.getByRole('region', { name: 'common.modelProvider.models' })
|
||||
const searchInput = screen.getByPlaceholderText('datasetSettings.form.searchModel')
|
||||
const settingsButton = screen.getByRole('button', { name: /common\.modelProvider\.selector\.modelProviderSettings/ })
|
||||
|
||||
expect(scrollRegion)!.toBeInTheDocument()
|
||||
expect(scrollRegion).not.toContainElement(searchInput)
|
||||
expect(scrollRegion).not.toContainElement(settingsButton)
|
||||
expect(scrollRegion).toContainElement(screen.getByTestId('compatible-models-banner'))
|
||||
})
|
||||
|
||||
it('should filter by scope features including toolCall and non-toolCall checks', () => {
|
||||
|
||||
@ -88,7 +88,7 @@ const ModelSelector: FC<ModelSelectorProps> = ({
|
||||
placement="bottom-start"
|
||||
sideOffset={4}
|
||||
className={popupClassName}
|
||||
popupClassName="overflow-hidden rounded-lg"
|
||||
popupClassName="overflow-hidden rounded-xl"
|
||||
popupProps={{ style: { minWidth: '320px', width: 'var(--anchor-width, auto)' } }}
|
||||
>
|
||||
<Popup
|
||||
|
||||
@ -0,0 +1,98 @@
|
||||
import type { FC } from 'react'
|
||||
import type { ModelProviderQuotaGetPaid } from '@/types/model-provider'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { getMarketplaceUrl } from '@/utils/var'
|
||||
import { modelNameMap, providerIconMap } from '../utils'
|
||||
|
||||
type MarketplaceSectionProps = {
|
||||
marketplaceProviders: ModelProviderQuotaGetPaid[]
|
||||
marketplaceCollapsed: boolean
|
||||
installingProvider: ModelProviderQuotaGetPaid | null
|
||||
isMarketplacePluginsLoading: boolean
|
||||
theme?: string
|
||||
onMarketplaceCollapsedChange: (collapsed: boolean) => void
|
||||
onInstallPlugin: (key: ModelProviderQuotaGetPaid) => void | Promise<void>
|
||||
}
|
||||
|
||||
const MarketplaceSection: FC<MarketplaceSectionProps> = ({
|
||||
marketplaceProviders,
|
||||
marketplaceCollapsed,
|
||||
installingProvider,
|
||||
isMarketplacePluginsLoading,
|
||||
theme,
|
||||
onMarketplaceCollapsedChange,
|
||||
onInstallPlugin,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
if (marketplaceProviders.length === 0)
|
||||
return null
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="py-2">
|
||||
<div className="h-px bg-divider-subtle" />
|
||||
</div>
|
||||
<div>
|
||||
<div className="flex h-[22px] items-center pr-2 pl-4">
|
||||
<div
|
||||
className="flex flex-1 cursor-pointer items-center system-sm-medium text-text-primary"
|
||||
onClick={() => onMarketplaceCollapsedChange(!marketplaceCollapsed)}
|
||||
>
|
||||
{t('modelProvider.selector.fromMarketplace', { ns: 'common' })}
|
||||
<span className={cn('i-custom-vender-solid-general-arrow-down-round-fill h-4 w-4 text-text-quaternary', marketplaceCollapsed && '-rotate-90')} />
|
||||
</div>
|
||||
</div>
|
||||
{!marketplaceCollapsed && (
|
||||
<div className="px-1 pb-1">
|
||||
{marketplaceProviders.map((key) => {
|
||||
const Icon = providerIconMap[key]
|
||||
const isInstalling = installingProvider === key
|
||||
return (
|
||||
<div
|
||||
key={key}
|
||||
className="group flex cursor-pointer items-center gap-1 rounded-lg py-0.5 pr-0.5 pl-3 hover:bg-state-base-hover"
|
||||
>
|
||||
<div className="flex flex-1 items-center gap-2 py-0.5">
|
||||
<Icon className="h-5 w-5 shrink-0 rounded-md" />
|
||||
<span className="system-sm-regular text-text-secondary">{modelNameMap[key]}</span>
|
||||
</div>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
className={cn(
|
||||
'shrink-0 backdrop-blur-[5px]',
|
||||
!isInstalling && 'hidden group-hover:flex',
|
||||
)}
|
||||
disabled={isInstalling || isMarketplacePluginsLoading}
|
||||
onClick={() => onInstallPlugin(key)}
|
||||
>
|
||||
{isInstalling && <span className="i-ri-loader-2-line h-3.5 w-3.5 animate-spin" />}
|
||||
{isInstalling
|
||||
? t('installModal.installing', { ns: 'plugin' })
|
||||
: t('modelProvider.selector.install', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
<a
|
||||
className="flex cursor-pointer items-center gap-0.5 px-3 py-1.5"
|
||||
href={getMarketplaceUrl('', { theme })}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
<span className="flex-1 system-xs-regular text-text-accent">
|
||||
{t('modelProvider.selector.discoverMoreInMarketplace', { ns: 'common' })}
|
||||
</span>
|
||||
<span className="i-ri-arrow-right-up-line h-3! w-3! text-text-accent" />
|
||||
</a>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default MarketplaceSection
|
||||
@ -0,0 +1,39 @@
|
||||
import type { FC } from 'react'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
type ModelSelectorEmptyStateProps = {
|
||||
onConfigure: () => void
|
||||
}
|
||||
|
||||
const ModelSelectorEmptyState: FC<ModelSelectorEmptyStateProps> = ({
|
||||
onConfigure,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className="mx-2 flex flex-col gap-2 rounded-[10px] bg-linear-to-r from-state-base-hover to-background-gradient-mask-transparent p-4">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-[10px] border-[0.5px] border-components-card-border bg-components-card-bg shadow-lg backdrop-blur-[5px]">
|
||||
<span className="i-ri-brain-2-line h-5 w-5 text-text-tertiary" />
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<p className="system-sm-medium text-text-secondary">
|
||||
{t('modelProvider.selector.noProviderConfigured', { ns: 'common' })}
|
||||
</p>
|
||||
<p className="system-xs-regular text-text-tertiary">
|
||||
{t('modelProvider.selector.noProviderConfiguredDesc', { ns: 'common' })}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant="primary"
|
||||
className="w-[108px]"
|
||||
onClick={onConfigure}
|
||||
>
|
||||
{t('modelProvider.selector.configure', { ns: 'common' })}
|
||||
<span className="i-ri-arrow-right-line h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ModelSelectorEmptyState
|
||||
@ -107,7 +107,8 @@ const PopupItem: FC<PopupItemProps> = ({
|
||||
|
||||
return (
|
||||
<div className="mb-1">
|
||||
<div className="sticky top-12 z-2 flex h-[22px] items-center justify-between bg-components-panel-bg px-3 text-xs font-medium text-text-tertiary">
|
||||
{/* Keep the sticky provider header above model rows while the list scrolls. */}
|
||||
<div className="sticky top-0 z-1 flex h-[22px] items-center justify-between bg-components-panel-bg px-3 text-xs font-medium text-text-tertiary">
|
||||
<div
|
||||
className="flex cursor-pointer items-center"
|
||||
onClick={() => setCollapsed(prev => !prev)}
|
||||
|
||||
@ -0,0 +1,130 @@
|
||||
import type { FC, ReactNode } from 'react'
|
||||
import {
|
||||
ScrollAreaContent,
|
||||
ScrollAreaRoot,
|
||||
ScrollAreaScrollbar,
|
||||
ScrollAreaThumb,
|
||||
ScrollAreaViewport,
|
||||
} from '@langgenius/dify-ui/scroll-area'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
type ModelSelectorPopupFrameProps = {
|
||||
children: ReactNode
|
||||
}
|
||||
|
||||
export const ModelSelectorPopupFrame: FC<ModelSelectorPopupFrameProps> = ({
|
||||
children,
|
||||
}) => {
|
||||
return (
|
||||
<div className="flex max-h-[min(624px,var(--available-height,624px))] flex-col overflow-hidden rounded-xl bg-components-panel-bg">
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
type ModelSelectorSearchHeaderProps = {
|
||||
searchText: string
|
||||
onSearchTextChange: (value: string) => void
|
||||
}
|
||||
|
||||
export const ModelSelectorSearchHeader: FC<ModelSelectorSearchHeaderProps> = ({
|
||||
searchText,
|
||||
onSearchTextChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className="shrink-0 bg-components-panel-bg px-2 pt-2 pb-1">
|
||||
<div className={`
|
||||
flex h-8 items-center rounded-lg border px-2
|
||||
${searchText ? 'border-components-input-border-active bg-components-input-bg-active shadow-xs' : 'border-transparent bg-components-input-bg-normal'}
|
||||
`}
|
||||
>
|
||||
<span
|
||||
className={`
|
||||
mr-0.5 i-ri-search-line h-4 w-4 shrink-0
|
||||
${searchText ? 'text-text-tertiary' : 'text-text-quaternary'}
|
||||
`}
|
||||
/>
|
||||
<input
|
||||
className="block h-[18px] grow appearance-none bg-transparent px-1 text-[13px] text-text-primary outline-hidden"
|
||||
placeholder={t('form.searchModel', { ns: 'datasetSettings' }) || ''}
|
||||
value={searchText}
|
||||
onChange={e => onSearchTextChange(e.target.value)}
|
||||
/>
|
||||
{
|
||||
searchText && (
|
||||
<span
|
||||
className="ml-1.5 i-custom-vender-solid-general-x-circle h-[14px] w-[14px] shrink-0 cursor-pointer text-text-quaternary"
|
||||
onClick={() => onSearchTextChange('')}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
type ModelSelectorScrollBodyProps = {
|
||||
children: ReactNode
|
||||
label: string
|
||||
}
|
||||
|
||||
export const ModelSelectorScrollBody: FC<ModelSelectorScrollBodyProps> = ({
|
||||
children,
|
||||
label,
|
||||
}) => {
|
||||
return (
|
||||
<ScrollAreaRoot className="relative min-h-0 overflow-hidden overscroll-contain">
|
||||
<ScrollAreaViewport
|
||||
aria-label={label}
|
||||
className="max-h-[calc(min(624px,var(--available-height,624px))-84px)] overscroll-contain"
|
||||
role="region"
|
||||
>
|
||||
<ScrollAreaContent className="min-w-0">
|
||||
{children}
|
||||
</ScrollAreaContent>
|
||||
</ScrollAreaViewport>
|
||||
{/* Keep the overlay scrollbar above sticky provider headers inside this scroll area. */}
|
||||
<ScrollAreaScrollbar className="z-2 data-[orientation=vertical]:my-1 data-[orientation=vertical]:me-1">
|
||||
<ScrollAreaThumb />
|
||||
</ScrollAreaScrollbar>
|
||||
</ScrollAreaRoot>
|
||||
)
|
||||
}
|
||||
|
||||
export const CompatibleModelsNotice = () => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div
|
||||
data-testid="compatible-models-banner"
|
||||
className="px-4 py-2 system-xs-regular text-text-tertiary"
|
||||
>
|
||||
{t('modelProvider.selector.onlyCompatibleModelsShown', { ns: 'common' })}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
type ModelProviderSettingsFooterProps = {
|
||||
onOpenSettings: () => void
|
||||
}
|
||||
|
||||
export const ModelProviderSettingsFooter: FC<ModelProviderSettingsFooterProps> = ({
|
||||
onOpenSettings,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className="shrink-0 border-t border-divider-subtle p-1">
|
||||
<button
|
||||
type="button"
|
||||
className="flex h-8 w-full cursor-pointer items-center gap-2 rounded-lg px-3 py-1 text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary"
|
||||
onClick={onOpenSettings}
|
||||
>
|
||||
<span className="i-ri-equalizer-2-line h-4 w-4 shrink-0" />
|
||||
<span className="system-xs-medium">{t('modelProvider.selector.modelProviderSettings', { ns: 'common' })}</span>
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@ -5,8 +5,6 @@ import type {
|
||||
ModelItem,
|
||||
} from '../declarations'
|
||||
import type { ModelProviderQuotaGetPaid } from '@/types/model-provider'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import { useSuspenseQuery } from '@tanstack/react-query'
|
||||
import { useTheme } from 'next-themes'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
@ -19,7 +17,6 @@ import { useProviderContext } from '@/context/provider-context'
|
||||
import { systemFeaturesQueryOptions } from '@/service/system-features'
|
||||
import { useInstallPackageFromMarketPlace } from '@/service/use-plugins'
|
||||
import { supportFunctionCall } from '@/utils/tool-call'
|
||||
import { getMarketplaceUrl } from '@/utils/var'
|
||||
import {
|
||||
CustomConfigurationStatusEnum,
|
||||
ModelFeatureEnum,
|
||||
@ -29,8 +26,17 @@ import { useLanguage, useMarketplaceAllPlugins } from '../hooks'
|
||||
import CreditsExhaustedAlert from '../provider-added-card/model-auth-dropdown/credits-exhausted-alert'
|
||||
import { useTrialCredits } from '../provider-added-card/use-trial-credits'
|
||||
import { providerSupportsCredits } from '../supports-credits'
|
||||
import { MODEL_PROVIDER_QUOTA_GET_PAID, modelNameMap, providerIconMap, providerKeyToPluginId } from '../utils'
|
||||
import { MODEL_PROVIDER_QUOTA_GET_PAID, providerKeyToPluginId } from '../utils'
|
||||
import MarketplaceSection from './marketplace-section'
|
||||
import ModelSelectorEmptyState from './popup-empty-state'
|
||||
import PopupItem from './popup-item'
|
||||
import {
|
||||
CompatibleModelsNotice,
|
||||
ModelProviderSettingsFooter,
|
||||
ModelSelectorPopupFrame,
|
||||
ModelSelectorScrollBody,
|
||||
ModelSelectorSearchHeader,
|
||||
} from './popup-layout'
|
||||
|
||||
type PopupProps = {
|
||||
defaultModel?: DefaultModel
|
||||
@ -137,18 +143,26 @@ const Popup: FC<PopupProps> = ({
|
||||
}, [aiCreditVisibleProviders, installedProviderMap, modelList])
|
||||
|
||||
const filteredModelList = useMemo(() => {
|
||||
const normalizedSearch = searchText.toLowerCase()
|
||||
const matchesLabel = (label: Record<string, string>) => {
|
||||
if (label[language] !== undefined)
|
||||
return label[language].toLowerCase().includes(normalizedSearch)
|
||||
return Object.values(label).some(value =>
|
||||
value.toLowerCase().includes(normalizedSearch),
|
||||
)
|
||||
}
|
||||
|
||||
const filtered = installedModelList.map((model) => {
|
||||
const matchesProviderSearch = !searchText
|
||||
|| model.provider.toLowerCase().includes(searchText.toLowerCase())
|
||||
|| Object.values(model.label).some(label => label.toLowerCase().includes(searchText.toLowerCase()))
|
||||
const providerMatched = !!searchText && (
|
||||
matchesLabel(model.label)
|
||||
|| model.provider.toLowerCase().includes(normalizedSearch)
|
||||
)
|
||||
|
||||
const filteredModels = model.models
|
||||
.filter((modelItem) => {
|
||||
if (modelItem.label[language] !== undefined)
|
||||
return modelItem.label[language].toLowerCase().includes(searchText.toLowerCase())
|
||||
return Object.values(modelItem.label).some(label =>
|
||||
label.toLowerCase().includes(searchText.toLowerCase()),
|
||||
)
|
||||
if (!searchText || providerMatched)
|
||||
return true
|
||||
return matchesLabel(modelItem.label)
|
||||
})
|
||||
.filter((modelItem) => {
|
||||
if (scopeFeatures.length === 0)
|
||||
@ -159,8 +173,12 @@ const Popup: FC<PopupProps> = ({
|
||||
return modelItem.features?.includes(feature) ?? false
|
||||
})
|
||||
})
|
||||
if (!matchesProviderSearch || (filteredModels.length === 0 && !aiCreditVisibleProviders.has(model.provider)))
|
||||
if (
|
||||
(searchText && filteredModels.length === 0)
|
||||
|| (!searchText && filteredModels.length === 0 && !aiCreditVisibleProviders.has(model.provider))
|
||||
) {
|
||||
return null
|
||||
}
|
||||
|
||||
return { ...model, models: filteredModels }
|
||||
}).filter((model): model is Model => model !== null)
|
||||
@ -181,166 +199,59 @@ const Popup: FC<PopupProps> = ({
|
||||
return MODEL_PROVIDER_QUOTA_GET_PAID.filter(key => !installedProviders.has(key))
|
||||
}, [modelProviders])
|
||||
|
||||
const handleOpenSettings = useCallback(() => {
|
||||
onHide()
|
||||
setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.PROVIDER })
|
||||
}, [onHide, setShowAccountSettingModal])
|
||||
|
||||
return (
|
||||
<div className="no-scrollbar max-h-[480px] overflow-y-auto">
|
||||
<div className="sticky top-0 z-10 bg-components-panel-bg pt-3 pr-2 pb-1 pl-3">
|
||||
<div className={`
|
||||
flex h-8 items-center rounded-lg border pr-[10px] pl-[9px]
|
||||
${searchText ? 'border-components-input-border-active bg-components-input-bg-active shadow-xs' : 'border-transparent bg-components-input-bg-normal'}
|
||||
`}
|
||||
>
|
||||
<span
|
||||
className={`
|
||||
mr-[7px] i-ri-search-line h-[14px] w-[14px] shrink-0
|
||||
${searchText ? 'text-text-tertiary' : 'text-text-quaternary'}
|
||||
`}
|
||||
/>
|
||||
<input
|
||||
className="block h-[18px] grow appearance-none bg-transparent text-[13px] text-text-primary outline-hidden"
|
||||
placeholder={t('form.searchModel', { ns: 'datasetSettings' }) || ''}
|
||||
value={searchText}
|
||||
onChange={e => setSearchText(e.target.value)}
|
||||
/>
|
||||
{
|
||||
searchText && (
|
||||
<span
|
||||
className="ml-1.5 i-custom-vender-solid-general-x-circle h-[14px] w-[14px] shrink-0 cursor-pointer text-text-quaternary"
|
||||
onClick={() => setSearchText('')}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
{scopeFeatures.length > 0 && (
|
||||
<div
|
||||
data-testid="compatible-models-banner"
|
||||
className="mt-2 flex items-center gap-1 rounded-lg bg-background-section-burn px-2.5 py-2"
|
||||
>
|
||||
<span className="i-ri-information-2-fill h-4 w-4 shrink-0 text-text-accent" />
|
||||
<p className="system-xs-medium text-text-secondary">
|
||||
{t('modelProvider.selector.onlyCompatibleModelsShown', { ns: 'common' })}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<ModelSelectorPopupFrame>
|
||||
<ModelSelectorSearchHeader
|
||||
searchText={searchText}
|
||||
onSearchTextChange={setSearchText}
|
||||
/>
|
||||
{showCreditsExhaustedAlert && (
|
||||
<CreditsExhaustedAlert hasApiKeyFallback={hasApiKeyFallback} />
|
||||
)}
|
||||
<div className="pr-1 pb-1 pl-3">
|
||||
{
|
||||
filteredModelList.map(model => (
|
||||
<PopupItem
|
||||
key={model.provider}
|
||||
defaultModel={defaultModel}
|
||||
model={model}
|
||||
onSelect={onSelect}
|
||||
onHide={onHide}
|
||||
<ModelSelectorScrollBody label={t('modelProvider.models', { ns: 'common' })}>
|
||||
<div className="pb-1">
|
||||
{
|
||||
filteredModelList.map(model => (
|
||||
<PopupItem
|
||||
key={model.provider}
|
||||
defaultModel={defaultModel}
|
||||
model={model}
|
||||
onSelect={onSelect}
|
||||
onHide={onHide}
|
||||
/>
|
||||
))
|
||||
}
|
||||
{!filteredModelList.length && !installedModelList.length && (
|
||||
<ModelSelectorEmptyState
|
||||
onConfigure={handleOpenSettings}
|
||||
/>
|
||||
))
|
||||
}
|
||||
{!filteredModelList.length && !installedModelList.length && (
|
||||
<div className="flex flex-col gap-2 rounded-[10px] bg-linear-to-r from-state-base-hover to-background-gradient-mask-transparent p-4">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-[10px] border-[0.5px] border-components-card-border bg-components-card-bg shadow-lg backdrop-blur-[5px]">
|
||||
<span className="i-ri-brain-2-line h-5 w-5 text-text-tertiary" />
|
||||
)}
|
||||
{!filteredModelList.length && installedModelList.length > 0 && (
|
||||
<div className="px-3 py-1.5 text-center text-xs leading-[18px] break-all text-text-tertiary">
|
||||
{`No model found for \u201C${searchText}\u201D`}
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<p className="system-sm-medium text-text-secondary">
|
||||
{t('modelProvider.selector.noProviderConfigured', { ns: 'common' })}
|
||||
</p>
|
||||
<p className="system-xs-regular text-text-tertiary">
|
||||
{t('modelProvider.selector.noProviderConfiguredDesc', { ns: 'common' })}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant="primary"
|
||||
className="w-[108px]"
|
||||
onClick={() => {
|
||||
onHide()
|
||||
setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.PROVIDER })
|
||||
}}
|
||||
>
|
||||
{t('modelProvider.selector.configure', { ns: 'common' })}
|
||||
<span className="i-ri-arrow-right-line h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
{!filteredModelList.length && installedModelList.length > 0 && (
|
||||
<div className="px-3 py-1.5 text-center text-xs leading-[18px] break-all text-text-tertiary">
|
||||
{`No model found for \u201C${searchText}\u201D`}
|
||||
</div>
|
||||
)}
|
||||
{marketplaceProviders.length > 0 && (
|
||||
<>
|
||||
<div className="mx-2 my-1 border-t border-divider-subtle" />
|
||||
<div className="mb-1">
|
||||
<div className="flex h-[22px] items-center px-3">
|
||||
<div
|
||||
className="flex flex-1 cursor-pointer items-center system-sm-medium text-text-primary"
|
||||
onClick={() => setMarketplaceCollapsed(prev => !prev)}
|
||||
>
|
||||
{t('modelProvider.selector.fromMarketplace', { ns: 'common' })}
|
||||
<span className={cn('i-custom-vender-solid-general-arrow-down-round-fill h-4 w-4 text-text-quaternary', marketplaceCollapsed && '-rotate-90')} />
|
||||
</div>
|
||||
</div>
|
||||
{!marketplaceCollapsed && (
|
||||
<>
|
||||
{marketplaceProviders.map((key) => {
|
||||
const Icon = providerIconMap[key]
|
||||
const isInstalling = installingProvider === key
|
||||
return (
|
||||
<div
|
||||
key={key}
|
||||
className="group flex cursor-pointer items-center gap-1 rounded-lg py-0.5 pr-0.5 pl-3 hover:bg-state-base-hover"
|
||||
>
|
||||
<div className="flex flex-1 items-center gap-2 py-0.5">
|
||||
<Icon className="h-5 w-5 shrink-0 rounded-md" />
|
||||
<span className="system-sm-regular text-text-secondary">{modelNameMap[key]}</span>
|
||||
</div>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
className={cn(
|
||||
'shrink-0 backdrop-blur-[5px]',
|
||||
!isInstalling && 'hidden group-hover:flex',
|
||||
)}
|
||||
disabled={isInstalling || isMarketplacePluginsLoading}
|
||||
onClick={() => handleInstallPlugin(key)}
|
||||
>
|
||||
{isInstalling && <span className="i-ri-loader-2-line h-3.5 w-3.5 animate-spin" />}
|
||||
{isInstalling
|
||||
? t('installModal.installing', { ns: 'plugin' })
|
||||
: t('modelProvider.selector.install', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
<a
|
||||
className="flex cursor-pointer items-center gap-0.5 px-3 pt-1.5"
|
||||
href={getMarketplaceUrl('', { theme })}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
<span className="flex-1 system-xs-regular text-text-accent">
|
||||
{t('modelProvider.selector.discoverMoreInMarketplace', { ns: 'common' })}
|
||||
</span>
|
||||
<span className="i-ri-arrow-right-up-line h-3! w-3! text-text-accent" />
|
||||
</a>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
<div
|
||||
className="sticky bottom-0 flex cursor-pointer items-center gap-1 rounded-b-lg border-t border-divider-subtle bg-components-panel-bg px-3 py-2 text-text-tertiary hover:text-text-secondary"
|
||||
onClick={() => {
|
||||
onHide()
|
||||
setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.PROVIDER })
|
||||
}}
|
||||
>
|
||||
<span className="i-ri-equalizer-2-line h-4 w-4 shrink-0" />
|
||||
<span className="system-xs-medium">{t('modelProvider.selector.modelProviderSettings', { ns: 'common' })}</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{scopeFeatures.length > 0 && (
|
||||
<CompatibleModelsNotice />
|
||||
)}
|
||||
<MarketplaceSection
|
||||
marketplaceProviders={marketplaceProviders}
|
||||
marketplaceCollapsed={marketplaceCollapsed}
|
||||
installingProvider={installingProvider}
|
||||
isMarketplacePluginsLoading={isMarketplacePluginsLoading}
|
||||
theme={theme}
|
||||
onMarketplaceCollapsedChange={setMarketplaceCollapsed}
|
||||
onInstallPlugin={handleInstallPlugin}
|
||||
/>
|
||||
</div>
|
||||
</ModelSelectorScrollBody>
|
||||
<ModelProviderSettingsFooter onOpenSettings={handleOpenSettings} />
|
||||
</ModelSelectorPopupFrame>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -2,13 +2,16 @@ import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useParams } from '@/next/navigation'
|
||||
import { useParams, usePathname, useRouter } from '@/next/navigation'
|
||||
import { useInfiniteAppList } from '@/service/use-apps'
|
||||
import { useCreateSnippetMutation, useInfiniteSnippetList, useSnippetApiDetail } from '@/service/use-snippets'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import AppNav from '../index'
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useParams: vi.fn(),
|
||||
usePathname: vi.fn(),
|
||||
useRouter: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
@ -29,6 +32,19 @@ vi.mock('@/service/use-apps', () => ({
|
||||
useInfiniteAppList: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-snippets', () => ({
|
||||
useCreateSnippetMutation: vi.fn(),
|
||||
useInfiniteSnippetList: vi.fn(),
|
||||
useSnippetApiDetail: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@langgenius/dify-ui/toast', () => ({
|
||||
toast: {
|
||||
error: vi.fn(),
|
||||
success: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/create-app-dialog', () => ({
|
||||
default: ({ show, onClose, onSuccess }: { show: boolean, onClose: () => void, onSuccess: () => void }) =>
|
||||
show
|
||||
@ -83,17 +99,67 @@ vi.mock('@/app/components/app/create-from-dsl-modal', () => ({
|
||||
: null,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/create-snippet-dialog', () => ({
|
||||
default: ({
|
||||
isOpen,
|
||||
onClose,
|
||||
onConfirm,
|
||||
}: {
|
||||
isOpen: boolean
|
||||
onClose: () => void
|
||||
onConfirm: (payload: {
|
||||
name: string
|
||||
description: string
|
||||
icon: { type: 'emoji', icon: string, background: string }
|
||||
}) => void
|
||||
}) =>
|
||||
isOpen
|
||||
? (
|
||||
<button
|
||||
type="button"
|
||||
data-testid="create-snippet-dialog"
|
||||
onClick={() => {
|
||||
onConfirm({
|
||||
name: 'Created Snippet',
|
||||
description: '',
|
||||
icon: {
|
||||
type: 'emoji',
|
||||
icon: '🤖',
|
||||
background: '#fff',
|
||||
},
|
||||
})
|
||||
onClose()
|
||||
}}
|
||||
>
|
||||
Create Snippet
|
||||
</button>
|
||||
)
|
||||
: null,
|
||||
}))
|
||||
|
||||
vi.mock('../../nav', () => ({
|
||||
default: ({
|
||||
createText,
|
||||
curNav,
|
||||
isApp,
|
||||
link,
|
||||
onCreate,
|
||||
onLoadMore,
|
||||
navigationItems,
|
||||
}: {
|
||||
createText: string
|
||||
curNav?: { id: string, name: string }
|
||||
isApp?: boolean
|
||||
link: string
|
||||
onCreate: (state: string) => void
|
||||
onLoadMore?: () => void
|
||||
navigationItems?: Array<{ id: string, name: string, link: string }>
|
||||
}) => (
|
||||
<div data-testid="nav">
|
||||
<div data-testid="nav-link">{link}</div>
|
||||
<div data-testid="nav-is-app">{String(isApp)}</div>
|
||||
<div data-testid="nav-create-text">{createText}</div>
|
||||
<div data-testid="nav-current">{curNav ? `${curNav.id}:${curNav.name}` : ''}</div>
|
||||
<ul data-testid="nav-items">
|
||||
{(navigationItems ?? []).map(item => (
|
||||
<li key={item.id}>{`${item.name} -> ${item.link}`}</li>
|
||||
@ -127,10 +193,52 @@ const mockAppData = [
|
||||
},
|
||||
]
|
||||
|
||||
const mockSnippetData = [
|
||||
{
|
||||
id: 'snippet-1',
|
||||
name: 'Snippet 1',
|
||||
description: '',
|
||||
icon_info: {
|
||||
icon_type: 'emoji',
|
||||
icon: '🧩',
|
||||
icon_background: '#fff',
|
||||
icon_url: null,
|
||||
},
|
||||
is_published: true,
|
||||
use_count: 0,
|
||||
created_at: 1,
|
||||
created_by: 'user-1',
|
||||
updated_at: 1,
|
||||
updated_by: 'user-1',
|
||||
},
|
||||
{
|
||||
id: 'snippet-2',
|
||||
name: 'Snippet 2',
|
||||
description: '',
|
||||
icon_info: {
|
||||
icon_type: 'emoji',
|
||||
icon: '⚙️',
|
||||
icon_background: '#000',
|
||||
icon_url: null,
|
||||
},
|
||||
is_published: true,
|
||||
use_count: 1,
|
||||
created_at: 1,
|
||||
created_by: 'user-1',
|
||||
updated_at: 1,
|
||||
updated_by: 'user-1',
|
||||
},
|
||||
]
|
||||
|
||||
const mockUseParams = vi.mocked(useParams)
|
||||
const mockUsePathname = vi.mocked(usePathname)
|
||||
const mockUseRouter = vi.mocked(useRouter)
|
||||
const mockUseAppContext = vi.mocked(useAppContext)
|
||||
const mockUseAppStore = vi.mocked(useAppStore)
|
||||
const mockUseInfiniteAppList = vi.mocked(useInfiniteAppList)
|
||||
const mockUseInfiniteSnippetList = vi.mocked(useInfiniteSnippetList)
|
||||
const mockUseSnippetApiDetail = vi.mocked(useSnippetApiDetail)
|
||||
const mockUseCreateSnippetMutation = vi.mocked(useCreateSnippetMutation)
|
||||
let mockAppDetail: { id: string, name: string } | null = null
|
||||
|
||||
const setupDefaultMocks = (options?: {
|
||||
@ -144,6 +252,8 @@ const setupDefaultMocks = (options?: {
|
||||
const fetchNextPage = options?.fetchNextPage ?? vi.fn()
|
||||
|
||||
mockUseParams.mockReturnValue({ appId: 'app-1' } as ReturnType<typeof useParams>)
|
||||
mockUsePathname.mockReturnValue('/app/app-1/workflow')
|
||||
mockUseRouter.mockReturnValue({ push: vi.fn() } as unknown as ReturnType<typeof useRouter>)
|
||||
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceEditor: options?.isEditor ?? false } as ReturnType<typeof useAppContext>)
|
||||
mockUseAppStore.mockImplementation((selector: unknown) => (selector as (state: { appDetail: { id: string, name: string } | null }) => unknown)({ appDetail: mockAppDetail }))
|
||||
mockUseInfiniteAppList.mockReturnValue({
|
||||
@ -153,10 +263,51 @@ const setupDefaultMocks = (options?: {
|
||||
isFetchingNextPage: false,
|
||||
refetch,
|
||||
} as ReturnType<typeof useInfiniteAppList>)
|
||||
mockUseInfiniteSnippetList.mockReturnValue({
|
||||
data: undefined,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
isFetchingNextPage: false,
|
||||
} as unknown as ReturnType<typeof useInfiniteSnippetList>)
|
||||
mockUseSnippetApiDetail.mockReturnValue({
|
||||
data: undefined,
|
||||
} as ReturnType<typeof useSnippetApiDetail>)
|
||||
mockUseCreateSnippetMutation.mockReturnValue({
|
||||
isPending: false,
|
||||
mutate: vi.fn(),
|
||||
} as unknown as ReturnType<typeof useCreateSnippetMutation>)
|
||||
|
||||
return { refetch, fetchNextPage }
|
||||
}
|
||||
|
||||
const setupSnippetMocks = (options?: {
|
||||
fetchNextPage?: () => void
|
||||
hasNextPage?: boolean
|
||||
mutate?: ReturnType<typeof vi.fn>
|
||||
}) => {
|
||||
const fetchNextPage = options?.fetchNextPage ?? vi.fn()
|
||||
const mutate = options?.mutate ?? vi.fn()
|
||||
|
||||
setupDefaultMocks({ isEditor: true })
|
||||
mockUseParams.mockReturnValue({ snippetId: 'snippet-1' } as ReturnType<typeof useParams>)
|
||||
mockUsePathname.mockReturnValue('/snippets/snippet-1/orchestrate')
|
||||
mockUseInfiniteSnippetList.mockReturnValue({
|
||||
data: { pages: [{ data: mockSnippetData }] },
|
||||
fetchNextPage,
|
||||
hasNextPage: options?.hasNextPage ?? false,
|
||||
isFetchingNextPage: false,
|
||||
} as unknown as ReturnType<typeof useInfiniteSnippetList>)
|
||||
mockUseSnippetApiDetail.mockReturnValue({
|
||||
data: mockSnippetData[0],
|
||||
} as ReturnType<typeof useSnippetApiDetail>)
|
||||
mockUseCreateSnippetMutation.mockReturnValue({
|
||||
isPending: false,
|
||||
mutate,
|
||||
} as unknown as ReturnType<typeof useCreateSnippetMutation>)
|
||||
|
||||
return { fetchNextPage, mutate }
|
||||
}
|
||||
|
||||
describe('AppNav', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@ -338,4 +489,67 @@ describe('AppNav', () => {
|
||||
expect(screen.getByText('App 1 -> /app/app-1/configuration')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should switch the main nav to snippet list and render snippet items on snippet detail routes', () => {
|
||||
setupSnippetMocks()
|
||||
|
||||
render(<AppNav />)
|
||||
|
||||
expect(screen.getByTestId('nav-link')).toHaveTextContent('/snippets')
|
||||
expect(screen.getByTestId('nav-is-app')).toHaveTextContent('false')
|
||||
expect(screen.getByTestId('nav-current')).toHaveTextContent('snippet-1:Snippet 1')
|
||||
expect(screen.getByTestId('nav-create-text')).toHaveTextContent('createFromBlank')
|
||||
expect(screen.getByText('Snippet 1 -> /snippets/snippet-1/orchestrate')).toBeInTheDocument()
|
||||
expect(screen.getByText('Snippet 2 -> /snippets/snippet-2/orchestrate')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show stale snippet detail as the current nav while switching snippets', () => {
|
||||
setupSnippetMocks()
|
||||
mockUseParams.mockReturnValue({ snippetId: 'snippet-2' } as ReturnType<typeof useParams>)
|
||||
mockUseSnippetApiDetail.mockReturnValue({
|
||||
data: mockSnippetData[0],
|
||||
} as ReturnType<typeof useSnippetApiDetail>)
|
||||
|
||||
render(<AppNav />)
|
||||
|
||||
expect(screen.getByTestId('nav-current')).toBeEmptyDOMElement()
|
||||
})
|
||||
|
||||
it('should load more snippets from the snippet selector when more data is available', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { fetchNextPage } = setupSnippetMocks({ hasNextPage: true })
|
||||
|
||||
render(<AppNav />)
|
||||
|
||||
await user.click(screen.getByTestId('load-more'))
|
||||
expect(fetchNextPage).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should open the create snippet dialog from snippet nav create action', async () => {
|
||||
const user = userEvent.setup()
|
||||
const mutate = vi.fn()
|
||||
setupSnippetMocks({ mutate })
|
||||
|
||||
render(<AppNav />)
|
||||
|
||||
await user.click(screen.getByTestId('create-blank'))
|
||||
expect(screen.getByTestId('create-snippet-dialog')).toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByTestId('create-snippet-dialog'))
|
||||
expect(mutate).toHaveBeenCalledWith({
|
||||
body: {
|
||||
name: 'Created Snippet',
|
||||
description: undefined,
|
||||
icon_info: {
|
||||
icon: '🤖',
|
||||
icon_type: 'emoji',
|
||||
icon_background: '#fff',
|
||||
icon_url: undefined,
|
||||
},
|
||||
},
|
||||
}, expect.objectContaining({
|
||||
onSuccess: expect.any(Function),
|
||||
onError: expect.any(Function),
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,19 +1,22 @@
|
||||
'use client'
|
||||
|
||||
import type { NavItem } from '../nav/nav-selector'
|
||||
import {
|
||||
RiRobot2Fill,
|
||||
RiRobot2Line,
|
||||
} from '@remixicon/react'
|
||||
import type { AppIconSelection } from '@/app/components/base/app-icon-picker'
|
||||
import { toast } from '@langgenius/dify-ui/toast'
|
||||
import { flatten } from 'es-toolkit/compat'
|
||||
import { produce } from 'immer'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import CreateSnippetDialog from '@/app/components/workflow/create-snippet-dialog'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import dynamic from '@/next/dynamic'
|
||||
import { useParams } from '@/next/navigation'
|
||||
import { useParams, usePathname, useRouter } from '@/next/navigation'
|
||||
import { useInfiniteAppList } from '@/service/use-apps'
|
||||
import {
|
||||
useCreateSnippetMutation,
|
||||
useInfiniteSnippetList,
|
||||
useSnippetApiDetail,
|
||||
} from '@/service/use-snippets'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import Nav from '../nav'
|
||||
|
||||
@ -23,13 +26,18 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro
|
||||
|
||||
const AppNav = () => {
|
||||
const { t } = useTranslation()
|
||||
const { appId } = useParams()
|
||||
const { appId, snippetId } = useParams()
|
||||
const { push } = useRouter()
|
||||
const pathname = usePathname()
|
||||
const isSnippetSegment = pathname === '/snippets' || pathname.startsWith('/snippets/')
|
||||
const currentSnippetId = typeof snippetId === 'string' ? snippetId : ''
|
||||
const { isCurrentWorkspaceEditor } = useAppContext()
|
||||
const appDetail = useAppStore(state => state.appDetail)
|
||||
const [showNewAppDialog, setShowNewAppDialog] = useState(false)
|
||||
const [showNewAppTemplateDialog, setShowNewAppTemplateDialog] = useState(false)
|
||||
const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(false)
|
||||
const [navItems, setNavItems] = useState<NavItem[]>([])
|
||||
const [showCreateSnippetDialog, setShowCreateSnippetDialog] = useState(false)
|
||||
const createSnippetMutation = useCreateSnippetMutation()
|
||||
|
||||
const {
|
||||
data: appsData,
|
||||
@ -41,14 +49,36 @@ const AppNav = () => {
|
||||
page: 1,
|
||||
limit: 30,
|
||||
name: '',
|
||||
}, { enabled: !!appId })
|
||||
}, { enabled: !!appId && !isSnippetSegment })
|
||||
|
||||
const {
|
||||
data: snippetsData,
|
||||
fetchNextPage: fetchNextSnippetPage,
|
||||
hasNextPage: hasNextSnippetPage,
|
||||
isFetchingNextPage: isFetchingNextSnippetPage,
|
||||
} = useInfiniteSnippetList({
|
||||
page: 1,
|
||||
limit: 30,
|
||||
}, { enabled: !!currentSnippetId })
|
||||
|
||||
const { data: snippetDetail } = useSnippetApiDetail(currentSnippetId)
|
||||
|
||||
const handleLoadMore = useCallback(() => {
|
||||
if (hasNextPage)
|
||||
fetchNextPage()
|
||||
}, [fetchNextPage, hasNextPage])
|
||||
|
||||
const handleLoadMoreSnippet = useCallback(() => {
|
||||
if (hasNextSnippetPage)
|
||||
fetchNextSnippetPage()
|
||||
}, [fetchNextSnippetPage, hasNextSnippetPage])
|
||||
|
||||
const openModal = (state: string) => {
|
||||
if (isSnippetSegment) {
|
||||
setShowCreateSnippetDialog(true)
|
||||
return
|
||||
}
|
||||
|
||||
if (state === 'blank')
|
||||
setShowNewAppDialog(true)
|
||||
if (state === 'template')
|
||||
@ -57,64 +87,125 @@ const AppNav = () => {
|
||||
setShowCreateFromDSLModal(true)
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (appsData) {
|
||||
const appItems = flatten((appsData.pages ?? []).map(appData => appData.data))
|
||||
const navItems = appItems.map((app) => {
|
||||
const link = ((isCurrentWorkspaceEditor, app) => {
|
||||
if (!isCurrentWorkspaceEditor) {
|
||||
return `/app/${app.id}/overview`
|
||||
}
|
||||
else {
|
||||
if (app.mode === AppModeEnum.WORKFLOW || app.mode === AppModeEnum.ADVANCED_CHAT)
|
||||
return `/app/${app.id}/workflow`
|
||||
else
|
||||
return `/app/${app.id}/configuration`
|
||||
}
|
||||
})(isCurrentWorkspaceEditor, app)
|
||||
return {
|
||||
id: app.id,
|
||||
icon_type: app.icon_type,
|
||||
icon: app.icon,
|
||||
icon_background: app.icon_background,
|
||||
icon_url: app.icon_url,
|
||||
name: app.name,
|
||||
mode: app.mode,
|
||||
link,
|
||||
}
|
||||
})
|
||||
setNavItems(navItems as any)
|
||||
}
|
||||
}, [appsData, isCurrentWorkspaceEditor, setNavItems])
|
||||
const appNavItems = useMemo<NavItem[]>(() => {
|
||||
if (!appsData)
|
||||
return []
|
||||
|
||||
// update current app name
|
||||
useEffect(() => {
|
||||
if (appDetail) {
|
||||
const newNavItems = produce(navItems, (draft: NavItem[]) => {
|
||||
navItems.forEach((app, index) => {
|
||||
if (app.id === appDetail.id)
|
||||
draft[index]!.name = appDetail.name
|
||||
})
|
||||
})
|
||||
setNavItems(newNavItems)
|
||||
const appItems = flatten((appsData.pages ?? []).map(appData => appData.data))
|
||||
|
||||
return appItems.map((app) => {
|
||||
const link = (() => {
|
||||
if (!isCurrentWorkspaceEditor)
|
||||
return `/app/${app.id}/overview`
|
||||
|
||||
if (app.mode === AppModeEnum.WORKFLOW || app.mode === AppModeEnum.ADVANCED_CHAT)
|
||||
return `/app/${app.id}/workflow`
|
||||
|
||||
return `/app/${app.id}/configuration`
|
||||
})()
|
||||
|
||||
return {
|
||||
id: app.id,
|
||||
icon_type: app.icon_type,
|
||||
icon: app.icon,
|
||||
icon_background: app.icon_background,
|
||||
icon_url: app.icon_url,
|
||||
name: appDetail?.id === app.id ? appDetail.name : app.name,
|
||||
mode: app.mode,
|
||||
link,
|
||||
}
|
||||
})
|
||||
}, [appDetail?.id, appDetail?.name, appsData, isCurrentWorkspaceEditor])
|
||||
|
||||
const snippetNavItems = useMemo<NavItem[]>(() => {
|
||||
if (!snippetsData)
|
||||
return []
|
||||
|
||||
const snippetItems = flatten((snippetsData.pages ?? []).map(snippetData => snippetData.data))
|
||||
|
||||
return snippetItems.map(snippet => ({
|
||||
id: snippet.id,
|
||||
icon_type: snippet.icon_info.icon_type,
|
||||
icon: snippet.icon_info.icon,
|
||||
icon_background: snippet.icon_info.icon_background ?? null,
|
||||
icon_url: snippet.icon_info.icon_url ?? null,
|
||||
name: snippet.name,
|
||||
link: `/snippets/${snippet.id}/orchestrate`,
|
||||
}))
|
||||
}, [snippetsData])
|
||||
|
||||
const currentSnippetNav = useMemo(() => {
|
||||
if (!snippetDetail)
|
||||
return
|
||||
|
||||
if (snippetDetail.id !== currentSnippetId)
|
||||
return
|
||||
|
||||
return {
|
||||
id: snippetDetail.id,
|
||||
icon_type: snippetDetail.icon_info.icon_type,
|
||||
icon: snippetDetail.icon_info.icon,
|
||||
icon_background: snippetDetail.icon_info.icon_background ?? null,
|
||||
icon_url: snippetDetail.icon_info.icon_url ?? null,
|
||||
name: snippetDetail.name,
|
||||
}
|
||||
}, [appDetail, navItems])
|
||||
}, [currentSnippetId, snippetDetail])
|
||||
|
||||
const handleCreateSnippet = useCallback(({
|
||||
name,
|
||||
description,
|
||||
icon,
|
||||
}: {
|
||||
name: string
|
||||
description: string
|
||||
icon: AppIconSelection
|
||||
}) => {
|
||||
createSnippetMutation.mutate({
|
||||
body: {
|
||||
name,
|
||||
description: description || undefined,
|
||||
icon_info: {
|
||||
icon: icon.type === 'emoji' ? icon.icon : icon.fileId,
|
||||
icon_type: icon.type,
|
||||
icon_background: icon.type === 'emoji' ? icon.background : undefined,
|
||||
icon_url: icon.type === 'image' ? icon.url : undefined,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
onSuccess: (snippet) => {
|
||||
toast.success(t('snippet.createSuccess', { ns: 'workflow' }))
|
||||
setShowCreateSnippetDialog(false)
|
||||
push(`/snippets/${snippet.id}/orchestrate`)
|
||||
},
|
||||
onError: (error) => {
|
||||
toast.error(error instanceof Error ? error.message : t('createFailed', { ns: 'snippet' }))
|
||||
},
|
||||
})
|
||||
}, [createSnippetMutation, push, t])
|
||||
|
||||
const currentNav = isSnippetSegment ? currentSnippetNav : appDetail
|
||||
const currentNavigationItems = isSnippetSegment ? snippetNavItems : appNavItems
|
||||
const currentCreateText = isSnippetSegment
|
||||
? t('createFromBlank', { ns: 'snippet' })
|
||||
: t('menus.newApp', { ns: 'common' })
|
||||
const currentLoadMore = isSnippetSegment ? handleLoadMoreSnippet : handleLoadMore
|
||||
const currentIsLoadingMore = isSnippetSegment ? isFetchingNextSnippetPage : isFetchingNextPage
|
||||
|
||||
return (
|
||||
<>
|
||||
<Nav
|
||||
isApp
|
||||
icon={<RiRobot2Line className="h-4 w-4" />}
|
||||
activeIcon={<RiRobot2Fill className="h-4 w-4" />}
|
||||
isApp={!isSnippetSegment}
|
||||
icon={<span className="i-ri-robot-2-line h-4 w-4" />}
|
||||
activeIcon={<span className="i-ri-robot-2-fill h-4 w-4" />}
|
||||
text={t('menus.apps', { ns: 'common' })}
|
||||
activeSegment={['apps', 'app', 'snippets']}
|
||||
link="/apps"
|
||||
curNav={appDetail}
|
||||
navigationItems={navItems}
|
||||
createText={t('menus.newApp', { ns: 'common' })}
|
||||
link={isSnippetSegment ? '/snippets' : '/apps'}
|
||||
curNav={currentNav ?? undefined}
|
||||
navigationItems={currentNavigationItems}
|
||||
createText={currentCreateText}
|
||||
onCreate={openModal}
|
||||
onLoadMore={handleLoadMore}
|
||||
isLoadingMore={isFetchingNextPage}
|
||||
onLoadMore={currentLoadMore}
|
||||
isLoadingMore={currentIsLoadingMore}
|
||||
/>
|
||||
<CreateAppModal
|
||||
show={showNewAppDialog}
|
||||
@ -131,6 +222,14 @@ const AppNav = () => {
|
||||
onClose={() => setShowCreateFromDSLModal(false)}
|
||||
onSuccess={() => refetch()}
|
||||
/>
|
||||
{showCreateSnippetDialog && (
|
||||
<CreateSnippetDialog
|
||||
isOpen={showCreateSnippetDialog}
|
||||
isSubmitting={createSnippetMutation.isPending}
|
||||
onClose={() => setShowCreateSnippetDialog(false)}
|
||||
onConfirm={handleCreateSnippet}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
@ -201,6 +201,14 @@ describe('Nav Component', () => {
|
||||
expect(mockSetAppDetail).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not call setAppDetail from snippets segment', () => {
|
||||
vi.mocked(useSelectedLayoutSegment).mockReturnValue('snippets')
|
||||
render(<Nav {...defaultProps} activeSegment={['apps', 'app', 'snippets']} />)
|
||||
const link = screen.getByRole('link')
|
||||
fireEvent.click(link.firstChild!)
|
||||
expect(mockSetAppDetail).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show ArrowNarrowLeft on hover when curNav is provided and activated', () => {
|
||||
const curNav = navigationItems[0]
|
||||
render(<Nav {...defaultProps} curNav={curNav} />)
|
||||
@ -238,19 +246,20 @@ describe('Nav Component', () => {
|
||||
})
|
||||
|
||||
it('should navigate when an item is selected', async () => {
|
||||
render(<Nav {...defaultProps} curNav={curNav} />)
|
||||
vi.mocked(useSelectedLayoutSegment).mockReturnValue('snippets')
|
||||
render(<Nav {...defaultProps} activeSegment={['apps', 'app', 'snippets']} curNav={curNav} />)
|
||||
const selectorButton = screen.getByRole('button', { name: /Item 1/i })
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(selectorButton)
|
||||
})
|
||||
mockSetAppDetail.mockClear()
|
||||
|
||||
const item2 = await screen.findByText('Item 2')
|
||||
await act(async () => {
|
||||
fireEvent.click(item2)
|
||||
})
|
||||
|
||||
expect(mockSetAppDetail).toHaveBeenCalled()
|
||||
expect(mockPush).toHaveBeenCalledWith('/item2')
|
||||
})
|
||||
|
||||
|
||||
@ -5,7 +5,6 @@ import { cn } from '@langgenius/dify-ui/cn'
|
||||
import * as React from 'react'
|
||||
import { useState } from 'react'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows'
|
||||
import Link from '@/next/link'
|
||||
import { useSelectedLayoutSegment } from '@/next/navigation'
|
||||
import NavSelector from './nav-selector'
|
||||
@ -51,6 +50,8 @@ const Nav = ({
|
||||
// Don't clear state if opening in new tab/window
|
||||
if (e.metaKey || e.ctrlKey || e.shiftKey || e.button !== 0)
|
||||
return
|
||||
if (segment === 'snippets')
|
||||
return
|
||||
setAppDetail()
|
||||
}}
|
||||
className={cn('flex h-7 cursor-pointer items-center rounded-[10px] px-2.5', isActivated ? 'text-components-main-nav-nav-button-text-active' : 'text-components-main-nav-nav-button-text', curNav && isActivated && 'hover:bg-components-main-nav-nav-button-bg-active-hover')}
|
||||
@ -60,7 +61,7 @@ const Nav = ({
|
||||
<div>
|
||||
{
|
||||
(hovered && curNav)
|
||||
? <ArrowNarrowLeft className="h-4 w-4" />
|
||||
? <span className="i-custom-vender-line-arrows-arrow-narrow-left h-4 w-4" />
|
||||
: isActivated
|
||||
? activeIcon
|
||||
: icon
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type { FC, ReactNode } from 'react'
|
||||
import type { PluginStatus } from '@/app/components/plugins/types'
|
||||
import type { Locale } from '@/i18n-config'
|
||||
import { ScrollArea } from '@langgenius/dify-ui/scroll-area'
|
||||
import PluginItem from './plugin-item'
|
||||
|
||||
type PluginSectionProps = {
|
||||
@ -43,7 +44,14 @@ const PluginSection: FC<PluginSectionProps> = ({
|
||||
)
|
||||
{headerAction}
|
||||
</div>
|
||||
<div className="max-h-[300px] overflow-x-hidden overflow-y-auto [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
<ScrollArea
|
||||
className="max-h-[300px] overflow-hidden"
|
||||
label={title}
|
||||
slotClassNames={{
|
||||
viewport: 'overscroll-contain',
|
||||
content: 'min-w-0',
|
||||
}}
|
||||
>
|
||||
{plugins.map(plugin => (
|
||||
<PluginItem
|
||||
key={plugin.plugin_unique_identifier}
|
||||
@ -59,7 +67,7 @@ const PluginSection: FC<PluginSectionProps> = ({
|
||||
: undefined}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type { FC } from 'react'
|
||||
import type { PluginStatus } from '@/app/components/plugins/types'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
import { ScrollArea } from '@langgenius/dify-ui/scroll-area'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
import ErrorPluginItem from './error-plugin-item'
|
||||
@ -86,7 +87,14 @@ const PluginTaskList: FC<PluginTaskListProps> = ({
|
||||
{t('task.clearAll', { ns: 'plugin' })}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="max-h-[300px] overflow-x-hidden overflow-y-auto [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
<ScrollArea
|
||||
className="max-h-[300px] overflow-hidden"
|
||||
label={t('task.installedError', { ns: 'plugin', errorLength: errorPlugins.length })}
|
||||
slotClassNames={{
|
||||
viewport: 'overscroll-contain',
|
||||
content: 'min-w-0',
|
||||
}}
|
||||
>
|
||||
{errorPlugins.map(plugin => (
|
||||
<ErrorPluginItem
|
||||
key={plugin.plugin_unique_identifier}
|
||||
@ -96,7 +104,7 @@ const PluginTaskList: FC<PluginTaskListProps> = ({
|
||||
onClear={() => onClearSingle(plugin.taskId, plugin.plugin_unique_identifier)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@ -117,7 +117,7 @@ const PluginTasks = () => {
|
||||
<DropdownMenuContent
|
||||
placement="bottom-end"
|
||||
sideOffset={4}
|
||||
popupClassName="[scrollbar-width:none] overflow-visible border-0 bg-transparent p-0 shadow-none backdrop-blur-none [&::-webkit-scrollbar]:hidden"
|
||||
popupClassName="overflow-visible border-0 bg-transparent p-0 shadow-none backdrop-blur-none"
|
||||
>
|
||||
<PluginTaskList
|
||||
runningPlugins={runningPlugins}
|
||||
|
||||
@ -2,6 +2,7 @@ import type { AutoUpdateConfig } from '../types'
|
||||
import type { PluginDeclaration, PluginDetail } from '@/app/components/plugins/types'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import dayjs from 'dayjs'
|
||||
import timezone from 'dayjs/plugin/timezone'
|
||||
import utc from 'dayjs/plugin/utc'
|
||||
@ -803,165 +804,103 @@ describe('auto-update-setting', () => {
|
||||
})
|
||||
|
||||
describe('StrategyPicker (strategy-picker.tsx)', () => {
|
||||
const defaultProps = {
|
||||
value: AUTO_UPDATE_STRATEGY.disabled,
|
||||
onChange: vi.fn(),
|
||||
const i18nKeyByStrategy: Record<AUTO_UPDATE_STRATEGY, 'disabled' | 'fixOnly' | 'latest'> = {
|
||||
[AUTO_UPDATE_STRATEGY.disabled]: 'disabled',
|
||||
[AUTO_UPDATE_STRATEGY.fixOnly]: 'fixOnly',
|
||||
[AUTO_UPDATE_STRATEGY.latest]: 'latest',
|
||||
}
|
||||
|
||||
const triggerName = (strategy: AUTO_UPDATE_STRATEGY) =>
|
||||
new RegExp(`plugin\\.autoUpdate\\.strategy\\.${i18nKeyByStrategy[strategy]}\\.name`, 'i')
|
||||
|
||||
const findOption = async (key: 'disabled' | 'fixOnly' | 'latest') => {
|
||||
const options = await screen.findAllByRole('menuitemradio')
|
||||
const option = options.find(item =>
|
||||
item.textContent?.includes(`plugin.autoUpdate.strategy.${key}.name`),
|
||||
)
|
||||
if (!option)
|
||||
throw new Error(`Strategy option "${key}" not found`)
|
||||
return option
|
||||
}
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render trigger button with current strategy label', () => {
|
||||
// Act
|
||||
render(<StrategyPicker {...defaultProps} value={AUTO_UPDATE_STRATEGY.disabled} />)
|
||||
render(<StrategyPicker value={AUTO_UPDATE_STRATEGY.disabled} onChange={vi.fn()} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByRole('button', { name: /plugin\.autoUpdate\.strategy\.disabled\.name/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: triggerName(AUTO_UPDATE_STRATEGY.disabled) })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render dropdown content when closed', () => {
|
||||
// Act
|
||||
render(<StrategyPicker {...defaultProps} />)
|
||||
render(<StrategyPicker value={AUTO_UPDATE_STRATEGY.disabled} onChange={vi.fn()} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('menu')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render all strategy options when open', () => {
|
||||
// Arrange
|
||||
mockPortalOpen = true
|
||||
it('should render all strategy options when open', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<StrategyPicker value={AUTO_UPDATE_STRATEGY.disabled} onChange={vi.fn()} />)
|
||||
|
||||
// Act
|
||||
render(<StrategyPicker {...defaultProps} />)
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
await user.click(screen.getByRole('button', { name: triggerName(AUTO_UPDATE_STRATEGY.disabled) }))
|
||||
|
||||
// Wait for portal to open
|
||||
if (mockPortalOpen) {
|
||||
// Assert all options visible (use getAllByText for strategy name as it appears in both trigger and dropdown)
|
||||
expect(screen.getAllByText('plugin.autoUpdate.strategy.disabled.name').length).toBeGreaterThanOrEqual(1)
|
||||
expect(screen.getByText('plugin.autoUpdate.strategy.fixOnly.name')).toBeInTheDocument()
|
||||
expect(screen.getByText('plugin.autoUpdate.strategy.latest.name')).toBeInTheDocument()
|
||||
}
|
||||
const options = await screen.findAllByRole('menuitemradio')
|
||||
expect(options).toHaveLength(3)
|
||||
expect(options.some(o => o.textContent?.includes('plugin.autoUpdate.strategy.disabled.name'))).toBe(true)
|
||||
expect(options.some(o => o.textContent?.includes('plugin.autoUpdate.strategy.fixOnly.name'))).toBe(true)
|
||||
expect(options.some(o => o.textContent?.includes('plugin.autoUpdate.strategy.latest.name'))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should toggle dropdown when trigger is clicked', () => {
|
||||
// Act
|
||||
render(<StrategyPicker {...defaultProps} />)
|
||||
|
||||
// Assert - initially closed
|
||||
expect(mockPortalOpen).toBe(false)
|
||||
|
||||
// Act - click trigger
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
|
||||
// Assert - portal trigger element should still be in document
|
||||
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onChange with fixOnly when Bug Fixes Only option is clicked', () => {
|
||||
// Arrange - force portal content to be visible for testing option selection
|
||||
forcePortalContentVisible = true
|
||||
const onChange = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<StrategyPicker value={AUTO_UPDATE_STRATEGY.disabled} onChange={onChange} />)
|
||||
|
||||
// Find and click the "Bug Fixes Only" option
|
||||
const fixOnlyOption = screen.getByText('plugin.autoUpdate.strategy.fixOnly.name').closest('div[class*="cursor-pointer"]')
|
||||
expect(fixOnlyOption).toBeInTheDocument()
|
||||
fireEvent.click(fixOnlyOption!)
|
||||
|
||||
// Assert
|
||||
expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.fixOnly)
|
||||
})
|
||||
|
||||
it('should call onChange with latest when Latest Version option is clicked', () => {
|
||||
// Arrange - force portal content to be visible for testing option selection
|
||||
forcePortalContentVisible = true
|
||||
const onChange = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<StrategyPicker value={AUTO_UPDATE_STRATEGY.disabled} onChange={onChange} />)
|
||||
|
||||
// Find and click the "Latest Version" option
|
||||
const latestOption = screen.getByText('plugin.autoUpdate.strategy.latest.name').closest('div[class*="cursor-pointer"]')
|
||||
expect(latestOption).toBeInTheDocument()
|
||||
fireEvent.click(latestOption!)
|
||||
|
||||
// Assert
|
||||
expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.latest)
|
||||
})
|
||||
|
||||
it('should call onChange with disabled when Disabled option is clicked', () => {
|
||||
// Arrange - force portal content to be visible for testing option selection
|
||||
forcePortalContentVisible = true
|
||||
const onChange = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<StrategyPicker value={AUTO_UPDATE_STRATEGY.fixOnly} onChange={onChange} />)
|
||||
|
||||
// Find and click the "Disabled" option - need to find the one in the dropdown, not the button
|
||||
const disabledOptions = screen.getAllByText('plugin.autoUpdate.strategy.disabled.name')
|
||||
// The second one should be in the dropdown
|
||||
const dropdownOption = disabledOptions.find(el => el.closest('div[class*="cursor-pointer"]'))
|
||||
expect(dropdownOption).toBeInTheDocument()
|
||||
fireEvent.click(dropdownOption!.closest('div[class*="cursor-pointer"]')!)
|
||||
|
||||
// Assert
|
||||
expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.disabled)
|
||||
})
|
||||
|
||||
it('should stop event propagation when option is clicked', () => {
|
||||
// Arrange - force portal content to be visible
|
||||
forcePortalContentVisible = true
|
||||
const onChange = vi.fn()
|
||||
const parentClickHandler = vi.fn()
|
||||
|
||||
// Act
|
||||
render(
|
||||
<div onClick={parentClickHandler}>
|
||||
<StrategyPicker value={AUTO_UPDATE_STRATEGY.disabled} onChange={onChange} />
|
||||
</div>,
|
||||
)
|
||||
|
||||
// Click an option
|
||||
const fixOnlyOption = screen.getByText('plugin.autoUpdate.strategy.fixOnly.name').closest('div[class*="cursor-pointer"]')
|
||||
fireEvent.click(fixOnlyOption!)
|
||||
|
||||
// Assert - onChange is called but parent click handler should not propagate
|
||||
expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.fixOnly)
|
||||
})
|
||||
|
||||
it('should render check icon for currently selected option', () => {
|
||||
// Arrange - force portal content to be visible
|
||||
forcePortalContentVisible = true
|
||||
|
||||
// Act - render with fixOnly selected
|
||||
render(<StrategyPicker value={AUTO_UPDATE_STRATEGY.fixOnly} onChange={vi.fn()} />)
|
||||
|
||||
// Assert - RiCheckLine should be rendered (check icon)
|
||||
// Find all "Bug Fixes Only" texts and get the one in the dropdown (has cursor-pointer parent)
|
||||
const allFixOnlyTexts = screen.getAllByText('plugin.autoUpdate.strategy.fixOnly.name')
|
||||
const dropdownOption = allFixOnlyTexts.find(el => el.closest('div[class*="cursor-pointer"]'))
|
||||
const optionContainer = dropdownOption?.closest('div[class*="cursor-pointer"]')
|
||||
expect(optionContainer).toBeInTheDocument()
|
||||
// The check icon SVG should exist within the option
|
||||
expect(optionContainer?.querySelector('svg')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render check icon for non-selected options', () => {
|
||||
// Arrange - force portal content to be visible
|
||||
forcePortalContentVisible = true
|
||||
|
||||
// Act - render with disabled selected
|
||||
it('should open and close the menu when the trigger is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<StrategyPicker value={AUTO_UPDATE_STRATEGY.disabled} onChange={vi.fn()} />)
|
||||
|
||||
// Assert - check the Latest Version option should not have check icon
|
||||
const latestOption = screen.getByText('plugin.autoUpdate.strategy.latest.name').closest('div[class*="cursor-pointer"]')
|
||||
// The svg should only be in selected option, not in non-selected
|
||||
const checkIconContainer = latestOption?.querySelector('div.mr-1')
|
||||
// Non-selected option should have empty check icon container
|
||||
expect(checkIconContainer?.querySelector('svg')).toBeNull()
|
||||
const trigger = screen.getByRole('button', { name: triggerName(AUTO_UPDATE_STRATEGY.disabled) })
|
||||
expect(screen.queryByRole('menu')).not.toBeInTheDocument()
|
||||
|
||||
await user.click(trigger)
|
||||
expect(await screen.findByRole('menu')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it.each<[AUTO_UPDATE_STRATEGY, 'disabled' | 'fixOnly' | 'latest', AUTO_UPDATE_STRATEGY]>([
|
||||
[AUTO_UPDATE_STRATEGY.disabled, 'fixOnly', AUTO_UPDATE_STRATEGY.fixOnly],
|
||||
[AUTO_UPDATE_STRATEGY.disabled, 'latest', AUTO_UPDATE_STRATEGY.latest],
|
||||
[AUTO_UPDATE_STRATEGY.fixOnly, 'disabled', AUTO_UPDATE_STRATEGY.disabled],
|
||||
])('should call onChange with %s -> %s when option is selected', async (initial, optionKey, expected) => {
|
||||
const user = userEvent.setup()
|
||||
const onChange = vi.fn()
|
||||
render(<StrategyPicker value={initial} onChange={onChange} />)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: triggerName(initial) }))
|
||||
await user.click(await findOption(optionKey))
|
||||
|
||||
expect(onChange).toHaveBeenCalledWith(expected)
|
||||
})
|
||||
|
||||
it('should mark only the currently selected option with aria-checked', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<StrategyPicker value={AUTO_UPDATE_STRATEGY.fixOnly} onChange={vi.fn()} />)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: triggerName(AUTO_UPDATE_STRATEGY.fixOnly) }))
|
||||
|
||||
const options = await screen.findAllByRole('menuitemradio')
|
||||
const checked = options.filter(o => o.getAttribute('aria-checked') === 'true')
|
||||
|
||||
expect(checked).toHaveLength(1)
|
||||
expect(checked[0]).toHaveTextContent('plugin.autoUpdate.strategy.fixOnly.name')
|
||||
})
|
||||
|
||||
it('should render the check indicator inside the selected option only', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<StrategyPicker value={AUTO_UPDATE_STRATEGY.fixOnly} onChange={vi.fn()} />)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: triggerName(AUTO_UPDATE_STRATEGY.fixOnly) }))
|
||||
|
||||
const fixOnlyOption = await findOption('fixOnly')
|
||||
const latestOption = await findOption('latest')
|
||||
|
||||
expect(fixOnlyOption.querySelector('.i-ri-check-line')).toBeInTheDocument()
|
||||
expect(latestOption.querySelector('.i-ri-check-line')).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -1280,7 +1219,9 @@ describe('auto-update-setting', () => {
|
||||
render(<AutoUpdateSetting {...defaultProps} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByRole('button', { name: /plugin\.autoUpdate\.strategy\.fixOnly\.name/i }),
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show time picker when strategy is not disabled', () => {
|
||||
@ -1407,16 +1348,27 @@ describe('auto-update-setting', () => {
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should call onChange with updated strategy when strategy changes', () => {
|
||||
it('should call onChange with updated strategy when strategy changes', async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup()
|
||||
const onChange = vi.fn()
|
||||
const payload = createMockAutoUpdateConfig()
|
||||
const payload = createMockAutoUpdateConfig({ strategy_setting: AUTO_UPDATE_STRATEGY.fixOnly })
|
||||
|
||||
// Act
|
||||
render(<AutoUpdateSetting payload={payload} onChange={onChange} />)
|
||||
|
||||
// Assert - component renders with strategy picker
|
||||
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: /plugin\.autoUpdate\.strategy\.fixOnly\.name/i }),
|
||||
)
|
||||
const latestOption = (await screen.findAllByRole('menuitemradio')).find(item =>
|
||||
item.textContent?.includes('plugin.autoUpdate.strategy.latest.name'),
|
||||
)!
|
||||
await user.click(latestOption)
|
||||
|
||||
// Assert
|
||||
expect(onChange).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ strategy_setting: AUTO_UPDATE_STRATEGY.latest }),
|
||||
)
|
||||
})
|
||||
|
||||
it('should call onChange with updated time when time changes', () => {
|
||||
|
||||
@ -1,62 +1,12 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import StrategyPicker from '../strategy-picker'
|
||||
import { AUTO_UPDATE_STRATEGY } from '../types'
|
||||
|
||||
let portalOpen = false
|
||||
|
||||
vi.mock('@langgenius/dify-ui/button', () => ({
|
||||
Button: ({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
}) => <span data-testid="picker-button">{children}</span>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
|
||||
const _React = await import('react')
|
||||
return {
|
||||
PortalToFollowElem: ({
|
||||
open,
|
||||
children,
|
||||
}: {
|
||||
open: boolean
|
||||
children: React.ReactNode
|
||||
}) => {
|
||||
portalOpen = open
|
||||
return <div>{children}</div>
|
||||
},
|
||||
PortalToFollowElemTrigger: ({
|
||||
children,
|
||||
onClick,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
onClick: (event: { stopPropagation: () => void, nativeEvent: { stopImmediatePropagation: () => void } }) => void
|
||||
}) => (
|
||||
<button
|
||||
data-testid="trigger"
|
||||
onClick={() => onClick({
|
||||
stopPropagation: vi.fn(),
|
||||
nativeEvent: { stopImmediatePropagation: vi.fn() },
|
||||
})}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
),
|
||||
PortalToFollowElemContent: ({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
}) => portalOpen ? <div data-testid="portal-content">{children}</div> : null,
|
||||
}
|
||||
})
|
||||
const triggerName = (key: string) => new RegExp(`plugin\\.autoUpdate\\.strategy\\.${key}\\.name`, 'i')
|
||||
|
||||
describe('StrategyPicker', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
portalOpen = false
|
||||
})
|
||||
|
||||
it('renders the selected strategy label in the trigger', () => {
|
||||
render(
|
||||
<StrategyPicker
|
||||
@ -65,10 +15,12 @@ describe('StrategyPicker', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('trigger')).toHaveTextContent('plugin.autoUpdate.strategy.fixOnly.name')
|
||||
expect(screen.getByRole('button', { name: triggerName('fixOnly') })).toBeInTheDocument()
|
||||
expect(screen.queryByRole('menu')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('opens the option list when the trigger is clicked', () => {
|
||||
it('opens the option list when the trigger is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<StrategyPicker
|
||||
value={AUTO_UPDATE_STRATEGY.disabled}
|
||||
@ -76,14 +28,33 @@ describe('StrategyPicker', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('trigger'))
|
||||
await user.click(screen.getByRole('button', { name: triggerName('disabled') }))
|
||||
|
||||
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('portal-content').querySelectorAll('svg')).toHaveLength(1)
|
||||
const options = await screen.findAllByRole('menuitemradio')
|
||||
expect(options).toHaveLength(3)
|
||||
expect(screen.getByText('plugin.autoUpdate.strategy.latest.description')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('calls onChange when a new strategy is selected', () => {
|
||||
it('marks only the currently selected strategy as checked', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<StrategyPicker
|
||||
value={AUTO_UPDATE_STRATEGY.fixOnly}
|
||||
onChange={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: triggerName('fixOnly') }))
|
||||
|
||||
const checkedOptions = (await screen.findAllByRole('menuitemradio'))
|
||||
.filter(item => item.getAttribute('aria-checked') === 'true')
|
||||
|
||||
expect(checkedOptions).toHaveLength(1)
|
||||
expect(checkedOptions[0]).toHaveTextContent('plugin.autoUpdate.strategy.fixOnly.name')
|
||||
})
|
||||
|
||||
it('calls onChange and closes the menu when a new strategy is selected', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onChange = vi.fn()
|
||||
render(
|
||||
<StrategyPicker
|
||||
@ -92,9 +63,12 @@ describe('StrategyPicker', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('trigger'))
|
||||
fireEvent.click(screen.getByText('plugin.autoUpdate.strategy.latest.name'))
|
||||
await user.click(screen.getByRole('button', { name: triggerName('disabled') }))
|
||||
const latestOption = (await screen.findAllByRole('menuitemradio'))
|
||||
.find(item => item.textContent?.includes('plugin.autoUpdate.strategy.latest.name'))!
|
||||
await user.click(latestOption)
|
||||
|
||||
expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.latest)
|
||||
expect(await screen.findByRole('button', { name: triggerName('disabled') })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@ -105,7 +105,7 @@ const AutoUpdateSetting: FC<Props> = ({
|
||||
const renderTimePickerTrigger = useCallback(({ inputElem, onClick, isOpen }: TriggerParams) => {
|
||||
return (
|
||||
<div
|
||||
className="group float-right flex h-8 w-[160px] cursor-pointer items-center justify-between rounded-lg bg-components-input-bg-normal px-2 hover:bg-state-base-hover-alt"
|
||||
className="group flex h-8 w-[160px] cursor-pointer items-center justify-between rounded-lg bg-components-input-bg-normal px-2 hover:bg-state-base-hover-alt"
|
||||
onClick={onClick}
|
||||
>
|
||||
<div className="flex w-0 grow items-center gap-x-1">
|
||||
@ -137,7 +137,7 @@ const AutoUpdateSetting: FC<Props> = ({
|
||||
<>
|
||||
<div className="flex items-center justify-between">
|
||||
<Label label={t(`${i18nPrefix}.updateTime`, { ns: 'plugin' })} />
|
||||
<div className="flex flex-col justify-start">
|
||||
<div className="flex flex-col items-end">
|
||||
<TimePicker
|
||||
value={timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(upgrade_time_of_day, timezone!))}
|
||||
timezone={timezone}
|
||||
|
||||
@ -1,15 +1,14 @@
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
import {
|
||||
RiArrowDownSLine,
|
||||
RiCheckLine,
|
||||
} from '@remixicon/react'
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuRadioGroup,
|
||||
DropdownMenuRadioItem,
|
||||
DropdownMenuRadioItemIndicator,
|
||||
DropdownMenuTrigger,
|
||||
} from '@langgenius/dify-ui/dropdown-menu'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import { AUTO_UPDATE_STRATEGY } from './types'
|
||||
|
||||
const i18nPrefix = 'autoUpdate.strategy'
|
||||
@ -42,58 +41,48 @@ const StrategyPicker = ({
|
||||
},
|
||||
]
|
||||
const selectedOption = options.find(option => option.value === value)
|
||||
const handleValueChange = (nextValue: string) => {
|
||||
onChange(nextValue as AUTO_UPDATE_STRATEGY)
|
||||
setOpen(false)
|
||||
}
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
<DropdownMenu
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement="top-end"
|
||||
offset={4}
|
||||
>
|
||||
<PortalToFollowElemTrigger onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
e.nativeEvent.stopImmediatePropagation()
|
||||
setOpen(v => !v)
|
||||
}}
|
||||
<DropdownMenuTrigger render={<Button size="small" />}>
|
||||
{selectedOption?.label}
|
||||
<span aria-hidden className="i-ri-arrow-down-s-line h-3.5 w-3.5" />
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent
|
||||
placement="top-end"
|
||||
sideOffset={4}
|
||||
className="z-99"
|
||||
popupClassName="w-[280px] p-1"
|
||||
>
|
||||
<Button
|
||||
size="small"
|
||||
<DropdownMenuRadioGroup
|
||||
value={value}
|
||||
onValueChange={handleValueChange}
|
||||
>
|
||||
{selectedOption?.label}
|
||||
<RiArrowDownSLine className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className="z-99">
|
||||
<div className="w-[280px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg">
|
||||
{
|
||||
options.map(option => (
|
||||
<div
|
||||
key={option.value}
|
||||
className="flex cursor-pointer rounded-lg p-2 pr-3 hover:bg-state-base-hover"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
e.nativeEvent.stopImmediatePropagation()
|
||||
onChange(option.value)
|
||||
setOpen(false)
|
||||
}}
|
||||
>
|
||||
<div className="mr-1 w-4 shrink-0">
|
||||
{
|
||||
value === option.value && (
|
||||
<RiCheckLine className="h-4 w-4 text-text-accent" />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
<div className="grow">
|
||||
<div className="mb-0.5 system-sm-semibold text-text-secondary">{option.label}</div>
|
||||
<div className="system-xs-regular text-text-tertiary">{option.description}</div>
|
||||
</div>
|
||||
{options.map(option => (
|
||||
<DropdownMenuRadioItem
|
||||
key={option.value}
|
||||
value={option.value}
|
||||
className="mx-0 h-auto items-start gap-1 p-2 pr-3"
|
||||
>
|
||||
<div className="mr-1 flex w-4 shrink-0 justify-center pt-0.5">
|
||||
<DropdownMenuRadioItemIndicator className="ml-0" />
|
||||
</div>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
<div className="grow">
|
||||
<div className="mb-0.5 system-sm-semibold text-text-secondary">{option.label}</div>
|
||||
<div className="system-xs-regular text-text-tertiary">{option.description}</div>
|
||||
</div>
|
||||
</DropdownMenuRadioItem>
|
||||
))}
|
||||
</DropdownMenuRadioGroup>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -91,7 +91,6 @@ const mockSnippetDetail: SnippetDetailPayload = {
|
||||
id: 'snippet-1',
|
||||
name: 'Tone Rewriter',
|
||||
description: 'A static snippet mock.',
|
||||
author: 'Evan',
|
||||
updatedAt: 'Updated 2h ago',
|
||||
usage: 'Used 19 times',
|
||||
icon: '🪄',
|
||||
|
||||
@ -79,7 +79,6 @@ const mockSnippetDetail: SnippetDetailPayload = {
|
||||
id: 'snippet-1',
|
||||
name: 'Tone Rewriter',
|
||||
description: 'A static snippet mock.',
|
||||
author: 'Evan',
|
||||
updatedAt: '2024-03-24',
|
||||
usage: '19',
|
||||
icon: '🪄',
|
||||
|
||||
@ -0,0 +1,60 @@
|
||||
import type { SnippetListItem } from '@/types/snippet'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import SnippetCard from '../snippet-card'
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useMembers: () => ({
|
||||
data: {
|
||||
accounts: [
|
||||
{ id: 'creator-id', name: 'Creator', email: 'creator@example.com', avatar: '', avatar_url: null, role: 'editor', last_login_at: '', created_at: '', status: 'active' },
|
||||
{ id: 'updater-id', name: 'Updater', email: 'updater@example.com', avatar: '', avatar_url: null, role: 'editor', last_login_at: '', created_at: '', status: 'active' },
|
||||
],
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/time', () => ({
|
||||
formatTime: () => 'formatted-time',
|
||||
}))
|
||||
|
||||
const createSnippet = (overrides: Partial<SnippetListItem> = {}): SnippetListItem => ({
|
||||
id: 'snippet-1',
|
||||
name: 'Tone Rewriter',
|
||||
description: 'Rewrites rough drafts.',
|
||||
type: 'node',
|
||||
is_published: true,
|
||||
use_count: 19,
|
||||
icon_info: {
|
||||
icon_type: 'emoji',
|
||||
icon: '🪄',
|
||||
icon_background: '#E0EAFF',
|
||||
icon_url: '',
|
||||
},
|
||||
created_at: 1_704_067_200,
|
||||
created_by: 'creator-id',
|
||||
updated_at: 1_704_153_600,
|
||||
updated_by: 'updater-id',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('SnippetCard', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render updater name and updated time from member data', () => {
|
||||
render(<SnippetCard snippet={createSnippet()} />)
|
||||
|
||||
expect(screen.getByText('Tone Rewriter')).toBeInTheDocument()
|
||||
expect(screen.getByText('snippet.updatedBy:{"name":"Updater","time":"formatted-time"}')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Creator')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should fall back to creator name when updater is unavailable', () => {
|
||||
render(<SnippetCard snippet={createSnippet({ updated_by: 'missing-user' })} />)
|
||||
|
||||
expect(screen.getByText('snippet.updatedBy:{"name":"Creator","time":"formatted-time"}')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -66,7 +66,6 @@ const createSnippet = (overrides: Partial<SnippetDetail> = {}): SnippetDetail =>
|
||||
id: 'snippet-1',
|
||||
name: 'Snippet Title',
|
||||
description: 'Snippet description',
|
||||
author: 'tester',
|
||||
updatedAt: '2026-04-15',
|
||||
usage: '42',
|
||||
icon: 'emoji',
|
||||
|
||||
@ -23,6 +23,7 @@ const mockHandleRun = vi.fn()
|
||||
const mockHandleStartWorkflowRun = vi.fn()
|
||||
const mockHandleStopRun = vi.fn()
|
||||
const mockHandleWorkflowStartRunInWorkflow = vi.fn()
|
||||
const mockHandleCheckBeforePublish = vi.fn()
|
||||
const mockInspectVarsCrud = {
|
||||
hasNodeInspectVars: vi.fn(),
|
||||
hasSetInspectVar: vi.fn(),
|
||||
@ -81,6 +82,12 @@ vi.mock('@/app/components/workflow/hooks/use-fetch-workflow-inspect-vars', () =>
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/hooks/use-checklist', () => ({
|
||||
useChecklistBeforePublish: () => ({
|
||||
handleCheckBeforePublish: mockHandleCheckBeforePublish,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/snippets/hooks/use-inspect-vars-crud', () => ({
|
||||
useInspectVarsCrud: () => mockInspectVarsCrud,
|
||||
}))
|
||||
@ -165,7 +172,6 @@ const payload: SnippetDetailPayload = {
|
||||
id: 'snippet-1',
|
||||
name: 'Snippet',
|
||||
description: 'desc',
|
||||
author: '',
|
||||
updatedAt: '2026-03-29 10:00',
|
||||
usage: '0',
|
||||
icon: '',
|
||||
@ -208,6 +214,7 @@ describe('SnippetMain', () => {
|
||||
vi.clearAllMocks()
|
||||
mockSyncInputFieldsDraft.mockResolvedValue(undefined)
|
||||
mockPublishSnippetMutateAsync.mockResolvedValue({ created_at: 1_744_000_000 })
|
||||
mockHandleCheckBeforePublish.mockResolvedValue(true)
|
||||
capturedHooksStore = undefined
|
||||
snippetDetailStoreState = {
|
||||
editingField: null,
|
||||
|
||||
@ -7,6 +7,7 @@ const mockSetPublishMenuOpen = vi.fn()
|
||||
const mockUseKeyPress = vi.fn()
|
||||
const mockSetPublishedAt = vi.fn()
|
||||
const mockSetQueryData = vi.fn()
|
||||
const mockHandleCheckBeforePublish = vi.fn<() => Promise<boolean>>()
|
||||
|
||||
let isPublishMenuOpen = false
|
||||
let isPending = false
|
||||
@ -44,6 +45,12 @@ vi.mock('@/app/components/workflow/store', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/hooks/use-checklist', () => ({
|
||||
useChecklistBeforePublish: () => ({
|
||||
handleCheckBeforePublish: mockHandleCheckBeforePublish,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../../store', () => ({
|
||||
useSnippetDetailStore: (selector: (state: {
|
||||
isPublishMenuOpen: boolean
|
||||
@ -60,6 +67,7 @@ describe('useSnippetPublish', () => {
|
||||
isPublishMenuOpen = false
|
||||
isPending = false
|
||||
shortcutHandler = undefined
|
||||
mockHandleCheckBeforePublish.mockResolvedValue(true)
|
||||
mockMutateAsync.mockResolvedValue({ created_at: 1_712_345_678 })
|
||||
mockUseKeyPress.mockImplementation((_key, handler) => {
|
||||
shortcutHandler = handler
|
||||
@ -76,17 +84,39 @@ describe('useSnippetPublish', () => {
|
||||
await result.current.handlePublish()
|
||||
})
|
||||
|
||||
expect(mockHandleCheckBeforePublish).toHaveBeenCalledTimes(1)
|
||||
expect(mockMutateAsync).toHaveBeenCalledWith({
|
||||
params: { snippetId: 'snippet-1' },
|
||||
})
|
||||
expect(mockSetQueryData).toHaveBeenCalledTimes(1)
|
||||
const updateSnippetDetail = mockSetQueryData.mock.calls[0]![1] as (old: { is_published: boolean }) => { is_published: boolean }
|
||||
const setQueryDataCall = mockSetQueryData.mock.calls[0]
|
||||
expect(setQueryDataCall).toBeDefined()
|
||||
const updateSnippetDetail = setQueryDataCall![1] as (old: { is_published: boolean }) => { is_published: boolean }
|
||||
expect(updateSnippetDetail({ is_published: false })).toEqual({ is_published: true })
|
||||
expect(mockSetPublishedAt).toHaveBeenCalledWith(1_712_345_678)
|
||||
expect(mockSetPublishMenuOpen).toHaveBeenCalledWith(false)
|
||||
expect(toast.success).toHaveBeenCalledWith('snippet.publishSuccess')
|
||||
})
|
||||
|
||||
it('should not publish the snippet when checklist validation fails', async () => {
|
||||
mockHandleCheckBeforePublish.mockResolvedValue(false)
|
||||
|
||||
const { result } = renderHook(() => useSnippetPublish({
|
||||
snippetId: 'snippet-1',
|
||||
}))
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handlePublish()
|
||||
})
|
||||
|
||||
expect(mockHandleCheckBeforePublish).toHaveBeenCalledTimes(1)
|
||||
expect(mockMutateAsync).not.toHaveBeenCalled()
|
||||
expect(mockSetQueryData).not.toHaveBeenCalled()
|
||||
expect(mockSetPublishedAt).not.toHaveBeenCalled()
|
||||
expect(mockSetPublishMenuOpen).not.toHaveBeenCalled()
|
||||
expect(toast.success).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should surface publish errors through toast feedback', async () => {
|
||||
mockMutateAsync.mockRejectedValue(new Error('publish failed'))
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ import { useKeyPress } from 'ahooks'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useChecklistBeforePublish } from '@/app/components/workflow/hooks/use-checklist'
|
||||
import { useWorkflowStore } from '@/app/components/workflow/store'
|
||||
import { getKeyboardKeyCodeBySystem } from '@/app/components/workflow/utils'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
@ -22,6 +23,7 @@ export const useSnippetPublish = ({
|
||||
const workflowStore = useWorkflowStore()
|
||||
const queryClient = useQueryClient()
|
||||
const publishSnippetMutation = usePublishSnippetWorkflowMutation(snippetId)
|
||||
const { handleCheckBeforePublish } = useChecklistBeforePublish()
|
||||
const {
|
||||
isPublishMenuOpen,
|
||||
setPublishMenuOpen,
|
||||
@ -32,6 +34,10 @@ export const useSnippetPublish = ({
|
||||
|
||||
const handlePublish = useCallback(async () => {
|
||||
try {
|
||||
const canPublish = await handleCheckBeforePublish()
|
||||
if (!canPublish)
|
||||
return
|
||||
|
||||
const publishedWorkflow = await publishSnippetMutation.mutateAsync({
|
||||
params: { snippetId },
|
||||
})
|
||||
@ -50,7 +56,7 @@ export const useSnippetPublish = ({
|
||||
catch (error) {
|
||||
toast.error(error instanceof Error ? error.message : t('publishFailed'))
|
||||
}
|
||||
}, [publishSnippetMutation, queryClient, setPublishMenuOpen, snippetId, t, workflowStore])
|
||||
}, [handleCheckBeforePublish, publishSnippetMutation, queryClient, setPublishMenuOpen, snippetId, t, workflowStore])
|
||||
|
||||
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.shift.p`, (event) => {
|
||||
if (publishSnippetMutation.isPending)
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
'use client'
|
||||
|
||||
import type { SnippetListItem } from '@/types/snippet'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import Link from '@/next/link'
|
||||
import { useMembers } from '@/service/use-common'
|
||||
import { formatTime } from '@/utils/time'
|
||||
|
||||
type Props = {
|
||||
snippet: SnippetListItem
|
||||
@ -11,16 +14,35 @@ type Props = {
|
||||
|
||||
const SnippetCard = ({ snippet }: Props) => {
|
||||
const { t } = useTranslation('snippet')
|
||||
const { data: membersData } = useMembers()
|
||||
|
||||
const memberNameById = useMemo(() => {
|
||||
return new Map((membersData?.accounts ?? []).map(member => [member.id, member.name]))
|
||||
}, [membersData?.accounts])
|
||||
|
||||
const updatedByName = memberNameById.get(snippet.updated_by)
|
||||
|| memberNameById.get(snippet.created_by)
|
||||
|| t('unknownUser')
|
||||
|
||||
const updatedAt = snippet.updated_at || snippet.created_at
|
||||
const updatedAtText = formatTime({
|
||||
date: (updatedAt > 1_000_000_000_000 ? updatedAt : updatedAt * 1000),
|
||||
dateFormat: `${t('segment.dateTimeFormat', { ns: 'datasetDocuments' })}`,
|
||||
})
|
||||
const updatedText = t('updatedBy', {
|
||||
name: updatedByName,
|
||||
time: updatedAtText,
|
||||
})
|
||||
|
||||
return (
|
||||
<Link href={`/snippets/${snippet.id}/orchestrate`} className="group col-span-1">
|
||||
<article className="relative inline-flex h-[160px] w-full flex-col rounded-xl border border-components-card-border bg-components-card-bg shadow-sm transition-all duration-200 ease-in-out hover:-translate-y-0.5 hover:shadow-lg">
|
||||
{!snippet.is_published && (
|
||||
<div className="absolute right-0 top-0 rounded-bl-lg rounded-tr-xl bg-background-default-dimmed px-2 py-1 text-[10px] font-medium uppercase leading-3 text-text-placeholder">
|
||||
Draft
|
||||
<div className="absolute top-0 right-0 rounded-tr-xl rounded-bl-lg bg-background-default-dimmed px-2 py-1 text-[10px] leading-3 font-medium text-text-placeholder uppercase">
|
||||
{t('draft')}
|
||||
</div>
|
||||
)}
|
||||
<div className="flex h-[66px] items-center gap-3 px-[14px] pb-3 pt-[14px]">
|
||||
<div className="flex h-[66px] items-center gap-3 px-[14px] pt-[14px] pb-3">
|
||||
<AppIcon
|
||||
size="large"
|
||||
iconType={snippet.icon_info.icon_type}
|
||||
@ -28,8 +50,8 @@ const SnippetCard = ({ snippet }: Props) => {
|
||||
background={snippet.icon_info.icon_background}
|
||||
imageUrl={snippet.icon_info.icon_url}
|
||||
/>
|
||||
<div className="w-0 grow py-[1px]">
|
||||
<div className="truncate text-sm font-semibold leading-5 text-text-secondary" title={snippet.name}>
|
||||
<div className="w-0 grow py-px">
|
||||
<div className="truncate text-sm leading-5 font-semibold text-text-secondary" title={snippet.name}>
|
||||
{snippet.name}
|
||||
</div>
|
||||
</div>
|
||||
@ -39,11 +61,9 @@ const SnippetCard = ({ snippet }: Props) => {
|
||||
{snippet.description}
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-auto flex items-center gap-1 px-[14px] pb-3 pt-2 text-xs leading-4 text-text-tertiary">
|
||||
<span className="truncate">{snippet.author}</span>
|
||||
<span>·</span>
|
||||
<span className="truncate">{snippet.updated_at}</span>
|
||||
{!snippet.is_published && (
|
||||
<div className="mt-auto flex items-center gap-1 px-[14px] pt-2 pb-3 text-xs leading-4 text-text-tertiary">
|
||||
<span className="truncate" title={updatedText}>{updatedText}</span>
|
||||
{snippet.is_published && (
|
||||
<>
|
||||
<span>·</span>
|
||||
<span className="truncate">{t('usageCount', { count: snippet.use_count })}</span>
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import type { WorkflowProps } from '@/app/components/workflow'
|
||||
import type { Shape as HooksStoreShape } from '@/app/components/workflow/hooks-store'
|
||||
import type { SnippetDetailPayload } from '@/models/snippet'
|
||||
import type { SnippetDetailPayload, SnippetDetailUIModel, SnippetInputField } from '@/models/snippet'
|
||||
import {
|
||||
useEffect,
|
||||
useMemo,
|
||||
@ -28,6 +28,69 @@ type SnippetMainProps = {
|
||||
snippetId: string
|
||||
} & Pick<WorkflowProps, 'nodes' | 'edges' | 'viewport'>
|
||||
|
||||
type SnippetMainContentProps = {
|
||||
snippetId: string
|
||||
fields: SnippetInputField[]
|
||||
uiMeta: SnippetDetailUIModel
|
||||
editingField: SnippetInputField | null
|
||||
isEditorOpen: boolean
|
||||
isInputPanelOpen: boolean
|
||||
onToggleInputPanel: () => void
|
||||
onCloseInputPanel: () => void
|
||||
onOpenEditor: (field?: SnippetInputField | null) => void
|
||||
onCloseEditor: () => void
|
||||
onSubmitField: (field: SnippetInputField) => void
|
||||
onRemoveField: (index: number) => void
|
||||
onSortChange: (fields: SnippetInputField[]) => void
|
||||
}
|
||||
|
||||
const SnippetMainContent = ({
|
||||
snippetId,
|
||||
fields,
|
||||
uiMeta,
|
||||
editingField,
|
||||
isEditorOpen,
|
||||
isInputPanelOpen,
|
||||
onToggleInputPanel,
|
||||
onCloseInputPanel,
|
||||
onOpenEditor,
|
||||
onCloseEditor,
|
||||
onSubmitField,
|
||||
onRemoveField,
|
||||
onSortChange,
|
||||
}: SnippetMainContentProps) => {
|
||||
const {
|
||||
handlePublish,
|
||||
isPublishMenuOpen,
|
||||
isPublishing,
|
||||
setPublishMenuOpen,
|
||||
} = useSnippetPublish({
|
||||
snippetId,
|
||||
})
|
||||
|
||||
return (
|
||||
<SnippetChildren
|
||||
snippetId={snippetId}
|
||||
fields={fields}
|
||||
uiMeta={uiMeta}
|
||||
editingField={editingField}
|
||||
isEditorOpen={isEditorOpen}
|
||||
isInputPanelOpen={isInputPanelOpen}
|
||||
isPublishMenuOpen={isPublishMenuOpen}
|
||||
isPublishing={isPublishing}
|
||||
onToggleInputPanel={onToggleInputPanel}
|
||||
onPublishMenuOpenChange={setPublishMenuOpen}
|
||||
onCloseInputPanel={onCloseInputPanel}
|
||||
onPublish={handlePublish}
|
||||
onOpenEditor={onOpenEditor}
|
||||
onCloseEditor={onCloseEditor}
|
||||
onSubmitField={onSubmitField}
|
||||
onRemoveField={onRemoveField}
|
||||
onSortChange={onSortChange}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
const SnippetMain = ({
|
||||
payload,
|
||||
snippetId,
|
||||
@ -109,14 +172,6 @@ const SnippetMain = ({
|
||||
} = useSnippetInputFieldActions({
|
||||
snippetId,
|
||||
})
|
||||
const {
|
||||
handlePublish,
|
||||
isPublishMenuOpen,
|
||||
isPublishing,
|
||||
setPublishMenuOpen,
|
||||
} = useSnippetPublish({
|
||||
snippetId,
|
||||
})
|
||||
const {
|
||||
handleStartWorkflowRun,
|
||||
handleWorkflowStartRunInWorkflow,
|
||||
@ -200,19 +255,15 @@ const SnippetMain = ({
|
||||
viewport={viewport ?? graph.viewport}
|
||||
hooksStore={hooksStore as unknown as Partial<HooksStoreShape>}
|
||||
>
|
||||
<SnippetChildren
|
||||
<SnippetMainContent
|
||||
snippetId={snippetId}
|
||||
fields={fields}
|
||||
uiMeta={uiMeta}
|
||||
editingField={editingField}
|
||||
isEditorOpen={isEditorOpen}
|
||||
isInputPanelOpen={isInputPanelOpen}
|
||||
isPublishMenuOpen={isPublishMenuOpen}
|
||||
isPublishing={isPublishing}
|
||||
onToggleInputPanel={handleToggleInputPanel}
|
||||
onPublishMenuOpenChange={setPublishMenuOpen}
|
||||
onCloseInputPanel={handleCloseInputPanel}
|
||||
onPublish={handlePublish}
|
||||
onOpenEditor={openEditor}
|
||||
onCloseEditor={closeEditor}
|
||||
onSubmitField={handleSubmitField}
|
||||
|
||||
@ -10,6 +10,8 @@ const mockPostWithKeepalive = vi.fn()
|
||||
const mockSyncDraftWorkflow = vi.fn()
|
||||
const mockSetDraftUpdatedAt = vi.fn()
|
||||
const mockSetSyncWorkflowDraftHash = vi.fn()
|
||||
let deferSerialCallbacks = false
|
||||
let queuedSerialCallbacks: Array<() => Promise<void> | void> = []
|
||||
|
||||
let reactFlowState: {
|
||||
getNodes: typeof mockGetNodes
|
||||
@ -37,6 +39,11 @@ vi.mock('@/app/components/workflow/hooks/use-serial-async-callback', () => ({
|
||||
if (checkFn?.())
|
||||
return
|
||||
|
||||
if (deferSerialCallbacks) {
|
||||
queuedSerialCallbacks.push(() => fn(...args))
|
||||
return Promise.resolve()
|
||||
}
|
||||
|
||||
return fn(...args)
|
||||
},
|
||||
}))
|
||||
@ -77,6 +84,8 @@ const createInputField = (variable: string): SnippetInputField => ({
|
||||
describe('snippet/use-nodes-sync-draft', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
deferSerialCallbacks = false
|
||||
queuedSerialCallbacks = []
|
||||
reactFlowState = {
|
||||
getNodes: mockGetNodes,
|
||||
edges: [{ id: 'edge-1', source: 'node-1', target: 'node-2', data: { stable: true } }],
|
||||
@ -121,6 +130,38 @@ describe('snippet/use-nodes-sync-draft', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should snapshot graph before queued draft sync executes', async () => {
|
||||
deferSerialCallbacks = true
|
||||
const { result } = renderHook(() => useNodesSyncDraft('snippet-1'))
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
mockGetNodes.mockReturnValue([
|
||||
{ id: 'late-node', position: { x: 9, y: 9 }, data: { title: 'Late' } },
|
||||
])
|
||||
reactFlowState.edges = [{ id: 'late-edge', source: 'late-node', target: 'late-target', data: { stable: false } }]
|
||||
reactFlowState.transform = [99, 88, 0.5]
|
||||
|
||||
await act(async () => {
|
||||
await Promise.all(queuedSerialCallbacks.map(run => run()))
|
||||
})
|
||||
|
||||
expect(mockSyncDraftWorkflow).toHaveBeenCalledWith({
|
||||
params: { snippetId: 'snippet-1' },
|
||||
body: {
|
||||
graph: {
|
||||
nodes: [{ id: 'node-1', position: { x: 0, y: 0 }, data: { title: 'Start' } }],
|
||||
edges: [{ id: 'edge-1', source: 'node-1', target: 'node-2', data: { stable: true } }],
|
||||
viewport: { x: 12, y: 24, zoom: 1.5 },
|
||||
},
|
||||
input_fields: [createInputField('topic')],
|
||||
hash: 'draft-hash',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should include the latest graph when syncing input fields', async () => {
|
||||
const { result } = renderHook(() => useNodesSyncDraft('snippet-1'))
|
||||
const nextFields = [createInputField('summary')]
|
||||
|
||||
@ -72,8 +72,9 @@ describe('useSnippetInit', () => {
|
||||
},
|
||||
input_fields: [],
|
||||
created_at: 1_712_300_000,
|
||||
created_by: 'user-1',
|
||||
updated_at: 1_712_300_000,
|
||||
author: 'Evan',
|
||||
updated_by: 'user-1',
|
||||
},
|
||||
error: null,
|
||||
isLoading: false,
|
||||
@ -124,8 +125,9 @@ describe('useSnippetInit', () => {
|
||||
},
|
||||
],
|
||||
created_at: 1_712_300_000,
|
||||
created_by: 'user-1',
|
||||
updated_at: 1_712_300_000,
|
||||
author: 'Evan',
|
||||
updated_by: 'user-1',
|
||||
},
|
||||
error: null,
|
||||
isLoading: false,
|
||||
|
||||
@ -136,21 +136,20 @@ export const useNodesSyncDraft = (snippetId: string) => {
|
||||
}, [getDraftSyncPayload, getNodesReadOnly, snippetId, workflowStore])
|
||||
|
||||
const performSync = useCallback(async (
|
||||
draftPayload: Omit<SnippetDraftSyncPayload, 'hash'> | null,
|
||||
notRefreshWhenSyncError?: boolean,
|
||||
callback?: SyncDraftCallback,
|
||||
) => {
|
||||
const draftPayload = getDraftSyncPayload()
|
||||
if (!draftPayload)
|
||||
return
|
||||
|
||||
await syncDraft(draftPayload, notRefreshWhenSyncError, callback)
|
||||
}, [getDraftSyncPayload, syncDraft])
|
||||
}, [syncDraft])
|
||||
|
||||
const performInputFieldsSync = useCallback(async (
|
||||
inputFields: SnippetInputField[],
|
||||
draftPayload: Omit<SnippetDraftSyncPayload, 'hash'> | null,
|
||||
callback?: SyncInputFieldsDraftCallback,
|
||||
) => {
|
||||
const draftPayload = getDraftSyncPayload(inputFields)
|
||||
if (!draftPayload)
|
||||
return
|
||||
|
||||
@ -165,10 +164,29 @@ export const useNodesSyncDraft = (snippetId: string) => {
|
||||
callback?.onRefresh?.(refreshedInputFields)
|
||||
},
|
||||
)
|
||||
}, [getDraftSyncPayload, syncDraft])
|
||||
}, [syncDraft])
|
||||
|
||||
const doSyncWorkflowDraft = useSerialAsyncCallback(performSync, getNodesReadOnly)
|
||||
const syncInputFieldsDraft = useSerialAsyncCallback(performInputFieldsSync)
|
||||
const syncWorkflowDraftWithPayload = useSerialAsyncCallback(performSync, getNodesReadOnly)
|
||||
const syncInputFieldsDraftWithPayload = useSerialAsyncCallback(performInputFieldsSync)
|
||||
|
||||
const doSyncWorkflowDraft = useCallback((
|
||||
notRefreshWhenSyncError?: boolean,
|
||||
callback?: SyncDraftCallback,
|
||||
) => {
|
||||
if (getNodesReadOnly())
|
||||
return Promise.resolve()
|
||||
|
||||
const draftPayload = getDraftSyncPayload()
|
||||
return syncWorkflowDraftWithPayload(draftPayload, notRefreshWhenSyncError, callback)
|
||||
}, [getDraftSyncPayload, getNodesReadOnly, syncWorkflowDraftWithPayload])
|
||||
|
||||
const syncInputFieldsDraft = useCallback((
|
||||
inputFields: SnippetInputField[],
|
||||
callback?: SyncInputFieldsDraftCallback,
|
||||
) => {
|
||||
const draftPayload = getDraftSyncPayload(inputFields)
|
||||
return syncInputFieldsDraftWithPayload(draftPayload, callback)
|
||||
}, [getDraftSyncPayload, syncInputFieldsDraftWithPayload])
|
||||
|
||||
return {
|
||||
doSyncWorkflowDraft,
|
||||
|
||||
@ -10,6 +10,7 @@ import type {
|
||||
import type {
|
||||
CommonNodeType,
|
||||
NodeDefault,
|
||||
OnNodeAdd,
|
||||
OnSelectBlock,
|
||||
ToolWithProvider,
|
||||
} from '../types'
|
||||
@ -65,6 +66,7 @@ export type NodeSelectorProps = {
|
||||
ignoreNodeIds?: string[]
|
||||
forceEnableStartTab?: boolean // Force enabling Start tab regardless of existing trigger/user input nodes (e.g., when changing Start node type).
|
||||
allowUserInputSelection?: boolean // Override user-input availability; default logic blocks it when triggers exist.
|
||||
snippetInsertPayload?: Parameters<OnNodeAdd>[1]
|
||||
}
|
||||
const NodeSelector: FC<NodeSelectorProps> = ({
|
||||
open: openFromProps,
|
||||
@ -90,6 +92,7 @@ const NodeSelector: FC<NodeSelectorProps> = ({
|
||||
ignoreNodeIds = [],
|
||||
forceEnableStartTab = false,
|
||||
allowUserInputSelection,
|
||||
snippetInsertPayload,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const nodes = useNodes()
|
||||
@ -335,7 +338,14 @@ const NodeSelector: FC<NodeSelectorProps> = ({
|
||||
noTools={noTools}
|
||||
onTagsChange={setTags}
|
||||
forceShowStartContent={forceShowStartContent}
|
||||
snippetsElem={<Snippets loading={snippetsLoading} searchText={searchText} />}
|
||||
snippetsElem={(
|
||||
<Snippets
|
||||
loading={snippetsLoading}
|
||||
searchText={searchText}
|
||||
insertPayload={snippetInsertPayload}
|
||||
onInserted={() => handleOpenChange(false)}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</PopoverContent>
|
||||
|
||||
@ -74,7 +74,6 @@ describe('Snippets', () => {
|
||||
id: 'snippet-1',
|
||||
name: 'Customer Review',
|
||||
description: 'Snippet description',
|
||||
author: 'Evan',
|
||||
type: 'group',
|
||||
is_published: true,
|
||||
version: '1.0.0',
|
||||
@ -87,7 +86,9 @@ describe('Snippets', () => {
|
||||
},
|
||||
input_fields: [],
|
||||
created_at: 1,
|
||||
created_by: 'user-1',
|
||||
updated_at: 2,
|
||||
updated_by: 'user-1',
|
||||
}],
|
||||
}],
|
||||
},
|
||||
@ -127,7 +128,6 @@ describe('Snippets', () => {
|
||||
id: 'snippet-1',
|
||||
name: 'Customer Review',
|
||||
description: 'Snippet description',
|
||||
author: 'Evan',
|
||||
type: 'group',
|
||||
is_published: true,
|
||||
version: '1.0.0',
|
||||
@ -140,7 +140,9 @@ describe('Snippets', () => {
|
||||
},
|
||||
input_fields: [],
|
||||
created_at: 1,
|
||||
created_by: 'user-1',
|
||||
updated_at: 2,
|
||||
updated_by: 'user-1',
|
||||
}],
|
||||
}],
|
||||
},
|
||||
@ -155,7 +157,7 @@ describe('Snippets', () => {
|
||||
|
||||
fireEvent.click(screen.getByText('Customer Review'))
|
||||
|
||||
expect(mockHandleInsertSnippet).toHaveBeenCalledWith('snippet-1')
|
||||
expect(mockHandleInsertSnippet).toHaveBeenCalledWith('snippet-1', undefined)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -8,11 +8,20 @@ vi.mock('@/service/use-snippet-workflows', () => ({
|
||||
useSnippetPublishedWorkflow: (...args: unknown[]) => mockUseSnippetPublishedWorkflow(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useMembers: () => ({
|
||||
data: {
|
||||
accounts: [
|
||||
{ id: 'user-1', name: 'Evan', email: 'evan@example.com', avatar: '', avatar_url: null, role: 'editor', last_login_at: '', created_at: '', status: 'active' },
|
||||
],
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
const createSnippet = (overrides: Partial<PublishedSnippetListItem> = {}): PublishedSnippetListItem => ({
|
||||
id: 'snippet-1',
|
||||
name: 'Customer Review',
|
||||
description: 'Snippet description',
|
||||
author: 'Evan',
|
||||
type: 'group',
|
||||
is_published: true,
|
||||
use_count: 3,
|
||||
@ -23,7 +32,9 @@ const createSnippet = (overrides: Partial<PublishedSnippetListItem> = {}): Publi
|
||||
icon_url: '',
|
||||
},
|
||||
created_at: 1,
|
||||
created_by: 'user-1',
|
||||
updated_at: 2,
|
||||
updated_by: 'user-1',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ const createSnippet = (overrides: Partial<PublishedSnippetListItem> = {}): Publi
|
||||
id: 'snippet-1',
|
||||
name: 'Customer Review',
|
||||
description: 'Snippet description',
|
||||
author: 'Evan',
|
||||
type: 'group',
|
||||
is_published: true,
|
||||
use_count: 3,
|
||||
@ -17,7 +16,9 @@ const createSnippet = (overrides: Partial<PublishedSnippetListItem> = {}): Publi
|
||||
icon_url: '',
|
||||
},
|
||||
created_at: 1,
|
||||
created_by: 'user-1',
|
||||
updated_at: 2,
|
||||
updated_by: 'user-1',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
@ -38,10 +39,9 @@ describe('SnippetListItem', () => {
|
||||
)
|
||||
|
||||
expect(screen.getByText('Customer Review')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Evan')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render author when hovered', () => {
|
||||
it('should not render metadata when hovered', () => {
|
||||
render(
|
||||
<SnippetListItem
|
||||
snippet={createSnippet()}
|
||||
@ -51,7 +51,7 @@ describe('SnippetListItem', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Evan')).toBeInTheDocument()
|
||||
expect(screen.getByText('Customer Review')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -1,6 +1,27 @@
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { useInsertSnippet } from '../use-insert-snippet'
|
||||
|
||||
type TestNode = {
|
||||
id: string
|
||||
position: { x: number, y: number }
|
||||
selected?: boolean
|
||||
parentId?: string
|
||||
data: {
|
||||
selected?: boolean
|
||||
_children?: { nodeId: string, nodeType: string }[]
|
||||
_connectedSourceHandleIds?: string[]
|
||||
_connectedTargetHandleIds?: string[]
|
||||
}
|
||||
}
|
||||
|
||||
type TestEdge = {
|
||||
id: string
|
||||
source: string
|
||||
sourceHandle?: string
|
||||
target: string
|
||||
targetHandle?: string
|
||||
}
|
||||
|
||||
const mockFetchQuery = vi.fn()
|
||||
const mockHandleSyncWorkflowDraft = vi.fn()
|
||||
const mockSaveStateToHistory = vi.fn()
|
||||
@ -8,6 +29,8 @@ const mockToastError = vi.fn()
|
||||
const mockGetNodes = vi.fn()
|
||||
const mockSetNodes = vi.fn()
|
||||
const mockSetEdges = vi.fn()
|
||||
const mockIncrementSnippetUseCount = vi.fn()
|
||||
let mockEdges: unknown[] = [{ id: 'existing-edge', source: 'old', target: 'old-2' }]
|
||||
|
||||
vi.mock('@tanstack/react-query', () => ({
|
||||
useQueryClient: () => ({
|
||||
@ -20,7 +43,7 @@ vi.mock('reactflow', () => ({
|
||||
getState: () => ({
|
||||
getNodes: mockGetNodes,
|
||||
setNodes: mockSetNodes,
|
||||
edges: [{ id: 'existing-edge', source: 'old', target: 'old-2' }],
|
||||
edges: mockEdges,
|
||||
setEdges: mockSetEdges,
|
||||
}),
|
||||
}),
|
||||
@ -44,9 +67,16 @@ vi.mock('@langgenius/dify-ui/toast', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-snippets', () => ({
|
||||
useIncrementSnippetUseCountMutation: () => ({
|
||||
mutate: mockIncrementSnippetUseCount,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('useInsertSnippet', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockEdges = [{ id: 'existing-edge', source: 'old', target: 'old-2' }]
|
||||
mockGetNodes.mockReturnValue([
|
||||
{
|
||||
id: 'existing-node',
|
||||
@ -96,23 +126,131 @@ describe('useInsertSnippet', () => {
|
||||
expect(mockSetNodes).toHaveBeenCalledTimes(1)
|
||||
expect(mockSetEdges).toHaveBeenCalledTimes(1)
|
||||
|
||||
const nextNodes = mockSetNodes.mock.calls[0][0]
|
||||
expect(nextNodes[0].selected).toBe(false)
|
||||
expect(nextNodes[0].data.selected).toBe(false)
|
||||
const nextNodes = mockSetNodes.mock.calls[0]![0] as TestNode[]
|
||||
expect(nextNodes[0]!.selected).toBe(false)
|
||||
expect(nextNodes[0]!.data.selected).toBe(false)
|
||||
expect(nextNodes).toHaveLength(3)
|
||||
expect(nextNodes[1].id).not.toBe('snippet-node-1')
|
||||
expect(nextNodes[2].parentId).toBe(nextNodes[1].id)
|
||||
expect(nextNodes[1].data._children[0].nodeId).toBe(nextNodes[2].id)
|
||||
expect(nextNodes[1]!.id).not.toBe('snippet-node-1')
|
||||
expect(nextNodes[2]!.parentId).toBe(nextNodes[1]!.id)
|
||||
expect(nextNodes[1]!.data._children![0]!.nodeId).toBe(nextNodes[2]!.id)
|
||||
|
||||
const nextEdges = mockSetEdges.mock.calls[0][0]
|
||||
const nextEdges = mockSetEdges.mock.calls[0]![0] as TestEdge[]
|
||||
expect(nextEdges).toHaveLength(2)
|
||||
expect(nextEdges[1].source).toBe(nextNodes[1].id)
|
||||
expect(nextEdges[1].target).toBe(nextNodes[2].id)
|
||||
expect(nextEdges[1]!.source).toBe(nextNodes[1]!.id)
|
||||
expect(nextEdges[1]!.target).toBe(nextNodes[2]!.id)
|
||||
|
||||
expect(mockSaveStateToHistory).toHaveBeenCalledWith('NodePaste', {
|
||||
nodeId: nextNodes[1].id,
|
||||
nodeId: nextNodes[1]!.id,
|
||||
})
|
||||
expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledTimes(1)
|
||||
expect(mockIncrementSnippetUseCount).toHaveBeenCalledWith({
|
||||
params: { snippetId: 'snippet-1' },
|
||||
})
|
||||
})
|
||||
|
||||
it('should connect inserted snippet nodes to the requested edge position', async () => {
|
||||
mockGetNodes.mockReturnValue([
|
||||
{
|
||||
id: 'prev-node',
|
||||
position: { x: 0, y: 0 },
|
||||
width: 240,
|
||||
data: { type: 'start', selected: true, _connectedSourceHandleIds: ['source'] },
|
||||
},
|
||||
{
|
||||
id: 'next-node',
|
||||
position: { x: 300, y: 0 },
|
||||
data: { type: 'answer', selected: false, _connectedTargetHandleIds: ['target'] },
|
||||
},
|
||||
])
|
||||
mockEdges = [
|
||||
{
|
||||
id: 'prev-node-source-next-node-target',
|
||||
source: 'prev-node',
|
||||
sourceHandle: 'source',
|
||||
target: 'next-node',
|
||||
targetHandle: 'target',
|
||||
data: {
|
||||
sourceType: 'start',
|
||||
targetType: 'answer',
|
||||
},
|
||||
},
|
||||
]
|
||||
mockFetchQuery.mockResolvedValue({
|
||||
graph: {
|
||||
nodes: [
|
||||
{
|
||||
id: 'snippet-entry',
|
||||
position: { x: 0, y: 0 },
|
||||
data: { type: 'llm', selected: false },
|
||||
},
|
||||
{
|
||||
id: 'snippet-exit',
|
||||
position: { x: 300, y: 0 },
|
||||
data: { type: 'code', selected: false },
|
||||
},
|
||||
],
|
||||
edges: [
|
||||
{
|
||||
id: 'snippet-entry-source-snippet-exit-target',
|
||||
source: 'snippet-entry',
|
||||
sourceHandle: 'source',
|
||||
target: 'snippet-exit',
|
||||
targetHandle: 'target',
|
||||
data: {
|
||||
sourceType: 'llm',
|
||||
targetType: 'code',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useInsertSnippet())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleInsertSnippet('snippet-1', {
|
||||
prevNodeId: 'prev-node',
|
||||
prevNodeSourceHandle: 'source',
|
||||
nextNodeId: 'next-node',
|
||||
nextNodeTargetHandle: 'target',
|
||||
})
|
||||
})
|
||||
|
||||
const nextNodes = mockSetNodes.mock.calls[0]![0] as TestNode[]
|
||||
const insertedEntry = nextNodes.find(node => node.id !== 'prev-node' && node.id !== 'next-node' && node.id.includes('snippet-entry'))!
|
||||
const insertedExit = nextNodes.find(node => node.id !== 'prev-node' && node.id !== 'next-node' && node.id.includes('snippet-exit'))!
|
||||
const shiftedNextNode = nextNodes.find(node => node.id === 'next-node')!
|
||||
expect(insertedEntry.position).toEqual({ x: 300, y: 0 })
|
||||
expect(shiftedNextNode.position.x).toBe(600)
|
||||
expect(nextNodes.find(node => node.id === 'prev-node')!.data._connectedSourceHandleIds).toEqual(['source'])
|
||||
expect(insertedEntry.data._connectedTargetHandleIds).toEqual(['target'])
|
||||
expect(insertedExit.data._connectedSourceHandleIds).toEqual(['source'])
|
||||
expect(shiftedNextNode.data._connectedTargetHandleIds).toEqual(['target'])
|
||||
|
||||
const nextEdges = mockSetEdges.mock.calls[0]![0] as TestEdge[]
|
||||
expect(nextEdges).toHaveLength(3)
|
||||
expect(nextEdges.some(edge => edge.id === 'prev-node-source-next-node-target')).toBe(false)
|
||||
expect(nextEdges).toEqual(expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
source: 'prev-node',
|
||||
sourceHandle: 'source',
|
||||
target: insertedEntry.id,
|
||||
targetHandle: 'target',
|
||||
}),
|
||||
expect.objectContaining({
|
||||
source: insertedEntry.id,
|
||||
target: insertedExit.id,
|
||||
}),
|
||||
expect.objectContaining({
|
||||
source: insertedExit.id,
|
||||
sourceHandle: 'source',
|
||||
target: 'next-node',
|
||||
targetHandle: 'target',
|
||||
}),
|
||||
]))
|
||||
expect(mockIncrementSnippetUseCount).toHaveBeenCalledWith({
|
||||
params: { snippetId: 'snippet-1' },
|
||||
})
|
||||
})
|
||||
|
||||
it('should show error toast when fetching snippet workflow fails', async () => {
|
||||
@ -125,6 +263,7 @@ describe('useInsertSnippet', () => {
|
||||
})
|
||||
|
||||
expect(mockToastError).toHaveBeenCalledWith('insert failed')
|
||||
expect(mockIncrementSnippetUseCount).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import type { OnNodeAdd } from '../../types'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import {
|
||||
ScrollAreaContent,
|
||||
@ -14,6 +15,7 @@ import {
|
||||
import { useInfiniteScroll } from 'ahooks'
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useDeferredValue,
|
||||
useMemo,
|
||||
useRef,
|
||||
@ -31,6 +33,8 @@ import { useInsertSnippet } from './use-insert-snippet'
|
||||
type SnippetsProps = {
|
||||
loading?: boolean
|
||||
searchText: string
|
||||
insertPayload?: Parameters<OnNodeAdd>[1]
|
||||
onInserted?: () => void
|
||||
}
|
||||
|
||||
const LoadingSkeleton = () => {
|
||||
@ -60,6 +64,8 @@ const LoadingSkeleton = () => {
|
||||
const Snippets = ({
|
||||
loading = false,
|
||||
searchText,
|
||||
insertPayload,
|
||||
onInserted,
|
||||
}: SnippetsProps) => {
|
||||
const {
|
||||
createSnippetMutation,
|
||||
@ -95,6 +101,11 @@ const Snippets = ({
|
||||
}, [data?.pages])
|
||||
|
||||
const isNoMore = hasNextPage === false
|
||||
const handleSnippetClick = useCallback(async (snippetId: string) => {
|
||||
const inserted = await handleInsertSnippet(snippetId, insertPayload)
|
||||
if (inserted)
|
||||
onInserted?.()
|
||||
}, [handleInsertSnippet, insertPayload, onInserted])
|
||||
|
||||
useInfiniteScroll(
|
||||
async () => {
|
||||
@ -129,7 +140,7 @@ const Snippets = ({
|
||||
<SnippetListItem
|
||||
snippet={item}
|
||||
isHovered={hoveredSnippetId === item.id}
|
||||
onClick={() => handleInsertSnippet(item.id)}
|
||||
onClick={() => handleSnippetClick(item.id)}
|
||||
onMouseEnter={() => setHoveredSnippetId(item.id)}
|
||||
onMouseLeave={() => setHoveredSnippetId(current => current === item.id ? null : current)}
|
||||
/>
|
||||
@ -146,7 +157,6 @@ const Snippets = ({
|
||||
/>
|
||||
<TooltipContent
|
||||
placement="left-start"
|
||||
variant="plain"
|
||||
className="bg-transparent! p-0!"
|
||||
>
|
||||
<SnippetDetailCard snippet={item} />
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import type { FC } from 'react'
|
||||
import type { SnippetListItem } from '@/types/snippet'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import { useMembers } from '@/service/use-common'
|
||||
import { useSnippetPublishedWorkflow } from '@/service/use-snippet-workflows'
|
||||
import BlockIcon from '../../block-icon'
|
||||
import { BlockEnum } from '../../types'
|
||||
@ -15,9 +17,16 @@ type SnippetDetailCardProps = {
|
||||
const SnippetDetailCard: FC<SnippetDetailCardProps> = ({
|
||||
snippet,
|
||||
}) => {
|
||||
const { author, description, icon_info, name } = snippet
|
||||
const { description, icon_info, name } = snippet
|
||||
const { t } = useTranslation('snippet')
|
||||
const { data: membersData } = useMembers()
|
||||
const { data: workflow } = useSnippetPublishedWorkflow(snippet.id)
|
||||
|
||||
const creatorName = useMemo(() => {
|
||||
const member = membersData?.accounts?.find(member => member.id === snippet.created_by)
|
||||
return member?.name || t('unknownUser')
|
||||
}, [membersData?.accounts, snippet.created_by, t])
|
||||
|
||||
const blockTypes = useMemo(() => {
|
||||
const graph = workflow?.graph
|
||||
if (!graph || typeof graph !== 'object')
|
||||
@ -51,7 +60,7 @@ const SnippetDetailCard: FC<SnippetDetailCardProps> = ({
|
||||
}, [workflow?.graph])
|
||||
|
||||
return (
|
||||
<div className="w-[224px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur px-3 pb-4 pt-3 shadow-lg backdrop-blur-[5px]">
|
||||
<div className="w-[224px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur px-3 pt-3 pb-4 shadow-lg backdrop-blur-[5px]">
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex flex-col gap-2">
|
||||
<AppIcon
|
||||
@ -61,10 +70,10 @@ const SnippetDetailCard: FC<SnippetDetailCardProps> = ({
|
||||
background={icon_info.icon_background}
|
||||
imageUrl={icon_info.icon_url}
|
||||
/>
|
||||
<div className="text-text-primary system-md-medium">{name}</div>
|
||||
<div className="system-md-medium text-text-primary">{name}</div>
|
||||
</div>
|
||||
{!!description && (
|
||||
<div className="w-[200px] text-text-secondary system-xs-regular">
|
||||
<div className="w-[200px] system-xs-regular text-text-secondary">
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
@ -80,11 +89,9 @@ const SnippetDetailCard: FC<SnippetDetailCardProps> = ({
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{!!author && (
|
||||
<div className="pt-3 text-text-tertiary system-xs-regular">
|
||||
{author}
|
||||
</div>
|
||||
)}
|
||||
<div className="pt-3 system-xs-regular text-text-tertiary">
|
||||
{creatorName}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -36,14 +36,9 @@ const SnippetListItem = ({
|
||||
background={snippet.icon_info.icon_background}
|
||||
imageUrl={snippet.icon_info.icon_url}
|
||||
/>
|
||||
<div className="system-sm-medium min-w-0 text-text-secondary">
|
||||
<div className="min-w-0 system-sm-medium text-text-secondary">
|
||||
{snippet.name}
|
||||
</div>
|
||||
{isHovered && snippet.author && (
|
||||
<div className="system-xs-regular ml-auto text-text-tertiary">
|
||||
{snippet.author}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,11 +1,17 @@
|
||||
import type { Edge, Node } from '../../types'
|
||||
import type { Edge, Node, OnNodeAdd } from '../../types'
|
||||
import { toast } from '@langgenius/dify-ui/toast'
|
||||
import { useQueryClient } from '@tanstack/react-query'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useStoreApi } from 'reactflow'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { useIncrementSnippetUseCountMutation } from '@/service/use-snippets'
|
||||
import { CUSTOM_EDGE, ITERATION_CHILDREN_Z_INDEX, LOOP_CHILDREN_Z_INDEX, NODE_WIDTH_X_OFFSET, X_OFFSET } from '../../constants'
|
||||
import { useNodesSyncDraft, useWorkflowHistory, WorkflowHistoryEvent } from '../../hooks'
|
||||
import { BlockEnum } from '../../types'
|
||||
import { getNodesConnectedSourceOrTargetHandleIdsMap } from '../../utils'
|
||||
|
||||
type SnippetInsertPayload = Parameters<OnNodeAdd>[1]
|
||||
|
||||
const getSnippetGraph = (graph: Record<string, unknown> | undefined) => {
|
||||
if (!graph)
|
||||
@ -17,7 +23,74 @@ const getSnippetGraph = (graph: Record<string, unknown> | undefined) => {
|
||||
}
|
||||
}
|
||||
|
||||
const remapSnippetGraph = (currentNodes: Node[], snippetNodes: Node[], snippetEdges: Edge[]) => {
|
||||
const getRootNodes = (nodes: Node[]) => {
|
||||
const rootNodes = nodes.filter(node => !node.parentId)
|
||||
return rootNodes.length ? rootNodes : nodes
|
||||
}
|
||||
|
||||
const getSnippetBoundaryNodes = (nodes: Node[], edges: Edge[]) => {
|
||||
const rootNodes = getRootNodes(nodes)
|
||||
const rootNodeIds = new Set(rootNodes.map(node => node.id))
|
||||
const incomingNodeIds = new Set<string>()
|
||||
const outgoingNodeIds = new Set<string>()
|
||||
|
||||
edges.forEach((edge) => {
|
||||
if (!rootNodeIds.has(edge.source) || !rootNodeIds.has(edge.target))
|
||||
return
|
||||
|
||||
outgoingNodeIds.add(edge.source)
|
||||
incomingNodeIds.add(edge.target)
|
||||
})
|
||||
|
||||
return {
|
||||
entryNodes: rootNodes.filter(node => !incomingNodeIds.has(node.id)),
|
||||
exitNodes: rootNodes.filter(node => !outgoingNodeIds.has(node.id)),
|
||||
}
|
||||
}
|
||||
|
||||
const canConnectToTarget = (node: Node) => {
|
||||
return node.data.type !== BlockEnum.DataSource
|
||||
}
|
||||
|
||||
const canConnectFromSource = (node: Node) => {
|
||||
return node.data.type !== BlockEnum.IfElse
|
||||
&& node.data.type !== BlockEnum.QuestionClassifier
|
||||
&& node.data.type !== BlockEnum.HumanInput
|
||||
&& node.data.type !== BlockEnum.LoopEnd
|
||||
}
|
||||
|
||||
const getInsertAnchor = (
|
||||
currentNodes: Node[],
|
||||
insertPayload?: SnippetInsertPayload,
|
||||
) => {
|
||||
const prevNode = insertPayload?.prevNodeId
|
||||
? currentNodes.find(node => node.id === insertPayload.prevNodeId)
|
||||
: undefined
|
||||
const nextNode = insertPayload?.nextNodeId
|
||||
? currentNodes.find(node => node.id === insertPayload.nextNodeId)
|
||||
: undefined
|
||||
|
||||
if (nextNode) {
|
||||
return {
|
||||
x: nextNode.position.x,
|
||||
y: nextNode.position.y,
|
||||
}
|
||||
}
|
||||
|
||||
if (prevNode) {
|
||||
return {
|
||||
x: prevNode.position.x + (prevNode.width ?? 0) + X_OFFSET,
|
||||
y: prevNode.position.y,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const remapSnippetGraph = (
|
||||
currentNodes: Node[],
|
||||
snippetNodes: Node[],
|
||||
snippetEdges: Edge[],
|
||||
insertPayload?: SnippetInsertPayload,
|
||||
) => {
|
||||
const existingIds = new Set(currentNodes.map(node => node.id))
|
||||
const idMapping = new Map<string, string>()
|
||||
const rootNodes = snippetNodes.filter(node => !node.parentId)
|
||||
@ -32,8 +105,9 @@ const remapSnippetGraph = (currentNodes: Node[], snippetNodes: Node[], snippetEd
|
||||
const currentMinY = currentNodes.length
|
||||
? Math.min(...currentNodes.map(node => node.positionAbsolute?.y ?? node.position.y))
|
||||
: 0
|
||||
const offsetX = (currentNodes.length ? currentMaxX + 80 : 80) - minRootX
|
||||
const offsetY = (currentNodes.length ? currentMinY : 80) - minRootY
|
||||
const insertAnchor = getInsertAnchor(currentNodes, insertPayload)
|
||||
const offsetX = (insertAnchor?.x ?? (currentNodes.length ? currentMaxX + 80 : 80)) - minRootX
|
||||
const offsetY = (insertAnchor?.y ?? (currentNodes.length ? currentMinY : 80)) - minRootY
|
||||
|
||||
snippetNodes.forEach((node, index) => {
|
||||
let nextId = `${node.id}-${Date.now()}-${index}`
|
||||
@ -94,14 +168,120 @@ const remapSnippetGraph = (currentNodes: Node[], snippetNodes: Node[], snippetEd
|
||||
return { nodes, edges }
|
||||
}
|
||||
|
||||
const getCurrentEdge = (edges: Edge[], insertPayload?: SnippetInsertPayload) => {
|
||||
if (!insertPayload?.prevNodeId || !insertPayload.nextNodeId)
|
||||
return undefined
|
||||
|
||||
return edges.find(edge =>
|
||||
edge.source === insertPayload.prevNodeId
|
||||
&& edge.target === insertPayload.nextNodeId
|
||||
&& (edge.sourceHandle || 'source') === (insertPayload.prevNodeSourceHandle || 'source')
|
||||
&& (edge.targetHandle || 'target') === (insertPayload.nextNodeTargetHandle || 'target'),
|
||||
)
|
||||
}
|
||||
|
||||
const getParentNode = (nodes: Node[], insertPayload?: SnippetInsertPayload) => {
|
||||
const prevNode = insertPayload?.prevNodeId
|
||||
? nodes.find(node => node.id === insertPayload.prevNodeId)
|
||||
: undefined
|
||||
const nextNode = insertPayload?.nextNodeId
|
||||
? nodes.find(node => node.id === insertPayload.nextNodeId)
|
||||
: undefined
|
||||
const parentId = prevNode?.parentId ?? nextNode?.parentId
|
||||
|
||||
return parentId ? nodes.find(node => node.id === parentId) : undefined
|
||||
}
|
||||
|
||||
const createBoundaryEdges = ({
|
||||
currentNodes,
|
||||
insertPayload,
|
||||
entryNodes,
|
||||
exitNodes,
|
||||
}: {
|
||||
currentNodes: Node[]
|
||||
insertPayload?: SnippetInsertPayload
|
||||
entryNodes: Node[]
|
||||
exitNodes: Node[]
|
||||
}) => {
|
||||
const prevNode = insertPayload?.prevNodeId
|
||||
? currentNodes.find(node => node.id === insertPayload.prevNodeId)
|
||||
: undefined
|
||||
const nextNode = insertPayload?.nextNodeId
|
||||
? currentNodes.find(node => node.id === insertPayload.nextNodeId)
|
||||
: undefined
|
||||
const parentNode = getParentNode(currentNodes, insertPayload)
|
||||
const isInIteration = parentNode?.data.type === BlockEnum.Iteration
|
||||
const isInLoop = parentNode?.data.type === BlockEnum.Loop
|
||||
const zIndex = parentNode
|
||||
? isInIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX
|
||||
: 0
|
||||
const incomingEdges: Edge[] = []
|
||||
const outgoingEdges: Edge[] = []
|
||||
|
||||
if (prevNode) {
|
||||
incomingEdges.push(...entryNodes.filter(canConnectToTarget).map((entryNode) => {
|
||||
const sourceHandle = insertPayload?.prevNodeSourceHandle || 'source'
|
||||
const targetHandle = 'target'
|
||||
|
||||
return {
|
||||
id: `${prevNode.id}-${sourceHandle}-${entryNode.id}-${targetHandle}`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: prevNode.id,
|
||||
sourceHandle,
|
||||
target: entryNode.id,
|
||||
targetHandle,
|
||||
data: {
|
||||
sourceType: prevNode.data.type,
|
||||
targetType: entryNode.data.type,
|
||||
isInIteration,
|
||||
isInLoop,
|
||||
iteration_id: isInIteration ? parentNode?.id : undefined,
|
||||
loop_id: isInLoop ? parentNode?.id : undefined,
|
||||
_connectedNodeIsSelected: true,
|
||||
},
|
||||
zIndex,
|
||||
} as Edge
|
||||
}))
|
||||
}
|
||||
|
||||
if (nextNode) {
|
||||
outgoingEdges.push(...exitNodes.filter(canConnectFromSource).map((exitNode) => {
|
||||
const sourceHandle = 'source'
|
||||
const targetHandle = insertPayload?.nextNodeTargetHandle || 'target'
|
||||
|
||||
return {
|
||||
id: `${exitNode.id}-${sourceHandle}-${nextNode.id}-${targetHandle}`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: exitNode.id,
|
||||
sourceHandle,
|
||||
target: nextNode.id,
|
||||
targetHandle,
|
||||
data: {
|
||||
sourceType: exitNode.data.type,
|
||||
targetType: nextNode.data.type,
|
||||
isInIteration,
|
||||
isInLoop,
|
||||
iteration_id: isInIteration ? parentNode?.id : undefined,
|
||||
loop_id: isInLoop ? parentNode?.id : undefined,
|
||||
_connectedNodeIsSelected: true,
|
||||
},
|
||||
zIndex,
|
||||
} as Edge
|
||||
}))
|
||||
}
|
||||
|
||||
return [...incomingEdges, ...outgoingEdges]
|
||||
}
|
||||
|
||||
export const useInsertSnippet = () => {
|
||||
const { t } = useTranslation()
|
||||
const queryClient = useQueryClient()
|
||||
const store = useStoreApi()
|
||||
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
|
||||
const { saveStateToHistory } = useWorkflowHistory()
|
||||
const { mutate: incrementSnippetUseCount } = useIncrementSnippetUseCountMutation()
|
||||
|
||||
const handleInsertSnippet = useCallback(async (snippetId: string) => {
|
||||
const handleInsertSnippet = useCallback(async (snippetId: string, insertPayload?: SnippetInsertPayload) => {
|
||||
try {
|
||||
const workflow = await queryClient.fetchQuery(consoleQuery.snippets.publishedWorkflow.queryOptions({
|
||||
input: {
|
||||
@ -115,27 +295,112 @@ export const useInsertSnippet = () => {
|
||||
|
||||
const { getNodes, setNodes, edges, setEdges } = store.getState()
|
||||
const currentNodes = getNodes()
|
||||
const remappedGraph = remapSnippetGraph(currentNodes, snippetNodes, snippetEdges)
|
||||
const remappedGraph = remapSnippetGraph(currentNodes, snippetNodes, snippetEdges, insertPayload)
|
||||
const parentNode = getParentNode(currentNodes, insertPayload)
|
||||
const rootNodeIds = new Set(getRootNodes(remappedGraph.nodes).map(node => node.id))
|
||||
const rootSnippetNodes = remappedGraph.nodes.filter(node => rootNodeIds.has(node.id))
|
||||
const currentEdge = getCurrentEdge(edges, insertPayload)
|
||||
const { entryNodes, exitNodes } = getSnippetBoundaryNodes(remappedGraph.nodes, remappedGraph.edges)
|
||||
const boundaryEdges = createBoundaryEdges({
|
||||
currentNodes,
|
||||
insertPayload,
|
||||
entryNodes,
|
||||
exitNodes,
|
||||
})
|
||||
const changes = [
|
||||
...(currentEdge ? [{ type: 'remove', edge: currentEdge }] : []),
|
||||
...boundaryEdges.map(edge => ({ type: 'add', edge })),
|
||||
]
|
||||
const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap(
|
||||
changes,
|
||||
[...currentNodes, ...remappedGraph.nodes],
|
||||
)
|
||||
const firstEntryNode = entryNodes.find(canConnectToTarget) ?? entryNodes[0]
|
||||
const clearedNodes = currentNodes.map(node => ({
|
||||
...node,
|
||||
selected: false,
|
||||
position: insertPayload?.nextNodeId && node.id === insertPayload.nextNodeId
|
||||
? {
|
||||
...node.position,
|
||||
x: node.position.x + NODE_WIDTH_X_OFFSET,
|
||||
}
|
||||
: node.position,
|
||||
data: {
|
||||
...node.data,
|
||||
selected: false,
|
||||
...(nodesConnectedSourceOrTargetHandleIdsMap[node.id] ?? {}),
|
||||
_children: parentNode?.id === node.id
|
||||
? [
|
||||
...(node.data._children ?? []),
|
||||
...rootSnippetNodes.map(rootNode => ({
|
||||
nodeId: rootNode.id,
|
||||
nodeType: rootNode.data.type,
|
||||
})),
|
||||
]
|
||||
: node.data._children,
|
||||
start_node_id: node.id === parentNode?.id
|
||||
&& node.data.start_node_id === insertPayload?.nextNodeId
|
||||
&& firstEntryNode
|
||||
? firstEntryNode.id
|
||||
: node.data.start_node_id,
|
||||
startNodeType: node.id === parentNode?.id
|
||||
&& node.data.start_node_id === insertPayload?.nextNodeId
|
||||
&& firstEntryNode
|
||||
? firstEntryNode.data.type
|
||||
: node.data.startNodeType,
|
||||
},
|
||||
}))
|
||||
const insertedNodes = remappedGraph.nodes.map((node) => {
|
||||
const shouldMoveIntoParent = !!parentNode && rootNodeIds.has(node.id)
|
||||
const isInIteration = parentNode?.data.type === BlockEnum.Iteration
|
||||
const isInLoop = parentNode?.data.type === BlockEnum.Loop
|
||||
|
||||
setNodes([...clearedNodes, ...remappedGraph.nodes])
|
||||
setEdges([...edges, ...remappedGraph.edges])
|
||||
return {
|
||||
...node,
|
||||
parentId: shouldMoveIntoParent ? parentNode.id : node.parentId,
|
||||
extent: shouldMoveIntoParent ? parentNode.extent : node.extent,
|
||||
zIndex: shouldMoveIntoParent
|
||||
? isInIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX
|
||||
: node.zIndex,
|
||||
data: {
|
||||
...node.data,
|
||||
...(nodesConnectedSourceOrTargetHandleIdsMap[node.id] ?? {}),
|
||||
isInIteration: shouldMoveIntoParent ? isInIteration : node.data.isInIteration,
|
||||
isInLoop: shouldMoveIntoParent ? isInLoop : node.data.isInLoop,
|
||||
iteration_id: shouldMoveIntoParent && isInIteration ? parentNode.id : node.data.iteration_id,
|
||||
loop_id: shouldMoveIntoParent && isInLoop ? parentNode.id : node.data.loop_id,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
setNodes([...clearedNodes, ...insertedNodes])
|
||||
setEdges([
|
||||
...edges
|
||||
.filter(edge => edge.id !== currentEdge?.id)
|
||||
.map(edge => ({
|
||||
...edge,
|
||||
data: {
|
||||
...edge.data,
|
||||
_connectedNodeIsSelected: false,
|
||||
},
|
||||
})),
|
||||
...remappedGraph.edges,
|
||||
...boundaryEdges,
|
||||
])
|
||||
saveStateToHistory(WorkflowHistoryEvent.NodePaste, {
|
||||
nodeId: remappedGraph.nodes[0]?.id,
|
||||
})
|
||||
handleSyncWorkflowDraft()
|
||||
incrementSnippetUseCount({
|
||||
params: { snippetId },
|
||||
})
|
||||
return true
|
||||
}
|
||||
catch (error) {
|
||||
toast.error(error instanceof Error ? error.message : t('createFailed', { ns: 'snippet' }))
|
||||
return false
|
||||
}
|
||||
}, [handleSyncWorkflowDraft, queryClient, saveStateToHistory, store, t])
|
||||
}, [handleSyncWorkflowDraft, incrementSnippetUseCount, queryClient, saveStateToHistory, store, t])
|
||||
|
||||
return {
|
||||
handleInsertSnippet,
|
||||
|
||||
@ -163,6 +163,12 @@ const CustomEdge = ({
|
||||
onOpenChange={handleOpenChange}
|
||||
asChild
|
||||
onSelect={handleInsert}
|
||||
snippetInsertPayload={{
|
||||
prevNodeId: source,
|
||||
prevNodeSourceHandle: sourceHandleId || 'source',
|
||||
nextNodeId: target,
|
||||
nextNodeTargetHandle: targetHandleId || 'target',
|
||||
}}
|
||||
availableBlocksTypes={intersection(availablePrevBlocks, availableNextBlocks)}
|
||||
triggerClassName={() => 'hover:scale-150 transition-all'}
|
||||
/>
|
||||
|
||||
@ -49,17 +49,17 @@ export const ChecklistNodeGroup = memo(({
|
||||
<div
|
||||
key={sub.key}
|
||||
className={cn(
|
||||
'group/item flex items-center gap-2 rounded-lg px-1',
|
||||
'group/item flex items-start gap-2 rounded-lg px-1',
|
||||
goToEnabled && 'cursor-pointer hover:bg-state-base-hover',
|
||||
)}
|
||||
onClick={() => goToEnabled && onItemClick(item)}
|
||||
>
|
||||
<ItemIndicator />
|
||||
<span className="min-w-0 grow truncate text-xs leading-4 text-text-warning">
|
||||
<span className="min-w-0 grow py-1 text-xs leading-4 text-text-warning">
|
||||
{sub.message}
|
||||
</span>
|
||||
{goToEnabled && (
|
||||
<div className="flex shrink-0 items-center gap-0.5 pr-0.5 opacity-0 transition-opacity duration-150 group-hover/item:opacity-100">
|
||||
<div className="flex shrink-0 items-center gap-0.5 pt-1 pr-0.5 opacity-0 transition-opacity duration-150 group-hover/item:opacity-100">
|
||||
<span className="text-xs leading-4 font-medium whitespace-nowrap text-text-accent">
|
||||
{t('panel.goToFix', { ns: 'workflow' })}
|
||||
</span>
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import type { Node, NodeOutPutVar, Var } from '../../types'
|
||||
import { renderHook } from '@testing-library/react'
|
||||
import { useSnippetDetailStore } from '@/app/components/snippets/store'
|
||||
import { PipelineInputVarType } from '@/models/pipeline'
|
||||
import { BlockEnum, VarType } from '../../types'
|
||||
import useNodesAvailableVarList, { useGetNodesAvailableVarList } from '../use-nodes-available-var-list'
|
||||
|
||||
@ -42,6 +44,8 @@ const outputVars: NodeOutPutVar[] = [{
|
||||
describe('useNodesAvailableVarList', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
globalThis.history.pushState({}, '', '/')
|
||||
useSnippetDetailStore.getState().reset()
|
||||
mockGetBeforeNodesInSameBranchIncludeParent.mockImplementation((nodeId: string) => [createNode({ id: `before-${nodeId}` })])
|
||||
mockGetTreeLeafNodes.mockImplementation((nodeId: string) => [createNode({ id: `leaf-${nodeId}` })])
|
||||
mockGetNodeAvailableVars.mockReturnValue(outputVars)
|
||||
@ -76,7 +80,7 @@ describe('useNodesAvailableVarList', () => {
|
||||
expect(mockGetBeforeNodesInSameBranchIncludeParent).toHaveBeenCalledWith('loop-1')
|
||||
expect(mockGetBeforeNodesInSameBranchIncludeParent).toHaveBeenCalledWith('child-1')
|
||||
expect(result.current['loop-1']?.availableNodes.map(node => node.id)).toEqual(['before-loop-1', 'loop-1'])
|
||||
expect(result.current['child-1']?.availableVars).toBe(outputVars)
|
||||
expect(result.current['child-1']?.availableVars).toEqual(outputVars)
|
||||
expect(mockGetNodeAvailableVars).toHaveBeenNthCalledWith(2, expect.objectContaining({
|
||||
parentNode: loopNode,
|
||||
isChatMode: true,
|
||||
@ -86,6 +90,37 @@ describe('useNodesAvailableVarList', () => {
|
||||
}))
|
||||
})
|
||||
|
||||
it('adds snippet input fields as virtual start variables on snippet canvases', () => {
|
||||
globalThis.history.pushState({}, '', '/snippets/snippet-1/orchestrate')
|
||||
useSnippetDetailStore.getState().setFields([{
|
||||
type: PipelineInputVarType.textInput,
|
||||
label: 'Topic',
|
||||
variable: 'topic',
|
||||
required: true,
|
||||
}])
|
||||
|
||||
const currentNode = createNode({ id: 'node-a' })
|
||||
|
||||
const { result } = renderHook(() => useNodesAvailableVarList([currentNode], {
|
||||
filterVar: () => true,
|
||||
}))
|
||||
|
||||
expect(result.current['node-a']?.availableNodes[0]).toEqual(expect.objectContaining({
|
||||
id: 'start',
|
||||
data: expect.objectContaining({
|
||||
type: BlockEnum.Start,
|
||||
}),
|
||||
}))
|
||||
expect(result.current['node-a']?.availableVars[0]).toEqual(expect.objectContaining({
|
||||
nodeId: 'start',
|
||||
isStartNode: true,
|
||||
vars: [expect.objectContaining({
|
||||
variable: 'topic',
|
||||
type: VarType.string,
|
||||
})],
|
||||
}))
|
||||
})
|
||||
|
||||
it('returns a callback version that can use leaf nodes or caller-provided nodes', () => {
|
||||
const firstNode = createNode({ id: 'node-a' })
|
||||
const secondNode = createNode({ id: 'node-b' })
|
||||
|
||||
@ -1,10 +1,15 @@
|
||||
import type { Node, NodeOutPutVar, ValueSelector, Var } from '@/app/components/workflow/types'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useSnippetDetailStore } from '@/app/components/snippets/store'
|
||||
import {
|
||||
useIsChatMode,
|
||||
useWorkflow,
|
||||
useWorkflowVariables,
|
||||
} from '@/app/components/workflow/hooks'
|
||||
import {
|
||||
appendSnippetInputFieldVars,
|
||||
} from '@/app/components/workflow/nodes/_base/hooks/snippet-input-field-vars'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
|
||||
type Params = {
|
||||
@ -41,6 +46,8 @@ const useNodesAvailableVarList = (nodes: Node[], {
|
||||
onlyLeafNodeVar: false,
|
||||
filterVar: () => true,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const snippetInputFields = useSnippetDetailStore(s => s.fields)
|
||||
const { getTreeLeafNodes, getBeforeNodesInSameBranchIncludeParent } = useWorkflow()
|
||||
const { getNodeAvailableVars } = useWorkflowVariables()
|
||||
const isChatMode = useIsChatMode()
|
||||
@ -52,23 +59,31 @@ const useNodesAvailableVarList = (nodes: Node[], {
|
||||
const availableNodes = passedInAvailableNodes || (onlyLeafNodeVar ? getTreeLeafNodes(nodeId) : getBeforeNodesInSameBranchIncludeParent(nodeId))
|
||||
if (node.data.type === BlockEnum.Loop)
|
||||
availableNodes.push(node)
|
||||
const snippetInputFieldAvailability = appendSnippetInputFieldVars({
|
||||
availableNodes,
|
||||
fields: snippetInputFields,
|
||||
title: t('panelTitle', { ns: 'snippet' }),
|
||||
})
|
||||
|
||||
const {
|
||||
parentNode: iterationNode,
|
||||
} = getNodeInfo(nodeId, nodes)
|
||||
|
||||
const availableVars = getNodeAvailableVars({
|
||||
parentNode: iterationNode,
|
||||
beforeNodes: availableNodes,
|
||||
isChatMode,
|
||||
filterVar,
|
||||
hideEnv,
|
||||
hideChatVar,
|
||||
})
|
||||
const availableVars = [
|
||||
...snippetInputFieldAvailability.availableVars,
|
||||
...getNodeAvailableVars({
|
||||
parentNode: iterationNode,
|
||||
beforeNodes: availableNodes,
|
||||
isChatMode,
|
||||
filterVar,
|
||||
hideEnv,
|
||||
hideChatVar,
|
||||
}),
|
||||
]
|
||||
const result = {
|
||||
node,
|
||||
availableVars,
|
||||
availableNodes,
|
||||
availableNodes: snippetInputFieldAvailability.availableNodes,
|
||||
}
|
||||
nodeAvailabilityMap[nodeId] = result
|
||||
})
|
||||
@ -76,6 +91,8 @@ const useNodesAvailableVarList = (nodes: Node[], {
|
||||
}
|
||||
|
||||
export const useGetNodesAvailableVarList = () => {
|
||||
const { t } = useTranslation()
|
||||
const snippetInputFields = useSnippetDetailStore(s => s.fields)
|
||||
const { getTreeLeafNodes, getBeforeNodesInSameBranchIncludeParent } = useWorkflow()
|
||||
const { getNodeAvailableVars } = useWorkflowVariables()
|
||||
const isChatMode = useIsChatMode()
|
||||
@ -96,28 +113,36 @@ export const useGetNodesAvailableVarList = () => {
|
||||
const availableNodes = passedInAvailableNodes || (onlyLeafNodeVar ? getTreeLeafNodes(nodeId) : getBeforeNodesInSameBranchIncludeParent(nodeId))
|
||||
if (node.data.type === BlockEnum.Loop)
|
||||
availableNodes.push(node)
|
||||
const snippetInputFieldAvailability = appendSnippetInputFieldVars({
|
||||
availableNodes,
|
||||
fields: snippetInputFields,
|
||||
title: t('panelTitle', { ns: 'snippet' }),
|
||||
})
|
||||
|
||||
const {
|
||||
parentNode: iterationNode,
|
||||
} = getNodeInfo(nodeId, nodes)
|
||||
|
||||
const availableVars = getNodeAvailableVars({
|
||||
parentNode: iterationNode,
|
||||
beforeNodes: availableNodes,
|
||||
isChatMode,
|
||||
filterVar,
|
||||
hideEnv,
|
||||
hideChatVar,
|
||||
})
|
||||
const availableVars = [
|
||||
...snippetInputFieldAvailability.availableVars,
|
||||
...getNodeAvailableVars({
|
||||
parentNode: iterationNode,
|
||||
beforeNodes: availableNodes,
|
||||
isChatMode,
|
||||
filterVar,
|
||||
hideEnv,
|
||||
hideChatVar,
|
||||
}),
|
||||
]
|
||||
const result = {
|
||||
node,
|
||||
availableVars,
|
||||
availableNodes,
|
||||
availableNodes: snippetInputFieldAvailability.availableNodes,
|
||||
}
|
||||
nodeAvailabilityMap[nodeId] = result
|
||||
})
|
||||
return nodeAvailabilityMap
|
||||
}, [getTreeLeafNodes, getBeforeNodesInSameBranchIncludeParent, getNodeAvailableVars, isChatMode])
|
||||
}, [getTreeLeafNodes, getBeforeNodesInSameBranchIncludeParent, getNodeAvailableVars, isChatMode, snippetInputFields, t])
|
||||
return {
|
||||
getNodesAvailableVarList,
|
||||
}
|
||||
|
||||
@ -91,6 +91,10 @@ const Add = ({
|
||||
onOpenChange={handleOpenChange}
|
||||
disabled={nodesReadOnly}
|
||||
onSelect={handleSelect}
|
||||
snippetInsertPayload={{
|
||||
prevNodeId: nodeId,
|
||||
prevNodeSourceHandle: sourceHandle,
|
||||
}}
|
||||
placement="top"
|
||||
offset={0}
|
||||
trigger={renderTrigger}
|
||||
|
||||
@ -110,6 +110,10 @@ export const NodeTargetHandle = memo(({
|
||||
open={open}
|
||||
onOpenChange={handleOpenChange}
|
||||
onSelect={handleSelect}
|
||||
snippetInsertPayload={{
|
||||
nextNodeId: id,
|
||||
nextNodeTargetHandle: handleId,
|
||||
}}
|
||||
asChild
|
||||
placement="left"
|
||||
triggerClassName={open => `
|
||||
@ -229,6 +233,10 @@ export const NodeSourceHandle = memo(({
|
||||
open={open}
|
||||
onOpenChange={handleOpenChange}
|
||||
onSelect={handleSelect}
|
||||
snippetInsertPayload={{
|
||||
prevNodeId: id,
|
||||
prevNodeSourceHandle: handleId,
|
||||
}}
|
||||
asChild
|
||||
triggerClassName={open => `
|
||||
absolute top-0 left-0 opacity-0 pointer-events-none transition-opacity duration-150
|
||||
|
||||
@ -0,0 +1,108 @@
|
||||
import type { InputVarType, Node, NodeOutPutVar } from '@/app/components/workflow/types'
|
||||
import type { SnippetInputField } from '@/models/snippet'
|
||||
import { NODE_WIDTH } from '@/app/components/workflow/constants'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { PipelineInputVarType } from '@/models/pipeline'
|
||||
import { inputVarTypeToVarType } from '../../data-source/utils'
|
||||
|
||||
export const SNIPPET_INPUT_FIELD_NODE_ID = 'start'
|
||||
|
||||
export const isSnippetCanvas = () => {
|
||||
if (typeof globalThis.location === 'undefined')
|
||||
return false
|
||||
|
||||
return /^\/snippets\/[^/]+\/orchestrate/.test(globalThis.location.pathname)
|
||||
}
|
||||
|
||||
const toWorkflowInputType = (type: SnippetInputField['type']) => type as unknown as InputVarType
|
||||
|
||||
export const buildSnippetInputFieldNode = (
|
||||
fields: SnippetInputField[],
|
||||
title: string,
|
||||
): Node | undefined => {
|
||||
const variables = fields.filter(field => !!field.variable)
|
||||
|
||||
if (!variables.length)
|
||||
return undefined
|
||||
|
||||
return {
|
||||
id: SNIPPET_INPUT_FIELD_NODE_ID,
|
||||
type: 'custom',
|
||||
position: { x: 0, y: 0 },
|
||||
width: NODE_WIDTH,
|
||||
height: 80,
|
||||
data: {
|
||||
title,
|
||||
desc: '',
|
||||
type: BlockEnum.Start,
|
||||
variables: variables.map(field => ({
|
||||
type: toWorkflowInputType(field.type),
|
||||
label: field.label,
|
||||
variable: field.variable,
|
||||
max_length: field.max_length,
|
||||
default: field.default_value,
|
||||
required: field.required,
|
||||
options: field.options,
|
||||
placeholder: field.placeholder,
|
||||
unit: field.unit,
|
||||
allowed_file_upload_methods: field.allowed_file_upload_methods,
|
||||
allowed_file_types: field.allowed_file_types,
|
||||
allowed_file_extensions: field.allowed_file_extensions,
|
||||
})),
|
||||
},
|
||||
} as Node
|
||||
}
|
||||
|
||||
export const buildSnippetInputFieldVars = (
|
||||
fields: SnippetInputField[],
|
||||
title: string,
|
||||
): NodeOutPutVar | undefined => {
|
||||
const vars = fields
|
||||
.filter(field => !!field.variable)
|
||||
.map(field => ({
|
||||
variable: field.variable,
|
||||
type: inputVarTypeToVarType(field.type as PipelineInputVarType),
|
||||
isParagraph: field.type === PipelineInputVarType.paragraph,
|
||||
isSelect: field.type === PipelineInputVarType.select,
|
||||
options: field.options,
|
||||
required: field.required,
|
||||
des: field.label,
|
||||
}))
|
||||
|
||||
if (!vars.length)
|
||||
return undefined
|
||||
|
||||
return {
|
||||
nodeId: SNIPPET_INPUT_FIELD_NODE_ID,
|
||||
title,
|
||||
vars,
|
||||
isStartNode: true,
|
||||
}
|
||||
}
|
||||
|
||||
export const appendSnippetInputFieldVars = ({
|
||||
availableNodes,
|
||||
fields,
|
||||
title,
|
||||
}: {
|
||||
availableNodes: Node[]
|
||||
fields: SnippetInputField[]
|
||||
title: string
|
||||
}) => {
|
||||
const shouldAppendSnippetInputFields = isSnippetCanvas()
|
||||
&& fields.length > 0
|
||||
&& !availableNodes.some(node => node.data.type === BlockEnum.Start)
|
||||
const snippetInputFieldNode = shouldAppendSnippetInputFields
|
||||
? buildSnippetInputFieldNode(fields, title)
|
||||
: undefined
|
||||
const snippetInputFieldVars = shouldAppendSnippetInputFields
|
||||
? buildSnippetInputFieldVars(fields, title)
|
||||
: undefined
|
||||
|
||||
return {
|
||||
availableNodes: snippetInputFieldNode
|
||||
? [snippetInputFieldNode, ...availableNodes]
|
||||
: availableNodes,
|
||||
availableVars: snippetInputFieldVars ? [snippetInputFieldVars] : [],
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,6 @@
|
||||
import type { Node, NodeOutPutVar, ValueSelector, Var } from '@/app/components/workflow/types'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useSnippetDetailStore } from '@/app/components/snippets/store'
|
||||
import {
|
||||
useIsChatMode,
|
||||
useWorkflow,
|
||||
@ -7,6 +9,7 @@ import {
|
||||
import { useStore as useWorkflowStore } from '@/app/components/workflow/store'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { inputVarTypeToVarType } from '../../data-source/utils'
|
||||
import { appendSnippetInputFieldVars } from './snippet-input-field-vars'
|
||||
import useNodeInfo from './use-node-info'
|
||||
|
||||
type Params = {
|
||||
@ -28,10 +31,17 @@ const useAvailableVarList = (nodeId: string, {
|
||||
onlyLeafNodeVar: false,
|
||||
filterVar: () => true,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const snippetInputFields = useSnippetDetailStore(s => s.fields)
|
||||
const { getTreeLeafNodes, getNodeById, getBeforeNodesInSameBranchIncludeParent } = useWorkflow()
|
||||
const { getNodeAvailableVars } = useWorkflowVariables()
|
||||
const isChatMode = useIsChatMode()
|
||||
const availableNodes = passedInAvailableNodes || (onlyLeafNodeVar ? getTreeLeafNodes(nodeId) : getBeforeNodesInSameBranchIncludeParent(nodeId))
|
||||
const snippetInputFieldAvailability = appendSnippetInputFieldVars({
|
||||
availableNodes,
|
||||
fields: snippetInputFields,
|
||||
title: t('panelTitle', { ns: 'snippet' }),
|
||||
})
|
||||
const {
|
||||
parentNode: iterationNode,
|
||||
} = useNodeInfo(nodeId)
|
||||
@ -63,20 +73,24 @@ const useAvailableVarList = (nodeId: string, {
|
||||
})
|
||||
}
|
||||
}
|
||||
const availableVars = [...getNodeAvailableVars({
|
||||
parentNode: iterationNode,
|
||||
beforeNodes: availableNodes,
|
||||
isChatMode,
|
||||
filterVar,
|
||||
hideEnv,
|
||||
hideChatVar,
|
||||
}), ...dataSourceRagVars]
|
||||
const availableVars = [
|
||||
...snippetInputFieldAvailability.availableVars,
|
||||
...getNodeAvailableVars({
|
||||
parentNode: iterationNode,
|
||||
beforeNodes: availableNodes,
|
||||
isChatMode,
|
||||
filterVar,
|
||||
hideEnv,
|
||||
hideChatVar,
|
||||
}),
|
||||
...dataSourceRagVars,
|
||||
]
|
||||
|
||||
return {
|
||||
availableVars,
|
||||
availableNodes,
|
||||
availableNodes: snippetInputFieldAvailability.availableNodes,
|
||||
availableNodesWithParent: [
|
||||
...availableNodes,
|
||||
...snippetInputFieldAvailability.availableNodes,
|
||||
...(isDataSourceNode ? [currNode] : []),
|
||||
],
|
||||
}
|
||||
|
||||
@ -68,6 +68,10 @@ const AddBlock = ({
|
||||
<BlockSelector
|
||||
disabled={nodesReadOnly}
|
||||
onSelect={handleSelect}
|
||||
snippetInsertPayload={{
|
||||
prevNodeId: iterationNodeData.start_node_id,
|
||||
prevNodeSourceHandle: 'source',
|
||||
}}
|
||||
trigger={renderTriggerElement}
|
||||
triggerInnerClassName="inline-flex"
|
||||
popupClassName="min-w-[256px]!"
|
||||
|
||||
@ -69,6 +69,10 @@ const AddBlock = ({
|
||||
<BlockSelector
|
||||
disabled={nodesReadOnly}
|
||||
onSelect={handleSelect}
|
||||
snippetInsertPayload={{
|
||||
prevNodeId: loopNodeData.start_node_id,
|
||||
prevNodeSourceHandle: 'source',
|
||||
}}
|
||||
trigger={renderTriggerElement}
|
||||
triggerInnerClassName="inline-flex"
|
||||
popupClassName="min-w-[256px]!"
|
||||
|
||||
@ -51,6 +51,10 @@ const InsertBlock = ({
|
||||
onOpenChange={handleOpenChange}
|
||||
asChild
|
||||
onSelect={handleInsert}
|
||||
snippetInsertPayload={{
|
||||
nextNodeId: startNodeId,
|
||||
nextNodeTargetHandle: 'target',
|
||||
}}
|
||||
availableBlocksTypes={availableBlocksTypes}
|
||||
triggerClassName={() => 'hover:scale-125 transition-all'}
|
||||
/>
|
||||
|
||||
@ -5,7 +5,7 @@ import { produce } from 'immer'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import RemoveButton from '@/app/components/workflow/nodes/_base/components/remove-button'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import VariableTypeSelector from '@/app/components/workflow/panel/chat-variable-panel/components/variable-type-select'
|
||||
import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type'
|
||||
|
||||
@ -116,7 +116,7 @@ const ObjectValueItem: FC<Props> = ({
|
||||
{/* Value */}
|
||||
<div className="relative w-[230px]">
|
||||
<input
|
||||
className="block h-7 w-full appearance-none px-2 text-text-secondary caret-primary-600 outline-hidden system-xs-regular placeholder:text-components-input-text-placeholder placeholder:system-xs-regular hover:bg-state-base-hover focus:bg-components-input-bg-active"
|
||||
className="block h-7 w-full appearance-none px-2 pr-9 system-xs-regular text-text-secondary caret-primary-600 outline-hidden placeholder:system-xs-regular placeholder:text-components-input-text-placeholder hover:bg-state-base-hover focus:bg-components-input-bg-active"
|
||||
placeholder={t('chatVariable.modal.objectValue', { ns: 'workflow' }) || ''}
|
||||
value={list[index].value}
|
||||
onChange={handleValueChange(index)}
|
||||
@ -125,10 +125,15 @@ const ObjectValueItem: FC<Props> = ({
|
||||
type={list[index].type === ChatVarType.Number ? 'number' : 'text'}
|
||||
/>
|
||||
{list.length > 1 && !isFocus && (
|
||||
<RemoveButton
|
||||
className="absolute right-1 top-0.5 z-10 hidden group-hover:block"
|
||||
onClick={handleItemRemove(index)}
|
||||
/>
|
||||
<div className="absolute top-0.5 right-1 z-10">
|
||||
<ActionButton
|
||||
size="m"
|
||||
className="group hover:bg-state-destructive-hover!"
|
||||
onClick={handleItemRemove(index)}
|
||||
>
|
||||
<span className="i-ri-delete-bin-line h-4 w-4 text-text-tertiary group-hover:text-text-destructive" />
|
||||
</ActionButton>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -498,7 +498,7 @@ export type ChildNodeTypeCount = {
|
||||
[key: string]: number
|
||||
}
|
||||
|
||||
const TRIGGER_NODE_TYPES = [
|
||||
export const TRIGGER_NODE_TYPES = [
|
||||
BlockEnum.TriggerSchedule,
|
||||
BlockEnum.TriggerWebhook,
|
||||
BlockEnum.TriggerPlugin,
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
"deleteConfirmTitle": "Delete Snippet?",
|
||||
"deleteFailed": "Failed to delete snippet",
|
||||
"deleted": "Snippet deleted",
|
||||
"draft": "Draft",
|
||||
"editDialogTitle": "Edit Snippet Info",
|
||||
"editDone": "Snippet info updated",
|
||||
"editFailed": "Failed to update snippet info",
|
||||
@ -32,6 +33,8 @@
|
||||
"sectionOrchestrate": "Orchestrate",
|
||||
"testRunButton": "Test run",
|
||||
"typeLabel": "Snippet",
|
||||
"unknownUser": "User",
|
||||
"updatedBy": "{{name}} updated {{time}}",
|
||||
"usageCount": "Used {{count}} times",
|
||||
"variableInspect": "Variable Inspect"
|
||||
}
|
||||
|
||||
@ -35,8 +35,8 @@
|
||||
"stepOne.uploader.cancel": "Cancel",
|
||||
"stepOne.uploader.change": "Change",
|
||||
"stepOne.uploader.failed": "Upload failed",
|
||||
"stepOne.uploader.tip": "Supports {{supportTypes}}. Max {{batchCount}} in a batch and {{size}} MB each.",
|
||||
"stepOne.uploader.tipWithTotalLimit": "Supports {{supportTypes}}. Max {{batchCount}} in a batch and {{size}} MB each. Max total {{totalCount}} files.",
|
||||
"stepOne.uploader.tip": "Ondersteunt {{supportTypes}}. Maximaal {{batchCount}} per batch en {{size}} MB per bestand.",
|
||||
"stepOne.uploader.tipWithTotalLimit": "Ondersteunt {{supportTypes}}. Maximaal {{batchCount}} per batch en {{size}} MB per bestand. Maximaal {{totalCount}} bestanden in totaal.",
|
||||
"stepOne.uploader.title": "Upload file",
|
||||
"stepOne.uploader.validation.count": "Multiple files not supported",
|
||||
"stepOne.uploader.validation.filesNumber": "You have reached the batch upload limit of {{filesNumber}}.",
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
"deleteConfirmTitle": "删除 Snippet?",
|
||||
"deleteFailed": "删除 Snippet 失败",
|
||||
"deleted": "Snippet 已删除",
|
||||
"draft": "草稿",
|
||||
"editDialogTitle": "编辑 Snippet 信息",
|
||||
"editDone": "Snippet 信息已更新",
|
||||
"editFailed": "更新 Snippet 信息失败",
|
||||
@ -32,6 +33,8 @@
|
||||
"sectionOrchestrate": "编排",
|
||||
"testRunButton": "测试运行",
|
||||
"typeLabel": "Snippet",
|
||||
"unknownUser": "用户",
|
||||
"updatedBy": "{{name}} 更新于 {{time}}",
|
||||
"usageCount": "已使用 {{count}} 次",
|
||||
"variableInspect": "变量查看"
|
||||
}
|
||||
|
||||
@ -8,7 +8,6 @@ export type SnippetListItem = {
|
||||
id: string
|
||||
name: string
|
||||
description: string
|
||||
author: string
|
||||
updatedAt: string
|
||||
usage: string
|
||||
icon: string
|
||||
@ -21,7 +20,6 @@ export type SnippetDetail = {
|
||||
id: string
|
||||
name: string
|
||||
description: string
|
||||
author: string
|
||||
updatedAt: string
|
||||
usage: string
|
||||
icon: string
|
||||
|
||||
@ -15,7 +15,6 @@ const getSnippetListMock = (): SnippetListItem[] => ([
|
||||
id: 'snippet-1',
|
||||
name: 'Tone Rewriter',
|
||||
description: 'Rewrites rough drafts into a concise, professional tone for internal stakeholder updates.',
|
||||
author: 'Evan',
|
||||
updatedAt: 'Updated 2h ago',
|
||||
usage: 'Used 19 times',
|
||||
icon: '🪄',
|
||||
@ -28,7 +27,6 @@ const createSnippetMock = (snippetId: string): SnippetListItem => ({
|
||||
id: snippetId,
|
||||
name: 'Tone Rewriter',
|
||||
description: 'Rewrites rough drafts into a concise, professional tone for internal stakeholder updates.',
|
||||
author: 'Evan',
|
||||
updatedAt: 'Updated 2h ago',
|
||||
usage: 'Used 19 times',
|
||||
icon: '🪄',
|
||||
|
||||
@ -70,7 +70,6 @@ const toSnippetListItem = (snippet: SnippetSummary): SnippetListItemUIModel => {
|
||||
id: snippet.id,
|
||||
name: snippet.name,
|
||||
description: snippet.description,
|
||||
author: '',
|
||||
updatedAt: formatTimestamp(snippet.updated_at),
|
||||
usage: String(snippet.use_count ?? 0),
|
||||
icon: getSnippetIcon(snippet.icon_info),
|
||||
@ -201,6 +200,20 @@ export const useDeleteSnippetMutation = () => {
|
||||
})
|
||||
}
|
||||
|
||||
export const useIncrementSnippetUseCountMutation = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
return useMutation({
|
||||
...consoleQuery.snippets.incrementUseCount.mutationOptions({
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: consoleQuery.snippets.key(),
|
||||
})
|
||||
},
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
export const useExportSnippetMutation = () => {
|
||||
return useMutation<string, Error, { snippetId: string, include?: boolean }>({
|
||||
mutationFn: ({ snippetId, include = false }) => {
|
||||
|
||||
@ -22,8 +22,9 @@ export type Snippet = {
|
||||
icon_info: SnippetIconInfo
|
||||
input_fields: SnippetInputField[]
|
||||
created_at: number
|
||||
created_by: string
|
||||
updated_at: number
|
||||
author: string
|
||||
updated_by: string
|
||||
}
|
||||
|
||||
export type SnippetListItem = Omit<Snippet, 'version' | 'input_fields'>
|
||||
|
||||
Loading…
Reference in New Issue
Block a user