mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
Merge remote-tracking branch 'myori/main' into feat/collaboration2
This commit is contained in:
commit
305a4b65cb
@ -112,6 +112,7 @@ S3_BUCKET_NAME=your-bucket-name
|
||||
S3_ACCESS_KEY=your-access-key
|
||||
S3_SECRET_KEY=your-secret-key
|
||||
S3_REGION=your-region
|
||||
S3_ADDRESS_STYLE=auto
|
||||
|
||||
# Workflow run and Conversation archive storage (S3-compatible)
|
||||
ARCHIVE_STORAGE_ENABLED=false
|
||||
|
||||
@ -34,9 +34,10 @@ from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportMode
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DataSource,
|
||||
InfoList,
|
||||
|
||||
@ -17,8 +17,9 @@ from fields.app_fields import (
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
@ -19,7 +19,7 @@ from fields.rag_pipeline_fields import (
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,8 @@ from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App
|
||||
from models.account import AccountStatus
|
||||
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
|
||||
|
||||
class InnerAppDSLImportPayload(BaseModel):
|
||||
|
||||
@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Float, create_engine, insert, select, text
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
|
||||
@ -55,9 +55,8 @@ class PGVectoRS(BaseVector):
|
||||
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
)
|
||||
self._client = create_engine(self._url)
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
|
||||
session.commit()
|
||||
self._fields: list[str] = []
|
||||
|
||||
class _Table(CollectionORM):
|
||||
@ -88,7 +87,7 @@ class PGVectoRS(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
index_name = f"{self._collection_name}_embedding_index"
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
create_statement = sql_text(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
||||
id UUID PRIMARY KEY,
|
||||
@ -111,12 +110,11 @@ class PGVectoRS(BaseVector):
|
||||
$$);
|
||||
""")
|
||||
session.execute(index_statement)
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
pks = []
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
for document, embedding in zip(documents, embeddings):
|
||||
pk = uuid4()
|
||||
session.execute(
|
||||
@ -128,7 +126,6 @@ class PGVectoRS(BaseVector):
|
||||
),
|
||||
)
|
||||
pks.append(pk)
|
||||
session.commit()
|
||||
|
||||
return pks
|
||||
|
||||
@ -145,10 +142,9 @@ class PGVectoRS(BaseVector):
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
|
||||
session.execute(select_statement, {"ids": ids})
|
||||
session.commit()
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
with Session(self._client) as session:
|
||||
@ -159,15 +155,13 @@ class PGVectoRS(BaseVector):
|
||||
if result:
|
||||
ids = [item[0] for item in result]
|
||||
if ids:
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
|
||||
session.execute(select_statement, {"ids": ids})
|
||||
session.commit()
|
||||
|
||||
def delete(self):
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
|
||||
session.commit()
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self._client) as session:
|
||||
|
||||
@ -75,22 +75,27 @@ class ToolProviderApiEntity(BaseModel):
|
||||
parameter.pop("input_schema", None)
|
||||
# -------------
|
||||
optional_fields = self.optional_field("server_url", self.server_url)
|
||||
if self.type == ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||
match self.type:
|
||||
case ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||
)
|
||||
)
|
||||
)
|
||||
optional_fields.update(
|
||||
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
|
||||
)
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
elif self.type == ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"authentication", self.authentication.model_dump() if self.authentication else None
|
||||
)
|
||||
)
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
case ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
case _:
|
||||
pass
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
|
||||
@ -11,6 +11,7 @@ from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
@ -166,13 +167,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -190,13 +185,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.where(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
message_file: MessageFile | None = session.scalar(select(MessageFile).where(MessageFile.id == id).limit(1))
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
@ -210,13 +199,7 @@ class ToolFileManager:
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -234,13 +217,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None, None
|
||||
|
||||
@ -7,14 +7,13 @@ from sqlalchemy import select
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
from typing import NotRequired, TypedDict, cast
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.entities import DocumentContext, RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
@ -17,18 +16,6 @@ from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
reranking_mode: NotRequired[str]
|
||||
weights: NotRequired[WeightsDict | None]
|
||||
score_threshold: NotRequired[float]
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Mapping
|
||||
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
@ -96,10 +97,10 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
:param app: the app
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow | None = (
|
||||
session.query(Workflow)
|
||||
workflow: Workflow | None = session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
@ -217,13 +218,13 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
return self.tools
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
db_provider: WorkflowToolProvider | None = (
|
||||
session.query(WorkflowToolProvider)
|
||||
db_provider: WorkflowToolProvider | None = session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == self.provider_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
|
||||
@ -68,46 +68,49 @@ class EnterpriseMetricHandler:
|
||||
|
||||
# Route to appropriate handler based on case
|
||||
case = envelope.case
|
||||
if case == TelemetryCase.APP_CREATED:
|
||||
self._on_app_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
|
||||
elif case == TelemetryCase.APP_UPDATED:
|
||||
self._on_app_updated(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
|
||||
elif case == TelemetryCase.APP_DELETED:
|
||||
self._on_app_deleted(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
|
||||
elif case == TelemetryCase.FEEDBACK_CREATED:
|
||||
self._on_feedback_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
|
||||
elif case == TelemetryCase.MESSAGE_RUN:
|
||||
self._on_message_run(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
|
||||
elif case == TelemetryCase.TOOL_EXECUTION:
|
||||
self._on_tool_execution(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
|
||||
elif case == TelemetryCase.MODERATION_CHECK:
|
||||
self._on_moderation_check(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
|
||||
elif case == TelemetryCase.SUGGESTED_QUESTION:
|
||||
self._on_suggested_question(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
|
||||
elif case == TelemetryCase.DATASET_RETRIEVAL:
|
||||
self._on_dataset_retrieval(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
|
||||
elif case == TelemetryCase.GENERATE_NAME:
|
||||
self._on_generate_name(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
|
||||
elif case == TelemetryCase.PROMPT_GENERATION:
|
||||
self._on_prompt_generation(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
|
||||
case,
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
)
|
||||
match case:
|
||||
case TelemetryCase.APP_CREATED:
|
||||
self._on_app_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
|
||||
case TelemetryCase.APP_UPDATED:
|
||||
self._on_app_updated(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
|
||||
case TelemetryCase.APP_DELETED:
|
||||
self._on_app_deleted(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
|
||||
case TelemetryCase.FEEDBACK_CREATED:
|
||||
self._on_feedback_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
|
||||
case TelemetryCase.MESSAGE_RUN:
|
||||
self._on_message_run(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
|
||||
case TelemetryCase.TOOL_EXECUTION:
|
||||
self._on_tool_execution(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
|
||||
case TelemetryCase.MODERATION_CHECK:
|
||||
self._on_moderation_check(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
|
||||
case TelemetryCase.SUGGESTED_QUESTION:
|
||||
self._on_suggested_question(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
|
||||
case TelemetryCase.DATASET_RETRIEVAL:
|
||||
self._on_dataset_retrieval(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
|
||||
case TelemetryCase.GENERATE_NAME:
|
||||
self._on_generate_name(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
|
||||
case TelemetryCase.PROMPT_GENERATION:
|
||||
self._on_prompt_generation(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
|
||||
case TelemetryCase.WORKFLOW_RUN | TelemetryCase.NODE_EXECUTION | TelemetryCase.DRAFT_NODE_EXECUTION:
|
||||
pass
|
||||
case _:
|
||||
logger.warning(
|
||||
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
|
||||
case,
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
)
|
||||
|
||||
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
|
||||
"""Check if this event has already been processed.
|
||||
|
||||
@ -813,56 +813,32 @@ class AppModelConfig(TypeBase):
|
||||
"file_upload": self.file_upload_dict,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dump_optional(value: Any) -> str | None:
|
||||
return json.dumps(value) if value else None
|
||||
|
||||
def from_model_config_dict(self, model_config: AppModelConfigDict):
|
||||
self.opening_statement = model_config.get("opening_statement")
|
||||
self.suggested_questions = (
|
||||
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
|
||||
)
|
||||
self.suggested_questions_after_answer = (
|
||||
json.dumps(model_config.get("suggested_questions_after_answer"))
|
||||
if model_config.get("suggested_questions_after_answer")
|
||||
else None
|
||||
)
|
||||
self.speech_to_text = (
|
||||
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
|
||||
)
|
||||
self.text_to_speech = (
|
||||
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
|
||||
)
|
||||
self.more_like_this = (
|
||||
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
|
||||
)
|
||||
self.sensitive_word_avoidance = (
|
||||
json.dumps(model_config.get("sensitive_word_avoidance"))
|
||||
if model_config.get("sensitive_word_avoidance")
|
||||
else None
|
||||
)
|
||||
self.external_data_tools = (
|
||||
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
|
||||
)
|
||||
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
|
||||
self.user_input_form = (
|
||||
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
|
||||
self.suggested_questions = self._dump_optional(model_config.get("suggested_questions"))
|
||||
self.suggested_questions_after_answer = self._dump_optional(
|
||||
model_config.get("suggested_questions_after_answer")
|
||||
)
|
||||
self.speech_to_text = self._dump_optional(model_config.get("speech_to_text"))
|
||||
self.text_to_speech = self._dump_optional(model_config.get("text_to_speech"))
|
||||
self.more_like_this = self._dump_optional(model_config.get("more_like_this"))
|
||||
self.sensitive_word_avoidance = self._dump_optional(model_config.get("sensitive_word_avoidance"))
|
||||
self.external_data_tools = self._dump_optional(model_config.get("external_data_tools"))
|
||||
self.model = self._dump_optional(model_config.get("model"))
|
||||
self.user_input_form = self._dump_optional(model_config.get("user_input_form"))
|
||||
self.dataset_query_variable = model_config.get("dataset_query_variable")
|
||||
self.pre_prompt = model_config.get("pre_prompt")
|
||||
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
|
||||
self.retriever_resource = (
|
||||
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
|
||||
)
|
||||
self.agent_mode = self._dump_optional(model_config.get("agent_mode"))
|
||||
self.retriever_resource = self._dump_optional(model_config.get("retriever_resource"))
|
||||
self.prompt_type = PromptType(model_config.get("prompt_type", "simple"))
|
||||
self.chat_prompt_config = (
|
||||
json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None
|
||||
)
|
||||
self.completion_prompt_config = (
|
||||
json.dumps(model_config.get("completion_prompt_config"))
|
||||
if model_config.get("completion_prompt_config")
|
||||
else None
|
||||
)
|
||||
self.dataset_configs = (
|
||||
json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None
|
||||
)
|
||||
self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None
|
||||
self.chat_prompt_config = self._dump_optional(model_config.get("chat_prompt_config"))
|
||||
self.completion_prompt_config = self._dump_optional(model_config.get("completion_prompt_config"))
|
||||
self.dataset_configs = self._dump_optional(model_config.get("dataset_configs"))
|
||||
self.file_upload = self._dump_optional(model_config.get("file_upload"))
|
||||
return self
|
||||
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ dependencies = [
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
|
||||
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.40.0",
|
||||
"opentelemetry-distro==0.61b0",
|
||||
"opentelemetry-exporter-otlp==1.40.0",
|
||||
|
||||
@ -3,7 +3,6 @@ import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
@ -19,7 +18,7 @@ from graphon.nodes.question_classifier.entities import QuestionClassifierNodeDat
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
from packaging import version
|
||||
from packaging.version import parse as parse_version
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -40,6 +39,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig, AppModelConfigDict, IconType
|
||||
from models.workflow import Workflow
|
||||
from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -53,18 +53,6 @@ DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.6.0"
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
YAML_CONTENT = "yaml-content"
|
||||
YAML_URL = "yaml-url"
|
||||
|
||||
|
||||
class ImportStatus(StrEnum):
|
||||
COMPLETED = "completed"
|
||||
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
|
||||
PENDING = "pending"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class Import(BaseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
@ -75,10 +63,6 @@ class Import(BaseModel):
|
||||
error: str = ""
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
"""Determine import status based on version comparison"""
|
||||
try:
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.provider_entities import FormType
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
@ -53,13 +53,12 @@ class DatasourceProviderService:
|
||||
"""
|
||||
remove oauth custom client params
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(DatasourceOauthTenantParamConfig).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
def decrypt_datasource_provider_credentials(
|
||||
self,
|
||||
@ -109,7 +108,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
get credential by id
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
if credential_id:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
@ -156,7 +155,6 @@ class DatasourceProviderService:
|
||||
datasource_provider=datasource_provider,
|
||||
)
|
||||
datasource_provider.expires_at = refreshed_credentials.expires_at
|
||||
session.commit()
|
||||
|
||||
return self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
@ -174,7 +172,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
get all datasource credentials by provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_providers = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
@ -224,7 +222,6 @@ class DatasourceProviderService:
|
||||
provider=provider,
|
||||
)
|
||||
real_credentials_list.append(real_credentials)
|
||||
session.commit()
|
||||
|
||||
return real_credentials_list
|
||||
|
||||
@ -234,7 +231,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
update datasource provider name
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
@ -266,7 +263,6 @@ class DatasourceProviderService:
|
||||
raise ValueError("Authorization name is already exists")
|
||||
|
||||
target_provider.name = name
|
||||
session.commit()
|
||||
return
|
||||
|
||||
def set_default_datasource_provider(
|
||||
@ -275,7 +271,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
set default datasource provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get provider
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
@ -300,7 +296,6 @@ class DatasourceProviderService:
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
def setup_oauth_custom_client_params(
|
||||
@ -315,7 +310,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
if client_params is None and enabled is None:
|
||||
return
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
@ -349,7 +344,6 @@ class DatasourceProviderService:
|
||||
|
||||
if enabled is not None:
|
||||
tenant_oauth_client_params.enabled = enabled
|
||||
session.commit()
|
||||
|
||||
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
|
||||
"""
|
||||
@ -488,7 +482,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
update datasource oauth provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
target_provider = (
|
||||
@ -535,7 +529,6 @@ class DatasourceProviderService:
|
||||
target_provider.expires_at = expire_at
|
||||
target_provider.encrypted_credentials = credentials
|
||||
target_provider.avatar_url = avatar_url or target_provider.avatar_url
|
||||
session.commit()
|
||||
|
||||
def add_datasource_oauth_provider(
|
||||
self,
|
||||
@ -550,7 +543,7 @@ class DatasourceProviderService:
|
||||
add datasource oauth provider
|
||||
"""
|
||||
credential_type = CredentialType.OAUTH2
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
|
||||
with redis_client.lock(lock, timeout=60):
|
||||
db_provider_name = name
|
||||
@ -604,7 +597,6 @@ class DatasourceProviderService:
|
||||
expires_at=expire_at,
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
session.commit()
|
||||
|
||||
def add_datasource_api_key_provider(
|
||||
self,
|
||||
@ -623,7 +615,7 @@ class DatasourceProviderService:
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
db_provider_name = name or self.generate_next_datasource_provider_name(
|
||||
@ -670,7 +662,6 @@ class DatasourceProviderService:
|
||||
encrypted_credentials=credentials,
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
session.commit()
|
||||
|
||||
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
|
||||
"""
|
||||
@ -926,7 +917,7 @@ class DatasourceProviderService:
|
||||
update datasource credentials.
|
||||
"""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
|
||||
@ -980,7 +971,6 @@ class DatasourceProviderService:
|
||||
encrypted_credentials[key] = value
|
||||
|
||||
datasource_provider.encrypted_credentials = encrypted_credentials
|
||||
session.commit()
|
||||
|
||||
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
|
||||
"""
|
||||
|
||||
21
api/services/entities/dsl_entities.py
Normal file
21
api/services/entities/dsl_entities.py
Normal file
@ -0,0 +1,21 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
YAML_CONTENT = "yaml-content"
|
||||
YAML_URL = "yaml-url"
|
||||
|
||||
|
||||
class ImportStatus(StrEnum):
|
||||
COMPLETED = "completed"
|
||||
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
|
||||
PENDING = "pending"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
@ -1,3 +1,4 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -8,10 +9,10 @@ class PluginAutoUpgradeService:
|
||||
@staticmethod
|
||||
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
return (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
return session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -24,10 +25,10 @@ class PluginAutoUpgradeService:
|
||||
include_plugins: list[str],
|
||||
) -> bool:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
exist_strategy = session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not exist_strategy:
|
||||
strategy = TenantPluginAutoUpgradeStrategy(
|
||||
@ -51,10 +52,10 @@ class PluginAutoUpgradeService:
|
||||
@staticmethod
|
||||
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
exist_strategy = session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not exist_strategy:
|
||||
# create for this tenant
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -8,7 +9,9 @@ class PluginPermissionService:
|
||||
@staticmethod
|
||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
return session.scalar(
|
||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def change_permission(
|
||||
@ -17,8 +20,8 @@ class PluginPermissionService:
|
||||
debug_permission: TenantPluginPermission.DebugPermission,
|
||||
):
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
permission = session.scalar(
|
||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
if not permission:
|
||||
permission = TenantPluginPermission(
|
||||
|
||||
@ -555,7 +555,7 @@ class RagPipelineService:
|
||||
workflow_node_execution.id
|
||||
)
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=pipeline.id,
|
||||
@ -569,7 +569,6 @@ class RagPipelineService:
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
)
|
||||
session.commit()
|
||||
if isinstance(workflow_node_execution_db_model, WorkflowNodeExecutionModel):
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution_db_model,
|
||||
@ -1325,7 +1324,7 @@ class RagPipelineService:
|
||||
# Convert node_execution to WorkflowNodeExecution after save
|
||||
workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution) # type: ignore
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=pipeline.id,
|
||||
@ -1339,7 +1338,6 @@ class RagPipelineService:
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
)
|
||||
session.commit()
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution_db_model,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
|
||||
@ -20,7 +20,7 @@ from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeDat
|
||||
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -37,7 +37,7 @@ from models import Account
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
|
||||
from models.enums import CollectionBindingType, DatasetRuntimeMode
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.app_dsl_service import ImportMode, ImportStatus
|
||||
from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
IconInfo,
|
||||
KnowledgeConfiguration,
|
||||
@ -64,10 +64,6 @@ class RagPipelineImportInfo(BaseModel):
|
||||
dataset_id: str | None = None
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
"""Determine import status based on version comparison"""
|
||||
try:
|
||||
|
||||
@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
@ -46,13 +46,12 @@ class BuiltinToolManageService:
|
||||
delete custom oauth client params
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(ToolOAuthTenantClient).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
@ -150,7 +149,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
update builtin tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get if the provider exists
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
@ -203,9 +202,7 @@ class BuiltinToolManageService:
|
||||
|
||||
db_provider.name = name
|
||||
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@ -222,7 +219,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
try:
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
@ -281,9 +278,7 @@ class BuiltinToolManageService:
|
||||
)
|
||||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
||||
return {"result": "success"}
|
||||
@ -379,7 +374,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
@ -393,7 +388,6 @@ class BuiltinToolManageService:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
session.delete(db_provider)
|
||||
session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
@ -409,7 +403,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
set default provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get provider
|
||||
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
|
||||
if target_provider is None:
|
||||
@ -422,7 +416,6 @@ class BuiltinToolManageService:
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -654,7 +647,7 @@ class BuiltinToolManageService:
|
||||
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
|
||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
custom_client_params = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
@ -690,7 +683,6 @@ class BuiltinToolManageService:
|
||||
if enable_oauth_custom_client is not None:
|
||||
custom_client_params.enabled = enable_oauth_custom_client
|
||||
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -48,21 +48,25 @@ class ToolTransformService:
|
||||
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
|
||||
)
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
return str(url_prefix / "builtin" / provider_name / "icon")
|
||||
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
|
||||
try:
|
||||
if isinstance(icon, str):
|
||||
parsed = emoji_icon_adapter.validate_json(icon)
|
||||
return {"background": parsed["background"], "content": parsed["content"]}
|
||||
return {"background": icon["background"], "content": icon["content"]}
|
||||
except (ValueError, ValidationError, KeyError):
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
if isinstance(icon, Mapping):
|
||||
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
|
||||
return icon
|
||||
return ""
|
||||
match provider_type:
|
||||
case ToolProviderType.BUILT_IN:
|
||||
return str(url_prefix / "builtin" / provider_name / "icon")
|
||||
case ToolProviderType.API | ToolProviderType.WORKFLOW:
|
||||
try:
|
||||
if isinstance(icon, str):
|
||||
parsed = emoji_icon_adapter.validate_json(icon)
|
||||
return {"background": parsed["background"], "content": parsed["content"]}
|
||||
return {"background": icon["background"], "content": icon["content"]}
|
||||
except (ValueError, ValidationError, KeyError):
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
case ToolProviderType.MCP:
|
||||
if isinstance(icon, Mapping):
|
||||
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
|
||||
return icon
|
||||
case ToolProviderType.PLUGIN | ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
|
||||
return ""
|
||||
case _:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
|
||||
|
||||
@ -1075,9 +1075,8 @@ class DraftVariableSaver:
|
||||
)
|
||||
engine = bind = self._session.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=engine, expire_on_commit=False).begin() as session:
|
||||
session.add(variable_file)
|
||||
session.commit()
|
||||
|
||||
return truncation_result.result, variable_file
|
||||
|
||||
|
||||
@ -909,7 +909,7 @@ class WorkflowService:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
outputs = workflow_node_execution.load_full_outputs(session, storage)
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=app_model.id,
|
||||
@ -920,7 +920,6 @@ class WorkflowService:
|
||||
user=account,
|
||||
)
|
||||
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
|
||||
session.commit()
|
||||
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution,
|
||||
@ -1049,7 +1048,7 @@ class WorkflowService:
|
||||
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
|
||||
enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=app_model.id,
|
||||
@ -1060,7 +1059,6 @@ class WorkflowService:
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
)
|
||||
draft_var_saver.save(outputs=outputs, process_data={})
|
||||
session.commit()
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@ -112,7 +112,9 @@ def clean_dataset_task(
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
|
||||
image_files = session.scalars(
|
||||
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
).all()
|
||||
for image_file in image_files:
|
||||
if image_file is None:
|
||||
continue
|
||||
@ -150,20 +152,22 @@ def clean_dataset_task(
|
||||
)
|
||||
session.execute(binding_delete_stmt)
|
||||
|
||||
session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
|
||||
session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
|
||||
session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
|
||||
session.execute(delete(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id))
|
||||
session.execute(delete(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id))
|
||||
session.execute(delete(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id))
|
||||
# delete dataset metadata
|
||||
session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
|
||||
session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
|
||||
session.execute(delete(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id))
|
||||
session.execute(delete(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id))
|
||||
# delete pipeline and workflow
|
||||
if pipeline_id:
|
||||
session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
|
||||
session.query(Workflow).where(
|
||||
Workflow.tenant_id == tenant_id,
|
||||
Workflow.app_id == pipeline_id,
|
||||
Workflow.type == WorkflowType.RAG_PIPELINE,
|
||||
).delete()
|
||||
session.execute(delete(Pipeline).where(Pipeline.id == pipeline_id))
|
||||
session.execute(
|
||||
delete(Workflow).where(
|
||||
Workflow.tenant_id == tenant_id,
|
||||
Workflow.app_id == pipeline_id,
|
||||
Workflow.type == WorkflowType.RAG_PIPELINE,
|
||||
)
|
||||
)
|
||||
# delete files
|
||||
if documents:
|
||||
file_ids = []
|
||||
@ -174,7 +178,7 @@ def clean_dataset_task(
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_ids.append(file_id)
|
||||
files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
|
||||
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
|
||||
for file in files:
|
||||
storage.delete(file.key)
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -29,7 +29,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
@ -49,23 +49,24 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id.in_(dataset_documents_ids))
|
||||
.values(indexing_status="indexing")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
# add from vector index
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
@ -82,13 +83,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="completed")
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="error", error=str(e))
|
||||
)
|
||||
session.commit()
|
||||
elif action == "update":
|
||||
@ -104,8 +109,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id.in_(dataset_documents_ids))
|
||||
.values(indexing_status="indexing")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@ -115,15 +122,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
for dataset_document in dataset_documents:
|
||||
# update from vector index
|
||||
try:
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if segments:
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
@ -172,13 +178,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
index_processor.load(
|
||||
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||
)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="completed")
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="error", error=str(e))
|
||||
)
|
||||
session.commit()
|
||||
else:
|
||||
|
||||
@ -32,7 +32,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
tenant_id = None
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
document = session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
@ -42,7 +44,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
|
||||
return
|
||||
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
|
||||
@ -87,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||
@ -112,7 +114,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
|
||||
@ -120,7 +122,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
logger.exception("Failed to clean vector index for document %s", document_id)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if not document:
|
||||
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
|
||||
return
|
||||
@ -140,7 +142,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if document:
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
@ -150,7 +152,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
except Exception as e:
|
||||
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
|
||||
@ -53,6 +53,31 @@ def _session_factory(calls, execute_results=None):
|
||||
return _session
|
||||
|
||||
|
||||
class _FakeBeginContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _sessionmaker_factory(calls, execute_results=None):
|
||||
def _sessionmaker(*args, **kwargs):
|
||||
session = _FakeSessionContext(calls=calls, execute_results=execute_results)
|
||||
return MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
|
||||
|
||||
return _sessionmaker
|
||||
|
||||
|
||||
def _patch_both(monkeypatch, module, calls, execute_results=None):
|
||||
"""Patch both Session and sessionmaker on the module with the same call tracker."""
|
||||
monkeypatch.setattr(module, "Session", _session_factory(calls, execute_results))
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(calls, execute_results))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pgvecto_module(monkeypatch):
|
||||
for name, module in _build_fake_pgvecto_modules().items():
|
||||
@ -105,7 +130,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
|
||||
_patch_both(monkeypatch, module, session_calls)
|
||||
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
vector.create_collection = MagicMock()
|
||||
@ -124,7 +149,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
|
||||
_patch_both(monkeypatch, module, session_calls)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -151,10 +176,10 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])]
|
||||
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
|
||||
_patch_both(monkeypatch, module, init_calls)
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results)))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=list(execute_results))
|
||||
|
||||
class _InsertBuilder:
|
||||
def __init__(self, table):
|
||||
@ -179,6 +204,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
"Session",
|
||||
_session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]),
|
||||
)
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
|
||||
assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
@ -204,12 +230,13 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
],
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
|
||||
vector.delete_by_ids(["doc-1"])
|
||||
assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls)
|
||||
assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
runtime_calls.clear()
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()]))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=[MagicMock()])
|
||||
vector.delete()
|
||||
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
@ -218,7 +245,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
init_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
|
||||
_patch_both(monkeypatch, module, init_calls)
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
|
||||
runtime_calls = []
|
||||
@ -277,7 +304,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
(SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1),
|
||||
(SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8),
|
||||
]
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows]))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=[rows])
|
||||
|
||||
docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
|
||||
assert len(docs) == 1
|
||||
|
||||
@ -129,7 +129,7 @@ def test_get_file_binary_returns_none_when_not_found() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -144,7 +144,7 @@ def test_get_file_binary_returns_bytes_when_found() -> None:
|
||||
manager = ToolFileManager()
|
||||
tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain")
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
session.scalar.return_value = tool_file
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
@ -160,11 +160,7 @@ def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = None
|
||||
second_query.where.return_value.first.return_value = None
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [None, None]
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -179,11 +175,7 @@ def test_get_file_binary_by_message_file_id_when_url_is_none() -> None:
|
||||
manager = ToolFileManager()
|
||||
message_file = SimpleNamespace(url=None)
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = message_file
|
||||
second_query.where.return_value.first.return_value = None
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [message_file, None]
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -199,11 +191,7 @@ def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None:
|
||||
message_file = SimpleNamespace(url="https://x/files/tools/tool123.png")
|
||||
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = message_file
|
||||
second_query.where.return_value.first.return_value = tool_file
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [message_file, tool_file]
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
@ -219,7 +207,7 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -242,7 +230,7 @@ def test_get_file_generator_returns_stream_when_found() -> None:
|
||||
size=12,
|
||||
)
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
session.scalar.return_value = tool_file
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
|
||||
@ -43,7 +43,7 @@ def test_get_db_provider_tool_builds_entity():
|
||||
controller = _controller()
|
||||
session = Mock()
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.scalar.return_value = workflow
|
||||
app = SimpleNamespace(id="app-1")
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
@ -136,7 +136,7 @@ def test_from_db_builds_controller():
|
||||
parameter_configurations=[],
|
||||
)
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
session.get.side_effect = [app, user]
|
||||
fake_cm = MagicMock()
|
||||
fake_cm.__enter__.return_value = session
|
||||
@ -163,7 +163,7 @@ def test_get_tools_returns_empty_when_provider_missing():
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
session_cls.return_value.__enter__.return_value = session
|
||||
|
||||
assert controller.get_tools("tenant-1") == []
|
||||
@ -189,7 +189,7 @@ def test_get_tools_raises_when_app_missing():
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
session.get.return_value = None
|
||||
session_cls.return_value.__enter__.return_value = session
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
|
||||
@ -20,7 +20,7 @@ class TestGetStrategy:
|
||||
def test_returns_strategy_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
strategy = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = strategy
|
||||
session.scalar.return_value = strategy
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
@ -31,7 +31,7 @@ class TestGetStrategy:
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
@ -44,9 +44,9 @@ class TestGetStrategy:
|
||||
class TestChangeStrategy:
|
||||
def test_creates_new_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
@ -65,7 +65,7 @@ class TestChangeStrategy:
|
||||
def test_updates_existing_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
@ -90,11 +90,12 @@ class TestChangeStrategy:
|
||||
class TestExcludePlugin:
|
||||
def test_creates_default_strategy_when_none_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with (
|
||||
p1,
|
||||
p2,
|
||||
patch(f"{MODULE}.select"),
|
||||
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
|
||||
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
|
||||
):
|
||||
@ -113,9 +114,9 @@ class TestExcludePlugin:
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p-existing"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
@ -131,9 +132,9 @@ class TestExcludePlugin:
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "partial"
|
||||
existing.include_plugins = ["p1", "p2"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
@ -148,9 +149,9 @@ class TestExcludePlugin:
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "all"
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
@ -167,9 +168,9 @@ class TestExcludePlugin:
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p1"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
|
||||
@ -20,7 +20,7 @@ class TestGetPermission:
|
||||
def test_returns_permission_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
permission = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = permission
|
||||
session.scalar.return_value = permission
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
@ -31,7 +31,7 @@ class TestGetPermission:
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
@ -44,9 +44,9 @@ class TestGetPermission:
|
||||
class TestChangePermission:
|
||||
def test_creates_new_permission_when_not_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
perm_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
@ -59,7 +59,7 @@ class TestChangePermission:
|
||||
def test_updates_existing_permission(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
@ -40,7 +40,10 @@ class TestDatasourceProviderService:
|
||||
q returns itself for .filter_by(), .order_by(), .where() so any
|
||||
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
|
||||
"""
|
||||
with patch("services.datasource_provider_service.Session") as mock_cls:
|
||||
with (
|
||||
patch("services.datasource_provider_service.Session") as mock_cls,
|
||||
patch("services.datasource_provider_service.sessionmaker") as mock_sm,
|
||||
):
|
||||
sess = MagicMock(spec=Session)
|
||||
|
||||
q = MagicMock()
|
||||
@ -63,6 +66,8 @@ class TestDatasourceProviderService:
|
||||
|
||||
mock_cls.return_value.__enter__.return_value = sess
|
||||
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
|
||||
mock_sm.return_value.begin.return_value.__enter__.return_value = sess
|
||||
mock_sm.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
yield sess
|
||||
|
||||
@ -266,7 +271,6 @@ class TestDatasourceProviderService:
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
|
||||
):
|
||||
service.get_datasource_credentials("t1", "prov", "org/plug")
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
|
||||
"""API key credentials with expires_at=-1 skip refresh and return directly."""
|
||||
@ -333,7 +337,6 @@ class TestDatasourceProviderService:
|
||||
p.name = "same"
|
||||
mock_db_session.query().first.return_value = p
|
||||
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@ -352,7 +355,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().count.return_value = 0
|
||||
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
|
||||
assert p.name == "new_name"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# set_default_datasource_provider (lines 277-303)
|
||||
@ -370,7 +372,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().first.return_value = target
|
||||
service.set_default_datasource_provider("t1", make_id(), "new-id")
|
||||
assert target.is_default is True
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_oauth_encrypter (lines 404-420)
|
||||
@ -460,7 +461,6 @@ class TestDatasourceProviderService:
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
|
||||
"""Conflict on name results in auto-incremented name, not an error."""
|
||||
@ -512,7 +512,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@ -523,7 +522,6 @@ class TestDatasourceProviderService:
|
||||
service.reauthorize_datasource_oauth_provider(
|
||||
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
|
||||
)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@ -571,7 +569,6 @@ class TestDatasourceProviderService:
|
||||
):
|
||||
service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
@ -747,7 +744,6 @@ class TestDatasourceProviderService:
|
||||
# encrypter must have been called with the new secret value
|
||||
self._enc.encrypt_token.assert_called()
|
||||
# commit must be called exactly once
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# remove_datasource_credentials (lines 980-997)
|
||||
@ -758,7 +754,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.scalar.return_value = p
|
||||
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
|
||||
mock_db_session.delete.assert_called_once_with(p)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
|
||||
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""
|
||||
|
||||
@ -15,17 +15,24 @@ def _mock_session(mock_session_cls):
|
||||
return session
|
||||
|
||||
|
||||
def _mock_sessionmaker(mock_sm_cls):
|
||||
"""Helper: set up a sessionmaker().begin() context manager mock and return the inner session."""
|
||||
session = MagicMock()
|
||||
mock_sm_cls.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
mock_sm_cls.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return session
|
||||
|
||||
|
||||
class TestDeleteCustomOauthClientParams:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_deletes_and_returns_success(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_deletes_and_returns_success(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
|
||||
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.query.return_value.filter_by.return_value.delete.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestListBuiltinToolProviderTools:
|
||||
@ -138,10 +145,10 @@ class TestIsOauthCustomClientEnabled:
|
||||
class TestDeleteBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
@ -149,10 +156,10 @@ class TestDeleteBuiltinToolProvider:
|
||||
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
db_provider = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mock_cache = MagicMock()
|
||||
@ -162,24 +169,23 @@ class TestDeleteBuiltinToolProvider:
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.delete.assert_called_once_with(db_provider)
|
||||
session.commit.assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
|
||||
class TestSetDefaultProvider:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_not_found(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="provider not found"):
|
||||
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
||||
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_sets_default_and_clears_old(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
target = MagicMock()
|
||||
session.query.return_value.filter_by.return_value.first.return_value = target
|
||||
|
||||
@ -187,14 +193,13 @@ class TestSetDefaultProvider:
|
||||
|
||||
assert result == {"result": "success"}
|
||||
assert target.is_default is True
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestUpdateBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
@ -203,10 +208,10 @@ class TestUpdateBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.CredentialType")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
db_provider = MagicMock(credential_type="api_key", credentials="{}")
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
|
||||
@ -227,7 +232,6 @@ class TestUpdateBuiltinToolProvider:
|
||||
result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"})
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.commit.assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@ -60,12 +60,6 @@ def mock_db_session():
|
||||
cm.__exit__.return_value = None
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
# Setup query chain
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.delete.return_value = 0
|
||||
|
||||
# Setup scalars for select queries
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
@ -220,11 +214,6 @@ class TestPipelineAndWorkflowDeletion:
|
||||
- Pipeline record is deleted
|
||||
- Related workflow record is deleted
|
||||
"""
|
||||
# Arrange
|
||||
mock_query = mock_db_session.session.query.return_value
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.delete.return_value = 1
|
||||
|
||||
# Act
|
||||
clean_dataset_task(
|
||||
dataset_id=dataset_id,
|
||||
@ -236,9 +225,9 @@ class TestPipelineAndWorkflowDeletion:
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
|
||||
# Assert - verify delete was called for pipeline-related queries
|
||||
# The actual count depends on total queries, but pipeline deletion should add 2 more
|
||||
assert mock_query.delete.call_count >= 7 # 5 base + 2 pipeline/workflow
|
||||
# Assert - verify execute was called for delete operations
|
||||
# 1 attachment JOIN query + 5 base deletes + 2 pipeline/workflow deletes = 8
|
||||
assert mock_db_session.session.execute.call_count >= 8
|
||||
|
||||
def test_clean_dataset_task_without_pipeline_id(
|
||||
self,
|
||||
@ -256,11 +245,6 @@ class TestPipelineAndWorkflowDeletion:
|
||||
Expected behavior:
|
||||
- Pipeline and workflow deletion queries are not executed
|
||||
"""
|
||||
# Arrange
|
||||
mock_query = mock_db_session.session.query.return_value
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.delete.return_value = 1
|
||||
|
||||
# Act
|
||||
clean_dataset_task(
|
||||
dataset_id=dataset_id,
|
||||
@ -272,8 +256,9 @@ class TestPipelineAndWorkflowDeletion:
|
||||
pipeline_id=None,
|
||||
)
|
||||
|
||||
# Assert - verify delete was called only for base queries (5 times)
|
||||
assert mock_query.delete.call_count == 5
|
||||
# Assert - verify execute was called for delete operations
|
||||
# 1 attachment JOIN query + 5 base deletes = 6
|
||||
assert mock_db_session.session.execute.call_count == 6
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@ -80,7 +80,7 @@ def mock_db_session(mock_document, mock_dataset):
|
||||
with patch("tasks.document_indexing_sync_task.session_factory", autospec=True) as mock_session_factory:
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
session.scalar.side_effect = [mock_document, mock_dataset]
|
||||
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
@ -242,14 +242,13 @@ class TestDataSourceInfoSerialization:
|
||||
# DB session mock — shared across all ``session_factory.create_session()`` calls
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
# .where() path: session 1 reads document + dataset, session 2 reads dataset
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
# All .first() calls are now session.scalar() — ordered by call sequence:
|
||||
# session 1: document + dataset, session 2: dataset (clean), session 3: document (update),
|
||||
# session 4: document (indexing)
|
||||
session.scalar.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
mock_dataset,
|
||||
]
|
||||
# .filter_by() path: session 3 (update), session 4 (indexing)
|
||||
session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
|
||||
8
api/uv.lock
generated
8
api/uv.lock
generated
@ -1519,7 +1519,7 @@ requires-dist = [
|
||||
{ name = "json-repair", specifier = ">=0.55.1" },
|
||||
{ name = "langfuse", specifier = ">=3.0.0,<5.0.0" },
|
||||
{ name = "langsmith", specifier = "~=0.7.16" },
|
||||
{ name = "litellm", specifier = "==1.82.6" },
|
||||
{ name = "litellm", specifier = "==1.83.0" },
|
||||
{ name = "markdown", specifier = "~=3.10.2" },
|
||||
{ name = "mlflow-skinny", specifier = ">=3.0.0" },
|
||||
{ name = "numpy", specifier = "~=1.26.4" },
|
||||
@ -3146,7 +3146,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.82.6"
|
||||
version = "1.83.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aiohttp" },
|
||||
@ -3162,9 +3162,9 @@ dependencies = [
|
||||
{ name = "tiktoken" },
|
||||
{ name = "tokenizers" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/29/75/1c537aa458426a9127a92bc2273787b2f987f4e5044e21f01f2eed5244fd/litellm-1.82.6.tar.gz", hash = "sha256:2aa1c2da21fe940c33613aa447119674a3ad4d2ad5eb064e4d5ce5ee42420136", size = 17414147, upload-time = "2026-03-22T06:36:00.452Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/22/92/6ce9737554994ca8e536e5f4f6a87cc7c4774b656c9eb9add071caf7d54b/litellm-1.83.0.tar.gz", hash = "sha256:860bebc76c4bb27b4cf90b4a77acd66dba25aced37e3db98750de8a1766bfb7a", size = 17333062, upload-time = "2026-03-31T05:08:25.331Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/02/6c/5327667e6dbe9e98cbfbd4261c8e91386a52e38f41419575854248bbab6a/litellm-1.82.6-py3-none-any.whl", hash = "sha256:164a3ef3e19f309e3cabc199bef3d2045212712fefdfa25fc7f75884a5b5b205", size = 15591595, upload-time = "2026-03-22T06:35:56.795Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/2c/a670cc050fcd6f45c6199eb99e259c73aea92edba8d5c2fc1b3686d36217/litellm-1.83.0-py3-none-any.whl", hash = "sha256:88c536d339248f3987571493015784671ba3f193a328e1ea6780dbebaa2094a8", size = 15610306, upload-time = "2026-03-31T05:08:21.987Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@ -476,6 +476,7 @@ S3_REGION=us-east-1
|
||||
S3_BUCKET_NAME=difyai
|
||||
S3_ACCESS_KEY=
|
||||
S3_SECRET_KEY=
|
||||
S3_ADDRESS_STYLE=auto
|
||||
# Whether to use AWS managed IAM roles for authenticating with the S3 service.
|
||||
# If set to false, the access key and secret key must be provided.
|
||||
S3_USE_AWS_MANAGED_IAM=false
|
||||
|
||||
@ -133,6 +133,7 @@ x-shared-env: &shared-api-worker-env
|
||||
S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai}
|
||||
S3_ACCESS_KEY: ${S3_ACCESS_KEY:-}
|
||||
S3_SECRET_KEY: ${S3_SECRET_KEY:-}
|
||||
S3_ADDRESS_STYLE: ${S3_ADDRESS_STYLE:-auto}
|
||||
S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false}
|
||||
ARCHIVE_STORAGE_ENABLED: ${ARCHIVE_STORAGE_ENABLED:-false}
|
||||
ARCHIVE_STORAGE_ENDPOINT: ${ARCHIVE_STORAGE_ENDPOINT:-}
|
||||
|
||||
11
e2e/features/apps/create-chatbot-app.feature
Normal file
11
e2e/features/apps/create-chatbot-app.feature
Normal file
@ -0,0 +1,11 @@
|
||||
@apps @authenticated
|
||||
Feature: Create Chatbot app
|
||||
Scenario: Create a new Chatbot app and redirect to the configuration page
|
||||
Given I am signed in as the default E2E admin
|
||||
When I open the apps console
|
||||
And I start creating a blank app
|
||||
And I expand the beginner app types
|
||||
And I select the "Chatbot" app type
|
||||
And I enter a unique E2E app name
|
||||
And I confirm app creation
|
||||
Then I should land on the app configuration page
|
||||
10
e2e/features/apps/create-workflow-app.feature
Normal file
10
e2e/features/apps/create-workflow-app.feature
Normal file
@ -0,0 +1,10 @@
|
||||
@apps @authenticated
|
||||
Feature: Create Workflow app
|
||||
Scenario: Create a new Workflow app and redirect to the workflow editor
|
||||
Given I am signed in as the default E2E admin
|
||||
When I open the apps console
|
||||
And I start creating a blank app
|
||||
And I select the "Workflow" app type
|
||||
And I enter a unique E2E app name
|
||||
And I confirm app creation
|
||||
Then I should land on the workflow editor
|
||||
8
e2e/features/auth/sign-out.feature
Normal file
8
e2e/features/auth/sign-out.feature
Normal file
@ -0,0 +1,8 @@
|
||||
@auth @authenticated
|
||||
Feature: Sign out
|
||||
Scenario: Sign out from the apps console
|
||||
Given I am signed in as the default E2E admin
|
||||
When I open the apps console
|
||||
And I open the account menu
|
||||
And I sign out
|
||||
Then I should be on the sign-in page
|
||||
@ -24,6 +24,30 @@ When('I confirm app creation', async function (this: DifyWorld) {
|
||||
await createButton.click()
|
||||
})
|
||||
|
||||
When('I select the {string} app type', async function (this: DifyWorld, appType: string) {
|
||||
const dialog = this.getPage().getByRole('dialog')
|
||||
const appTypeTitle = dialog.getByText(appType, { exact: true })
|
||||
|
||||
await expect(appTypeTitle).toBeVisible()
|
||||
await appTypeTitle.click()
|
||||
})
|
||||
|
||||
When('I expand the beginner app types', async function (this: DifyWorld) {
|
||||
const page = this.getPage()
|
||||
const toggle = page.getByRole('button', { name: 'More basic app types' })
|
||||
|
||||
await expect(toggle).toBeVisible()
|
||||
await toggle.click()
|
||||
})
|
||||
|
||||
Then('I should land on the app editor', async function (this: DifyWorld) {
|
||||
await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/(workflow|configuration)(?:\?.*)?$/)
|
||||
})
|
||||
|
||||
Then('I should land on the workflow editor', async function (this: DifyWorld) {
|
||||
await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/workflow(?:\?.*)?$/)
|
||||
})
|
||||
|
||||
Then('I should land on the app configuration page', async function (this: DifyWorld) {
|
||||
await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/configuration(?:\?.*)?$/)
|
||||
})
|
||||
|
||||
25
e2e/features/step-definitions/auth/sign-out.steps.ts
Normal file
25
e2e/features/step-definitions/auth/sign-out.steps.ts
Normal file
@ -0,0 +1,25 @@
|
||||
import { Then, When } from '@cucumber/cucumber'
|
||||
import { expect } from '@playwright/test'
|
||||
import type { DifyWorld } from '../../support/world'
|
||||
|
||||
When('I open the account menu', async function (this: DifyWorld) {
|
||||
const page = this.getPage()
|
||||
const trigger = page.getByRole('button', { name: 'Account' })
|
||||
|
||||
await expect(trigger).toBeVisible()
|
||||
await trigger.click()
|
||||
})
|
||||
|
||||
When('I sign out', async function (this: DifyWorld) {
|
||||
const page = this.getPage()
|
||||
|
||||
await expect(page.getByText('Log out')).toBeVisible()
|
||||
await page.getByText('Log out').click()
|
||||
})
|
||||
|
||||
Then('I should be on the sign-in page', async function (this: DifyWorld) {
|
||||
await expect(this.getPage()).toHaveURL(/\/signin/)
|
||||
await expect(this.getPage().getByRole('button', { name: /^Sign in$/i })).toBeVisible({
|
||||
timeout: 30_000,
|
||||
})
|
||||
})
|
||||
1249
pnpm-lock.yaml
generated
1249
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@ -47,7 +47,7 @@ overrides:
|
||||
catalog:
|
||||
"@amplitude/analytics-browser": 2.38.1
|
||||
"@amplitude/plugin-session-replay-browser": 1.27.6
|
||||
"@antfu/eslint-config": 8.0.0
|
||||
"@antfu/eslint-config": 8.1.1
|
||||
"@base-ui/react": 1.3.0
|
||||
"@chromatic-com/storybook": 5.1.1
|
||||
"@cucumber/cucumber": 12.7.0
|
||||
@ -73,8 +73,8 @@ catalog:
|
||||
"@mdx-js/react": 3.1.1
|
||||
"@mdx-js/rollup": 3.1.1
|
||||
"@monaco-editor/react": 4.7.0
|
||||
"@next/eslint-plugin-next": 16.2.2
|
||||
"@next/mdx": 16.2.2
|
||||
"@next/eslint-plugin-next": 16.2.3
|
||||
"@next/mdx": 16.2.3
|
||||
"@orpc/client": 1.13.13
|
||||
"@orpc/contract": 1.13.13
|
||||
"@orpc/openapi-client": 1.13.13
|
||||
@ -120,9 +120,9 @@ catalog:
|
||||
"@types/sortablejs": 1.15.9
|
||||
"@typescript-eslint/eslint-plugin": 8.58.1
|
||||
"@typescript-eslint/parser": 8.58.1
|
||||
"@typescript/native-preview": 7.0.0-dev.20260407.1
|
||||
"@typescript/native-preview": 7.0.0-dev.20260408.1
|
||||
"@vitejs/plugin-react": 6.0.1
|
||||
"@vitejs/plugin-rsc": 0.5.22
|
||||
"@vitejs/plugin-rsc": 0.5.23
|
||||
"@vitest/coverage-v8": 4.1.3
|
||||
abcjs: 6.6.2
|
||||
agentation: 3.0.2
|
||||
@ -146,7 +146,7 @@ catalog:
|
||||
emoji-mart: 5.6.0
|
||||
es-toolkit: 1.45.1
|
||||
eslint: 10.2.0
|
||||
eslint-markdown: 0.6.0
|
||||
eslint-markdown: 0.6.1
|
||||
eslint-plugin-better-tailwindcss: 4.3.2
|
||||
eslint-plugin-hyoban: 0.14.1
|
||||
eslint-plugin-markdown-preferences: 0.41.0
|
||||
@ -160,7 +160,7 @@ catalog:
|
||||
hono: 4.12.12
|
||||
html-entities: 2.6.0
|
||||
html-to-image: 1.11.13
|
||||
i18next: 26.0.3
|
||||
i18next: 26.0.4
|
||||
i18next-resources-to-backend: 1.2.1
|
||||
iconify-import-svg: 0.1.2
|
||||
immer: 11.1.4
|
||||
@ -170,7 +170,7 @@ catalog:
|
||||
js-yaml: 4.1.1
|
||||
jsonschema: 1.5.0
|
||||
katex: 0.16.45
|
||||
knip: 6.3.0
|
||||
knip: 6.3.1
|
||||
ky: 2.0.0
|
||||
lamejs: 1.2.1
|
||||
lexical: 0.42.0
|
||||
@ -179,24 +179,24 @@ catalog:
|
||||
mime: 4.1.0
|
||||
mitt: 3.0.1
|
||||
negotiator: 1.0.0
|
||||
next: 16.2.2
|
||||
next: 16.2.3
|
||||
next-themes: 0.4.6
|
||||
nuqs: 2.8.9
|
||||
pinyin-pro: 3.28.0
|
||||
postcss: 8.5.9
|
||||
postcss-js: 5.1.0
|
||||
qrcode.react: 4.2.0
|
||||
qs: 6.15.0
|
||||
react: 19.2.4
|
||||
qs: 6.15.1
|
||||
react: 19.2.5
|
||||
react-18-input-autosize: 3.0.0
|
||||
react-dom: 19.2.4
|
||||
react-dom: 19.2.5
|
||||
react-easy-crop: 5.5.7
|
||||
react-hotkeys-hook: 5.2.4
|
||||
react-i18next: 17.0.2
|
||||
react-multi-email: 1.0.25
|
||||
react-papaparse: 4.4.0
|
||||
react-pdf-highlighter: 8.0.0-rc.0
|
||||
react-server-dom-webpack: 19.2.4
|
||||
react-server-dom-webpack: 19.2.5
|
||||
react-sortablejs: 6.1.4
|
||||
react-textarea-autosize: 8.5.9
|
||||
reactflow: 11.11.4
|
||||
@ -221,7 +221,7 @@ catalog:
|
||||
unist-util-visit: 5.1.0
|
||||
use-context-selector: 2.0.0
|
||||
uuid: 13.0.0
|
||||
vinext: 0.0.40
|
||||
vinext: https://pkg.pr.new/vinext@adbf24d
|
||||
vite: npm:@voidzero-dev/vite-plus-core@0.1.16
|
||||
vite-plugin-inspect: 12.0.0-beta.1
|
||||
vite-plus: 0.1.16
|
||||
|
||||
224
web/__tests__/app-sidebar/dataset-info-flow.test.tsx
Normal file
224
web/__tests__/app-sidebar/dataset-info-flow.test.tsx
Normal file
@ -0,0 +1,224 @@
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import DatasetInfo from '@/app/components/app-sidebar/dataset-info'
|
||||
import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
|
||||
const mockReplace = vi.fn()
|
||||
const mockInvalidDatasetList = vi.fn()
|
||||
const mockInvalidDatasetDetail = vi.fn()
|
||||
const mockExportPipeline = vi.fn()
|
||||
const mockCheckIsUsedInApp = vi.fn()
|
||||
const mockDeleteDataset = vi.fn()
|
||||
const mockDownloadBlob = vi.fn()
|
||||
|
||||
let mockDataset: DataSet
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
replace: mockReplace,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/dataset-detail', () => ({
|
||||
useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({
|
||||
dataset: mockDataset,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) =>
|
||||
selector({ isCurrentWorkspaceDatasetOperator: false }),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-knowledge', () => ({
|
||||
useKnowledge: () => ({
|
||||
formatIndexingTechniqueAndMethod: () => 'indexing-technique',
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/knowledge/use-dataset', () => ({
|
||||
datasetDetailQueryKeyPrefix: ['dataset', 'detail'],
|
||||
useInvalidDatasetList: () => mockInvalidDatasetList,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-base', () => ({
|
||||
useInvalid: () => mockInvalidDatasetDetail,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-pipeline', () => ({
|
||||
useExportPipelineDSL: () => ({
|
||||
mutateAsync: mockExportPipeline,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/datasets', () => ({
|
||||
checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args),
|
||||
deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/download', () => ({
|
||||
downloadBlob: (...args: unknown[]) => mockDownloadBlob(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/datasets/rename-modal', () => ({
|
||||
default: ({
|
||||
show,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}: {
|
||||
show: boolean
|
||||
onClose: () => void
|
||||
onSuccess: () => void
|
||||
}) => show
|
||||
? (
|
||||
<div data-testid="rename-dataset-modal">
|
||||
<button type="button" onClick={onSuccess}>rename-success</button>
|
||||
<button type="button" onClick={onClose}>rename-close</button>
|
||||
</div>
|
||||
)
|
||||
: null,
|
||||
}))
|
||||
|
||||
const createDataset = (overrides: Partial<DataSet> = {}): DataSet => ({
|
||||
id: 'dataset-1',
|
||||
name: 'Dataset Name',
|
||||
indexing_status: 'completed',
|
||||
icon_info: {
|
||||
icon: '📙',
|
||||
icon_background: '#FFF4ED',
|
||||
icon_type: 'emoji',
|
||||
icon_url: '',
|
||||
},
|
||||
description: 'Dataset description',
|
||||
permission: DatasetPermission.onlyMe,
|
||||
data_source_type: DataSourceType.FILE,
|
||||
indexing_technique: 'high_quality' as DataSet['indexing_technique'],
|
||||
created_by: 'user-1',
|
||||
updated_by: 'user-1',
|
||||
updated_at: 1690000000,
|
||||
app_count: 0,
|
||||
doc_form: ChunkingMode.text,
|
||||
document_count: 1,
|
||||
total_document_count: 1,
|
||||
word_count: 1000,
|
||||
provider: 'internal',
|
||||
embedding_model: 'text-embedding-3',
|
||||
embedding_model_provider: 'openai',
|
||||
embedding_available: true,
|
||||
retrieval_model_dict: {
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
reranking_enable: false,
|
||||
reranking_model: {
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
},
|
||||
top_k: 5,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0,
|
||||
},
|
||||
retrieval_model: {
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
reranking_enable: false,
|
||||
reranking_model: {
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
},
|
||||
top_k: 5,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0,
|
||||
},
|
||||
tags: [],
|
||||
external_knowledge_info: {
|
||||
external_knowledge_id: '',
|
||||
external_knowledge_api_id: '',
|
||||
external_knowledge_api_name: '',
|
||||
external_knowledge_api_endpoint: '',
|
||||
},
|
||||
external_retrieval_model: {
|
||||
top_k: 0,
|
||||
score_threshold: 0,
|
||||
score_threshold_enabled: false,
|
||||
},
|
||||
built_in_field_enabled: false,
|
||||
runtime_mode: 'rag_pipeline',
|
||||
pipeline_id: 'pipeline-1',
|
||||
enable_api: false,
|
||||
is_multimodal: false,
|
||||
is_published: true,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const openDropdown = () => {
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
}
|
||||
|
||||
describe('App Sidebar Dataset Info Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockDataset = createDataset()
|
||||
mockExportPipeline.mockResolvedValue({ data: 'pipeline: demo' })
|
||||
mockCheckIsUsedInApp.mockResolvedValue({ is_using: false })
|
||||
mockDeleteDataset.mockResolvedValue({})
|
||||
})
|
||||
|
||||
it('exports the published pipeline from the dropdown menu', async () => {
|
||||
render(<DatasetInfo expand />)
|
||||
|
||||
expect(screen.getByText('Dataset Name')).toBeInTheDocument()
|
||||
|
||||
openDropdown()
|
||||
fireEvent.click(await screen.findByText('datasetPipeline.operations.exportPipeline'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockExportPipeline).toHaveBeenCalledWith({
|
||||
pipelineId: 'pipeline-1',
|
||||
include: false,
|
||||
})
|
||||
expect(mockDownloadBlob).toHaveBeenCalledWith(expect.objectContaining({
|
||||
fileName: 'Dataset Name.pipeline',
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
||||
it('opens the rename modal and refreshes dataset caches after a successful rename', async () => {
|
||||
render(<DatasetInfo expand />)
|
||||
|
||||
openDropdown()
|
||||
fireEvent.click(await screen.findByText('common.operation.edit'))
|
||||
|
||||
expect(screen.getByTestId('rename-dataset-modal')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'rename-success' }))
|
||||
|
||||
expect(mockInvalidDatasetList).toHaveBeenCalledTimes(1)
|
||||
expect(mockInvalidDatasetDetail).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('checks app usage before deleting and redirects back to the dataset list after confirmation', async () => {
|
||||
render(<DatasetInfo expand />)
|
||||
|
||||
openDropdown()
|
||||
fireEvent.click(await screen.findByText('common.operation.delete'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockCheckIsUsedInApp).toHaveBeenCalledWith('dataset-1')
|
||||
expect(screen.getByText('dataset.deleteDatasetConfirmTitle')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockDeleteDataset).toHaveBeenCalledWith('dataset-1')
|
||||
expect(mockInvalidDatasetList).toHaveBeenCalled()
|
||||
expect(mockReplace).toHaveBeenCalledWith('/datasets')
|
||||
})
|
||||
})
|
||||
})
|
||||
199
web/__tests__/app-sidebar/sidebar-shell-flow.test.tsx
Normal file
199
web/__tests__/app-sidebar/sidebar-shell-flow.test.tsx
Normal file
@ -0,0 +1,199 @@
|
||||
import type { SVGProps } from 'react'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import AppDetailNav from '@/app/components/app-sidebar'
|
||||
|
||||
const mockSetAppSidebarExpand = vi.fn()
|
||||
|
||||
let mockAppSidebarExpand = 'expand'
|
||||
let mockPathname = '/app/app-1/logs'
|
||||
let mockSelectedSegment = 'logs'
|
||||
let mockIsHovering = true
|
||||
let keyPressHandler: ((event: { preventDefault: () => void }) => void) | null = null
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: (selector: (state: Record<string, unknown>) => unknown) => selector({
|
||||
appDetail: {
|
||||
id: 'app-1',
|
||||
name: 'Demo App',
|
||||
mode: 'chat',
|
||||
icon: '🤖',
|
||||
icon_type: 'emoji',
|
||||
icon_background: '#FFEAD5',
|
||||
icon_url: null,
|
||||
},
|
||||
appSidebarExpand: mockAppSidebarExpand,
|
||||
setAppSidebarExpand: mockSetAppSidebarExpand,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('zustand/react/shallow', () => ({
|
||||
useShallow: (selector: unknown) => selector,
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
usePathname: () => mockPathname,
|
||||
useSelectedLayoutSegment: () => mockSelectedSegment,
|
||||
}))
|
||||
|
||||
vi.mock('@/next/link', () => ({
|
||||
default: ({
|
||||
href,
|
||||
children,
|
||||
className,
|
||||
title,
|
||||
}: {
|
||||
href: string
|
||||
children?: React.ReactNode
|
||||
className?: string
|
||||
title?: string
|
||||
}) => (
|
||||
<a href={href} className={className} title={title}>
|
||||
{children}
|
||||
</a>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
useHover: () => mockIsHovering,
|
||||
useKeyPress: (_key: string, handler: (event: { preventDefault: () => void }) => void) => {
|
||||
keyPressHandler = handler
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-breakpoints', () => ({
|
||||
default: () => 'desktop',
|
||||
MediaType: {
|
||||
mobile: 'mobile',
|
||||
desktop: 'desktop',
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: () => ({
|
||||
eventEmitter: {
|
||||
useSubscription: vi.fn(),
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
isCurrentWorkspaceEditor: true,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/utils', () => ({
|
||||
getKeyboardKeyCodeBySystem: () => 'ctrl',
|
||||
getKeyboardKeyNameBySystem: (key: string) => key,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
|
||||
const React = await vi.importActual<typeof import('react')>('react')
|
||||
const OpenContext = React.createContext(false)
|
||||
|
||||
return {
|
||||
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
|
||||
<OpenContext.Provider value={open}>
|
||||
<div>{children}</div>
|
||||
</OpenContext.Provider>
|
||||
),
|
||||
PortalToFollowElemTrigger: ({
|
||||
children,
|
||||
onClick,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
onClick?: () => void
|
||||
}) => (
|
||||
<button type="button" data-testid="portal-trigger" onClick={onClick}>
|
||||
{children}
|
||||
</button>
|
||||
),
|
||||
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
|
||||
const open = React.useContext(OpenContext)
|
||||
return open ? <div>{children}</div> : null
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/app/components/base/tooltip', () => ({
|
||||
default: ({ children }: { children?: React.ReactNode }) => <>{children}</>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app-sidebar/app-info', () => ({
|
||||
default: ({
|
||||
expand,
|
||||
onlyShowDetail,
|
||||
openState,
|
||||
}: {
|
||||
expand: boolean
|
||||
onlyShowDetail?: boolean
|
||||
openState?: boolean
|
||||
}) => (
|
||||
<div
|
||||
data-testid={onlyShowDetail ? 'app-info-detail' : 'app-info'}
|
||||
data-expand={expand}
|
||||
data-open={openState}
|
||||
/>
|
||||
),
|
||||
}))
|
||||
|
||||
const MockIcon = (props: SVGProps<SVGSVGElement>) => <svg {...props} />
|
||||
|
||||
const navigation = [
|
||||
{ name: 'Overview', href: '/app/app-1/overview', icon: MockIcon, selectedIcon: MockIcon },
|
||||
{ name: 'Logs', href: '/app/app-1/logs', icon: MockIcon, selectedIcon: MockIcon },
|
||||
]
|
||||
|
||||
describe('App Sidebar Shell Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
localStorage.clear()
|
||||
mockAppSidebarExpand = 'expand'
|
||||
mockPathname = '/app/app-1/logs'
|
||||
mockSelectedSegment = 'logs'
|
||||
mockIsHovering = true
|
||||
keyPressHandler = null
|
||||
})
|
||||
|
||||
it('renders the expanded sidebar, marks the active nav item, and toggles collapse by click and shortcut', () => {
|
||||
render(<AppDetailNav navigation={navigation} />)
|
||||
|
||||
expect(screen.getByTestId('app-info')).toHaveAttribute('data-expand', 'true')
|
||||
|
||||
const logsLink = screen.getByRole('link', { name: /Logs/i })
|
||||
expect(logsLink.className).toContain('bg-components-menu-item-bg-active')
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
expect(mockSetAppSidebarExpand).toHaveBeenCalledWith('collapse')
|
||||
|
||||
const preventDefault = vi.fn()
|
||||
keyPressHandler?.({ preventDefault })
|
||||
|
||||
expect(preventDefault).toHaveBeenCalled()
|
||||
expect(mockSetAppSidebarExpand).toHaveBeenCalledWith('collapse')
|
||||
})
|
||||
|
||||
it('switches to the workflow fullscreen dropdown shell and opens its navigation menu', () => {
|
||||
mockPathname = '/app/app-1/workflow'
|
||||
mockSelectedSegment = 'workflow'
|
||||
localStorage.setItem('workflow-canvas-maximize', 'true')
|
||||
|
||||
render(<AppDetailNav navigation={navigation} />)
|
||||
|
||||
expect(screen.queryByTestId('app-info')).not.toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
|
||||
expect(screen.getByText('Demo App')).toBeInTheDocument()
|
||||
expect(screen.getByRole('link', { name: /Overview/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('link', { name: /Logs/i })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
139
web/__tests__/app/app-access-control-flow.test.tsx
Normal file
139
web/__tests__/app/app-access-control-flow.test.tsx
Normal file
@ -0,0 +1,139 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import AppPublisher from '@/app/components/app/app-publisher'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
|
||||
const mockFetchAppDetailDirect = vi.fn()
|
||||
const mockSetAppDetail = vi.fn()
|
||||
const mockRefetch = vi.fn()
|
||||
|
||||
let mockAppDetail: {
|
||||
id: string
|
||||
name: string
|
||||
mode: AppModeEnum
|
||||
access_mode: AccessMode
|
||||
description: string
|
||||
icon: string
|
||||
icon_type: string
|
||||
icon_background: string
|
||||
site: {
|
||||
app_base_url: string
|
||||
access_token: string
|
||||
}
|
||||
} | null = null
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: (selector: (state: Record<string, unknown>) => unknown) => selector({
|
||||
appDetail: mockAppDetail,
|
||||
setAppDetail: mockSetAppDetail,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector: (state: Record<string, unknown>) => unknown) => selector({
|
||||
systemFeatures: {
|
||||
webapp_auth: {
|
||||
enabled: true,
|
||||
},
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-format-time-from-now', () => ({
|
||||
useFormatTimeFromNow: () => ({
|
||||
formatTimeFromNow: (value: number) => `ago:${value}`,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-async-window-open', () => ({
|
||||
useAsyncWindowOpen: () => vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/access-control', () => ({
|
||||
useGetUserCanAccessApp: () => ({
|
||||
data: { result: true },
|
||||
isLoading: false,
|
||||
refetch: mockRefetch,
|
||||
}),
|
||||
useAppWhiteListSubjects: () => ({
|
||||
data: { groups: [], members: [] },
|
||||
isLoading: false,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/apps', () => ({
|
||||
fetchAppDetailDirect: (...args: unknown[]) => mockFetchAppDetailDirect(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/overview/embedded', () => ({
|
||||
default: () => null,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/app-access-control', () => ({
|
||||
default: ({
|
||||
onConfirm,
|
||||
onClose,
|
||||
}: {
|
||||
onConfirm: () => Promise<void>
|
||||
onClose: () => void
|
||||
}) => (
|
||||
<div data-testid="access-control-modal">
|
||||
<button type="button" onClick={() => void onConfirm()}>confirm-access-control</button>
|
||||
<button type="button" onClick={onClose}>close-access-control</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
describe('App Access Control Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockAppDetail = {
|
||||
id: 'app-1',
|
||||
name: 'Demo App',
|
||||
mode: AppModeEnum.CHAT,
|
||||
access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS,
|
||||
description: 'Demo app description',
|
||||
icon: '🤖',
|
||||
icon_type: 'emoji',
|
||||
icon_background: '#FFEAD5',
|
||||
site: {
|
||||
app_base_url: 'https://example.com',
|
||||
access_token: 'token-1',
|
||||
},
|
||||
}
|
||||
mockFetchAppDetailDirect.mockResolvedValue({
|
||||
...mockAppDetail,
|
||||
access_mode: AccessMode.PUBLIC,
|
||||
})
|
||||
})
|
||||
|
||||
it('refreshes app detail after confirming access control updates', async () => {
|
||||
render(<AppPublisher publishedAt={1700000000} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.publish' }))
|
||||
fireEvent.click(screen.getByText('app.accessControlDialog.accessItems.specific'))
|
||||
|
||||
expect(screen.getByTestId('access-control-modal')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'confirm-access-control' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchAppDetailDirect).toHaveBeenCalledWith({ url: '/apps', id: 'app-1' })
|
||||
expect(mockSetAppDetail).toHaveBeenCalledWith(expect.objectContaining({
|
||||
id: 'app-1',
|
||||
access_mode: AccessMode.PUBLIC,
|
||||
}))
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('access-control-modal')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
243
web/__tests__/app/app-publisher-flow.test.tsx
Normal file
243
web/__tests__/app/app-publisher-flow.test.tsx
Normal file
@ -0,0 +1,243 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import AppPublisher from '@/app/components/app/app-publisher'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
|
||||
const mockTrackEvent = vi.fn()
|
||||
const mockRefetch = vi.fn()
|
||||
const mockFetchInstalledAppList = vi.fn()
|
||||
const mockFetchAppDetailDirect = vi.fn()
|
||||
const mockToastError = vi.fn()
|
||||
const mockOpenAsyncWindow = vi.fn()
|
||||
const mockSetAppDetail = vi.fn()
|
||||
|
||||
let mockAppDetail: {
|
||||
id: string
|
||||
name: string
|
||||
mode: AppModeEnum
|
||||
access_mode: AccessMode
|
||||
description: string
|
||||
icon: string
|
||||
icon_type: string
|
||||
icon_background: string
|
||||
site: {
|
||||
app_base_url: string
|
||||
access_token: string
|
||||
}
|
||||
} | null = null
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
useKeyPress: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: (selector: (state: Record<string, unknown>) => unknown) => selector({
|
||||
appDetail: mockAppDetail,
|
||||
setAppDetail: mockSetAppDetail,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector: (state: Record<string, unknown>) => unknown) => selector({
|
||||
systemFeatures: {
|
||||
webapp_auth: {
|
||||
enabled: true,
|
||||
},
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-format-time-from-now', () => ({
|
||||
useFormatTimeFromNow: () => ({
|
||||
formatTimeFromNow: (value: number) => `ago:${value}`,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-async-window-open', () => ({
|
||||
useAsyncWindowOpen: () => mockOpenAsyncWindow,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/access-control', () => ({
|
||||
useGetUserCanAccessApp: () => ({
|
||||
data: { result: true },
|
||||
isLoading: false,
|
||||
refetch: mockRefetch,
|
||||
}),
|
||||
useAppWhiteListSubjects: () => ({
|
||||
data: { groups: [], members: [] },
|
||||
isLoading: false,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/explore', () => ({
|
||||
fetchInstalledAppList: (...args: unknown[]) => mockFetchInstalledAppList(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/apps', () => ({
|
||||
fetchAppDetailDirect: (...args: unknown[]) => mockFetchAppDetailDirect(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: {
|
||||
error: (...args: unknown[]) => mockToastError(...args),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: (...args: unknown[]) => mockTrackEvent(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/overview/embedded', () => ({
|
||||
default: ({ isShow, onClose }: { isShow: boolean, onClose: () => void }) => (
|
||||
isShow
|
||||
? (
|
||||
<div data-testid="embedded-modal">
|
||||
<button onClick={onClose}>close-embedded</button>
|
||||
</div>
|
||||
)
|
||||
: null
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/app-access-control', () => ({
|
||||
default: () => <div data-testid="app-access-control" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
|
||||
const React = await vi.importActual<typeof import('react')>('react')
|
||||
const OpenContext = React.createContext(false)
|
||||
|
||||
return {
|
||||
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
|
||||
<OpenContext.Provider value={open}>
|
||||
<div>{children}</div>
|
||||
</OpenContext.Provider>
|
||||
),
|
||||
PortalToFollowElemTrigger: ({
|
||||
children,
|
||||
onClick,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
onClick?: () => void
|
||||
}) => (
|
||||
<div onClick={onClick}>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
|
||||
const open = React.useContext(OpenContext)
|
||||
return open ? <div>{children}</div> : null
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/app/components/workflow/utils', () => ({
|
||||
getKeyboardKeyCodeBySystem: () => 'ctrl',
|
||||
getKeyboardKeyNameBySystem: (key: string) => key,
|
||||
}))
|
||||
|
||||
describe('App Publisher Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockAppDetail = {
|
||||
id: 'app-1',
|
||||
name: 'Demo App',
|
||||
mode: AppModeEnum.CHAT,
|
||||
access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS,
|
||||
description: 'Demo app description',
|
||||
icon: '🤖',
|
||||
icon_type: 'emoji',
|
||||
icon_background: '#FFEAD5',
|
||||
site: {
|
||||
app_base_url: 'https://example.com',
|
||||
access_token: 'token-1',
|
||||
},
|
||||
}
|
||||
mockFetchInstalledAppList.mockResolvedValue({
|
||||
installed_apps: [{ id: 'installed-1' }],
|
||||
})
|
||||
mockFetchAppDetailDirect.mockResolvedValue({
|
||||
id: 'app-1',
|
||||
access_mode: AccessMode.PUBLIC,
|
||||
})
|
||||
mockOpenAsyncWindow.mockImplementation(async (
|
||||
resolver: () => Promise<string>,
|
||||
options?: { onError?: (error: Error) => void },
|
||||
) => {
|
||||
try {
|
||||
return await resolver()
|
||||
}
|
||||
catch (error) {
|
||||
options?.onError?.(error as Error)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('publishes from the summary panel and tracks the publish event', async () => {
|
||||
const onPublish = vi.fn().mockResolvedValue(undefined)
|
||||
|
||||
render(
|
||||
<AppPublisher
|
||||
publishedAt={1700000000}
|
||||
onPublish={onPublish}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('common.publish'))
|
||||
|
||||
expect(screen.getByText('common.latestPublished')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.publishUpdate')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByText('common.publishUpdate'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onPublish).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackEvent).toHaveBeenCalledWith('app_published_time', expect.objectContaining({
|
||||
action_mode: 'app',
|
||||
app_id: 'app-1',
|
||||
app_name: 'Demo App',
|
||||
}))
|
||||
})
|
||||
|
||||
expect(mockRefetch).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('opens embedded modal and resolves the installed explore target', async () => {
|
||||
render(<AppPublisher publishedAt={1700000000} />)
|
||||
|
||||
fireEvent.click(screen.getByText('common.publish'))
|
||||
fireEvent.click(screen.getByText('common.embedIntoSite'))
|
||||
|
||||
expect(screen.getByTestId('embedded-modal')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByText('common.publish'))
|
||||
fireEvent.click(screen.getByText('common.openInExplore'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchInstalledAppList).toHaveBeenCalledWith('app-1')
|
||||
expect(mockOpenAsyncWindow).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
it('shows a toast error when no installed explore app is available', async () => {
|
||||
mockFetchInstalledAppList.mockResolvedValue({
|
||||
installed_apps: [],
|
||||
})
|
||||
|
||||
render(<AppPublisher publishedAt={1700000000} />)
|
||||
|
||||
fireEvent.click(screen.getByText('common.publish'))
|
||||
fireEvent.click(screen.getByText('common.openInExplore'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastError).toHaveBeenCalledWith('No app found in Explore')
|
||||
})
|
||||
})
|
||||
})
|
||||
154
web/__tests__/base/chat-flow.test.tsx
Normal file
154
web/__tests__/base/chat-flow.test.tsx
Normal file
@ -0,0 +1,154 @@
|
||||
import type { RefObject } from 'react'
|
||||
import type { ChatConfig } from '@/app/components/base/chat/types'
|
||||
import type { AppConversationData, AppData, AppMeta, ConversationItem } from '@/models/share'
|
||||
import { fireEvent, render, renderHook, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import ChatWithHistory from '@/app/components/base/chat/chat-with-history'
|
||||
import { useChatWithHistory } from '@/app/components/base/chat/chat-with-history/hooks'
|
||||
import { useThemeContext } from '@/app/components/base/chat/embedded-chatbot/theme/theme-context'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
|
||||
vi.mock('@/app/components/base/chat/chat-with-history/hooks', () => ({
|
||||
useChatWithHistory: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-breakpoints', () => ({
|
||||
default: vi.fn(),
|
||||
MediaType: {
|
||||
mobile: 'mobile',
|
||||
tablet: 'tablet',
|
||||
pc: 'pc',
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-document-title', () => ({
|
||||
default: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: vi.fn(() => ({
|
||||
push: vi.fn(),
|
||||
replace: vi.fn(),
|
||||
prefetch: vi.fn(),
|
||||
})),
|
||||
usePathname: vi.fn(() => '/'),
|
||||
useSearchParams: vi.fn(() => new URLSearchParams()),
|
||||
useParams: vi.fn(() => ({})),
|
||||
}))
|
||||
|
||||
type HookReturn = ReturnType<typeof useChatWithHistory>
|
||||
|
||||
const mockAppData = {
|
||||
site: { title: 'Test Chat', chat_color_theme: 'blue', chat_color_theme_inverted: false },
|
||||
} as unknown as AppData
|
||||
|
||||
const defaultHookReturn: HookReturn = {
|
||||
isInstalledApp: false,
|
||||
appId: 'test-app-id',
|
||||
currentConversationId: '',
|
||||
currentConversationItem: undefined,
|
||||
handleConversationIdInfoChange: vi.fn(),
|
||||
appData: mockAppData,
|
||||
appParams: {} as ChatConfig,
|
||||
appMeta: {} as AppMeta,
|
||||
appPinnedConversationData: { data: [] as ConversationItem[], has_more: false, limit: 20 } as AppConversationData,
|
||||
appConversationData: { data: [] as ConversationItem[], has_more: false, limit: 20 } as AppConversationData,
|
||||
appConversationDataLoading: false,
|
||||
appChatListData: { data: [] as ConversationItem[], has_more: false, limit: 20 } as AppConversationData,
|
||||
appChatListDataLoading: false,
|
||||
appPrevChatTree: [],
|
||||
pinnedConversationList: [],
|
||||
conversationList: [],
|
||||
setShowNewConversationItemInList: vi.fn(),
|
||||
newConversationInputs: {},
|
||||
newConversationInputsRef: { current: {} } as unknown as RefObject<Record<string, unknown>>,
|
||||
handleNewConversationInputsChange: vi.fn(),
|
||||
inputsForms: [],
|
||||
handleNewConversation: vi.fn(),
|
||||
handleStartChat: vi.fn(),
|
||||
handleChangeConversation: vi.fn(),
|
||||
handlePinConversation: vi.fn(),
|
||||
handleUnpinConversation: vi.fn(),
|
||||
conversationDeleting: false,
|
||||
handleDeleteConversation: vi.fn(),
|
||||
conversationRenaming: false,
|
||||
handleRenameConversation: vi.fn(),
|
||||
handleNewConversationCompleted: vi.fn(),
|
||||
newConversationId: '',
|
||||
chatShouldReloadKey: 'test-reload-key',
|
||||
handleFeedback: vi.fn(),
|
||||
currentChatInstanceRef: { current: { handleStop: vi.fn() } },
|
||||
sidebarCollapseState: false,
|
||||
handleSidebarCollapse: vi.fn(),
|
||||
clearChatList: false,
|
||||
setClearChatList: vi.fn(),
|
||||
isResponding: false,
|
||||
setIsResponding: vi.fn(),
|
||||
currentConversationInputs: {},
|
||||
setCurrentConversationInputs: vi.fn(),
|
||||
allInputsHidden: false,
|
||||
initUserVariables: {},
|
||||
}
|
||||
|
||||
describe('Base Chat Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(useBreakpoints).mockReturnValue(MediaType.pc)
|
||||
vi.mocked(useChatWithHistory).mockReturnValue(defaultHookReturn)
|
||||
renderHook(() => useThemeContext()).result.current.buildTheme()
|
||||
})
|
||||
|
||||
// Chat-with-history shell integration across layout, responsive shell, and theme setup.
|
||||
describe('Chat With History Shell', () => {
|
||||
it('builds theme, updates the document title, and expands the collapsed desktop sidebar on hover', async () => {
|
||||
const themeBuilder = renderHook(() => useThemeContext()).result.current
|
||||
const { container } = render(<ChatWithHistory className="chat-history-shell" />)
|
||||
|
||||
const titles = screen.getAllByText('Test Chat')
|
||||
expect(titles.length).toBeGreaterThan(0)
|
||||
expect(useDocumentTitle).toHaveBeenCalledWith('Test Chat')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(themeBuilder.theme.primaryColor).toBe('blue')
|
||||
expect(themeBuilder.theme.chatColorThemeInverted).toBe(false)
|
||||
})
|
||||
|
||||
vi.mocked(useChatWithHistory).mockReturnValue({
|
||||
...defaultHookReturn,
|
||||
sidebarCollapseState: true,
|
||||
})
|
||||
|
||||
const { container: collapsedContainer } = render(<ChatWithHistory />)
|
||||
const hoverArea = collapsedContainer.querySelector('.absolute.top-0.z-20')
|
||||
|
||||
expect(container.querySelector('.chat-history-shell')).toBeInTheDocument()
|
||||
expect(hoverArea).toBeInTheDocument()
|
||||
|
||||
if (hoverArea) {
|
||||
fireEvent.mouseEnter(hoverArea)
|
||||
expect(hoverArea).toHaveClass('left-0')
|
||||
|
||||
fireEvent.mouseLeave(hoverArea)
|
||||
expect(hoverArea).toHaveClass('left-[-248px]')
|
||||
}
|
||||
})
|
||||
|
||||
it('falls back to the mobile loading shell when site metadata is unavailable', () => {
|
||||
vi.mocked(useBreakpoints).mockReturnValue(MediaType.mobile)
|
||||
vi.mocked(useChatWithHistory).mockReturnValue({
|
||||
...defaultHookReturn,
|
||||
appData: null,
|
||||
appChatListDataLoading: true,
|
||||
})
|
||||
|
||||
const { container } = render(<ChatWithHistory className="mobile-chat-shell" />)
|
||||
|
||||
expect(useDocumentTitle).toHaveBeenCalledWith('Chat')
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
expect(container.querySelector('.mobile-chat-shell')).toBeInTheDocument()
|
||||
expect(container.querySelector('.rounded-t-2xl')).toBeInTheDocument()
|
||||
expect(container.querySelector('.rounded-2xl')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
106
web/__tests__/base/file-uploader-flow.test.tsx
Normal file
106
web/__tests__/base/file-uploader-flow.test.tsx
Normal file
@ -0,0 +1,106 @@
|
||||
import type { FileUpload } from '@/app/components/base/features/types'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import FileUploaderInAttachmentWrapper from '@/app/components/base/file-uploader/file-uploader-in-attachment'
|
||||
import FileUploaderInChatInput from '@/app/components/base/file-uploader/file-uploader-in-chat-input'
|
||||
import { FileContextProvider } from '@/app/components/base/file-uploader/store'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
|
||||
const mockUploadRemoteFileInfo = vi.fn()
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useParams: () => ({}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/common', () => ({
|
||||
uploadRemoteFileInfo: (...args: unknown[]) => mockUploadRemoteFileInfo(...args),
|
||||
}))
|
||||
|
||||
const createFileConfig = (overrides: Partial<FileUpload> = {}): FileUpload => ({
|
||||
enabled: true,
|
||||
allowed_file_types: ['document'],
|
||||
allowed_file_extensions: [],
|
||||
allowed_file_upload_methods: [TransferMethod.remote_url],
|
||||
number_limits: 5,
|
||||
preview_config: {
|
||||
enabled: false,
|
||||
mode: 'current_page',
|
||||
file_type_list: [],
|
||||
},
|
||||
...overrides,
|
||||
} as FileUpload)
|
||||
|
||||
const renderChatInput = (fileConfig: FileUpload, readonly = false) => {
|
||||
return render(
|
||||
<FileContextProvider>
|
||||
<FileUploaderInChatInput fileConfig={fileConfig} readonly={readonly} />
|
||||
</FileContextProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('Base File Uploader Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUploadRemoteFileInfo.mockResolvedValue({
|
||||
id: 'remote-file-1',
|
||||
mime_type: 'application/pdf',
|
||||
size: 2048,
|
||||
name: 'guide.pdf',
|
||||
url: 'https://cdn.example.com/guide.pdf',
|
||||
})
|
||||
})
|
||||
|
||||
it('uploads a remote file from the attachment wrapper and pushes the updated file list to consumers', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onChange = vi.fn()
|
||||
|
||||
render(
|
||||
<FileUploaderInAttachmentWrapper
|
||||
value={[]}
|
||||
onChange={onChange}
|
||||
fileConfig={createFileConfig()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /fileUploader\.pasteFileLink/i }))
|
||||
await user.type(screen.getByPlaceholderText(/fileUploader\.pasteFileLinkInputPlaceholder/i), 'https://example.com/guide.pdf')
|
||||
await user.click(screen.getByRole('button', { name: /operation\.ok/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockUploadRemoteFileInfo).toHaveBeenCalledWith('https://example.com/guide.pdf', false)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onChange).toHaveBeenLastCalledWith([
|
||||
expect.objectContaining({
|
||||
name: 'https://example.com/guide.pdf',
|
||||
uploadedId: 'remote-file-1',
|
||||
url: 'https://cdn.example.com/guide.pdf',
|
||||
progress: 100,
|
||||
}),
|
||||
])
|
||||
})
|
||||
|
||||
expect(screen.getByText('https://example.com/guide.pdf')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('opens the link picker from chat input and keeps the trigger disabled in readonly mode', async () => {
|
||||
const user = userEvent.setup()
|
||||
const fileConfig = createFileConfig()
|
||||
|
||||
const { unmount } = renderChatInput(fileConfig)
|
||||
|
||||
const activeTrigger = screen.getByRole('button')
|
||||
expect(activeTrigger).toBeEnabled()
|
||||
|
||||
await user.click(activeTrigger)
|
||||
expect(screen.getByPlaceholderText(/fileUploader\.pasteFileLinkInputPlaceholder/i)).toBeInTheDocument()
|
||||
expect(screen.queryByText(/fileUploader\.uploadFromComputer/i)).not.toBeInTheDocument()
|
||||
|
||||
unmount()
|
||||
renderChatInput(fileConfig, true)
|
||||
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
})
|
||||
})
|
||||
65
web/__tests__/base/form-demo-flow.test.tsx
Normal file
65
web/__tests__/base/form-demo-flow.test.tsx
Normal file
@ -0,0 +1,65 @@
|
||||
import { render, screen, waitFor, within } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import DemoForm from '@/app/components/base/form/form-scenarios/demo'
|
||||
|
||||
describe('Base Form Demo Flow', () => {
|
||||
const consoleLogSpy = vi.spyOn(console, 'log').mockImplementation(() => {})
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('reveals contact fields and submits the composed form values through the shared form actions', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<DemoForm />)
|
||||
|
||||
expect(screen.queryByRole('heading', { name: /contacts/i })).not.toBeInTheDocument()
|
||||
|
||||
await user.type(screen.getByRole('textbox', { name: /^name$/i }), 'Alice')
|
||||
await user.type(screen.getByRole('textbox', { name: /^surname$/i }), 'Smith')
|
||||
await user.click(screen.getByText(/i accept the terms and conditions/i))
|
||||
|
||||
expect(await screen.findByRole('heading', { name: /contacts/i })).toBeInTheDocument()
|
||||
|
||||
await user.type(screen.getByRole('textbox', { name: /^email$/i }), 'alice@example.com')
|
||||
|
||||
const preferredMethodLabel = screen.getByText('Preferred Contact Method')
|
||||
const preferredMethodField = preferredMethodLabel.parentElement?.parentElement
|
||||
expect(preferredMethodField).toBeTruthy()
|
||||
|
||||
await user.click(within(preferredMethodField as HTMLElement).getByText('Email'))
|
||||
await user.click(screen.getByText('Whatsapp'))
|
||||
|
||||
const submitButton = screen.getByRole('button', { name: /operation\.submit/i })
|
||||
expect(submitButton).toBeEnabled()
|
||||
await user.click(submitButton)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(consoleLogSpy).toHaveBeenCalledWith('Form submitted:', expect.objectContaining({
|
||||
name: 'Alice',
|
||||
surname: 'Smith',
|
||||
isAcceptingTerms: true,
|
||||
contact: expect.objectContaining({
|
||||
email: 'alice@example.com',
|
||||
preferredContactMethod: 'whatsapp',
|
||||
}),
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
||||
it('removes the nested contact section again when the name field is cleared', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<DemoForm />)
|
||||
|
||||
const nameInput = screen.getByRole('textbox', { name: /^name$/i })
|
||||
await user.type(nameInput, 'Alice')
|
||||
expect(await screen.findByRole('heading', { name: /contacts/i })).toBeInTheDocument()
|
||||
|
||||
await user.clear(nameInput)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('heading', { name: /contacts/i })).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
151
web/__tests__/base/notion-page-selector-flow.test.tsx
Normal file
151
web/__tests__/base/notion-page-selector-flow.test.tsx
Normal file
@ -0,0 +1,151 @@
|
||||
import type { DataSourceCredential } from '@/app/components/header/account-setting/data-source-page-new/types'
|
||||
import type { DataSourceNotionWorkspace } from '@/models/common'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import NotionPageSelector from '@/app/components/base/notion-page-selector/base'
|
||||
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
|
||||
import { CredentialTypeEnum } from '@/app/components/plugins/plugin-auth/types'
|
||||
|
||||
const mockInvalidPreImportNotionPages = vi.fn()
|
||||
const mockSetShowAccountSettingModal = vi.fn()
|
||||
const mockUsePreImportNotionPages = vi.fn()
|
||||
|
||||
vi.mock('@tanstack/react-virtual', () => ({
|
||||
useVirtualizer: ({ count }: { count: number }) => ({
|
||||
getVirtualItems: () => Array.from({ length: count }, (_, index) => ({
|
||||
index,
|
||||
size: 28,
|
||||
start: index * 28,
|
||||
})),
|
||||
getTotalSize: () => count * 28 + 16,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/knowledge/use-import', () => ({
|
||||
usePreImportNotionPages: (params: { datasetId: string, credentialId: string }) => mockUsePreImportNotionPages(params),
|
||||
useInvalidPreImportNotionPages: () => mockInvalidPreImportNotionPages,
|
||||
}))
|
||||
|
||||
vi.mock('@/context/modal-context', () => ({
|
||||
useModalContextSelector: (selector: (state: { setShowAccountSettingModal: typeof mockSetShowAccountSettingModal }) => unknown) =>
|
||||
selector({ setShowAccountSettingModal: mockSetShowAccountSettingModal }),
|
||||
}))
|
||||
|
||||
const buildCredential = (id: string, name: string, workspaceName: string): DataSourceCredential => ({
|
||||
id,
|
||||
name,
|
||||
type: CredentialTypeEnum.OAUTH2,
|
||||
is_default: false,
|
||||
avatar_url: '',
|
||||
credential: {
|
||||
workspace_icon: '',
|
||||
workspace_name: workspaceName,
|
||||
},
|
||||
})
|
||||
|
||||
const credentials: DataSourceCredential[] = [
|
||||
buildCredential('c1', 'Cred 1', 'Workspace 1'),
|
||||
buildCredential('c2', 'Cred 2', 'Workspace 2'),
|
||||
]
|
||||
|
||||
const workspacePagesByCredential: Record<string, DataSourceNotionWorkspace[]> = {
|
||||
c1: [
|
||||
{
|
||||
workspace_id: 'w1',
|
||||
workspace_icon: '',
|
||||
workspace_name: 'Workspace 1',
|
||||
pages: [
|
||||
{ page_id: 'root-1', page_name: 'Root 1', parent_id: 'root', page_icon: null, type: 'page', is_bound: false },
|
||||
{ page_id: 'child-1', page_name: 'Child 1', parent_id: 'root-1', page_icon: null, type: 'page', is_bound: false },
|
||||
{ page_id: 'bound-1', page_name: 'Bound 1', parent_id: 'root', page_icon: null, type: 'page', is_bound: true },
|
||||
],
|
||||
},
|
||||
],
|
||||
c2: [
|
||||
{
|
||||
workspace_id: 'w2',
|
||||
workspace_icon: '',
|
||||
workspace_name: 'Workspace 2',
|
||||
pages: [
|
||||
{ page_id: 'external-1', page_name: 'External 1', parent_id: 'root', page_icon: null, type: 'page', is_bound: false },
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
describe('Base Notion Page Selector Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUsePreImportNotionPages.mockImplementation(({ credentialId }: { credentialId: string }) => ({
|
||||
data: {
|
||||
notion_info: workspacePagesByCredential[credentialId] ?? workspacePagesByCredential.c1,
|
||||
},
|
||||
isFetching: false,
|
||||
isError: false,
|
||||
}))
|
||||
})
|
||||
|
||||
it('selects a page tree, filters through search, clears search, and previews the current page', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSelect = vi.fn()
|
||||
const onPreview = vi.fn()
|
||||
|
||||
render(
|
||||
<NotionPageSelector
|
||||
credentialList={credentials}
|
||||
onSelect={onSelect}
|
||||
onPreview={onPreview}
|
||||
previewPageId="root-1"
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByTestId('checkbox-notion-page-checkbox-root-1'))
|
||||
|
||||
expect(onSelect).toHaveBeenLastCalledWith(expect.arrayContaining([
|
||||
expect.objectContaining({ page_id: 'root-1', workspace_id: 'w1' }),
|
||||
expect.objectContaining({ page_id: 'child-1', workspace_id: 'w1' }),
|
||||
expect.objectContaining({ page_id: 'bound-1', workspace_id: 'w1' }),
|
||||
]))
|
||||
|
||||
await user.type(screen.getByTestId('notion-search-input'), 'missing-page')
|
||||
expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByTestId('notion-search-input-clear'))
|
||||
expect(screen.getByTestId('notion-page-name-root-1')).toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByTestId('notion-page-preview-root-1'))
|
||||
expect(onPreview).toHaveBeenCalledWith(expect.objectContaining({ page_id: 'root-1', workspace_id: 'w1' }))
|
||||
})
|
||||
|
||||
it('switches workspace credentials and opens the configuration entry point', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSelect = vi.fn()
|
||||
const onSelectCredential = vi.fn()
|
||||
|
||||
render(
|
||||
<NotionPageSelector
|
||||
credentialList={credentials}
|
||||
onSelect={onSelect}
|
||||
onSelectCredential={onSelectCredential}
|
||||
datasetId="dataset-1"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(onSelectCredential).toHaveBeenCalledWith('c1')
|
||||
|
||||
await user.click(screen.getByTestId('notion-credential-selector-btn'))
|
||||
await user.click(screen.getByTestId('notion-credential-item-c2'))
|
||||
|
||||
expect(mockInvalidPreImportNotionPages).toHaveBeenCalledWith({ datasetId: 'dataset-1', credentialId: 'c2' })
|
||||
expect(onSelect).toHaveBeenCalledWith([])
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onSelectCredential).toHaveBeenLastCalledWith('c2')
|
||||
expect(screen.getByTestId('notion-page-name-external-1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'common.dataSource.notion.selector.configure' }))
|
||||
expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ payload: ACCOUNT_SETTING_TAB.DATA_SOURCE })
|
||||
})
|
||||
})
|
||||
191
web/__tests__/base/prompt-editor-flow.test.tsx
Normal file
191
web/__tests__/base/prompt-editor-flow.test.tsx
Normal file
@ -0,0 +1,191 @@
|
||||
import type { EventEmitter } from 'ahooks/lib/useEventEmitter'
|
||||
import type { ComponentProps } from 'react'
|
||||
import type { EventEmitterValue } from '@/context/event-emitter'
|
||||
import { act, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { getNearestEditorFromDOMNode } from 'lexical'
|
||||
import { useEffect } from 'react'
|
||||
import PromptEditor from '@/app/components/base/prompt-editor'
|
||||
import {
|
||||
UPDATE_DATASETS_EVENT_EMITTER,
|
||||
UPDATE_HISTORY_EVENT_EMITTER,
|
||||
} from '@/app/components/base/prompt-editor/constants'
|
||||
import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { EventEmitterContextProvider } from '@/context/event-emitter-provider'
|
||||
|
||||
type Captures = {
|
||||
eventEmitter: EventEmitter<EventEmitterValue> | null
|
||||
events: EventEmitterValue[]
|
||||
}
|
||||
|
||||
const EventProbe = ({ captures }: { captures: Captures }) => {
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
|
||||
useEffect(() => {
|
||||
captures.eventEmitter = eventEmitter
|
||||
}, [captures, eventEmitter])
|
||||
|
||||
eventEmitter?.useSubscription((value) => {
|
||||
captures.events.push(value)
|
||||
})
|
||||
|
||||
return <button type="button">outside</button>
|
||||
}
|
||||
|
||||
const PromptEditorHarness = ({
|
||||
captures,
|
||||
...props
|
||||
}: ComponentProps<typeof PromptEditor> & { captures: Captures }) => (
|
||||
<EventEmitterContextProvider>
|
||||
<EventProbe captures={captures} />
|
||||
<PromptEditor {...props} />
|
||||
</EventEmitterContextProvider>
|
||||
)
|
||||
|
||||
describe('Base Prompt Editor Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// Real prompt editor integration should emit block updates and transform editor updates into text output.
|
||||
describe('Editor Shell', () => {
|
||||
it('should render with the real editor, emit dataset/history events, and convert update events into text changes', async () => {
|
||||
const captures: Captures = { eventEmitter: null, events: [] }
|
||||
const onChange = vi.fn()
|
||||
const onFocus = vi.fn()
|
||||
const onBlur = vi.fn()
|
||||
const user = userEvent.setup()
|
||||
|
||||
const { rerender, container } = render(
|
||||
<PromptEditorHarness
|
||||
captures={captures}
|
||||
instanceId="editor-1"
|
||||
compact={true}
|
||||
className="editor-shell"
|
||||
placeholder="Type prompt"
|
||||
onChange={onChange}
|
||||
onFocus={onFocus}
|
||||
onBlur={onBlur}
|
||||
contextBlock={{
|
||||
show: false,
|
||||
datasets: [{ id: 'ds-1', name: 'Dataset One', type: 'dataset' }],
|
||||
}}
|
||||
historyBlock={{
|
||||
show: false,
|
||||
history: { user: 'user-role', assistant: 'assistant-role' },
|
||||
}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Type prompt')).toBeInTheDocument()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(captures.eventEmitter).not.toBeNull()
|
||||
})
|
||||
|
||||
const editable = container.querySelector('[contenteditable="true"]') as HTMLElement
|
||||
expect(editable).toBeInTheDocument()
|
||||
|
||||
await user.click(editable)
|
||||
await waitFor(() => {
|
||||
expect(onFocus).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'outside' }))
|
||||
await waitFor(() => {
|
||||
expect(onBlur).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
act(() => {
|
||||
captures.eventEmitter?.emit({
|
||||
type: PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER,
|
||||
instanceId: 'editor-1',
|
||||
payload: 'first line\nsecond line',
|
||||
})
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onChange).toHaveBeenCalledWith('first line\nsecond line')
|
||||
})
|
||||
|
||||
expect(captures.events).toContainEqual({
|
||||
type: UPDATE_DATASETS_EVENT_EMITTER,
|
||||
payload: [{ id: 'ds-1', name: 'Dataset One', type: 'dataset' }],
|
||||
})
|
||||
expect(captures.events).toContainEqual({
|
||||
type: UPDATE_HISTORY_EVENT_EMITTER,
|
||||
payload: { user: 'user-role', assistant: 'assistant-role' },
|
||||
})
|
||||
|
||||
rerender(
|
||||
<PromptEditorHarness
|
||||
captures={captures}
|
||||
instanceId="editor-1"
|
||||
contextBlock={{
|
||||
show: false,
|
||||
datasets: [{ id: 'ds-2', name: 'Dataset Two', type: 'dataset' }],
|
||||
}}
|
||||
historyBlock={{
|
||||
show: false,
|
||||
history: { user: 'user-next', assistant: 'assistant-next' },
|
||||
}}
|
||||
/>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(captures.events).toContainEqual({
|
||||
type: UPDATE_DATASETS_EVENT_EMITTER,
|
||||
payload: [{ id: 'ds-2', name: 'Dataset Two', type: 'dataset' }],
|
||||
})
|
||||
})
|
||||
expect(captures.events).toContainEqual({
|
||||
type: UPDATE_HISTORY_EVENT_EMITTER,
|
||||
payload: { user: 'user-next', assistant: 'assistant-next' },
|
||||
})
|
||||
})
|
||||
|
||||
it('should tolerate updates without onChange and rethrow lexical runtime errors through the configured handler', async () => {
|
||||
const captures: Captures = { eventEmitter: null, events: [] }
|
||||
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
const { container } = render(
|
||||
<PromptEditorHarness
|
||||
captures={captures}
|
||||
instanceId="editor-2"
|
||||
editable={false}
|
||||
placeholder="Read only prompt"
|
||||
/>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(captures.eventEmitter).not.toBeNull()
|
||||
})
|
||||
|
||||
act(() => {
|
||||
captures.eventEmitter?.emit({
|
||||
type: PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER,
|
||||
instanceId: 'editor-2',
|
||||
payload: 'silent update',
|
||||
})
|
||||
})
|
||||
|
||||
const editable = container.querySelector('[contenteditable="false"]') as HTMLElement
|
||||
const editor = getNearestEditorFromDOMNode(editable)
|
||||
|
||||
expect(editable).toBeInTheDocument()
|
||||
expect(editor).not.toBeNull()
|
||||
expect(screen.getByRole('textbox')).toHaveTextContent('silent update')
|
||||
|
||||
expect(() => {
|
||||
act(() => {
|
||||
editor?.update(() => {
|
||||
throw new Error('prompt-editor boom')
|
||||
})
|
||||
})
|
||||
}).toThrow('prompt-editor boom')
|
||||
|
||||
consoleErrorSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
})
|
||||
107
web/__tests__/custom/custom-page-flow.test.tsx
Normal file
107
web/__tests__/custom/custom-page-flow.test.tsx
Normal file
@ -0,0 +1,107 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createMockProviderContextValue } from '@/__mocks__/provider-context'
|
||||
import { contactSalesUrl, defaultPlan } from '@/app/components/billing/config'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import CustomPage from '@/app/components/custom/custom-page'
|
||||
import useWebAppBrand from '@/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand'
|
||||
|
||||
const mockSetShowPricingModal = vi.fn()
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/modal-context', () => ({
|
||||
useModalContext: () => ({
|
||||
setShowPricingModal: mockSetShowPricingModal,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand', () => ({
|
||||
__esModule: true,
|
||||
default: vi.fn(),
|
||||
}))
|
||||
|
||||
const { useProviderContext } = await import('@/context/provider-context')
|
||||
|
||||
const mockUseProviderContext = vi.mocked(useProviderContext)
|
||||
const mockUseWebAppBrand = vi.mocked(useWebAppBrand)
|
||||
|
||||
const createBrandState = (overrides: Partial<ReturnType<typeof useWebAppBrand>> = {}): ReturnType<typeof useWebAppBrand> => ({
|
||||
fileId: '',
|
||||
imgKey: 1,
|
||||
uploadProgress: 0,
|
||||
uploading: false,
|
||||
webappLogo: 'https://example.com/logo.png',
|
||||
webappBrandRemoved: false,
|
||||
uploadDisabled: false,
|
||||
workspaceLogo: 'https://example.com/workspace-logo.png',
|
||||
isCurrentWorkspaceManager: true,
|
||||
isSandbox: false,
|
||||
handleApply: vi.fn(),
|
||||
handleCancel: vi.fn(),
|
||||
handleChange: vi.fn(),
|
||||
handleRestore: vi.fn(),
|
||||
handleSwitch: vi.fn(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const setProviderPlan = (planType: Plan, enableBilling = true) => {
|
||||
mockUseProviderContext.mockReturnValue(createMockProviderContextValue({
|
||||
enableBilling,
|
||||
plan: {
|
||||
...defaultPlan,
|
||||
type: planType,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
describe('Custom Page Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setProviderPlan(Plan.professional)
|
||||
mockUseWebAppBrand.mockReturnValue(createBrandState())
|
||||
})
|
||||
|
||||
it('shows the billing upgrade banner for sandbox workspaces and opens pricing modal', () => {
|
||||
setProviderPlan(Plan.sandbox)
|
||||
|
||||
render(<CustomPage />)
|
||||
|
||||
expect(screen.getByText('custom.upgradeTip.title')).toBeInTheDocument()
|
||||
expect(screen.queryByText('custom.customize.contactUs')).not.toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByText('billing.upgradeBtn.encourageShort'))
|
||||
|
||||
expect(mockSetShowPricingModal).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('renders the branding controls and the sales contact footer for paid workspaces', () => {
|
||||
const hookState = createBrandState({
|
||||
fileId: 'pending-logo',
|
||||
})
|
||||
mockUseWebAppBrand.mockReturnValue(hookState)
|
||||
|
||||
render(<CustomPage />)
|
||||
|
||||
const contactLink = screen.getByText('custom.customize.contactUs').closest('a')
|
||||
expect(contactLink).toHaveAttribute('href', contactSalesUrl)
|
||||
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'custom.restore' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'custom.apply' }))
|
||||
|
||||
expect(hookState.handleSwitch).toHaveBeenCalledWith(true)
|
||||
expect(hookState.handleRestore).toHaveBeenCalledTimes(1)
|
||||
expect(hookState.handleCancel).toHaveBeenCalledTimes(1)
|
||||
expect(hookState.handleApply).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
182
web/__tests__/header/account-dropdown-flow.test.tsx
Normal file
182
web/__tests__/header/account-dropdown-flow.test.tsx
Normal file
@ -0,0 +1,182 @@
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import AccountDropdown from '@/app/components/header/account-dropdown'
|
||||
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
|
||||
|
||||
const {
|
||||
mockPush,
|
||||
mockLogout,
|
||||
mockResetUser,
|
||||
mockSetShowAccountSettingModal,
|
||||
} = vi.hoisted(() => ({
|
||||
mockPush: vi.fn(),
|
||||
mockLogout: vi.fn(),
|
||||
mockResetUser: vi.fn(),
|
||||
mockSetShowAccountSettingModal: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string, version?: string }) => {
|
||||
if (options?.version)
|
||||
return `${options.ns}.${key}:${options.version}`
|
||||
return options?.ns ? `${options.ns}.${key}` : key
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
userProfile: {
|
||||
name: 'Ada Lovelace',
|
||||
email: 'ada@example.com',
|
||||
avatar_url: '',
|
||||
},
|
||||
langGeniusVersionInfo: {
|
||||
current_version: '1.0.0',
|
||||
latest_version: '1.1.0',
|
||||
release_notes: 'https://example.com/releases/1.1.0',
|
||||
},
|
||||
isCurrentWorkspaceOwner: false,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => ({
|
||||
isEducationAccount: false,
|
||||
plan: {
|
||||
type: Plan.professional,
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector?: (state: Record<string, unknown>) => unknown) => {
|
||||
const state = {
|
||||
systemFeatures: {
|
||||
branding: {
|
||||
enabled: false,
|
||||
workspace_logo: null,
|
||||
},
|
||||
},
|
||||
}
|
||||
return selector ? selector(state) : state
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/context/modal-context', () => ({
|
||||
useModalContext: () => ({
|
||||
setShowAccountSettingModal: mockSetShowAccountSettingModal,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useDocLink: () => (path: string) => `https://docs.example.com${path}`,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useLogout: () => ({
|
||||
mutateAsync: mockLogout,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/amplitude/utils', () => ({
|
||||
resetUser: mockResetUser,
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/next/link', () => ({
|
||||
default: ({
|
||||
href,
|
||||
children,
|
||||
...props
|
||||
}: {
|
||||
href: string
|
||||
children?: React.ReactNode
|
||||
} & Record<string, unknown>) => (
|
||||
<a href={href} {...props}>
|
||||
{children}
|
||||
</a>
|
||||
),
|
||||
}))
|
||||
|
||||
const renderAccountDropdown = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
mutations: { retry: false },
|
||||
},
|
||||
})
|
||||
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<AccountDropdown />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('Header Account Dropdown Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(globalThis, 'fetch').mockResolvedValue(new Response(JSON.stringify({
|
||||
repo: { stars: 123456 },
|
||||
}), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}))
|
||||
localStorage.clear()
|
||||
})
|
||||
|
||||
it('opens account actions, fetches github stars, and opens the settings and about flows', async () => {
|
||||
renderAccountDropdown()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.account.account' }))
|
||||
|
||||
expect(screen.getByText('Ada Lovelace')).toBeInTheDocument()
|
||||
expect(screen.getByText('ada@example.com')).toBeInTheDocument()
|
||||
expect(await screen.findByText('123,456')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByText('common.userProfile.settings'))
|
||||
|
||||
expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({
|
||||
payload: ACCOUNT_SETTING_TAB.MEMBERS,
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByText('common.userProfile.about'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/Version/)).toBeInTheDocument()
|
||||
expect(screen.getByText(/1\.0\.0/)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('logs out, resets cached user markers, and redirects to signin', async () => {
|
||||
localStorage.setItem('setup_status', 'done')
|
||||
localStorage.setItem('education-reverify-prev-expire-at', '1')
|
||||
localStorage.setItem('education-reverify-has-noticed', '1')
|
||||
localStorage.setItem('education-expired-has-noticed', '1')
|
||||
|
||||
renderAccountDropdown()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.account.account' }))
|
||||
fireEvent.click(screen.getByText('common.userProfile.logout'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockLogout).toHaveBeenCalledTimes(1)
|
||||
expect(mockResetUser).toHaveBeenCalledTimes(1)
|
||||
expect(mockPush).toHaveBeenCalledWith('/signin')
|
||||
})
|
||||
|
||||
expect(localStorage.getItem('setup_status')).toBeNull()
|
||||
expect(localStorage.getItem('education-reverify-prev-expire-at')).toBeNull()
|
||||
expect(localStorage.getItem('education-reverify-has-noticed')).toBeNull()
|
||||
expect(localStorage.getItem('education-expired-has-noticed')).toBeNull()
|
||||
})
|
||||
})
|
||||
237
web/__tests__/header/nav-flow.test.tsx
Normal file
237
web/__tests__/header/nav-flow.test.tsx
Normal file
@ -0,0 +1,237 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import Nav from '@/app/components/header/nav'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
|
||||
const mockPush = vi.fn()
|
||||
const mockSetAppDetail = vi.fn()
|
||||
const mockOnCreate = vi.fn()
|
||||
const mockOnLoadMore = vi.fn()
|
||||
|
||||
let mockSelectedSegment = 'app'
|
||||
let mockIsCurrentWorkspaceEditor = true
|
||||
|
||||
vi.mock('@headlessui/react', () => {
|
||||
type MenuContextValue = {
|
||||
open: boolean
|
||||
setOpen: React.Dispatch<React.SetStateAction<boolean>>
|
||||
}
|
||||
const MenuContext = React.createContext<MenuContextValue | null>(null)
|
||||
|
||||
const Menu = ({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode | ((props: { open: boolean }) => React.ReactNode)
|
||||
}) => {
|
||||
const [open, setOpen] = React.useState(false)
|
||||
const value = React.useMemo(() => ({ open, setOpen }), [open])
|
||||
|
||||
return (
|
||||
<MenuContext.Provider value={value}>
|
||||
{typeof children === 'function' ? children({ open }) : children}
|
||||
</MenuContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
const MenuButton = ({
|
||||
children,
|
||||
onClick,
|
||||
...props
|
||||
}: React.ButtonHTMLAttributes<HTMLButtonElement>) => {
|
||||
const context = React.useContext(MenuContext)
|
||||
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
aria-expanded={context?.open ?? false}
|
||||
onClick={(event) => {
|
||||
context?.setOpen(v => !v)
|
||||
onClick?.(event)
|
||||
}}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
)
|
||||
}
|
||||
|
||||
const MenuItems = ({
|
||||
as: Component = 'div',
|
||||
children,
|
||||
...props
|
||||
}: {
|
||||
as?: React.ElementType
|
||||
children: React.ReactNode
|
||||
} & Record<string, unknown>) => {
|
||||
const context = React.useContext(MenuContext)
|
||||
if (!context?.open)
|
||||
return null
|
||||
|
||||
return <Component {...props}>{children}</Component>
|
||||
}
|
||||
|
||||
const MenuItem = ({
|
||||
as: Component = 'div',
|
||||
children,
|
||||
...props
|
||||
}: {
|
||||
as?: React.ElementType
|
||||
children: React.ReactNode
|
||||
} & Record<string, unknown>) => <Component {...props}>{children}</Component>
|
||||
|
||||
return {
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItems,
|
||||
MenuItem,
|
||||
Transition: ({ children }: { children: React.ReactNode }) => <>{children}</>,
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useSelectedLayoutSegment: () => mockSelectedSegment,
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/next/link', () => ({
|
||||
default: ({
|
||||
href,
|
||||
children,
|
||||
}: {
|
||||
href: string
|
||||
children?: React.ReactNode
|
||||
}) => <a href={href}>{children}</a>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: () => mockSetAppDetail,
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
isCurrentWorkspaceEditor: mockIsCurrentWorkspaceEditor,
|
||||
}),
|
||||
}))
|
||||
|
||||
const navigationItems = [
|
||||
{
|
||||
id: 'app-1',
|
||||
name: 'Alpha',
|
||||
link: '/app/app-1/configuration',
|
||||
icon_type: 'emoji' as const,
|
||||
icon: '🤖',
|
||||
icon_background: '#FFEAD5',
|
||||
icon_url: null,
|
||||
mode: AppModeEnum.CHAT,
|
||||
},
|
||||
{
|
||||
id: 'app-2',
|
||||
name: 'Bravo',
|
||||
link: '/app/app-2/workflow',
|
||||
icon_type: 'emoji' as const,
|
||||
icon: '⚙️',
|
||||
icon_background: '#E0F2FE',
|
||||
icon_url: null,
|
||||
mode: AppModeEnum.WORKFLOW,
|
||||
},
|
||||
]
|
||||
|
||||
const curNav = {
|
||||
id: 'app-1',
|
||||
name: 'Alpha',
|
||||
icon_type: 'emoji' as const,
|
||||
icon: '🤖',
|
||||
icon_background: '#FFEAD5',
|
||||
icon_url: null,
|
||||
mode: AppModeEnum.CHAT,
|
||||
}
|
||||
|
||||
const renderNav = (nav = curNav) => {
|
||||
return render(
|
||||
<Nav
|
||||
isApp
|
||||
icon={<span data-testid="nav-icon">icon</span>}
|
||||
activeIcon={<span data-testid="nav-icon-active">active-icon</span>}
|
||||
text="menus.apps"
|
||||
activeSegment={['apps', 'app']}
|
||||
link="/apps"
|
||||
curNav={nav}
|
||||
navigationItems={navigationItems}
|
||||
createText="menus.newApp"
|
||||
onCreate={mockOnCreate}
|
||||
onLoadMore={mockOnLoadMore}
|
||||
isLoadingMore={false}
|
||||
/>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('Header Nav Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockSelectedSegment = 'app'
|
||||
mockIsCurrentWorkspaceEditor = true
|
||||
})
|
||||
|
||||
it('switches to another app from the selector and clears stale app detail first', async () => {
|
||||
renderNav()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /Alpha/i }))
|
||||
fireEvent.click(await screen.findByText('Bravo'))
|
||||
|
||||
expect(mockSetAppDetail).toHaveBeenCalled()
|
||||
expect(mockPush).toHaveBeenCalledWith('/app/app-2/workflow')
|
||||
})
|
||||
|
||||
it('opens the nested create menu and emits all app creation branches', async () => {
|
||||
renderNav()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /Alpha/i }))
|
||||
fireEvent.click(await screen.findByText('menus.newApp'))
|
||||
fireEvent.click(await screen.findByText('newApp.startFromBlank'))
|
||||
fireEvent.click(await screen.findByText('newApp.startFromTemplate'))
|
||||
fireEvent.click(await screen.findByText('importDSL'))
|
||||
|
||||
expect(mockOnCreate).toHaveBeenNthCalledWith(1, 'blank')
|
||||
expect(mockOnCreate).toHaveBeenNthCalledWith(2, 'template')
|
||||
expect(mockOnCreate).toHaveBeenNthCalledWith(3, 'dsl')
|
||||
})
|
||||
|
||||
it('keeps the current nav label in sync with prop updates', async () => {
|
||||
const { rerender } = renderNav()
|
||||
|
||||
expect(screen.getByRole('button', { name: /Alpha/i })).toBeInTheDocument()
|
||||
|
||||
rerender(
|
||||
<Nav
|
||||
isApp
|
||||
icon={<span data-testid="nav-icon">icon</span>}
|
||||
activeIcon={<span data-testid="nav-icon-active">active-icon</span>}
|
||||
text="menus.apps"
|
||||
activeSegment={['apps', 'app']}
|
||||
link="/apps"
|
||||
curNav={{
|
||||
...curNav,
|
||||
name: 'Alpha Renamed',
|
||||
}}
|
||||
navigationItems={navigationItems}
|
||||
createText="menus.newApp"
|
||||
onCreate={mockOnCreate}
|
||||
onLoadMore={mockOnLoadMore}
|
||||
isLoadingMore={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /Alpha Renamed/i })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
163
web/__tests__/plugins/plugin-page-shell-flow.test.tsx
Normal file
163
web/__tests__/plugins/plugin-page-shell-flow.test.tsx
Normal file
@ -0,0 +1,163 @@
|
||||
import { fireEvent, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import PluginPage from '@/app/components/plugins/plugin-page'
|
||||
import { renderWithNuqs } from '@/test/nuqs-testing'
|
||||
|
||||
const mockFetchManifestFromMarketPlace = vi.fn()
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/utils')>()
|
||||
return {
|
||||
...actual,
|
||||
sleep: vi.fn(() => Promise.resolve()),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/hooks/use-document-title', () => ({
|
||||
__esModule: true,
|
||||
default: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useDocLink: () => (path: string) => `https://docs.example.com${path}`,
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
isCurrentWorkspaceManager: false,
|
||||
isCurrentWorkspaceOwner: false,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector: (state: Record<string, unknown>) => unknown) => selector({
|
||||
systemFeatures: {
|
||||
enable_marketplace: true,
|
||||
plugin_installation_permission: {
|
||||
restrict_to_marketplace_only: false,
|
||||
},
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-plugins', () => ({
|
||||
useReferenceSettings: () => ({
|
||||
data: {
|
||||
permission: {
|
||||
install_permission: 'everyone',
|
||||
debug_permission: 'noOne',
|
||||
},
|
||||
},
|
||||
}),
|
||||
useMutationReferenceSettings: () => ({
|
||||
mutate: vi.fn(),
|
||||
isPending: false,
|
||||
}),
|
||||
useInvalidateReferenceSettings: () => vi.fn(),
|
||||
useInstalledPluginList: () => ({
|
||||
data: {
|
||||
total: 2,
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/plugins', () => ({
|
||||
fetchManifestFromMarketPlace: (...args: unknown[]) => mockFetchManifestFromMarketPlace(...args),
|
||||
fetchBundleInfoFromMarketPlace: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/plugin-page/plugin-tasks', () => ({
|
||||
default: () => <div data-testid="plugin-tasks">plugin tasks</div>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/plugin-page/debug-info', () => ({
|
||||
default: () => <div data-testid="debug-info">debug info</div>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/plugin-page/install-plugin-dropdown', () => ({
|
||||
default: ({ onSwitchToMarketplaceTab }: { onSwitchToMarketplaceTab: () => void }) => (
|
||||
<button type="button" data-testid="install-plugin-dropdown" onClick={onSwitchToMarketplaceTab}>
|
||||
install plugin
|
||||
</button>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/install-plugin/install-from-marketplace', () => ({
|
||||
default: ({
|
||||
uniqueIdentifier,
|
||||
onClose,
|
||||
}: {
|
||||
uniqueIdentifier: string
|
||||
onClose: () => void
|
||||
}) => (
|
||||
<div data-testid="install-from-marketplace-modal">
|
||||
<span>{uniqueIdentifier}</span>
|
||||
<button type="button" onClick={onClose}>close-install-modal</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const renderPluginPage = (searchParams = '') => {
|
||||
return renderWithNuqs(
|
||||
<PluginPage
|
||||
plugins={<div data-testid="plugins-view">plugins view</div>}
|
||||
marketplace={<div data-testid="marketplace-view">marketplace view</div>}
|
||||
/>,
|
||||
{ searchParams },
|
||||
)
|
||||
}
|
||||
|
||||
describe('Plugin Page Shell Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockFetchManifestFromMarketPlace.mockResolvedValue({
|
||||
data: {
|
||||
plugin: {
|
||||
org: 'langgenius',
|
||||
name: 'plugin-demo',
|
||||
},
|
||||
version: {
|
||||
version: '1.0.0',
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('switches from installed plugins to marketplace and syncs the active tab into the URL', async () => {
|
||||
const { onUrlUpdate } = renderPluginPage()
|
||||
|
||||
expect(screen.getByTestId('plugins-view')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('marketplace-view')).not.toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByTestId('tab-item-discover'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('marketplace-view')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const tabUpdate = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
|
||||
expect(tabUpdate.searchParams.get('tab')).toBe('discover')
|
||||
})
|
||||
|
||||
it('hydrates marketplace installation from query params and clears the install state when closed', async () => {
|
||||
const { onUrlUpdate } = renderPluginPage('?package-ids=%5B%22langgenius%2Fplugin-demo%22%5D')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchManifestFromMarketPlace).toHaveBeenCalledWith('langgenius%2Fplugin-demo')
|
||||
expect(screen.getByTestId('install-from-marketplace-modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'close-install-modal' }))
|
||||
|
||||
await waitFor(() => {
|
||||
const clearUpdate = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
|
||||
expect(clearUpdate.searchParams.has('package-ids')).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
155
web/__tests__/share/text-generation-mode-flow.test.tsx
Normal file
155
web/__tests__/share/text-generation-mode-flow.test.tsx
Normal file
@ -0,0 +1,155 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import TextGeneration from '@/app/components/share/text-generation'
|
||||
|
||||
const useSearchParamsMock = vi.fn(() => new URLSearchParams())
|
||||
const mockUseTextGenerationAppState = vi.fn()
|
||||
const mockUseTextGenerationBatch = vi.fn()
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useSearchParams: () => useSearchParamsMock(),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-breakpoints', () => ({
|
||||
__esModule: true,
|
||||
default: () => 'pc',
|
||||
MediaType: { pc: 'pc', pad: 'pad', mobile: 'mobile' },
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/share/text-generation/hooks/use-text-generation-app-state', () => ({
|
||||
useTextGenerationAppState: (...args: unknown[]) => mockUseTextGenerationAppState(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/share/text-generation/hooks/use-text-generation-batch', () => ({
|
||||
useTextGenerationBatch: (...args: unknown[]) => mockUseTextGenerationBatch(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/share/text-generation/text-generation-sidebar', () => ({
|
||||
default: ({
|
||||
currentTab,
|
||||
onTabChange,
|
||||
}: {
|
||||
currentTab: string
|
||||
onTabChange: (tab: string) => void
|
||||
}) => (
|
||||
<div data-testid="text-generation-sidebar">
|
||||
<span data-testid="current-tab">{currentTab}</span>
|
||||
<button type="button" onClick={() => onTabChange('batch')}>switch-to-batch</button>
|
||||
<button type="button" onClick={() => onTabChange('create')}>switch-to-create</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/share/text-generation/text-generation-result-panel', () => ({
|
||||
default: ({
|
||||
isCallBatchAPI,
|
||||
resultExisted,
|
||||
}: {
|
||||
isCallBatchAPI: boolean
|
||||
resultExisted: boolean
|
||||
}) => (
|
||||
<div
|
||||
data-testid="text-generation-result-panel"
|
||||
data-batch={String(isCallBatchAPI)}
|
||||
data-result={String(resultExisted)}
|
||||
/>
|
||||
),
|
||||
}))
|
||||
|
||||
const createReadyAppState = () => ({
|
||||
accessMode: 'public',
|
||||
appId: 'app-123',
|
||||
appSourceType: 'published',
|
||||
customConfig: {
|
||||
remove_webapp_brand: false,
|
||||
replace_webapp_logo: '',
|
||||
},
|
||||
handleRemoveSavedMessage: vi.fn(),
|
||||
handleSaveMessage: vi.fn(),
|
||||
moreLikeThisConfig: {
|
||||
enabled: true,
|
||||
},
|
||||
promptConfig: {
|
||||
user_input_form: [],
|
||||
},
|
||||
savedMessages: [],
|
||||
siteInfo: {
|
||||
title: 'Text Generation',
|
||||
},
|
||||
systemFeatures: {
|
||||
branding: {
|
||||
enabled: false,
|
||||
workspace_logo: null,
|
||||
},
|
||||
},
|
||||
textToSpeechConfig: {
|
||||
enabled: true,
|
||||
},
|
||||
visionConfig: null,
|
||||
})
|
||||
|
||||
const createBatchState = () => ({
|
||||
allFailedTaskList: [],
|
||||
allSuccessTaskList: [],
|
||||
allTaskList: [],
|
||||
allTasksRun: false,
|
||||
controlRetry: 0,
|
||||
exportRes: vi.fn(),
|
||||
handleCompleted: vi.fn(),
|
||||
handleRetryAllFailedTask: vi.fn(),
|
||||
handleRunBatch: vi.fn(),
|
||||
isCallBatchAPI: false,
|
||||
noPendingTask: true,
|
||||
resetBatchExecution: vi.fn(),
|
||||
setIsCallBatchAPI: vi.fn(),
|
||||
showTaskList: false,
|
||||
})
|
||||
|
||||
describe('Text Generation Mode Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
useSearchParamsMock.mockReturnValue(new URLSearchParams())
|
||||
mockUseTextGenerationAppState.mockReturnValue(createReadyAppState())
|
||||
mockUseTextGenerationBatch.mockReturnValue(createBatchState())
|
||||
})
|
||||
|
||||
it('shows the loading state before app metadata is ready', () => {
|
||||
mockUseTextGenerationAppState.mockReturnValue({
|
||||
...createReadyAppState(),
|
||||
appId: '',
|
||||
promptConfig: null,
|
||||
siteInfo: null,
|
||||
})
|
||||
|
||||
render(<TextGeneration />)
|
||||
|
||||
expect(screen.getByRole('status', { name: 'appApi.loading' })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('hydrates the initial tab from the mode query parameter and lets the sidebar switch it', () => {
|
||||
useSearchParamsMock.mockReturnValue(new URLSearchParams('mode=batch'))
|
||||
|
||||
render(<TextGeneration />)
|
||||
|
||||
expect(screen.getByTestId('current-tab')).toHaveTextContent('batch')
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'switch-to-create' }))
|
||||
|
||||
expect(screen.getByTestId('current-tab')).toHaveTextContent('create')
|
||||
})
|
||||
|
||||
it('falls back to create mode for unsupported query values', () => {
|
||||
useSearchParamsMock.mockReturnValue(new URLSearchParams('mode=unsupported'))
|
||||
|
||||
render(<TextGeneration />)
|
||||
|
||||
expect(screen.getByTestId('current-tab')).toHaveTextContent('create')
|
||||
expect(screen.getByTestId('text-generation-result-panel')).toHaveAttribute('data-batch', 'false')
|
||||
})
|
||||
})
|
||||
205
web/__tests__/tools/provider-list-shell-flow.test.tsx
Normal file
205
web/__tests__/tools/provider-list-shell-flow.test.tsx
Normal file
@ -0,0 +1,205 @@
|
||||
import { fireEvent, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import ProviderList from '@/app/components/tools/provider-list'
|
||||
import { CollectionType } from '@/app/components/tools/types'
|
||||
import { renderWithNuqs } from '@/test/nuqs-testing'
|
||||
|
||||
const mockInvalidateInstalledPluginList = vi.fn()
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector: (state: Record<string, unknown>) => unknown) => selector({
|
||||
systemFeatures: {
|
||||
enable_marketplace: true,
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/hooks', () => ({
|
||||
useTags: () => ({
|
||||
getTagLabel: (name: string) => name,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-tools', () => ({
|
||||
useAllToolProviders: () => ({
|
||||
data: [
|
||||
{
|
||||
id: 'builtin-plugin',
|
||||
name: 'plugin-tool',
|
||||
author: 'Dify',
|
||||
description: { en_US: 'Plugin Tool' },
|
||||
icon: 'icon-plugin',
|
||||
label: { en_US: 'Plugin Tool' },
|
||||
type: CollectionType.builtIn,
|
||||
team_credentials: {},
|
||||
is_team_authorization: false,
|
||||
allow_delete: false,
|
||||
labels: ['search'],
|
||||
plugin_id: 'langgenius/plugin-tool',
|
||||
},
|
||||
{
|
||||
id: 'builtin-basic',
|
||||
name: 'basic-tool',
|
||||
author: 'Dify',
|
||||
description: { en_US: 'Basic Tool' },
|
||||
icon: 'icon-basic',
|
||||
label: { en_US: 'Basic Tool' },
|
||||
type: CollectionType.builtIn,
|
||||
team_credentials: {},
|
||||
is_team_authorization: false,
|
||||
allow_delete: false,
|
||||
labels: ['utility'],
|
||||
},
|
||||
],
|
||||
refetch: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-plugins', () => ({
|
||||
useCheckInstalled: ({ enabled }: { enabled: boolean }) => ({
|
||||
data: enabled
|
||||
? {
|
||||
plugins: [{
|
||||
plugin_id: 'langgenius/plugin-tool',
|
||||
declaration: {
|
||||
category: 'tool',
|
||||
},
|
||||
}],
|
||||
}
|
||||
: null,
|
||||
}),
|
||||
useInvalidateInstalledPluginList: () => mockInvalidateInstalledPluginList,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/tools/labels/filter', () => ({
|
||||
default: ({ onChange }: { onChange: (value: string[]) => void }) => (
|
||||
<div data-testid="tool-label-filter">
|
||||
<button type="button" onClick={() => onChange(['search'])}>apply-search-filter</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/card', () => ({
|
||||
default: ({ payload, className }: { payload: { name: string }, className?: string }) => (
|
||||
<div data-testid={`tool-card-${payload.name}`} className={className}>
|
||||
{payload.name}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/card/card-more-info', () => ({
|
||||
default: ({ tags }: { tags: string[] }) => <div data-testid="tool-card-more-info">{tags.join(',')}</div>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/tools/provider/detail', () => ({
|
||||
default: ({ collection, onHide }: { collection: { name: string }, onHide: () => void }) => (
|
||||
<div data-testid="tool-provider-detail">
|
||||
<span>{collection.name}</span>
|
||||
<button type="button" onClick={onHide}>close-provider-detail</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/plugin-detail-panel', () => ({
|
||||
default: ({
|
||||
detail,
|
||||
onHide,
|
||||
onUpdate,
|
||||
}: {
|
||||
detail?: { plugin_id: string }
|
||||
onHide: () => void
|
||||
onUpdate: () => void
|
||||
}) => detail
|
||||
? (
|
||||
<div data-testid="tool-plugin-detail-panel">
|
||||
<span>{detail.plugin_id}</span>
|
||||
<button type="button" onClick={onUpdate}>update-plugin-detail</button>
|
||||
<button type="button" onClick={onHide}>close-plugin-detail</button>
|
||||
</div>
|
||||
)
|
||||
: null,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/tools/provider/empty', () => ({
|
||||
default: () => <div data-testid="workflow-empty">workflow empty</div>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/marketplace/empty', () => ({
|
||||
default: ({ text }: { text: string }) => <div data-testid="tools-empty">{text}</div>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/tools/marketplace', () => ({
|
||||
default: ({
|
||||
isMarketplaceArrowVisible,
|
||||
showMarketplacePanel,
|
||||
}: {
|
||||
isMarketplaceArrowVisible: boolean
|
||||
showMarketplacePanel: () => void
|
||||
}) => (
|
||||
<button type="button" data-testid="marketplace-arrow" data-visible={String(isMarketplaceArrowVisible)} onClick={showMarketplacePanel}>
|
||||
marketplace-arrow
|
||||
</button>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/tools/marketplace/hooks', () => ({
|
||||
useMarketplace: () => ({
|
||||
handleScroll: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/tools/mcp', () => ({
|
||||
default: ({ searchText }: { searchText: string }) => <div data-testid="mcp-list">{searchText}</div>,
|
||||
}))
|
||||
|
||||
const renderProviderList = (searchParams = '') => {
|
||||
return renderWithNuqs(<ProviderList />, { searchParams })
|
||||
}
|
||||
|
||||
describe('Tool Provider List Shell Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
Element.prototype.scrollTo = vi.fn()
|
||||
})
|
||||
|
||||
it('opens a plugin-backed provider detail panel and invalidates installed plugins on update', async () => {
|
||||
renderProviderList('?category=builtin')
|
||||
|
||||
fireEvent.click(screen.getByTestId('tool-card-plugin-tool'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('tool-plugin-detail-panel')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'update-plugin-detail' }))
|
||||
expect(mockInvalidateInstalledPluginList).toHaveBeenCalledTimes(1)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'close-plugin-detail' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('tool-plugin-detail-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('scrolls to the marketplace section and syncs workflow tab selection into the URL', async () => {
|
||||
const { onUrlUpdate } = renderProviderList('?category=builtin')
|
||||
|
||||
fireEvent.click(screen.getByTestId('marketplace-arrow'))
|
||||
expect(Element.prototype.scrollTo).toHaveBeenCalled()
|
||||
|
||||
fireEvent.click(screen.getByTestId('tab-item-workflow'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('workflow-empty')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
|
||||
expect(update.searchParams.get('category')).toBe('workflow')
|
||||
})
|
||||
})
|
||||
@ -18,6 +18,7 @@ const mockInvalidDatasetDetail = vi.fn()
|
||||
const mockExportPipeline = vi.fn()
|
||||
const mockCheckIsUsedInApp = vi.fn()
|
||||
const mockDeleteDataset = vi.fn()
|
||||
const mockToast = vi.fn()
|
||||
|
||||
const createDataset = (overrides: Partial<DataSet> = {}): DataSet => ({
|
||||
id: 'dataset-1',
|
||||
@ -111,6 +112,10 @@ vi.mock('@/service/datasets', () => ({
|
||||
deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: (...args: unknown[]) => mockToast(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/datasets/rename-modal', () => ({
|
||||
default: ({
|
||||
show,
|
||||
@ -225,4 +230,49 @@ describe('Dropdown callback coverage', () => {
|
||||
expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show the used-by-app confirmation copy when the dataset is referenced by apps', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockCheckIsUsedInApp.mockResolvedValueOnce({ is_using: true })
|
||||
|
||||
render(<Dropdown expand />)
|
||||
|
||||
await user.click(screen.getByTestId('portal-trigger'))
|
||||
await user.click(screen.getByText('common.operation.delete'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('dataset.datasetUsedByApp')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should surface an export failure toast when pipeline export fails', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockExportPipeline.mockRejectedValueOnce(new Error('export failed'))
|
||||
|
||||
render(<Dropdown expand />)
|
||||
|
||||
await user.click(screen.getByTestId('portal-trigger'))
|
||||
await user.click(screen.getByText('datasetPipeline.operations.exportPipeline'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToast).toHaveBeenCalledWith('app.exportFailed', { type: 'error' })
|
||||
})
|
||||
})
|
||||
|
||||
it('should surface the backend message when checking app usage fails', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockCheckIsUsedInApp.mockRejectedValueOnce({
|
||||
json: vi.fn().mockResolvedValue({ message: 'check failed' }),
|
||||
})
|
||||
|
||||
render(<Dropdown expand />)
|
||||
|
||||
await user.click(screen.getByTestId('portal-trigger'))
|
||||
await user.click(screen.getByText('common.operation.delete'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToast).toHaveBeenCalledWith('check failed', { type: 'error' })
|
||||
})
|
||||
expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { RiEditLine } from '@remixicon/react'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { createEvent, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import {
|
||||
@ -218,6 +218,31 @@ describe('MenuItem', () => {
|
||||
// Assert
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should stop propagation before invoking the handler', () => {
|
||||
const parentClick = vi.fn()
|
||||
const handleClick = vi.fn()
|
||||
|
||||
render(
|
||||
<div onClick={parentClick}>
|
||||
<MenuItem name="Edit" Icon={RiEditLine} handleClick={handleClick} />
|
||||
</div>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('Edit'))
|
||||
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
expect(parentClick).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not crash when no click handler is provided', () => {
|
||||
render(<MenuItem name="Edit" Icon={RiEditLine} />)
|
||||
|
||||
const event = createEvent.click(screen.getByText('Edit'))
|
||||
fireEvent(screen.getByText('Edit'), event)
|
||||
|
||||
expect(event.defaultPrevented).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -265,6 +290,47 @@ describe('Menu', () => {
|
||||
expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Interactions', () => {
|
||||
it('should invoke the rename callback when edit is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const openRenameModal = vi.fn()
|
||||
|
||||
render(
|
||||
<Menu
|
||||
showDelete
|
||||
openRenameModal={openRenameModal}
|
||||
handleExportPipeline={vi.fn()}
|
||||
detectIsUsedByApp={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByText('common.operation.edit'))
|
||||
|
||||
expect(openRenameModal).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should invoke export and delete callbacks from their menu items', async () => {
|
||||
const user = userEvent.setup()
|
||||
const handleExportPipeline = vi.fn()
|
||||
const detectIsUsedByApp = vi.fn()
|
||||
|
||||
render(
|
||||
<Menu
|
||||
showDelete
|
||||
openRenameModal={vi.fn()}
|
||||
handleExportPipeline={handleExportPipeline}
|
||||
detectIsUsedByApp={detectIsUsedByApp}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByText('datasetPipeline.operations.exportPipeline'))
|
||||
await user.click(screen.getByText('common.operation.delete'))
|
||||
|
||||
expect(handleExportPipeline).toHaveBeenCalledTimes(1)
|
||||
expect(detectIsUsedByApp).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Dropdown', () => {
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
/* eslint-disable ts/no-explicit-any */
|
||||
import type { ReactNode } from 'react'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { act, fireEvent, screen, waitFor } from '@testing-library/react'
|
||||
import { renderWithNuqs } from '@/test/nuqs-testing'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import ConversationList from '../list'
|
||||
|
||||
const mockFetchChatMessages = vi.fn()
|
||||
const mockUpdateLogMessageFeedbacks = vi.fn()
|
||||
const mockUpdateLogMessageAnnotations = vi.fn()
|
||||
const mockPush = vi.fn()
|
||||
const mockReplace = vi.fn()
|
||||
const mockOnRefresh = vi.fn()
|
||||
const mockSetCurrentLogItem = vi.fn()
|
||||
const mockSetShowPromptLogModal = vi.fn()
|
||||
@ -17,7 +16,6 @@ const mockSetShowMessageLogModal = vi.fn()
|
||||
const mockCompletionRefetch = vi.fn()
|
||||
const mockDelAnnotation = vi.fn()
|
||||
|
||||
let mockSearchParams = new URLSearchParams()
|
||||
let mockChatConversationDetail: Record<string, unknown> | undefined
|
||||
let mockCompletionConversationDetail: Record<string, unknown> | undefined
|
||||
let mockShowMessageLogModal = false
|
||||
@ -53,18 +51,6 @@ vi.mock('@/hooks/use-breakpoints', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
replace: mockReplace,
|
||||
}),
|
||||
usePathname: () => '/apps/app-1/logs',
|
||||
useSearchParams: () => ({
|
||||
get: (key: string) => mockSearchParams.get(key),
|
||||
toString: () => mockSearchParams.toString(),
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-log', () => ({
|
||||
useChatConversationDetail: () => ({
|
||||
data: mockChatConversationDetail,
|
||||
@ -256,10 +242,28 @@ const createChatMessage = (id: string, overrides: Record<string, unknown> = {})
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const renderConversationList = ({
|
||||
appDetail = { id: 'app-1', mode: AppModeEnum.CHAT } as any,
|
||||
logs = createLogs() as any,
|
||||
searchParams = '?page=2',
|
||||
}: {
|
||||
appDetail?: any
|
||||
logs?: any
|
||||
searchParams?: string
|
||||
} = {}) => {
|
||||
return renderWithNuqs(
|
||||
<ConversationList
|
||||
appDetail={appDetail}
|
||||
logs={logs}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
{ searchParams },
|
||||
)
|
||||
}
|
||||
|
||||
describe('ConversationList', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockSearchParams = new URLSearchParams('page=2')
|
||||
mockChatConversationDetail = undefined
|
||||
mockCompletionConversationDetail = undefined
|
||||
mockShowMessageLogModal = false
|
||||
@ -273,34 +277,29 @@ describe('ConversationList', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should render chat rows and push the conversation id into the url when a row is clicked', () => {
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
|
||||
logs={createLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
it('should render chat rows and push the conversation id into the url when a row is clicked', async () => {
|
||||
const { onUrlUpdate } = renderConversationList()
|
||||
|
||||
expect(screen.getByText('hello world')).toBeInTheDocument()
|
||||
expect(screen.getAllByText('formatted-1710000000')).toHaveLength(2)
|
||||
|
||||
fireEvent.click(screen.getByText('hello world'))
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/apps/app-1/logs?page=2&conversation_id=conversation-1', { scroll: false })
|
||||
expect(screen.getByTestId('drawer')).toBeInTheDocument()
|
||||
await waitFor(() => {
|
||||
expect(onUrlUpdate).toHaveBeenCalled()
|
||||
expect(screen.getByTestId('drawer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(update.searchParams.get('page')).toBe('2')
|
||||
expect(update.searchParams.get('conversation_id')).toBe('conversation-1')
|
||||
expect(update.options.history).toBe('push')
|
||||
})
|
||||
|
||||
it('should close the drawer, refresh, and clear modal flags', () => {
|
||||
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
|
||||
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
|
||||
logs={createLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
it('should close the drawer, refresh, and clear modal flags', async () => {
|
||||
const { onUrlUpdate } = renderConversationList({
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByText('close-drawer'))
|
||||
|
||||
@ -308,11 +307,18 @@ describe('ConversationList', () => {
|
||||
expect(mockSetShowPromptLogModal).toHaveBeenCalledWith(false)
|
||||
expect(mockSetShowAgentLogModal).toHaveBeenCalledWith(false)
|
||||
expect(mockSetShowMessageLogModal).toHaveBeenCalledWith(false)
|
||||
expect(mockReplace).toHaveBeenCalledWith('/apps/app-1/logs?page=2', { scroll: false })
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onUrlUpdate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(update.searchParams.get('page')).toBe('2')
|
||||
expect(update.searchParams.has('conversation_id')).toBe(false)
|
||||
expect(update.options.history).toBe('replace')
|
||||
})
|
||||
|
||||
it('should render chat conversation details and submit feedback from the chat panel', async () => {
|
||||
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
|
||||
mockChatConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
created_at: 1710000000,
|
||||
@ -355,13 +361,9 @@ describe('ConversationList', () => {
|
||||
mockShowMessageLogModal = true
|
||||
mockCurrentLogItem = { id: 'log-1' }
|
||||
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
|
||||
logs={createLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
renderConversationList({
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchChatMessages).toHaveBeenCalledWith({
|
||||
@ -396,7 +398,6 @@ describe('ConversationList', () => {
|
||||
})
|
||||
|
||||
it('should render completion details and refetch after feedback updates', async () => {
|
||||
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
|
||||
mockCompletionConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
created_at: 1710000000,
|
||||
@ -423,13 +424,11 @@ describe('ConversationList', () => {
|
||||
mockShowPromptLogModal = true
|
||||
mockCurrentLogItem = { id: 'log-2' }
|
||||
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.COMPLETION } as any}
|
||||
logs={createCompletionLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
renderConversationList({
|
||||
appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any,
|
||||
logs: createCompletionLogs() as any,
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('text-generation')).toBeInTheDocument()
|
||||
@ -454,64 +453,61 @@ describe('ConversationList', () => {
|
||||
})
|
||||
|
||||
it('should render chatflow status cells and feedback counters for advanced chat logs', () => {
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } as any}
|
||||
logs={{
|
||||
data: [
|
||||
{
|
||||
id: 'conversation-pending',
|
||||
name: 'Pending row',
|
||||
from_account_name: 'user-a',
|
||||
read_at: 1710000001,
|
||||
message_count: 3,
|
||||
status_count: { paused: 1, success: 0, failed: 0, partial_success: 0 },
|
||||
user_feedback_stats: { like: 2, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 1 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-success',
|
||||
name: 'Success row',
|
||||
from_account_name: 'user-b',
|
||||
read_at: 1710000001,
|
||||
message_count: 4,
|
||||
status_count: { paused: 0, success: 4, failed: 0, partial_success: 0 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-partial',
|
||||
name: 'Partial row',
|
||||
from_account_name: 'user-c',
|
||||
read_at: 1710000001,
|
||||
message_count: 5,
|
||||
status_count: { paused: 0, success: 3, failed: 0, partial_success: 1 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-failure',
|
||||
name: 'Failure row',
|
||||
from_account_name: 'user-d',
|
||||
read_at: 1710000001,
|
||||
message_count: 1,
|
||||
status_count: { paused: 0, success: 0, failed: 2, partial_success: 0 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
],
|
||||
} as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
renderConversationList({
|
||||
appDetail: { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } as any,
|
||||
logs: {
|
||||
data: [
|
||||
{
|
||||
id: 'conversation-pending',
|
||||
name: 'Pending row',
|
||||
from_account_name: 'user-a',
|
||||
read_at: 1710000001,
|
||||
message_count: 3,
|
||||
status_count: { paused: 1, success: 0, failed: 0, partial_success: 0 },
|
||||
user_feedback_stats: { like: 2, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 1 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-success',
|
||||
name: 'Success row',
|
||||
from_account_name: 'user-b',
|
||||
read_at: 1710000001,
|
||||
message_count: 4,
|
||||
status_count: { paused: 0, success: 4, failed: 0, partial_success: 0 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-partial',
|
||||
name: 'Partial row',
|
||||
from_account_name: 'user-c',
|
||||
read_at: 1710000001,
|
||||
message_count: 5,
|
||||
status_count: { paused: 0, success: 3, failed: 0, partial_success: 1 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-failure',
|
||||
name: 'Failure row',
|
||||
from_account_name: 'user-d',
|
||||
read_at: 1710000001,
|
||||
message_count: 1,
|
||||
status_count: { paused: 0, success: 0, failed: 2, partial_success: 0 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
],
|
||||
} as any,
|
||||
})
|
||||
|
||||
expect(screen.getByText('Pending')).toBeInTheDocument()
|
||||
expect(screen.getByText('Success')).toBeInTheDocument()
|
||||
@ -522,7 +518,6 @@ describe('ConversationList', () => {
|
||||
})
|
||||
|
||||
it('should support annotation changes, modal closing, and paginated scroll loading in the detail drawer', async () => {
|
||||
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
|
||||
mockChatConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
created_at: 1710000000,
|
||||
@ -568,13 +563,9 @@ describe('ConversationList', () => {
|
||||
has_more: false,
|
||||
})
|
||||
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
|
||||
logs={createLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
renderConversationList({
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('chat-panel')).toBeInTheDocument()
|
||||
@ -609,7 +600,6 @@ describe('ConversationList', () => {
|
||||
})
|
||||
|
||||
it('should close the prompt log modal from completion detail drawers', async () => {
|
||||
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
|
||||
mockCompletionConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
created_at: 1710000000,
|
||||
@ -636,13 +626,11 @@ describe('ConversationList', () => {
|
||||
mockShowPromptLogModal = true
|
||||
mockCurrentLogItem = { id: 'log-2' }
|
||||
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.COMPLETION } as any}
|
||||
logs={createCompletionLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
renderConversationList({
|
||||
appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any,
|
||||
logs: createCompletionLogs() as any,
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
expect(await screen.findByTestId('prompt-log-modal')).toBeInTheDocument()
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ import dayjs from 'dayjs'
|
||||
import timezone from 'dayjs/plugin/timezone'
|
||||
import utc from 'dayjs/plugin/utc'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import { parseAsString, useQueryState } from 'nuqs'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@ -33,7 +34,6 @@ import { WorkflowContextProvider } from '@/app/components/workflow/context'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import useTimestamp from '@/hooks/use-timestamp'
|
||||
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { fetchChatMessages, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log'
|
||||
import { AppSourceType } from '@/service/share'
|
||||
import { useChatConversationDetail, useCompletionConversationDetail } from '@/service/use-log'
|
||||
@ -46,7 +46,6 @@ import {
|
||||
applyAnnotationEdited,
|
||||
applyAnnotationRemoved,
|
||||
buildChatThreadState,
|
||||
buildConversationUrl,
|
||||
getCompletionMessageFiles,
|
||||
getConversationRowValues,
|
||||
getDetailVarList,
|
||||
@ -674,10 +673,7 @@ const ChatConversationDetailComp: FC<{ appId?: string, conversationId?: string }
|
||||
const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh }) => {
|
||||
const { t } = useTranslation()
|
||||
const { formatTime } = useTimestamp()
|
||||
const router = useRouter()
|
||||
const pathname = usePathname()
|
||||
const searchParams = useSearchParams()
|
||||
const conversationIdInUrl = searchParams.get('conversation_id') ?? undefined
|
||||
const [conversationIdInUrl, setConversationIdInUrl] = useQueryState('conversation_id', parseAsString)
|
||||
|
||||
const media = useBreakpoints()
|
||||
const isMobile = media === MediaType.mobile
|
||||
@ -697,8 +693,6 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
|
||||
const activeConversationId = conversationIdInUrl ?? pendingConversationIdRef.current ?? currentConversation?.id
|
||||
|
||||
const buildUrlWithConversation = useCallback((conversationId?: string) => buildConversationUrl(pathname, searchParams.toString(), conversationId), [pathname, searchParams])
|
||||
|
||||
const handleRowClick = useCallback((log: ConversationListItem) => {
|
||||
if (conversationIdInUrl === log.id) {
|
||||
if (!showDrawer)
|
||||
@ -717,8 +711,8 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
if (currentConversation?.id !== log.id)
|
||||
setCurrentConversation(undefined)
|
||||
|
||||
router.push(buildUrlWithConversation(log.id), { scroll: false })
|
||||
}, [buildUrlWithConversation, conversationIdInUrl, currentConversation, router, showDrawer])
|
||||
void setConversationIdInUrl(log.id, { history: 'push' })
|
||||
}, [conversationIdInUrl, currentConversation, setConversationIdInUrl, showDrawer])
|
||||
|
||||
const currentConversationId = currentConversation?.id
|
||||
|
||||
@ -755,7 +749,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
|
||||
if (pendingConversationCacheRef.current?.id === conversationIdInUrl || matchedConversation)
|
||||
pendingConversationCacheRef.current = undefined
|
||||
}, [conversationIdInUrl, currentConversation, isChatMode, logs?.data, showDrawer])
|
||||
}, [conversationIdInUrl, currentConversation, currentConversationId, logs?.data, showDrawer])
|
||||
|
||||
const onCloseDrawer = useCallback(() => {
|
||||
onRefresh()
|
||||
@ -769,8 +763,8 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
closingConversationIdRef.current = conversationIdInUrl ?? null
|
||||
|
||||
if (conversationIdInUrl)
|
||||
router.replace(buildUrlWithConversation(), { scroll: false })
|
||||
}, [buildUrlWithConversation, conversationIdInUrl, onRefresh, router, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal])
|
||||
void setConversationIdInUrl(null, { history: 'replace' })
|
||||
}, [conversationIdInUrl, onRefresh, setConversationIdInUrl, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal])
|
||||
|
||||
// Annotated data needs to be highlighted
|
||||
const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => {
|
||||
|
||||
24
web/app/components/apps/__tests__/app-card-skeleton.spec.tsx
Normal file
24
web/app/components/apps/__tests__/app-card-skeleton.spec.tsx
Normal file
@ -0,0 +1,24 @@
|
||||
import { render } from '@testing-library/react'
|
||||
import { AppCardSkeleton } from '../app-card-skeleton'
|
||||
|
||||
describe('AppCardSkeleton', () => {
|
||||
it('should render six skeleton cards by default', () => {
|
||||
const { container } = render(<AppCardSkeleton />)
|
||||
|
||||
expect(container.childElementCount).toBe(6)
|
||||
expect(AppCardSkeleton.displayName).toBe('AppCardSkeleton')
|
||||
})
|
||||
|
||||
it('should respect the custom skeleton count and card classes', () => {
|
||||
const { container } = render(<AppCardSkeleton count={2} />)
|
||||
|
||||
expect(container.childElementCount).toBe(2)
|
||||
expect(container.firstElementChild).toHaveClass(
|
||||
'h-[160px]',
|
||||
'rounded-xl',
|
||||
'border-[0.5px]',
|
||||
'bg-components-card-bg',
|
||||
'p-4',
|
||||
)
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,144 @@
|
||||
import type { IChatItem } from '../type'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import { fetchAgentLogDetail } from '@/service/log'
|
||||
import ChatLogModals from '../chat-log-modals'
|
||||
|
||||
vi.mock('@/service/log', () => ({
|
||||
fetchAgentLogDetail: vi.fn(),
|
||||
}))
|
||||
|
||||
describe('ChatLogModals', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
useAppStore.setState({ appDetail: { id: 'app-1' } as ReturnType<typeof useAppStore.getState>['appDetail'] })
|
||||
})
|
||||
|
||||
// Modal visibility should follow the two booleans unless log modals are globally hidden.
|
||||
describe('Rendering', () => {
|
||||
it('should render real prompt and agent log modals when enabled', async () => {
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {}))
|
||||
|
||||
render(
|
||||
<ChatLogModals
|
||||
width={480}
|
||||
currentLogItem={{
|
||||
id: 'log-1',
|
||||
isAnswer: true,
|
||||
content: 'reply',
|
||||
input: { question: 'hello' },
|
||||
log: [{ role: 'user', text: 'Prompt body' }],
|
||||
conversationId: 'conversation-1',
|
||||
} as IChatItem}
|
||||
showPromptLogModal={true}
|
||||
showAgentLogModal={true}
|
||||
setCurrentLogItem={vi.fn()}
|
||||
setShowPromptLogModal={vi.fn()}
|
||||
setShowAgentLogModal={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('PROMPT LOG')).toBeInTheDocument()
|
||||
expect(screen.getByText('Prompt body')).toBeInTheDocument()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('heading', { name: /appLog.runDetail.workflowTitle/i })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should render nothing when hideLogModal is true', () => {
|
||||
render(
|
||||
<ChatLogModals
|
||||
width={320}
|
||||
currentLogItem={{
|
||||
id: 'log-2',
|
||||
isAnswer: true,
|
||||
content: 'reply',
|
||||
log: [{ role: 'user', text: 'Prompt body' }],
|
||||
conversationId: 'conversation-2',
|
||||
} as IChatItem}
|
||||
showPromptLogModal={true}
|
||||
showAgentLogModal={true}
|
||||
hideLogModal={true}
|
||||
setCurrentLogItem={vi.fn()}
|
||||
setShowPromptLogModal={vi.fn()}
|
||||
setShowAgentLogModal={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('PROMPT LOG')).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('heading', { name: /appLog.runDetail.workflowTitle/i })).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Cancel actions should clear the current item and close only the targeted modal.
|
||||
describe('User Interactions', () => {
|
||||
it('should close the prompt log modal through the real close action', async () => {
|
||||
const user = userEvent.setup()
|
||||
const setCurrentLogItem = vi.fn()
|
||||
const setShowPromptLogModal = vi.fn()
|
||||
const setShowAgentLogModal = vi.fn()
|
||||
|
||||
render(
|
||||
<ChatLogModals
|
||||
width={480}
|
||||
currentLogItem={{
|
||||
id: 'log-3',
|
||||
isAnswer: true,
|
||||
content: 'reply',
|
||||
input: { question: 'hello' },
|
||||
log: [{ role: 'user', text: 'Prompt body' }],
|
||||
} as IChatItem}
|
||||
showPromptLogModal={true}
|
||||
showAgentLogModal={false}
|
||||
setCurrentLogItem={setCurrentLogItem}
|
||||
setShowPromptLogModal={setShowPromptLogModal}
|
||||
setShowAgentLogModal={setShowAgentLogModal}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByTestId('close-btn-container'))
|
||||
|
||||
expect(setCurrentLogItem).toHaveBeenCalled()
|
||||
expect(setShowPromptLogModal).toHaveBeenCalledWith(false)
|
||||
expect(setShowAgentLogModal).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should close the agent log modal through the real close action', async () => {
|
||||
const user = userEvent.setup()
|
||||
const setCurrentLogItem = vi.fn()
|
||||
const setShowPromptLogModal = vi.fn()
|
||||
const setShowAgentLogModal = vi.fn()
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {}))
|
||||
|
||||
render(
|
||||
<ChatLogModals
|
||||
width={480}
|
||||
currentLogItem={{
|
||||
id: 'log-4',
|
||||
isAnswer: true,
|
||||
content: 'reply',
|
||||
input: { question: 'hello' },
|
||||
log: [{ role: 'user', text: 'Prompt body' }],
|
||||
conversationId: 'conversation-4',
|
||||
} as IChatItem}
|
||||
showPromptLogModal={false}
|
||||
showAgentLogModal={true}
|
||||
setCurrentLogItem={setCurrentLogItem}
|
||||
setShowPromptLogModal={setShowPromptLogModal}
|
||||
setShowAgentLogModal={setShowAgentLogModal}
|
||||
/>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('heading', { name: /appLog.runDetail.workflowTitle/i })).toBeInTheDocument()
|
||||
})
|
||||
await user.click(screen.getByRole('heading', { name: /appLog.runDetail.workflowTitle/i }).nextElementSibling as HTMLElement)
|
||||
|
||||
expect(setCurrentLogItem).toHaveBeenCalled()
|
||||
expect(setShowAgentLogModal).toHaveBeenCalledWith(false)
|
||||
expect(setShowPromptLogModal).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,293 @@
|
||||
import type { ChatItem } from '../../types'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import {
|
||||
afterEach,
|
||||
beforeEach,
|
||||
describe,
|
||||
expect,
|
||||
it,
|
||||
vi,
|
||||
} from 'vitest'
|
||||
import { useChatLayout } from '../use-chat-layout'
|
||||
|
||||
type ResizeCallback = (entries: ResizeObserverEntry[], observer: ResizeObserver) => void
|
||||
|
||||
let capturedResizeCallbacks: ResizeCallback[] = []
|
||||
let disconnectSpy: ReturnType<typeof vi.fn>
|
||||
let rafCallbacks: FrameRequestCallback[] = []
|
||||
|
||||
const makeChatItem = (overrides: Partial<ChatItem> = {}): ChatItem => ({
|
||||
id: `item-${Math.random().toString(36).slice(2)}`,
|
||||
content: 'Test content',
|
||||
isAnswer: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const makeResizeEntry = (blockSize: number, inlineSize: number): ResizeObserverEntry => ({
|
||||
borderBoxSize: [{ blockSize, inlineSize } as ResizeObserverSize],
|
||||
contentBoxSize: [{ blockSize, inlineSize } as ResizeObserverSize],
|
||||
contentRect: new DOMRect(0, 0, inlineSize, blockSize),
|
||||
devicePixelContentBoxSize: [{ blockSize, inlineSize } as ResizeObserverSize],
|
||||
target: document.createElement('div'),
|
||||
})
|
||||
|
||||
const assignMetric = (node: HTMLElement, key: 'clientWidth' | 'clientHeight' | 'scrollHeight', value: number) => {
|
||||
Object.defineProperty(node, key, {
|
||||
configurable: true,
|
||||
value,
|
||||
})
|
||||
}
|
||||
|
||||
const LayoutHarness = ({
|
||||
chatList,
|
||||
sidebarCollapseState,
|
||||
attachRefs = true,
|
||||
}: {
|
||||
chatList: ChatItem[]
|
||||
sidebarCollapseState?: boolean
|
||||
attachRefs?: boolean
|
||||
}) => {
|
||||
const {
|
||||
width,
|
||||
chatContainerRef,
|
||||
chatContainerInnerRef,
|
||||
chatFooterRef,
|
||||
chatFooterInnerRef,
|
||||
} = useChatLayout({ chatList, sidebarCollapseState })
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
data-testid="chat-container"
|
||||
ref={(node) => {
|
||||
chatContainerRef.current = attachRefs ? node : null
|
||||
if (node && attachRefs) {
|
||||
assignMetric(node, 'clientWidth', 400)
|
||||
assignMetric(node, 'clientHeight', 240)
|
||||
assignMetric(node, 'scrollHeight', 640)
|
||||
if (!node.dataset.metricsReady) {
|
||||
node.scrollTop = 0
|
||||
node.dataset.metricsReady = 'true'
|
||||
}
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div
|
||||
data-testid="chat-container-inner"
|
||||
ref={(node) => {
|
||||
chatContainerInnerRef.current = attachRefs ? node : null
|
||||
if (node && attachRefs)
|
||||
assignMetric(node, 'clientWidth', 360)
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
data-testid="chat-footer"
|
||||
ref={(node) => {
|
||||
chatFooterRef.current = attachRefs ? node : null
|
||||
}}
|
||||
>
|
||||
<div
|
||||
data-testid="chat-footer-inner"
|
||||
ref={(node) => {
|
||||
chatFooterInnerRef.current = attachRefs ? node : null
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<output data-testid="layout-width">{width}</output>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
const flushAnimationFrames = () => {
|
||||
const queuedCallbacks = [...rafCallbacks]
|
||||
rafCallbacks = []
|
||||
queuedCallbacks.forEach(callback => callback(0))
|
||||
}
|
||||
|
||||
describe('useChatLayout', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.useFakeTimers()
|
||||
capturedResizeCallbacks = []
|
||||
disconnectSpy = vi.fn()
|
||||
rafCallbacks = []
|
||||
|
||||
Object.defineProperty(document.body, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 1024,
|
||||
})
|
||||
|
||||
vi.stubGlobal('requestAnimationFrame', (cb: FrameRequestCallback) => {
|
||||
rafCallbacks.push(cb)
|
||||
return rafCallbacks.length
|
||||
})
|
||||
|
||||
vi.stubGlobal('ResizeObserver', class {
|
||||
constructor(cb: ResizeCallback) {
|
||||
capturedResizeCallbacks.push(cb)
|
||||
}
|
||||
|
||||
observe() { }
|
||||
unobserve() { }
|
||||
disconnect = disconnectSpy
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
// The hook should compute shell dimensions and auto-scroll when enough chat items exist.
|
||||
describe('Layout Calculation', () => {
|
||||
it('should auto-scroll and compute the chat shell widths on mount', () => {
|
||||
const addSpy = vi.spyOn(window, 'addEventListener')
|
||||
|
||||
render(
|
||||
<LayoutHarness
|
||||
chatList={[
|
||||
makeChatItem({ id: 'q1' }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
]}
|
||||
sidebarCollapseState={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
act(() => {
|
||||
flushAnimationFrames()
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(screen.getByTestId('layout-width')).toHaveTextContent('600')
|
||||
expect(screen.getByTestId('chat-footer').style.width).toBe('400px')
|
||||
expect(screen.getByTestId('chat-footer-inner').style.width).toBe('360px')
|
||||
expect((screen.getByTestId('chat-container') as HTMLDivElement).scrollTop).toBe(640)
|
||||
expect(addSpy).toHaveBeenCalledWith('resize', expect.any(Function))
|
||||
})
|
||||
})
|
||||
|
||||
// Resize observers should keep padding and widths in sync, then fully clean up on unmount.
|
||||
describe('Resize Observers', () => {
|
||||
it('should react to observer updates and disconnect both observers on unmount', () => {
|
||||
const removeSpy = vi.spyOn(window, 'removeEventListener')
|
||||
const { unmount } = render(
|
||||
<LayoutHarness
|
||||
chatList={[
|
||||
makeChatItem({ id: 'q1' }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
]}
|
||||
/>,
|
||||
)
|
||||
|
||||
act(() => {
|
||||
capturedResizeCallbacks[0]?.([makeResizeEntry(80, 400)], {} as ResizeObserver)
|
||||
})
|
||||
expect(screen.getByTestId('chat-container').style.paddingBottom).toBe('80px')
|
||||
|
||||
act(() => {
|
||||
capturedResizeCallbacks[1]?.([makeResizeEntry(50, 560)], {} as ResizeObserver)
|
||||
})
|
||||
expect(screen.getByTestId('chat-footer').style.width).toBe('560px')
|
||||
|
||||
unmount()
|
||||
|
||||
expect(removeSpy).toHaveBeenCalledWith('resize', expect.any(Function))
|
||||
expect(disconnectSpy).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('should respect manual scrolling until a new first message arrives and safely ignore missing refs', () => {
|
||||
const { rerender } = render(
|
||||
<LayoutHarness
|
||||
chatList={[
|
||||
makeChatItem({ id: 'q1' }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
]}
|
||||
/>,
|
||||
)
|
||||
|
||||
const container = screen.getByTestId('chat-container') as HTMLDivElement
|
||||
|
||||
act(() => {
|
||||
fireEvent.scroll(container)
|
||||
flushAnimationFrames()
|
||||
})
|
||||
|
||||
act(() => {
|
||||
container.scrollTop = 10
|
||||
fireEvent.scroll(container)
|
||||
})
|
||||
|
||||
rerender(
|
||||
<LayoutHarness
|
||||
chatList={[
|
||||
makeChatItem({ id: 'q1' }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
makeChatItem({ id: 'a2', isAnswer: true }),
|
||||
]}
|
||||
/>,
|
||||
)
|
||||
|
||||
act(() => {
|
||||
flushAnimationFrames()
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
act(() => {
|
||||
container.scrollTop = 420
|
||||
fireEvent.scroll(container)
|
||||
})
|
||||
|
||||
rerender(
|
||||
<LayoutHarness
|
||||
chatList={[
|
||||
makeChatItem({ id: 'q2' }),
|
||||
makeChatItem({ id: 'a3', isAnswer: true }),
|
||||
]}
|
||||
/>,
|
||||
)
|
||||
|
||||
act(() => {
|
||||
flushAnimationFrames()
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(container.scrollTop).toBe(640)
|
||||
|
||||
rerender(
|
||||
<LayoutHarness
|
||||
chatList={[
|
||||
makeChatItem({ id: 'q2' }),
|
||||
makeChatItem({ id: 'a3', isAnswer: true }),
|
||||
]}
|
||||
attachRefs={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
act(() => {
|
||||
fireEvent.scroll(container)
|
||||
flushAnimationFrames()
|
||||
})
|
||||
})
|
||||
|
||||
it('should keep the hook stable when the DOM refs are not attached', () => {
|
||||
render(
|
||||
<LayoutHarness
|
||||
chatList={[makeChatItem({ id: 'q1' })]}
|
||||
sidebarCollapseState={true}
|
||||
attachRefs={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
act(() => {
|
||||
flushAnimationFrames()
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(screen.getByTestId('layout-width')).toHaveTextContent('0')
|
||||
expect(capturedResizeCallbacks).toHaveLength(0)
|
||||
expect(screen.getByTestId('chat-footer').style.width).toBe('')
|
||||
expect(screen.getByTestId('chat-footer-inner').style.width).toBe('')
|
||||
})
|
||||
})
|
||||
})
|
||||
56
web/app/components/base/chat/chat/chat-log-modals.tsx
Normal file
56
web/app/components/base/chat/chat/chat-log-modals.tsx
Normal file
@ -0,0 +1,56 @@
|
||||
import type { FC } from 'react'
|
||||
import type { IChatItem } from './type'
|
||||
import AgentLogModal from '@/app/components/base/agent-log-modal'
|
||||
import PromptLogModal from '@/app/components/base/prompt-log-modal'
|
||||
|
||||
type ChatLogModalsProps = {
|
||||
width: number
|
||||
currentLogItem?: IChatItem
|
||||
showPromptLogModal: boolean
|
||||
showAgentLogModal: boolean
|
||||
hideLogModal?: boolean
|
||||
setCurrentLogItem: (item?: IChatItem) => void
|
||||
setShowPromptLogModal: (showPromptLogModal: boolean) => void
|
||||
setShowAgentLogModal: (showAgentLogModal: boolean) => void
|
||||
}
|
||||
|
||||
const ChatLogModals: FC<ChatLogModalsProps> = ({
|
||||
width,
|
||||
currentLogItem,
|
||||
showPromptLogModal,
|
||||
showAgentLogModal,
|
||||
hideLogModal,
|
||||
setCurrentLogItem,
|
||||
setShowPromptLogModal,
|
||||
setShowAgentLogModal,
|
||||
}) => {
|
||||
if (hideLogModal)
|
||||
return null
|
||||
|
||||
return (
|
||||
<>
|
||||
{showPromptLogModal && (
|
||||
<PromptLogModal
|
||||
width={width}
|
||||
currentLogItem={currentLogItem}
|
||||
onCancel={() => {
|
||||
setCurrentLogItem()
|
||||
setShowPromptLogModal(false)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{showAgentLogModal && (
|
||||
<AgentLogModal
|
||||
width={width}
|
||||
currentLogItem={currentLogItem}
|
||||
onCancel={() => {
|
||||
setCurrentLogItem()
|
||||
setShowAgentLogModal(false)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default ChatLogModals
|
||||
@ -13,26 +13,19 @@ import type {
|
||||
import type { InputForm } from './type'
|
||||
import type { Emoji } from '@/app/components/tools/types'
|
||||
import type { AppData } from '@/models/share'
|
||||
import { debounce } from 'es-toolkit/compat'
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { memo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import AgentLogModal from '@/app/components/base/agent-log-modal'
|
||||
import Button from '@/app/components/base/button'
|
||||
import PromptLogModal from '@/app/components/base/prompt-log-modal'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Answer from './answer'
|
||||
import ChatInputArea from './chat-input-area'
|
||||
import ChatLogModals from './chat-log-modals'
|
||||
import { ChatContextProvider } from './context-provider'
|
||||
import Question from './question'
|
||||
import TryToAsk from './try-to-ask'
|
||||
import { useChatLayout } from './use-chat-layout'
|
||||
|
||||
export type ChatProps = {
|
||||
isTryApp?: boolean
|
||||
@ -133,128 +126,17 @@ const Chat: FC<ChatProps> = ({
|
||||
showAgentLogModal: state.showAgentLogModal,
|
||||
setShowAgentLogModal: state.setShowAgentLogModal,
|
||||
})))
|
||||
const [width, setWidth] = useState(0)
|
||||
const chatContainerRef = useRef<HTMLDivElement>(null)
|
||||
const chatContainerInnerRef = useRef<HTMLDivElement>(null)
|
||||
const chatFooterRef = useRef<HTMLDivElement>(null)
|
||||
const chatFooterInnerRef = useRef<HTMLDivElement>(null)
|
||||
const userScrolledRef = useRef(false)
|
||||
const isAutoScrollingRef = useRef(false)
|
||||
|
||||
const handleScrollToBottom = useCallback(() => {
|
||||
if (chatList.length > 1 && chatContainerRef.current && !userScrolledRef.current) {
|
||||
isAutoScrollingRef.current = true
|
||||
chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
isAutoScrollingRef.current = false
|
||||
})
|
||||
}
|
||||
}, [chatList.length])
|
||||
|
||||
const handleWindowResize = useCallback(() => {
|
||||
if (chatContainerRef.current)
|
||||
setWidth(document.body.clientWidth - (chatContainerRef.current?.clientWidth + 16) - 8)
|
||||
|
||||
if (chatContainerRef.current && chatFooterRef.current)
|
||||
chatFooterRef.current.style.width = `${chatContainerRef.current.clientWidth}px`
|
||||
|
||||
if (chatContainerInnerRef.current && chatFooterInnerRef.current)
|
||||
chatFooterInnerRef.current.style.width = `${chatContainerInnerRef.current.clientWidth}px`
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
handleScrollToBottom()
|
||||
handleWindowResize()
|
||||
}, [handleScrollToBottom, handleWindowResize])
|
||||
|
||||
useEffect(() => {
|
||||
/* v8 ignore next - @preserve */
|
||||
if (chatContainerRef.current) {
|
||||
requestAnimationFrame(() => {
|
||||
handleScrollToBottom()
|
||||
handleWindowResize()
|
||||
})
|
||||
}
|
||||
const {
|
||||
width,
|
||||
chatContainerRef,
|
||||
chatContainerInnerRef,
|
||||
chatFooterRef,
|
||||
chatFooterInnerRef,
|
||||
} = useChatLayout({
|
||||
chatList,
|
||||
sidebarCollapseState,
|
||||
})
|
||||
|
||||
useEffect(() => {
|
||||
const debouncedHandler = debounce(handleWindowResize, 200)
|
||||
window.addEventListener('resize', debouncedHandler)
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('resize', debouncedHandler)
|
||||
debouncedHandler.cancel()
|
||||
}
|
||||
}, [handleWindowResize])
|
||||
|
||||
useEffect(() => {
|
||||
/* v8 ignore next - @preserve */
|
||||
if (chatFooterRef.current && chatContainerRef.current) {
|
||||
const resizeContainerObserver = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
const { blockSize } = entry.borderBoxSize[0]
|
||||
chatContainerRef.current!.style.paddingBottom = `${blockSize}px`
|
||||
handleScrollToBottom()
|
||||
}
|
||||
})
|
||||
resizeContainerObserver.observe(chatFooterRef.current)
|
||||
|
||||
const resizeFooterObserver = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
const { inlineSize } = entry.borderBoxSize[0]
|
||||
chatFooterRef.current!.style.width = `${inlineSize}px`
|
||||
}
|
||||
})
|
||||
resizeFooterObserver.observe(chatContainerRef.current)
|
||||
|
||||
return () => {
|
||||
resizeContainerObserver.disconnect()
|
||||
resizeFooterObserver.disconnect()
|
||||
}
|
||||
}
|
||||
}, [handleScrollToBottom])
|
||||
|
||||
useEffect(() => {
|
||||
const setUserScrolled = () => {
|
||||
const container = chatContainerRef.current
|
||||
/* v8 ignore next 2 - @preserve */
|
||||
if (!container)
|
||||
return
|
||||
/* v8 ignore next 2 - @preserve */
|
||||
if (isAutoScrollingRef.current)
|
||||
return
|
||||
|
||||
const distanceToBottom = container.scrollHeight - container.clientHeight - container.scrollTop
|
||||
const SCROLL_UP_THRESHOLD = 100
|
||||
|
||||
userScrolledRef.current = distanceToBottom > SCROLL_UP_THRESHOLD
|
||||
}
|
||||
|
||||
const container = chatContainerRef.current
|
||||
/* v8 ignore next 2 - @preserve */
|
||||
if (!container)
|
||||
return
|
||||
|
||||
container.addEventListener('scroll', setUserScrolled)
|
||||
return () => container.removeEventListener('scroll', setUserScrolled)
|
||||
}, [])
|
||||
|
||||
const prevFirstMessageIdRef = useRef<string | undefined>(undefined)
|
||||
useEffect(() => {
|
||||
const firstMessageId = chatList[0]?.id
|
||||
if (chatList.length <= 1 || (firstMessageId && prevFirstMessageIdRef.current !== firstMessageId))
|
||||
userScrolledRef.current = false
|
||||
prevFirstMessageIdRef.current = firstMessageId
|
||||
}, [chatList])
|
||||
|
||||
useEffect(() => {
|
||||
if (!sidebarCollapseState) {
|
||||
const timer = setTimeout(handleWindowResize, 200)
|
||||
return () => clearTimeout(timer)
|
||||
}
|
||||
}, [handleWindowResize, sidebarCollapseState])
|
||||
|
||||
const hasTryToAsk = config?.suggested_questions_after_answer?.enabled && !!suggestedQuestions?.length && onSend
|
||||
|
||||
return (
|
||||
@ -279,7 +161,7 @@ const Chat: FC<ChatProps> = ({
|
||||
<div
|
||||
data-testid="chat-container"
|
||||
ref={chatContainerRef}
|
||||
className={cn('relative h-full overflow-y-auto overflow-x-hidden', isTryApp && 'h-0 grow', chatContainerClassName)}
|
||||
className={cn('relative h-full overflow-x-hidden overflow-y-auto', isTryApp && 'h-0 grow', chatContainerClassName)}
|
||||
>
|
||||
{chatNode}
|
||||
<div
|
||||
@ -338,7 +220,7 @@ const Chat: FC<ChatProps> = ({
|
||||
!noStopResponding && isResponding && (
|
||||
<div data-testid="stop-responding-container" className="mb-2 flex justify-center">
|
||||
<Button className="border-components-panel-border bg-components-panel-bg text-components-button-secondary-text" onClick={onStopResponding}>
|
||||
<div className="i-custom-vender-solid-mediaAndDevices-stop-circle mr-[5px] h-3.5 w-3.5" />
|
||||
<div className="mr-[5px] i-custom-vender-solid-mediaAndDevices-stop-circle h-3.5 w-3.5" />
|
||||
<span className="text-xs font-normal">{t('operation.stopResponding', { ns: 'appDebug' })}</span>
|
||||
</Button>
|
||||
</div>
|
||||
@ -375,26 +257,16 @@ const Chat: FC<ChatProps> = ({
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
{showPromptLogModal && !hideLogModal && (
|
||||
<PromptLogModal
|
||||
width={width}
|
||||
currentLogItem={currentLogItem}
|
||||
onCancel={() => {
|
||||
setCurrentLogItem()
|
||||
setShowPromptLogModal(false)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{showAgentLogModal && !hideLogModal && (
|
||||
<AgentLogModal
|
||||
width={width}
|
||||
currentLogItem={currentLogItem}
|
||||
onCancel={() => {
|
||||
setCurrentLogItem()
|
||||
setShowAgentLogModal(false)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<ChatLogModals
|
||||
width={width}
|
||||
currentLogItem={currentLogItem}
|
||||
showPromptLogModal={showPromptLogModal}
|
||||
showAgentLogModal={showAgentLogModal}
|
||||
hideLogModal={hideLogModal}
|
||||
setCurrentLogItem={setCurrentLogItem}
|
||||
setShowPromptLogModal={setShowPromptLogModal}
|
||||
setShowAgentLogModal={setShowAgentLogModal}
|
||||
/>
|
||||
</div>
|
||||
</ChatContextProvider>
|
||||
)
|
||||
|
||||
144
web/app/components/base/chat/chat/use-chat-layout.ts
Normal file
144
web/app/components/base/chat/chat/use-chat-layout.ts
Normal file
@ -0,0 +1,144 @@
|
||||
import type { ChatItem } from '../types'
|
||||
import { debounce } from 'es-toolkit/compat'
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
|
||||
type UseChatLayoutOptions = {
|
||||
chatList: ChatItem[]
|
||||
sidebarCollapseState?: boolean
|
||||
}
|
||||
|
||||
export const useChatLayout = ({ chatList, sidebarCollapseState }: UseChatLayoutOptions) => {
|
||||
const [width, setWidth] = useState(0)
|
||||
const chatContainerRef = useRef<HTMLDivElement>(null)
|
||||
const chatContainerInnerRef = useRef<HTMLDivElement>(null)
|
||||
const chatFooterRef = useRef<HTMLDivElement>(null)
|
||||
const chatFooterInnerRef = useRef<HTMLDivElement>(null)
|
||||
const userScrolledRef = useRef(false)
|
||||
const isAutoScrollingRef = useRef(false)
|
||||
const prevFirstMessageIdRef = useRef<string | undefined>(undefined)
|
||||
|
||||
const handleScrollToBottom = useCallback(() => {
|
||||
if (chatList.length > 1 && chatContainerRef.current && !userScrolledRef.current) {
|
||||
isAutoScrollingRef.current = true
|
||||
chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
isAutoScrollingRef.current = false
|
||||
})
|
||||
}
|
||||
}, [chatList.length])
|
||||
|
||||
const handleWindowResize = useCallback(() => {
|
||||
if (chatContainerRef.current)
|
||||
setWidth(document.body.clientWidth - (chatContainerRef.current.clientWidth + 16) - 8)
|
||||
|
||||
if (chatContainerRef.current && chatFooterRef.current)
|
||||
chatFooterRef.current.style.width = `${chatContainerRef.current.clientWidth}px`
|
||||
|
||||
if (chatContainerInnerRef.current && chatFooterInnerRef.current)
|
||||
chatFooterInnerRef.current.style.width = `${chatContainerInnerRef.current.clientWidth}px`
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
handleScrollToBottom()
|
||||
const animationFrame = requestAnimationFrame(handleWindowResize)
|
||||
|
||||
return () => {
|
||||
cancelAnimationFrame(animationFrame)
|
||||
}
|
||||
}, [handleScrollToBottom, handleWindowResize])
|
||||
|
||||
useEffect(() => {
|
||||
if (chatContainerRef.current) {
|
||||
requestAnimationFrame(() => {
|
||||
handleScrollToBottom()
|
||||
handleWindowResize()
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
useEffect(() => {
|
||||
const debouncedHandler = debounce(handleWindowResize, 200)
|
||||
window.addEventListener('resize', debouncedHandler)
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('resize', debouncedHandler)
|
||||
debouncedHandler.cancel()
|
||||
}
|
||||
}, [handleWindowResize])
|
||||
|
||||
useEffect(() => {
|
||||
if (chatFooterRef.current && chatContainerRef.current) {
|
||||
const resizeContainerObserver = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
const { blockSize } = entry.borderBoxSize[0]
|
||||
chatContainerRef.current!.style.paddingBottom = `${blockSize}px`
|
||||
handleScrollToBottom()
|
||||
}
|
||||
})
|
||||
resizeContainerObserver.observe(chatFooterRef.current)
|
||||
|
||||
const resizeFooterObserver = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
const { inlineSize } = entry.borderBoxSize[0]
|
||||
chatFooterRef.current!.style.width = `${inlineSize}px`
|
||||
}
|
||||
})
|
||||
resizeFooterObserver.observe(chatContainerRef.current)
|
||||
|
||||
return () => {
|
||||
resizeContainerObserver.disconnect()
|
||||
resizeFooterObserver.disconnect()
|
||||
}
|
||||
}
|
||||
}, [handleScrollToBottom])
|
||||
|
||||
useEffect(() => {
|
||||
const setUserScrolled = () => {
|
||||
const container = chatContainerRef.current
|
||||
if (!container)
|
||||
return
|
||||
if (isAutoScrollingRef.current)
|
||||
return
|
||||
|
||||
const distanceToBottom = container.scrollHeight - container.clientHeight - container.scrollTop
|
||||
const scrollUpThreshold = 100
|
||||
|
||||
userScrolledRef.current = distanceToBottom > scrollUpThreshold
|
||||
}
|
||||
|
||||
const container = chatContainerRef.current
|
||||
if (!container)
|
||||
return
|
||||
|
||||
container.addEventListener('scroll', setUserScrolled)
|
||||
return () => container.removeEventListener('scroll', setUserScrolled)
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
const firstMessageId = chatList[0]?.id
|
||||
if (chatList.length <= 1 || (firstMessageId && prevFirstMessageIdRef.current !== firstMessageId))
|
||||
userScrolledRef.current = false
|
||||
prevFirstMessageIdRef.current = firstMessageId
|
||||
}, [chatList])
|
||||
|
||||
useEffect(() => {
|
||||
if (!sidebarCollapseState) {
|
||||
const timer = setTimeout(handleWindowResize, 200)
|
||||
return () => clearTimeout(timer)
|
||||
}
|
||||
}, [handleWindowResize, sidebarCollapseState])
|
||||
|
||||
return {
|
||||
width,
|
||||
chatContainerRef,
|
||||
chatContainerInnerRef,
|
||||
chatFooterRef,
|
||||
chatFooterInnerRef,
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,113 @@
|
||||
import type { ComponentProps } from 'react'
|
||||
import type { NotionPageRow } from '../types'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import PageRow from '../page-row'
|
||||
|
||||
const buildRow = (overrides: Partial<NotionPageRow> = {}): NotionPageRow => ({
|
||||
page: {
|
||||
page_id: 'page-1',
|
||||
page_name: 'Page 1',
|
||||
parent_id: 'root',
|
||||
page_icon: null,
|
||||
type: 'page',
|
||||
is_bound: false,
|
||||
},
|
||||
parentExists: false,
|
||||
depth: 0,
|
||||
expand: false,
|
||||
hasChild: false,
|
||||
ancestors: [],
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const renderPageRow = (overrides: Partial<ComponentProps<typeof PageRow>> = {}) => {
|
||||
const props: ComponentProps<typeof PageRow> = {
|
||||
checked: false,
|
||||
disabled: false,
|
||||
isPreviewed: false,
|
||||
onPreview: vi.fn(),
|
||||
onSelect: vi.fn(),
|
||||
onToggle: vi.fn(),
|
||||
row: buildRow(),
|
||||
searchValue: '',
|
||||
selectionMode: 'multiple',
|
||||
showPreview: true,
|
||||
style: { height: 28 },
|
||||
...overrides,
|
||||
}
|
||||
|
||||
return {
|
||||
...render(<PageRow {...props} />),
|
||||
props,
|
||||
}
|
||||
}
|
||||
|
||||
describe('PageRow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should call onSelect with the page id when the checkbox is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
renderPageRow({ onSelect })
|
||||
|
||||
await user.click(screen.getByTestId('checkbox-notion-page-checkbox-page-1'))
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith('page-1')
|
||||
})
|
||||
|
||||
it('should call onToggle when the row has children and the toggle is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onToggle = vi.fn()
|
||||
|
||||
renderPageRow({
|
||||
onToggle,
|
||||
row: buildRow({
|
||||
hasChild: true,
|
||||
expand: true,
|
||||
}),
|
||||
})
|
||||
|
||||
await user.click(screen.getByTestId('notion-page-toggle-page-1'))
|
||||
|
||||
expect(onToggle).toHaveBeenCalledWith('page-1')
|
||||
})
|
||||
|
||||
it('should render breadcrumbs and hide the toggle while searching', () => {
|
||||
renderPageRow({
|
||||
searchValue: 'Page',
|
||||
row: buildRow({
|
||||
parentExists: true,
|
||||
ancestors: ['Workspace', 'Section'],
|
||||
}),
|
||||
})
|
||||
|
||||
expect(screen.queryByTestId('notion-page-toggle-page-1')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('Workspace / Section / Page 1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render preview state and call onPreview when the preview button is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onPreview = vi.fn()
|
||||
|
||||
renderPageRow({
|
||||
isPreviewed: true,
|
||||
onPreview,
|
||||
})
|
||||
|
||||
expect(screen.getByTestId('notion-page-row-page-1')).toHaveClass('bg-state-base-hover')
|
||||
|
||||
await user.click(screen.getByTestId('notion-page-preview-page-1'))
|
||||
|
||||
expect(onPreview).toHaveBeenCalledWith('page-1')
|
||||
})
|
||||
|
||||
it('should hide the preview button when showPreview is false', () => {
|
||||
renderPageRow({ showPreview: false })
|
||||
|
||||
expect(screen.queryByTestId('notion-page-preview-page-1')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,127 @@
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { usePageSelectorModel } from '../use-page-selector-model'
|
||||
|
||||
const buildPage = (overrides: Partial<DataSourceNotionPage>): DataSourceNotionPage => ({
|
||||
page_id: 'page-id',
|
||||
page_name: 'Page name',
|
||||
parent_id: 'root',
|
||||
page_icon: null,
|
||||
type: 'page',
|
||||
is_bound: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const list: DataSourceNotionPage[] = [
|
||||
buildPage({ page_id: 'root-1', page_name: 'Root 1', parent_id: 'root' }),
|
||||
buildPage({ page_id: 'child-1', page_name: 'Child 1', parent_id: 'root-1' }),
|
||||
buildPage({ page_id: 'grandchild-1', page_name: 'Grandchild 1', parent_id: 'child-1' }),
|
||||
buildPage({ page_id: 'child-2', page_name: 'Child 2', parent_id: 'root-1' }),
|
||||
]
|
||||
|
||||
const pagesMap: DataSourceNotionPageMap = {
|
||||
'root-1': { ...list[0], workspace_id: 'workspace-1' },
|
||||
'child-1': { ...list[1], workspace_id: 'workspace-1' },
|
||||
'grandchild-1': { ...list[2], workspace_id: 'workspace-1' },
|
||||
'child-2': { ...list[3], workspace_id: 'workspace-1' },
|
||||
}
|
||||
|
||||
const createProps = (
|
||||
overrides: Partial<Parameters<typeof usePageSelectorModel>[0]> = {},
|
||||
): Parameters<typeof usePageSelectorModel>[0] => ({
|
||||
checkedIds: new Set<string>(),
|
||||
searchValue: '',
|
||||
pagesMap,
|
||||
list,
|
||||
onSelect: vi.fn(),
|
||||
previewPageId: undefined,
|
||||
onPreview: vi.fn(),
|
||||
selectionMode: 'multiple',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('usePageSelectorModel', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should build visible rows from the expanded tree state', async () => {
|
||||
const { result } = renderHook(() => usePageSelectorModel(createProps()))
|
||||
|
||||
expect(result.current.rows.map(row => row.page.page_id)).toEqual(['root-1'])
|
||||
|
||||
act(() => {
|
||||
result.current.handleToggle('root-1')
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.rows.map(row => row.page.page_id)).toEqual([
|
||||
'root-1',
|
||||
'child-1',
|
||||
'child-2',
|
||||
])
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.handleToggle('child-1')
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.rows.map(row => row.page.page_id)).toEqual([
|
||||
'root-1',
|
||||
'child-1',
|
||||
'grandchild-1',
|
||||
'child-2',
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
it('should select descendants when selecting a parent in multiple mode', () => {
|
||||
const onSelect = vi.fn()
|
||||
const { result } = renderHook(() => usePageSelectorModel(createProps({ onSelect })))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSelect('root-1')
|
||||
})
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith(new Set([
|
||||
'root-1',
|
||||
'child-1',
|
||||
'grandchild-1',
|
||||
'child-2',
|
||||
]))
|
||||
})
|
||||
|
||||
it('should update local preview and respect the controlled previewPageId when provided', () => {
|
||||
const onPreview = vi.fn()
|
||||
const { result, rerender } = renderHook(
|
||||
props => usePageSelectorModel(props),
|
||||
{ initialProps: createProps({ onPreview }) },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handlePreview('child-1')
|
||||
})
|
||||
|
||||
expect(onPreview).toHaveBeenCalledWith('child-1')
|
||||
expect(result.current.currentPreviewPageId).toBe('child-1')
|
||||
|
||||
rerender(createProps({ onPreview, previewPageId: 'grandchild-1' }))
|
||||
|
||||
expect(result.current.currentPreviewPageId).toBe('grandchild-1')
|
||||
})
|
||||
|
||||
it('should expose filtered rows when the deferred search value changes', async () => {
|
||||
const { result, rerender } = renderHook(
|
||||
props => usePageSelectorModel(props),
|
||||
{ initialProps: createProps() },
|
||||
)
|
||||
|
||||
rerender(createProps({ searchValue: 'Grandchild' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.effectiveSearchValue).toBe('Grandchild')
|
||||
expect(result.current.rows.map(row => row.page.page_id)).toEqual(['grandchild-1'])
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,118 @@
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
import {
|
||||
buildNotionPageTree,
|
||||
getNextSelectedPageIds,
|
||||
getRootPageIds,
|
||||
getVisiblePageRows,
|
||||
} from '../utils'
|
||||
|
||||
const buildPage = (overrides: Partial<DataSourceNotionPage>): DataSourceNotionPage => ({
|
||||
page_id: 'page-id',
|
||||
page_name: 'Page name',
|
||||
parent_id: 'root',
|
||||
page_icon: null,
|
||||
type: 'page',
|
||||
is_bound: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const list: DataSourceNotionPage[] = [
|
||||
buildPage({ page_id: 'root-1', page_name: 'Root 1', parent_id: 'root' }),
|
||||
buildPage({ page_id: 'child-1', page_name: 'Child 1', parent_id: 'root-1' }),
|
||||
buildPage({ page_id: 'grandchild-1', page_name: 'Grandchild 1', parent_id: 'child-1' }),
|
||||
buildPage({ page_id: 'child-2', page_name: 'Child 2', parent_id: 'root-1' }),
|
||||
buildPage({ page_id: 'orphan-1', page_name: 'Orphan 1', parent_id: 'missing-parent' }),
|
||||
]
|
||||
|
||||
const pagesMap: DataSourceNotionPageMap = {
|
||||
'root-1': { ...list[0], workspace_id: 'workspace-1' },
|
||||
'child-1': { ...list[1], workspace_id: 'workspace-1' },
|
||||
'grandchild-1': { ...list[2], workspace_id: 'workspace-1' },
|
||||
'child-2': { ...list[3], workspace_id: 'workspace-1' },
|
||||
'orphan-1': { ...list[4], workspace_id: 'workspace-1' },
|
||||
}
|
||||
|
||||
describe('page-selector utils', () => {
|
||||
it('should build a tree with descendants, depth, and ancestors', () => {
|
||||
const treeMap = buildNotionPageTree(list, pagesMap)
|
||||
|
||||
expect(treeMap['root-1'].children).toEqual(new Set(['child-1', 'child-2']))
|
||||
expect(treeMap['root-1'].descendants).toEqual(new Set(['child-1', 'grandchild-1', 'child-2']))
|
||||
expect(treeMap['grandchild-1'].depth).toBe(2)
|
||||
expect(treeMap['grandchild-1'].ancestors).toEqual(['Root 1', 'Child 1'])
|
||||
})
|
||||
|
||||
it('should return root page ids for true roots and pages with missing parents', () => {
|
||||
expect(getRootPageIds(list, pagesMap)).toEqual(['root-1', 'orphan-1'])
|
||||
})
|
||||
|
||||
it('should return expanded tree rows in depth-first order when not searching', () => {
|
||||
const treeMap = buildNotionPageTree(list, pagesMap)
|
||||
|
||||
const rows = getVisiblePageRows({
|
||||
list,
|
||||
pagesMap,
|
||||
searchValue: '',
|
||||
treeMap,
|
||||
rootPageIds: ['root-1'],
|
||||
expandedIds: new Set(['root-1', 'child-1']),
|
||||
})
|
||||
|
||||
expect(rows.map(row => row.page.page_id)).toEqual([
|
||||
'root-1',
|
||||
'child-1',
|
||||
'grandchild-1',
|
||||
'child-2',
|
||||
])
|
||||
})
|
||||
|
||||
it('should return filtered search rows with ancestry metadata when searching', () => {
|
||||
const treeMap = buildNotionPageTree(list, pagesMap)
|
||||
|
||||
const rows = getVisiblePageRows({
|
||||
list,
|
||||
pagesMap,
|
||||
searchValue: 'Grandchild',
|
||||
treeMap,
|
||||
rootPageIds: ['root-1'],
|
||||
expandedIds: new Set<string>(),
|
||||
})
|
||||
|
||||
expect(rows).toEqual([
|
||||
expect.objectContaining({
|
||||
page: expect.objectContaining({ page_id: 'grandchild-1' }),
|
||||
ancestors: ['Root 1', 'Child 1'],
|
||||
hasChild: false,
|
||||
parentExists: true,
|
||||
}),
|
||||
])
|
||||
})
|
||||
|
||||
it('should toggle selected ids correctly in single and multiple mode', () => {
|
||||
const treeMap = buildNotionPageTree(list, pagesMap)
|
||||
|
||||
expect(getNextSelectedPageIds({
|
||||
checkedIds: new Set(['root-1']),
|
||||
pageId: 'child-1',
|
||||
searchValue: '',
|
||||
selectionMode: 'single',
|
||||
treeMap,
|
||||
})).toEqual(new Set(['child-1']))
|
||||
|
||||
expect(getNextSelectedPageIds({
|
||||
checkedIds: new Set<string>(),
|
||||
pageId: 'root-1',
|
||||
searchValue: '',
|
||||
selectionMode: 'multiple',
|
||||
treeMap,
|
||||
})).toEqual(new Set(['root-1', 'child-1', 'grandchild-1', 'child-2']))
|
||||
|
||||
expect(getNextSelectedPageIds({
|
||||
checkedIds: new Set(['child-1']),
|
||||
pageId: 'child-1',
|
||||
searchValue: 'Child',
|
||||
selectionMode: 'multiple',
|
||||
treeMap,
|
||||
})).toEqual(new Set<string>())
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,144 @@
|
||||
import type { ComponentProps } from 'react'
|
||||
import type { NotionPageRow } from '../types'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import VirtualPageList from '../virtual-page-list'
|
||||
|
||||
vi.mock('@tanstack/react-virtual')
|
||||
|
||||
const pageRowPropsSpy = vi.fn()
|
||||
type MockPageRowProps = ComponentProps<typeof import('../page-row').default>
|
||||
|
||||
vi.mock('../page-row', () => ({
|
||||
default: ({
|
||||
checked,
|
||||
disabled,
|
||||
isPreviewed,
|
||||
onPreview,
|
||||
onSelect,
|
||||
onToggle,
|
||||
row,
|
||||
searchValue,
|
||||
selectionMode,
|
||||
showPreview,
|
||||
style,
|
||||
}: MockPageRowProps) => {
|
||||
pageRowPropsSpy({
|
||||
checked,
|
||||
disabled,
|
||||
isPreviewed,
|
||||
onPreview,
|
||||
onSelect,
|
||||
onToggle,
|
||||
row,
|
||||
searchValue,
|
||||
selectionMode,
|
||||
showPreview,
|
||||
style,
|
||||
})
|
||||
return <div data-testid={`page-row-${row.page.page_id}`} />
|
||||
},
|
||||
}))
|
||||
|
||||
const buildRow = (overrides: Partial<NotionPageRow> = {}): NotionPageRow => ({
|
||||
page: {
|
||||
page_id: 'page-1',
|
||||
page_name: 'Page 1',
|
||||
parent_id: 'root',
|
||||
page_icon: null,
|
||||
type: 'page',
|
||||
is_bound: false,
|
||||
},
|
||||
parentExists: false,
|
||||
depth: 0,
|
||||
expand: false,
|
||||
hasChild: false,
|
||||
ancestors: [],
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('VirtualPageList', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should render virtual rows and pass row state to PageRow', () => {
|
||||
const rows = [
|
||||
buildRow(),
|
||||
buildRow({
|
||||
page: {
|
||||
page_id: 'page-2',
|
||||
page_name: 'Page 2',
|
||||
parent_id: 'root',
|
||||
page_icon: null,
|
||||
type: 'page',
|
||||
is_bound: false,
|
||||
},
|
||||
}),
|
||||
]
|
||||
|
||||
render(
|
||||
<VirtualPageList
|
||||
checkedIds={new Set(['page-1'])}
|
||||
disabledValue={new Set(['page-2'])}
|
||||
onPreview={vi.fn()}
|
||||
onSelect={vi.fn()}
|
||||
onToggle={vi.fn()}
|
||||
previewPageId="page-2"
|
||||
rows={rows}
|
||||
searchValue=""
|
||||
selectionMode="multiple"
|
||||
showPreview
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('virtual-list')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('page-row-page-1')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('page-row-page-2')).toBeInTheDocument()
|
||||
expect(pageRowPropsSpy).toHaveBeenNthCalledWith(1, expect.objectContaining({
|
||||
checked: true,
|
||||
disabled: false,
|
||||
isPreviewed: false,
|
||||
searchValue: '',
|
||||
selectionMode: 'multiple',
|
||||
showPreview: true,
|
||||
row: rows[0],
|
||||
style: expect.objectContaining({
|
||||
height: '28px',
|
||||
width: 'calc(100% - 16px)',
|
||||
}),
|
||||
}))
|
||||
expect(pageRowPropsSpy).toHaveBeenNthCalledWith(2, expect.objectContaining({
|
||||
checked: false,
|
||||
disabled: true,
|
||||
isPreviewed: true,
|
||||
row: rows[1],
|
||||
}))
|
||||
})
|
||||
|
||||
it('should size the virtual container using the row estimate', () => {
|
||||
const rows = [buildRow(), buildRow()]
|
||||
|
||||
render(
|
||||
<VirtualPageList
|
||||
checkedIds={new Set<string>()}
|
||||
disabledValue={new Set<string>()}
|
||||
onPreview={vi.fn()}
|
||||
onSelect={vi.fn()}
|
||||
onToggle={vi.fn()}
|
||||
previewPageId=""
|
||||
rows={rows}
|
||||
searchValue=""
|
||||
selectionMode="multiple"
|
||||
showPreview={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
const list = screen.getByTestId('virtual-list')
|
||||
const innerContainer = list.firstElementChild as HTMLElement
|
||||
|
||||
expect(innerContainer).toHaveStyle({
|
||||
height: '56px',
|
||||
position: 'relative',
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,295 @@
|
||||
import type { EventEmitter } from 'ahooks/lib/useEventEmitter'
|
||||
import type { LexicalEditor } from 'lexical'
|
||||
import type { ComponentProps } from 'react'
|
||||
import type { EventEmitterValue } from '@/context/event-emitter'
|
||||
import { CodeNode } from '@lexical/code'
|
||||
import { LexicalComposer } from '@lexical/react/LexicalComposer'
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import {
|
||||
BLUR_COMMAND,
|
||||
COMMAND_PRIORITY_EDITOR,
|
||||
createCommand,
|
||||
FOCUS_COMMAND,
|
||||
TextNode,
|
||||
} from 'lexical'
|
||||
import { useEffect } from 'react'
|
||||
import { GeneratorType } from '@/app/components/app/configuration/config/automatic/types'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { EventEmitterContextProvider } from '@/context/event-emitter-provider'
|
||||
import { ContextBlockNode } from '../plugins/context-block'
|
||||
import { CurrentBlockNode } from '../plugins/current-block'
|
||||
import { CustomTextNode } from '../plugins/custom-text/node'
|
||||
import { ErrorMessageBlockNode } from '../plugins/error-message-block'
|
||||
import { HistoryBlockNode } from '../plugins/history-block'
|
||||
import { HITLInputNode } from '../plugins/hitl-input-block'
|
||||
import { LastRunBlockNode } from '../plugins/last-run-block'
|
||||
import { QueryBlockNode } from '../plugins/query-block'
|
||||
import { RequestURLBlockNode } from '../plugins/request-url-block'
|
||||
import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '../plugins/update-block'
|
||||
import { VariableValueBlockNode } from '../plugins/variable-value-block/node'
|
||||
import { WorkflowVariableBlockNode } from '../plugins/workflow-variable-block'
|
||||
import PromptEditorContent from '../prompt-editor-content'
|
||||
import { textToEditorState } from '../utils'
|
||||
|
||||
type Captures = {
|
||||
editor: LexicalEditor | null
|
||||
eventEmitter: EventEmitter<EventEmitterValue> | null
|
||||
}
|
||||
|
||||
const mockDOMRect = {
|
||||
x: 100,
|
||||
y: 100,
|
||||
width: 100,
|
||||
height: 20,
|
||||
top: 100,
|
||||
right: 200,
|
||||
bottom: 120,
|
||||
left: 100,
|
||||
toJSON: () => ({}),
|
||||
}
|
||||
|
||||
const originalRangeGetClientRects = Range.prototype.getClientRects
|
||||
const originalRangeGetBoundingClientRect = Range.prototype.getBoundingClientRect
|
||||
|
||||
const setSelectionOnEditable = (editable: HTMLElement) => {
|
||||
const lexicalTextNode = editable.querySelector('[data-lexical-text="true"]')?.firstChild
|
||||
const range = document.createRange()
|
||||
|
||||
if (lexicalTextNode) {
|
||||
range.setStart(lexicalTextNode, 0)
|
||||
range.setEnd(lexicalTextNode, 1)
|
||||
}
|
||||
else {
|
||||
range.selectNodeContents(editable)
|
||||
}
|
||||
|
||||
const selection = window.getSelection()
|
||||
selection?.removeAllRanges()
|
||||
selection?.addRange(range)
|
||||
}
|
||||
|
||||
const CaptureEditorAndEmitter = ({ captures }: { captures: Captures }) => {
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
useEffect(() => {
|
||||
captures.editor = editor
|
||||
}, [captures, editor])
|
||||
|
||||
useEffect(() => {
|
||||
captures.eventEmitter = eventEmitter
|
||||
}, [captures, eventEmitter])
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
const PromptEditorContentHarness = ({
|
||||
captures,
|
||||
initialText = '',
|
||||
...props
|
||||
}: ComponentProps<typeof PromptEditorContent> & { captures: Captures, initialText?: string }) => (
|
||||
<EventEmitterContextProvider>
|
||||
<LexicalComposer
|
||||
initialConfig={{
|
||||
namespace: 'prompt-editor-content-test',
|
||||
editable: true,
|
||||
nodes: [
|
||||
CodeNode,
|
||||
CustomTextNode,
|
||||
{
|
||||
replace: TextNode,
|
||||
with: (node: TextNode) => new CustomTextNode(node.__text),
|
||||
withKlass: CustomTextNode,
|
||||
},
|
||||
ContextBlockNode,
|
||||
HistoryBlockNode,
|
||||
QueryBlockNode,
|
||||
RequestURLBlockNode,
|
||||
WorkflowVariableBlockNode,
|
||||
VariableValueBlockNode,
|
||||
HITLInputNode,
|
||||
CurrentBlockNode,
|
||||
ErrorMessageBlockNode,
|
||||
LastRunBlockNode,
|
||||
],
|
||||
editorState: textToEditorState(initialText),
|
||||
onError: (error: Error) => {
|
||||
throw error
|
||||
},
|
||||
}}
|
||||
>
|
||||
<CaptureEditorAndEmitter captures={captures} />
|
||||
<PromptEditorContent {...props} />
|
||||
</LexicalComposer>
|
||||
</EventEmitterContextProvider>
|
||||
)
|
||||
|
||||
describe('PromptEditorContent', () => {
|
||||
beforeAll(() => {
|
||||
Range.prototype.getClientRects = vi.fn(() => {
|
||||
const rectList = [mockDOMRect] as unknown as DOMRectList
|
||||
Object.defineProperty(rectList, 'length', { value: 1 })
|
||||
Object.defineProperty(rectList, 'item', { value: (index: number) => index === 0 ? mockDOMRect : null })
|
||||
return rectList
|
||||
})
|
||||
Range.prototype.getBoundingClientRect = vi.fn(() => mockDOMRect as DOMRect)
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
Range.prototype.getClientRects = originalRangeGetClientRects
|
||||
Range.prototype.getBoundingClientRect = originalRangeGetBoundingClientRect
|
||||
})
|
||||
|
||||
// The extracted content shell should run with the real Lexical stack and forward editor commands through its composed plugins.
|
||||
describe('Rendering', () => {
|
||||
it('should render with real dependencies and forward update/focus/blur events', async () => {
|
||||
const captures: Captures = { editor: null, eventEmitter: null }
|
||||
const onEditorChange = vi.fn()
|
||||
const onFocus = vi.fn()
|
||||
const onBlur = vi.fn()
|
||||
const anchorElem = document.createElement('div')
|
||||
|
||||
const { container } = render(
|
||||
<PromptEditorContentHarness
|
||||
captures={captures}
|
||||
compact={true}
|
||||
className="editor-shell"
|
||||
placeholder="Type prompt"
|
||||
shortcutPopups={[]}
|
||||
instanceId="content-editor"
|
||||
floatingAnchorElem={anchorElem}
|
||||
onEditorChange={onEditorChange}
|
||||
onFocus={onFocus}
|
||||
onBlur={onBlur}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Type prompt')).toBeInTheDocument()
|
||||
|
||||
const editable = container.querySelector('[contenteditable="true"]') as HTMLElement
|
||||
expect(editable.className).toContain('text-[13px]')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(captures.editor).not.toBeNull()
|
||||
expect(captures.eventEmitter).not.toBeNull()
|
||||
})
|
||||
|
||||
act(() => {
|
||||
captures.eventEmitter?.emit({
|
||||
type: PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER,
|
||||
instanceId: 'content-editor',
|
||||
payload: 'updated prompt',
|
||||
})
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onEditorChange).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
act(() => {
|
||||
captures.editor?.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus'))
|
||||
captures.editor?.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: null }))
|
||||
})
|
||||
|
||||
expect(onFocus).toHaveBeenCalledTimes(1)
|
||||
expect(onBlur).toHaveBeenCalledTimes(1)
|
||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render optional blocks and open shortcut popups with the real editor runtime', async () => {
|
||||
const captures: Captures = { editor: null, eventEmitter: null }
|
||||
const onEditorChange = vi.fn()
|
||||
const insertCommand = createCommand<string[]>('prompt-editor-shortcut-insert')
|
||||
const insertSpy = vi.fn()
|
||||
const Popup = ({ onClose, onInsert }: { onClose: () => void, onInsert: (command: typeof insertCommand, params: string[]) => void }) => (
|
||||
<>
|
||||
<button type="button" onClick={() => onInsert(insertCommand, ['from-shortcut'])}>Insert shortcut</button>
|
||||
<button type="button" onClick={onClose}>Close shortcut</button>
|
||||
</>
|
||||
)
|
||||
|
||||
const { container } = render(
|
||||
<PromptEditorContentHarness
|
||||
captures={captures}
|
||||
shortcutPopups={[{ hotkey: 'ctrl+/', Popup }]}
|
||||
initialText="seed prompt"
|
||||
floatingAnchorElem={document.createElement('div')}
|
||||
onEditorChange={onEditorChange}
|
||||
contextBlock={{ show: true, datasets: [] }}
|
||||
queryBlock={{ show: true }}
|
||||
requestURLBlock={{ show: true }}
|
||||
historyBlock={{ show: true, history: { user: 'user-role', assistant: 'assistant-role' } }}
|
||||
variableBlock={{ show: true, variables: [] }}
|
||||
externalToolBlock={{ show: true, externalTools: [] }}
|
||||
workflowVariableBlock={{ show: true, variables: [] }}
|
||||
hitlInputBlock={{
|
||||
show: true,
|
||||
nodeId: 'node-1',
|
||||
onFormInputItemRemove: vi.fn(),
|
||||
onFormInputItemRename: vi.fn(),
|
||||
}}
|
||||
currentBlock={{ show: true, generatorType: GeneratorType.prompt }}
|
||||
errorMessageBlock={{ show: true }}
|
||||
lastRunBlock={{ show: true }}
|
||||
isSupportFileVar={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(captures.editor).not.toBeNull()
|
||||
})
|
||||
|
||||
const unregister = captures.editor?.registerCommand(
|
||||
insertCommand,
|
||||
(payload) => {
|
||||
insertSpy(payload)
|
||||
return true
|
||||
},
|
||||
COMMAND_PRIORITY_EDITOR,
|
||||
)
|
||||
|
||||
const editable = container.querySelector('[contenteditable="true"]') as HTMLElement
|
||||
editable.focus()
|
||||
setSelectionOnEditable(editable)
|
||||
|
||||
fireEvent.keyDown(document, { key: '/', ctrlKey: true })
|
||||
|
||||
const insertButton = await screen.findByRole('button', { name: 'Insert shortcut' })
|
||||
fireEvent.click(insertButton)
|
||||
|
||||
expect(insertSpy).toHaveBeenCalledWith(['from-shortcut'])
|
||||
expect(onEditorChange).toHaveBeenCalled()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('button', { name: 'Insert shortcut' })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
unregister?.()
|
||||
})
|
||||
|
||||
it('should keep the shell stable without optional anchor or placeholder overrides', async () => {
|
||||
const captures: Captures = { editor: null, eventEmitter: null }
|
||||
|
||||
render(
|
||||
<PromptEditorContentHarness
|
||||
captures={captures}
|
||||
shortcutPopups={[]}
|
||||
floatingAnchorElem={null}
|
||||
onEditorChange={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(captures.editor).not.toBeNull()
|
||||
})
|
||||
|
||||
expect(screen.queryByTestId('draggable-target-line')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('common.promptEditor.placeholder')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -23,11 +23,6 @@ import type {
|
||||
import { CodeNode } from '@lexical/code'
|
||||
import { LexicalComposer } from '@lexical/react/LexicalComposer'
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { ContentEditable } from '@lexical/react/LexicalContentEditable'
|
||||
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'
|
||||
import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin'
|
||||
import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin'
|
||||
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'
|
||||
import {
|
||||
$getRoot,
|
||||
TextNode,
|
||||
@ -40,63 +35,37 @@ import {
|
||||
UPDATE_DATASETS_EVENT_EMITTER,
|
||||
UPDATE_HISTORY_EVENT_EMITTER,
|
||||
} from './constants'
|
||||
import ComponentPickerBlock from './plugins/component-picker-block'
|
||||
import {
|
||||
ContextBlock,
|
||||
ContextBlockNode,
|
||||
ContextBlockReplacementBlock,
|
||||
} from './plugins/context-block'
|
||||
import {
|
||||
CurrentBlock,
|
||||
CurrentBlockNode,
|
||||
CurrentBlockReplacementBlock,
|
||||
} from './plugins/current-block'
|
||||
import { CustomTextNode } from './plugins/custom-text/node'
|
||||
import DraggableBlockPlugin from './plugins/draggable-plugin'
|
||||
import {
|
||||
ErrorMessageBlock,
|
||||
ErrorMessageBlockNode,
|
||||
ErrorMessageBlockReplacementBlock,
|
||||
} from './plugins/error-message-block'
|
||||
import {
|
||||
HistoryBlock,
|
||||
HistoryBlockNode,
|
||||
HistoryBlockReplacementBlock,
|
||||
} from './plugins/history-block'
|
||||
|
||||
import {
|
||||
HITLInputBlock,
|
||||
HITLInputBlockReplacementBlock,
|
||||
HITLInputNode,
|
||||
} from './plugins/hitl-input-block'
|
||||
import {
|
||||
LastRunBlock,
|
||||
LastRunBlockNode,
|
||||
LastRunReplacementBlock,
|
||||
} from './plugins/last-run-block'
|
||||
import OnBlurBlock from './plugins/on-blur-or-focus-block'
|
||||
// import TreeView from './plugins/tree-view'
|
||||
import Placeholder from './plugins/placeholder'
|
||||
import {
|
||||
QueryBlock,
|
||||
QueryBlockNode,
|
||||
QueryBlockReplacementBlock,
|
||||
} from './plugins/query-block'
|
||||
import {
|
||||
RequestURLBlock,
|
||||
RequestURLBlockNode,
|
||||
RequestURLBlockReplacementBlock,
|
||||
} from './plugins/request-url-block'
|
||||
import ShortcutsPopupPlugin from './plugins/shortcuts-popup-plugin'
|
||||
import UpdateBlock from './plugins/update-block'
|
||||
import VariableBlock from './plugins/variable-block'
|
||||
import VariableValueBlock from './plugins/variable-value-block'
|
||||
import { VariableValueBlockNode } from './plugins/variable-value-block/node'
|
||||
import {
|
||||
WorkflowVariableBlock,
|
||||
WorkflowVariableBlockNode,
|
||||
WorkflowVariableBlockReplacementBlock,
|
||||
} from './plugins/workflow-variable-block'
|
||||
import PromptEditorContent from './prompt-editor-content'
|
||||
import { textToEditorState } from './utils'
|
||||
|
||||
const ValueSyncPlugin: FC<{ value?: string }> = ({ value }) => {
|
||||
@ -238,153 +207,31 @@ const PromptEditor: FC<PromptEditorProps> = ({
|
||||
return (
|
||||
<LexicalComposer initialConfig={{ ...initialConfig, editable }}>
|
||||
<div className={cn('relative', wrapperClassName)} ref={onRef}>
|
||||
<RichTextPlugin
|
||||
contentEditable={(
|
||||
<ContentEditable
|
||||
className={cn(
|
||||
'group/editable text-text-secondary outline-hidden group-[.clamp]:max-h-24 group-[.clamp]:overflow-y-auto',
|
||||
compact ? 'text-[13px] leading-5' : 'text-sm leading-6',
|
||||
className,
|
||||
)}
|
||||
style={style || {}}
|
||||
/>
|
||||
)}
|
||||
placeholder={(
|
||||
<Placeholder
|
||||
value={placeholder}
|
||||
className={cn('truncate', placeholderClassName)}
|
||||
compact={compact}
|
||||
/>
|
||||
)}
|
||||
ErrorBoundary={LexicalErrorBoundary}
|
||||
/>
|
||||
{shortcutPopups?.map(({ hotkey, Popup }, idx) => (
|
||||
<ShortcutsPopupPlugin key={idx} hotkey={hotkey}>
|
||||
{(closePortal, onInsert) => <Popup onClose={closePortal} onInsert={onInsert} />}
|
||||
</ShortcutsPopupPlugin>
|
||||
))}
|
||||
<ComponentPickerBlock
|
||||
triggerString="/"
|
||||
<PromptEditorContent
|
||||
compact={compact}
|
||||
className={className}
|
||||
placeholder={placeholder}
|
||||
placeholderClassName={placeholderClassName}
|
||||
style={style}
|
||||
shortcutPopups={shortcutPopups}
|
||||
contextBlock={contextBlock}
|
||||
historyBlock={historyBlock}
|
||||
queryBlock={queryBlock}
|
||||
requestURLBlock={requestURLBlock}
|
||||
historyBlock={historyBlock}
|
||||
variableBlock={variableBlock}
|
||||
externalToolBlock={externalToolBlock}
|
||||
workflowVariableBlock={workflowVariableBlock}
|
||||
hitlInputBlock={hitlInputBlock}
|
||||
currentBlock={currentBlock}
|
||||
errorMessageBlock={errorMessageBlock}
|
||||
lastRunBlock={lastRunBlock}
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
onBlur={onBlur}
|
||||
onFocus={onFocus}
|
||||
instanceId={instanceId}
|
||||
floatingAnchorElem={floatingAnchorElem}
|
||||
onEditorChange={handleEditorChange}
|
||||
/>
|
||||
<ComponentPickerBlock
|
||||
triggerString="{"
|
||||
contextBlock={contextBlock}
|
||||
historyBlock={historyBlock}
|
||||
queryBlock={queryBlock}
|
||||
requestURLBlock={requestURLBlock}
|
||||
variableBlock={variableBlock}
|
||||
externalToolBlock={externalToolBlock}
|
||||
workflowVariableBlock={workflowVariableBlock}
|
||||
currentBlock={currentBlock}
|
||||
errorMessageBlock={errorMessageBlock}
|
||||
lastRunBlock={lastRunBlock}
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
/>
|
||||
{
|
||||
contextBlock?.show && (
|
||||
<>
|
||||
<ContextBlock {...contextBlock} />
|
||||
<ContextBlockReplacementBlock {...contextBlock} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
queryBlock?.show && (
|
||||
<>
|
||||
<QueryBlock {...queryBlock} />
|
||||
<QueryBlockReplacementBlock />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
historyBlock?.show && (
|
||||
<>
|
||||
<HistoryBlock {...historyBlock} />
|
||||
<HistoryBlockReplacementBlock {...historyBlock} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
(variableBlock?.show || externalToolBlock?.show) && (
|
||||
<>
|
||||
<VariableBlock />
|
||||
<VariableValueBlock />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
workflowVariableBlock?.show && (
|
||||
<>
|
||||
<WorkflowVariableBlock {...workflowVariableBlock} />
|
||||
<WorkflowVariableBlockReplacementBlock {...workflowVariableBlock} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
hitlInputBlock?.show && (
|
||||
<>
|
||||
<HITLInputBlock {...hitlInputBlock} />
|
||||
<HITLInputBlockReplacementBlock {...hitlInputBlock} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
currentBlock?.show && (
|
||||
<>
|
||||
<CurrentBlock {...currentBlock} />
|
||||
<CurrentBlockReplacementBlock {...currentBlock} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
requestURLBlock?.show && (
|
||||
<>
|
||||
<RequestURLBlock {...requestURLBlock} />
|
||||
<RequestURLBlockReplacementBlock {...requestURLBlock} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
errorMessageBlock?.show && (
|
||||
<>
|
||||
<ErrorMessageBlock {...errorMessageBlock} />
|
||||
<ErrorMessageBlockReplacementBlock {...errorMessageBlock} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
lastRunBlock?.show && (
|
||||
<>
|
||||
<LastRunBlock {...lastRunBlock} />
|
||||
<LastRunReplacementBlock {...lastRunBlock} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
isSupportFileVar && (
|
||||
<VariableValueBlock />
|
||||
)
|
||||
}
|
||||
<ValueSyncPlugin value={value} />
|
||||
<OnChangePlugin onChange={handleEditorChange} />
|
||||
<OnBlurBlock onBlur={onBlur} onFocus={onFocus} />
|
||||
<UpdateBlock instanceId={instanceId} />
|
||||
<HistoryPlugin />
|
||||
{floatingAnchorElem && (
|
||||
<DraggableBlockPlugin anchorElem={floatingAnchorElem} />
|
||||
)}
|
||||
{/* <TreeView /> */}
|
||||
</div>
|
||||
</LexicalComposer>
|
||||
)
|
||||
|
||||
257
web/app/components/base/prompt-editor/prompt-editor-content.tsx
Normal file
257
web/app/components/base/prompt-editor/prompt-editor-content.tsx
Normal file
@ -0,0 +1,257 @@
|
||||
import type {
|
||||
EditorState,
|
||||
LexicalCommand,
|
||||
} from 'lexical'
|
||||
import type { FC } from 'react'
|
||||
import type { Hotkey } from './plugins/shortcuts-popup-plugin'
|
||||
import type {
|
||||
ContextBlockType,
|
||||
CurrentBlockType,
|
||||
ErrorMessageBlockType,
|
||||
ExternalToolBlockType,
|
||||
HistoryBlockType,
|
||||
HITLInputBlockType,
|
||||
LastRunBlockType,
|
||||
QueryBlockType,
|
||||
RequestURLBlockType,
|
||||
VariableBlockType,
|
||||
WorkflowVariableBlockType,
|
||||
} from './types'
|
||||
import { ContentEditable } from '@lexical/react/LexicalContentEditable'
|
||||
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'
|
||||
import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin'
|
||||
import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin'
|
||||
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'
|
||||
import * as React from 'react'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import ComponentPickerBlock from './plugins/component-picker-block'
|
||||
import {
|
||||
ContextBlock,
|
||||
ContextBlockReplacementBlock,
|
||||
} from './plugins/context-block'
|
||||
import {
|
||||
CurrentBlock,
|
||||
CurrentBlockReplacementBlock,
|
||||
} from './plugins/current-block'
|
||||
import DraggableBlockPlugin from './plugins/draggable-plugin'
|
||||
import {
|
||||
ErrorMessageBlock,
|
||||
ErrorMessageBlockReplacementBlock,
|
||||
} from './plugins/error-message-block'
|
||||
import {
|
||||
HistoryBlock,
|
||||
HistoryBlockReplacementBlock,
|
||||
} from './plugins/history-block'
|
||||
import {
|
||||
HITLInputBlock,
|
||||
HITLInputBlockReplacementBlock,
|
||||
} from './plugins/hitl-input-block'
|
||||
import {
|
||||
LastRunBlock,
|
||||
LastRunReplacementBlock,
|
||||
} from './plugins/last-run-block'
|
||||
import OnBlurBlock from './plugins/on-blur-or-focus-block'
|
||||
import Placeholder from './plugins/placeholder'
|
||||
import {
|
||||
QueryBlock,
|
||||
QueryBlockReplacementBlock,
|
||||
} from './plugins/query-block'
|
||||
import {
|
||||
RequestURLBlock,
|
||||
RequestURLBlockReplacementBlock,
|
||||
} from './plugins/request-url-block'
|
||||
import ShortcutsPopupPlugin from './plugins/shortcuts-popup-plugin'
|
||||
import UpdateBlock from './plugins/update-block'
|
||||
import VariableBlock from './plugins/variable-block'
|
||||
import VariableValueBlock from './plugins/variable-value-block'
|
||||
import {
|
||||
WorkflowVariableBlock,
|
||||
WorkflowVariableBlockReplacementBlock,
|
||||
} from './plugins/workflow-variable-block'
|
||||
|
||||
type ShortcutPopup = {
|
||||
hotkey: Hotkey
|
||||
Popup: React.ComponentType<{ onClose: () => void, onInsert: (command: LexicalCommand<unknown>, params: unknown[]) => void }>
|
||||
}
|
||||
|
||||
type PromptEditorContentProps = {
|
||||
compact?: boolean
|
||||
className?: string
|
||||
placeholder?: string | React.ReactNode
|
||||
placeholderClassName?: string
|
||||
style?: React.CSSProperties
|
||||
shortcutPopups: ShortcutPopup[]
|
||||
contextBlock?: ContextBlockType
|
||||
queryBlock?: QueryBlockType
|
||||
requestURLBlock?: RequestURLBlockType
|
||||
historyBlock?: HistoryBlockType
|
||||
variableBlock?: VariableBlockType
|
||||
externalToolBlock?: ExternalToolBlockType
|
||||
workflowVariableBlock?: WorkflowVariableBlockType
|
||||
hitlInputBlock?: HITLInputBlockType
|
||||
currentBlock?: CurrentBlockType
|
||||
errorMessageBlock?: ErrorMessageBlockType
|
||||
lastRunBlock?: LastRunBlockType
|
||||
isSupportFileVar?: boolean
|
||||
onBlur?: () => void
|
||||
onFocus?: () => void
|
||||
instanceId?: string
|
||||
floatingAnchorElem: HTMLDivElement | null
|
||||
onEditorChange: (editorState: EditorState) => void
|
||||
}
|
||||
|
||||
const PromptEditorContent: FC<PromptEditorContentProps> = ({
|
||||
compact,
|
||||
className,
|
||||
placeholder,
|
||||
placeholderClassName,
|
||||
style,
|
||||
shortcutPopups,
|
||||
contextBlock,
|
||||
queryBlock,
|
||||
requestURLBlock,
|
||||
historyBlock,
|
||||
variableBlock,
|
||||
externalToolBlock,
|
||||
workflowVariableBlock,
|
||||
hitlInputBlock,
|
||||
currentBlock,
|
||||
errorMessageBlock,
|
||||
lastRunBlock,
|
||||
isSupportFileVar,
|
||||
onBlur,
|
||||
onFocus,
|
||||
instanceId,
|
||||
floatingAnchorElem,
|
||||
onEditorChange,
|
||||
}) => {
|
||||
return (
|
||||
<>
|
||||
<RichTextPlugin
|
||||
contentEditable={(
|
||||
<ContentEditable
|
||||
className={cn(
|
||||
'group/editable text-text-secondary outline-hidden group-[.clamp]:max-h-24 group-[.clamp]:overflow-y-auto',
|
||||
compact ? 'text-[13px] leading-5' : 'text-sm leading-6',
|
||||
className,
|
||||
)}
|
||||
style={style || {}}
|
||||
/>
|
||||
)}
|
||||
placeholder={(
|
||||
<Placeholder
|
||||
value={placeholder}
|
||||
className={cn('truncate', placeholderClassName)}
|
||||
compact={compact}
|
||||
/>
|
||||
)}
|
||||
ErrorBoundary={LexicalErrorBoundary}
|
||||
/>
|
||||
{shortcutPopups.map(({ hotkey, Popup }, idx) => (
|
||||
<ShortcutsPopupPlugin key={idx} hotkey={hotkey}>
|
||||
{(closePortal, onInsert) => <Popup onClose={closePortal} onInsert={onInsert} />}
|
||||
</ShortcutsPopupPlugin>
|
||||
))}
|
||||
<ComponentPickerBlock
|
||||
triggerString="/"
|
||||
contextBlock={contextBlock}
|
||||
historyBlock={historyBlock}
|
||||
queryBlock={queryBlock}
|
||||
requestURLBlock={requestURLBlock}
|
||||
variableBlock={variableBlock}
|
||||
externalToolBlock={externalToolBlock}
|
||||
workflowVariableBlock={workflowVariableBlock}
|
||||
currentBlock={currentBlock}
|
||||
errorMessageBlock={errorMessageBlock}
|
||||
lastRunBlock={lastRunBlock}
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
/>
|
||||
<ComponentPickerBlock
|
||||
triggerString="{"
|
||||
contextBlock={contextBlock}
|
||||
historyBlock={historyBlock}
|
||||
queryBlock={queryBlock}
|
||||
requestURLBlock={requestURLBlock}
|
||||
variableBlock={variableBlock}
|
||||
externalToolBlock={externalToolBlock}
|
||||
workflowVariableBlock={workflowVariableBlock}
|
||||
currentBlock={currentBlock}
|
||||
errorMessageBlock={errorMessageBlock}
|
||||
lastRunBlock={lastRunBlock}
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
/>
|
||||
{contextBlock?.show && (
|
||||
<>
|
||||
<ContextBlock {...contextBlock} />
|
||||
<ContextBlockReplacementBlock {...contextBlock} />
|
||||
</>
|
||||
)}
|
||||
{queryBlock?.show && (
|
||||
<>
|
||||
<QueryBlock {...queryBlock} />
|
||||
<QueryBlockReplacementBlock />
|
||||
</>
|
||||
)}
|
||||
{historyBlock?.show && (
|
||||
<>
|
||||
<HistoryBlock {...historyBlock} />
|
||||
<HistoryBlockReplacementBlock {...historyBlock} />
|
||||
</>
|
||||
)}
|
||||
{(variableBlock?.show || externalToolBlock?.show) && (
|
||||
<>
|
||||
<VariableBlock />
|
||||
<VariableValueBlock />
|
||||
</>
|
||||
)}
|
||||
{workflowVariableBlock?.show && (
|
||||
<>
|
||||
<WorkflowVariableBlock {...workflowVariableBlock} />
|
||||
<WorkflowVariableBlockReplacementBlock {...workflowVariableBlock} />
|
||||
</>
|
||||
)}
|
||||
{hitlInputBlock?.show && (
|
||||
<>
|
||||
<HITLInputBlock {...hitlInputBlock} />
|
||||
<HITLInputBlockReplacementBlock {...hitlInputBlock} />
|
||||
</>
|
||||
)}
|
||||
{currentBlock?.show && (
|
||||
<>
|
||||
<CurrentBlock {...currentBlock} />
|
||||
<CurrentBlockReplacementBlock {...currentBlock} />
|
||||
</>
|
||||
)}
|
||||
{requestURLBlock?.show && (
|
||||
<>
|
||||
<RequestURLBlock {...requestURLBlock} />
|
||||
<RequestURLBlockReplacementBlock {...requestURLBlock} />
|
||||
</>
|
||||
)}
|
||||
{errorMessageBlock?.show && (
|
||||
<>
|
||||
<ErrorMessageBlock {...errorMessageBlock} />
|
||||
<ErrorMessageBlockReplacementBlock {...errorMessageBlock} />
|
||||
</>
|
||||
)}
|
||||
{lastRunBlock?.show && (
|
||||
<>
|
||||
<LastRunBlock {...lastRunBlock} />
|
||||
<LastRunReplacementBlock {...lastRunBlock} />
|
||||
</>
|
||||
)}
|
||||
{isSupportFileVar && (
|
||||
<VariableValueBlock />
|
||||
)}
|
||||
<OnChangePlugin onChange={onEditorChange} />
|
||||
<OnBlurBlock onBlur={onBlur} onFocus={onFocus} />
|
||||
<UpdateBlock instanceId={instanceId} />
|
||||
<HistoryPlugin />
|
||||
{floatingAnchorElem && (
|
||||
<DraggableBlockPlugin anchorElem={floatingAnchorElem} />
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default PromptEditorContent
|
||||
@ -0,0 +1,30 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import Loading from '../loading'
|
||||
|
||||
vi.mock('@/app/components/base/skeleton', () => ({
|
||||
SkeletonContainer: ({ children, className }: { children?: ReactNode, className?: string }) => (
|
||||
<div data-testid="skeleton-container" className={className}>{children}</div>
|
||||
),
|
||||
SkeletonRectangle: ({ className }: { className?: string }) => (
|
||||
<div data-testid="skeleton-rectangle" className={className} />
|
||||
),
|
||||
}))
|
||||
|
||||
describe('CreateFromPipelinePreviewLoading', () => {
|
||||
it('should render the preview loading shell and all skeleton blocks', () => {
|
||||
const { container } = render(<Loading />)
|
||||
|
||||
expect(container.firstElementChild).toHaveClass(
|
||||
'flex',
|
||||
'h-full',
|
||||
'w-full',
|
||||
'flex-col',
|
||||
'overflow-hidden',
|
||||
'px-6',
|
||||
'py-5',
|
||||
)
|
||||
expect(screen.getAllByTestId('skeleton-container')).toHaveLength(6)
|
||||
expect(screen.getAllByTestId('skeleton-rectangle')).toHaveLength(29)
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,30 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { renderHook } from '@testing-library/react'
|
||||
import { DocumentContext, useDocumentContext } from '../context'
|
||||
|
||||
describe('DocumentContext', () => {
|
||||
it('should return the default empty context value when no provider is present', () => {
|
||||
const { result } = renderHook(() => useDocumentContext(value => value))
|
||||
|
||||
expect(result.current).toEqual({})
|
||||
})
|
||||
|
||||
it('should select values from the nearest provider', () => {
|
||||
const wrapper = ({ children }: { children: ReactNode }) => (
|
||||
<DocumentContext.Provider value={{
|
||||
datasetId: 'dataset-1',
|
||||
documentId: 'document-1',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</DocumentContext.Provider>
|
||||
)
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentContext(value => `${value.datasetId}:${value.documentId}`),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
expect(result.current).toBe('dataset-1:document-1')
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,55 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { renderHook } from '@testing-library/react'
|
||||
import { SegmentListContext, useSegmentListContext } from '../segment-list-context'
|
||||
|
||||
describe('SegmentListContext', () => {
|
||||
it('should expose the default collapsed state', () => {
|
||||
const { result } = renderHook(() => useSegmentListContext(value => value))
|
||||
|
||||
expect(result.current).toEqual({
|
||||
isCollapsed: true,
|
||||
fullScreen: false,
|
||||
toggleFullScreen: expect.any(Function),
|
||||
currSegment: { showModal: false },
|
||||
currChildChunk: { showModal: false },
|
||||
})
|
||||
})
|
||||
|
||||
it('should select provider values from the current segment list context', () => {
|
||||
const toggleFullScreen = vi.fn()
|
||||
const wrapper = ({ children }: { children: ReactNode }) => (
|
||||
<SegmentListContext.Provider value={{
|
||||
isCollapsed: false,
|
||||
fullScreen: true,
|
||||
toggleFullScreen,
|
||||
currSegment: {
|
||||
showModal: true,
|
||||
isEditMode: true,
|
||||
segInfo: { id: 'segment-1' } as never,
|
||||
},
|
||||
currChildChunk: {
|
||||
showModal: true,
|
||||
childChunkInfo: { id: 'child-1' } as never,
|
||||
},
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</SegmentListContext.Provider>
|
||||
)
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useSegmentListContext(value => ({
|
||||
fullScreen: value.fullScreen,
|
||||
segmentOpen: value.currSegment.showModal,
|
||||
childOpen: value.currChildChunk.showModal,
|
||||
})),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
expect(result.current).toEqual({
|
||||
fullScreen: true,
|
||||
segmentOpen: true,
|
||||
childOpen: true,
|
||||
})
|
||||
})
|
||||
})
|
||||
125
web/app/components/develop/hooks/__tests__/use-doc-toc.spec.tsx
Normal file
125
web/app/components/develop/hooks/__tests__/use-doc-toc.spec.tsx
Normal file
@ -0,0 +1,125 @@
|
||||
import type { TocItem } from '../use-doc-toc'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { useDocToc } from '../use-doc-toc'
|
||||
|
||||
const mockMatchMedia = (matches: boolean) => {
|
||||
vi.stubGlobal('matchMedia', vi.fn().mockImplementation((query: string) => ({
|
||||
matches,
|
||||
media: query,
|
||||
onchange: null,
|
||||
addEventListener: vi.fn(),
|
||||
removeEventListener: vi.fn(),
|
||||
addListener: vi.fn(),
|
||||
removeListener: vi.fn(),
|
||||
dispatchEvent: vi.fn(),
|
||||
})))
|
||||
}
|
||||
|
||||
const setupDocument = () => {
|
||||
document.body.innerHTML = `
|
||||
<div class="overflow-auto"></div>
|
||||
<article>
|
||||
<h2 id="intro"><a href="#intro">Intro</a></h2>
|
||||
<h2 id="details"><a href="#details">Details</a></h2>
|
||||
</article>
|
||||
`
|
||||
|
||||
const scrollContainer = document.querySelector('.overflow-auto') as HTMLDivElement
|
||||
scrollContainer.scrollTo = vi.fn()
|
||||
|
||||
const intro = document.getElementById('intro') as HTMLElement
|
||||
const details = document.getElementById('details') as HTMLElement
|
||||
|
||||
Object.defineProperty(intro, 'offsetTop', { configurable: true, value: 140 })
|
||||
Object.defineProperty(details, 'offsetTop', { configurable: true, value: 320 })
|
||||
|
||||
return {
|
||||
scrollContainer,
|
||||
intro,
|
||||
details,
|
||||
}
|
||||
}
|
||||
|
||||
describe('useDocToc', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.useFakeTimers()
|
||||
document.body.innerHTML = ''
|
||||
mockMatchMedia(false)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('should extract headings and expand the TOC on wide screens', async () => {
|
||||
setupDocument()
|
||||
mockMatchMedia(true)
|
||||
|
||||
const { result } = renderHook(() => useDocToc({
|
||||
appDetail: { id: 'app-1' },
|
||||
locale: 'en',
|
||||
}))
|
||||
|
||||
act(() => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(result.current.toc).toEqual<TocItem[]>([
|
||||
{ href: '#intro', text: 'Intro' },
|
||||
{ href: '#details', text: 'Details' },
|
||||
])
|
||||
expect(result.current.activeSection).toBe('intro')
|
||||
expect(result.current.isTocExpanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should update the active section when the scroll container scrolls', async () => {
|
||||
const { scrollContainer, intro, details } = setupDocument()
|
||||
Object.defineProperty(window, 'innerHeight', { configurable: true, value: 800 })
|
||||
|
||||
intro.getBoundingClientRect = vi.fn(() => ({ top: 500 } as DOMRect))
|
||||
details.getBoundingClientRect = vi.fn(() => ({ top: 300 } as DOMRect))
|
||||
|
||||
const { result } = renderHook(() => useDocToc({
|
||||
appDetail: { id: 'app-1' },
|
||||
locale: 'en',
|
||||
}))
|
||||
|
||||
act(() => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
act(() => {
|
||||
scrollContainer.dispatchEvent(new Event('scroll'))
|
||||
})
|
||||
|
||||
expect(result.current.activeSection).toBe('details')
|
||||
})
|
||||
|
||||
it('should scroll the container to the clicked heading offset', async () => {
|
||||
const { scrollContainer } = setupDocument()
|
||||
const { result } = renderHook(() => useDocToc({
|
||||
appDetail: { id: 'app-1' },
|
||||
locale: 'en',
|
||||
}))
|
||||
|
||||
act(() => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
const preventDefault = vi.fn()
|
||||
act(() => {
|
||||
result.current.handleTocClick(
|
||||
{ preventDefault } as unknown as React.MouseEvent<HTMLAnchorElement>,
|
||||
{ href: '#details', text: 'Details' },
|
||||
)
|
||||
})
|
||||
|
||||
expect(preventDefault).toHaveBeenCalledTimes(1)
|
||||
expect(scrollContainer.scrollTo).toHaveBeenCalledWith({
|
||||
top: 240,
|
||||
behavior: 'smooth',
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,69 @@
|
||||
import type { SearchResult } from '../types'
|
||||
import { ragPipelineNodesAction } from '../rag-pipeline-nodes'
|
||||
import { workflowNodesAction } from '../workflow-nodes'
|
||||
|
||||
describe('workflowNodesAction', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
workflowNodesAction.searchFn = undefined
|
||||
})
|
||||
|
||||
it('should return an empty result when no workflow search function is registered', async () => {
|
||||
await expect(workflowNodesAction.search('@node llm', 'llm', 'en')).resolves.toEqual([])
|
||||
})
|
||||
|
||||
it('should delegate to the injected workflow search function', async () => {
|
||||
const results: SearchResult[] = [
|
||||
{ id: 'workflow-node-1', title: 'LLM', type: 'workflow-node', data: {} as never },
|
||||
]
|
||||
workflowNodesAction.searchFn = vi.fn().mockReturnValue(results)
|
||||
|
||||
await expect(workflowNodesAction.search('@node llm', 'llm', 'en')).resolves.toEqual(results)
|
||||
expect(workflowNodesAction.searchFn).toHaveBeenCalledWith('llm')
|
||||
})
|
||||
|
||||
it('should warn and return an empty list when workflow node search throws', async () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
workflowNodesAction.searchFn = vi.fn(() => {
|
||||
throw new Error('failed')
|
||||
})
|
||||
|
||||
await expect(workflowNodesAction.search('@node llm', 'llm', 'en')).resolves.toEqual([])
|
||||
expect(warnSpy).toHaveBeenCalledWith('Workflow nodes search failed:', expect.any(Error))
|
||||
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ragPipelineNodesAction', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
ragPipelineNodesAction.searchFn = undefined
|
||||
})
|
||||
|
||||
it('should return an empty result when no rag pipeline search function is registered', async () => {
|
||||
await expect(ragPipelineNodesAction.search('@node embed', 'embed', 'en')).resolves.toEqual([])
|
||||
})
|
||||
|
||||
it('should delegate to the injected rag pipeline search function', async () => {
|
||||
const results: SearchResult[] = [
|
||||
{ id: 'rag-node-1', title: 'Retriever', type: 'workflow-node', data: {} as never },
|
||||
]
|
||||
ragPipelineNodesAction.searchFn = vi.fn().mockReturnValue(results)
|
||||
|
||||
await expect(ragPipelineNodesAction.search('@node retrieve', 'retrieve', 'en')).resolves.toEqual(results)
|
||||
expect(ragPipelineNodesAction.searchFn).toHaveBeenCalledWith('retrieve')
|
||||
})
|
||||
|
||||
it('should warn and return an empty list when rag pipeline node search throws', async () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
ragPipelineNodesAction.searchFn = vi.fn(() => {
|
||||
throw new Error('failed')
|
||||
})
|
||||
|
||||
await expect(ragPipelineNodesAction.search('@node retrieve', 'retrieve', 'en')).resolves.toEqual([])
|
||||
expect(warnSpy).toHaveBeenCalledWith('RAG pipeline nodes search failed:', expect.any(Error))
|
||||
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,124 @@
|
||||
import type { SearchResult } from '../../types'
|
||||
import { render } from '@testing-library/react'
|
||||
import { slashAction, SlashCommandProvider } from '../slash'
|
||||
|
||||
const {
|
||||
mockSetTheme,
|
||||
mockSetLocale,
|
||||
mockExecuteCommand,
|
||||
mockRegister,
|
||||
mockSearch,
|
||||
mockUnregister,
|
||||
} = vi.hoisted(() => ({
|
||||
mockSetTheme: vi.fn(),
|
||||
mockSetLocale: vi.fn(),
|
||||
mockExecuteCommand: vi.fn(),
|
||||
mockRegister: vi.fn(),
|
||||
mockSearch: vi.fn(),
|
||||
mockUnregister: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('next-themes', () => ({
|
||||
useTheme: () => ({
|
||||
setTheme: mockSetTheme,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
getI18n: () => ({
|
||||
language: 'ja',
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n-config', () => ({
|
||||
setLocaleOnClient: mockSetLocale,
|
||||
}))
|
||||
|
||||
vi.mock('../command-bus', () => ({
|
||||
executeCommand: (...args: unknown[]) => mockExecuteCommand(...args),
|
||||
}))
|
||||
|
||||
vi.mock('../registry', () => ({
|
||||
slashCommandRegistry: {
|
||||
register: (...args: unknown[]) => mockRegister(...args),
|
||||
search: (...args: unknown[]) => mockSearch(...args),
|
||||
unregister: (...args: unknown[]) => mockUnregister(...args),
|
||||
},
|
||||
}))
|
||||
|
||||
describe('slashAction', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should expose translated title and description', () => {
|
||||
expect(slashAction.title).toBe('gotoAnything.actions.slashTitle')
|
||||
expect(slashAction.description).toBe('gotoAnything.actions.slashDesc')
|
||||
})
|
||||
|
||||
it('should execute command results and ignore non-command results', () => {
|
||||
slashAction.action?.({
|
||||
id: 'cmd-1',
|
||||
title: 'Command',
|
||||
type: 'command',
|
||||
data: {
|
||||
command: 'navigation.docs',
|
||||
args: { path: '/docs' },
|
||||
},
|
||||
} as SearchResult)
|
||||
|
||||
slashAction.action?.({
|
||||
id: 'app-1',
|
||||
title: 'App',
|
||||
type: 'app',
|
||||
data: {} as never,
|
||||
} as SearchResult)
|
||||
|
||||
expect(mockExecuteCommand).toHaveBeenCalledTimes(1)
|
||||
expect(mockExecuteCommand).toHaveBeenCalledWith('navigation.docs', { path: '/docs' })
|
||||
})
|
||||
|
||||
it('should delegate search to the slash command registry with the active language', async () => {
|
||||
mockSearch.mockResolvedValue([{ id: 'theme', title: '/theme', type: 'command', data: { command: 'theme' } }])
|
||||
|
||||
const results = await slashAction.search('/theme dark', 'dark')
|
||||
|
||||
expect(mockSearch).toHaveBeenCalledWith('/theme dark', 'ja')
|
||||
expect(results).toEqual([{ id: 'theme', title: '/theme', type: 'command', data: { command: 'theme' } }])
|
||||
})
|
||||
})
|
||||
|
||||
describe('SlashCommandProvider', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should register commands on mount and unregister them on unmount', () => {
|
||||
const { unmount } = render(<SlashCommandProvider />)
|
||||
|
||||
expect(mockRegister.mock.calls.map(call => call[0].name)).toEqual([
|
||||
'theme',
|
||||
'language',
|
||||
'forum',
|
||||
'docs',
|
||||
'community',
|
||||
'account',
|
||||
'zen',
|
||||
])
|
||||
expect(mockRegister).toHaveBeenCalledWith(expect.objectContaining({ name: 'theme' }), { setTheme: mockSetTheme })
|
||||
expect(mockRegister).toHaveBeenCalledWith(expect.objectContaining({ name: 'language' }), { setLocale: mockSetLocale })
|
||||
|
||||
unmount()
|
||||
|
||||
expect(mockUnregister.mock.calls.map(call => call[0])).toEqual([
|
||||
'theme',
|
||||
'language',
|
||||
'forum',
|
||||
'docs',
|
||||
'community',
|
||||
'account',
|
||||
'zen',
|
||||
])
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,28 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { ExternalLinkIndicator, MenuItemContent } from '../menu-item-content'
|
||||
|
||||
describe('MenuItemContent', () => {
|
||||
it('should render the icon, label, and trailing content', () => {
|
||||
const { container } = render(
|
||||
<MenuItemContent
|
||||
iconClassName="i-ri-settings-4-line"
|
||||
label="Settings"
|
||||
trailing={<span data-testid="menu-trailing">Soon</span>}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Settings')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('menu-trailing')).toHaveTextContent('Soon')
|
||||
expect(container.querySelector('.i-ri-settings-4-line')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ExternalLinkIndicator', () => {
|
||||
it('should render the external-link icon with aria-hidden semantics', () => {
|
||||
const { container } = render(<ExternalLinkIndicator />)
|
||||
|
||||
const indicator = container.querySelector('.i-ri-arrow-right-up-line')
|
||||
expect(indicator).toBeInTheDocument()
|
||||
expect(indicator).toHaveAttribute('aria-hidden')
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,23 @@
|
||||
import * as ModelAuth from '../index'
|
||||
|
||||
vi.mock('../add-credential-in-load-balancing', () => ({ default: 'AddCredentialInLoadBalancing' }))
|
||||
vi.mock('../add-custom-model', () => ({ default: 'AddCustomModel' }))
|
||||
vi.mock('../authorized', () => ({ default: 'Authorized' }))
|
||||
vi.mock('../config-model', () => ({ default: 'ConfigModel' }))
|
||||
vi.mock('../credential-selector', () => ({ default: 'CredentialSelector' }))
|
||||
vi.mock('../manage-custom-model-credentials', () => ({ default: 'ManageCustomModelCredentials' }))
|
||||
vi.mock('../switch-credential-in-load-balancing', () => ({ default: 'SwitchCredentialInLoadBalancing' }))
|
||||
|
||||
describe('model-auth index exports', () => {
|
||||
it('should re-export the model auth entry points', () => {
|
||||
expect(ModelAuth).toMatchObject({
|
||||
AddCredentialInLoadBalancing: 'AddCredentialInLoadBalancing',
|
||||
AddCustomModel: 'AddCustomModel',
|
||||
Authorized: 'Authorized',
|
||||
ConfigModel: 'ConfigModel',
|
||||
CredentialSelector: 'CredentialSelector',
|
||||
ManageCustomModelCredentials: 'ManageCustomModelCredentials',
|
||||
SwitchCredentialInLoadBalancing: 'SwitchCredentialInLoadBalancing',
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,18 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import CreditsFallbackAlert from '../credits-fallback-alert'
|
||||
|
||||
describe('CreditsFallbackAlert', () => {
|
||||
it('should render the credential fallback copy and description when credentials exist', () => {
|
||||
render(<CreditsFallbackAlert hasCredentials />)
|
||||
|
||||
expect(screen.getByText('common.modelProvider.card.apiKeyUnavailableFallback')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.card.apiKeyUnavailableFallbackDescription')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the no-credentials fallback copy without the description', () => {
|
||||
render(<CreditsFallbackAlert hasCredentials={false} />)
|
||||
|
||||
expect(screen.getByText('common.modelProvider.card.noApiKeysFallback')).toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.card.apiKeyUnavailableFallbackDescription')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,16 @@
|
||||
import { render } from '@testing-library/react'
|
||||
import DownloadingIcon from '../downloading-icon'
|
||||
|
||||
describe('DownloadingIcon', () => {
|
||||
it('should render the animated install icon wrapper and svg markup', () => {
|
||||
const { container } = render(<DownloadingIcon />)
|
||||
|
||||
const wrapper = container.firstElementChild as HTMLElement
|
||||
const svg = container.querySelector('svg.install-icon')
|
||||
|
||||
expect(wrapper).toHaveClass('inline-flex', 'text-components-button-secondary-text')
|
||||
expect(svg).toBeInTheDocument()
|
||||
expect(svg).toHaveAttribute('viewBox', '0 0 24 24')
|
||||
expect(svg?.querySelectorAll('path')).toHaveLength(3)
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,219 @@
|
||||
import type { TextGenerationRunControl } from '../types'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import TextGeneration from '../index'
|
||||
|
||||
const {
|
||||
mockMode,
|
||||
mockMedia,
|
||||
mockAppStateRef,
|
||||
mockBatchStateRef,
|
||||
sidebarPropsSpy,
|
||||
resultPanelPropsSpy,
|
||||
mockSetIsCallBatchAPI,
|
||||
mockResetBatchExecution,
|
||||
mockHandleRunBatch,
|
||||
} = vi.hoisted(() => ({
|
||||
mockMode: { value: 'create' },
|
||||
mockMedia: { value: 'pc' },
|
||||
mockAppStateRef: { value: null as unknown },
|
||||
mockBatchStateRef: { value: null as unknown },
|
||||
sidebarPropsSpy: vi.fn(),
|
||||
resultPanelPropsSpy: vi.fn(),
|
||||
mockSetIsCallBatchAPI: vi.fn(),
|
||||
mockResetBatchExecution: vi.fn(),
|
||||
mockHandleRunBatch: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-breakpoints', () => ({
|
||||
MediaType: {
|
||||
mobile: 'mobile',
|
||||
pc: 'pc',
|
||||
tablet: 'tablet',
|
||||
},
|
||||
default: () => mockMedia.value,
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useSearchParams: () => ({
|
||||
get: (key: string) => key === 'mode' ? mockMode.value : null,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/loading', () => ({
|
||||
default: ({ type }: { type: string }) => <div data-testid="loading-app">{type}</div>,
|
||||
}))
|
||||
|
||||
vi.mock('../hooks/use-text-generation-app-state', () => ({
|
||||
useTextGenerationAppState: () => mockAppStateRef.value,
|
||||
}))
|
||||
|
||||
vi.mock('../hooks/use-text-generation-batch', () => ({
|
||||
useTextGenerationBatch: () => mockBatchStateRef.value,
|
||||
}))
|
||||
|
||||
vi.mock('../text-generation-sidebar', () => ({
|
||||
default: (props: {
|
||||
currentTab: string
|
||||
onRunOnceSend: () => void
|
||||
onBatchSend: (data: string[][]) => void
|
||||
}) => {
|
||||
sidebarPropsSpy(props)
|
||||
return (
|
||||
<div data-testid="sidebar">
|
||||
<span data-testid="sidebar-current-tab">{props.currentTab}</span>
|
||||
<button type="button" onClick={props.onRunOnceSend}>run-once</button>
|
||||
<button type="button" onClick={() => props.onBatchSend([['name'], ['Alice']])}>run-batch</button>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('../text-generation-result-panel', () => ({
|
||||
default: (props: {
|
||||
allTaskList: unknown[]
|
||||
controlSend: number
|
||||
controlStopResponding: number
|
||||
isShowResultPanel: boolean
|
||||
onRunControlChange: (value: TextGenerationRunControl | null) => void
|
||||
onRunStart: () => void
|
||||
}) => {
|
||||
resultPanelPropsSpy(props)
|
||||
return (
|
||||
<div data-testid="result-panel">
|
||||
<span data-testid="show-result">{props.isShowResultPanel ? 'shown' : 'hidden'}</span>
|
||||
<span data-testid="control-send">{String(props.controlSend)}</span>
|
||||
<span data-testid="control-stop">{String(props.controlStopResponding)}</span>
|
||||
<span data-testid="task-count">{String(props.allTaskList.length)}</span>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => props.onRunControlChange({ isStopping: false, onStop: vi.fn() })}
|
||||
>
|
||||
set-run-control
|
||||
</button>
|
||||
<button type="button" onClick={props.onRunStart}>start-run</button>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
const createAppState = (overrides: Record<string, unknown> = {}) => ({
|
||||
accessMode: AccessMode.PUBLIC,
|
||||
appId: 'app-1',
|
||||
appSourceType: 'webApp',
|
||||
customConfig: {
|
||||
remove_webapp_brand: false,
|
||||
replace_webapp_logo: '',
|
||||
},
|
||||
handleRemoveSavedMessage: vi.fn(),
|
||||
handleSaveMessage: vi.fn(),
|
||||
moreLikeThisConfig: { enabled: true },
|
||||
promptConfig: {
|
||||
prompt_template: '',
|
||||
prompt_variables: [{ key: 'name', name: 'Name', type: 'string', required: true }],
|
||||
},
|
||||
savedMessages: [],
|
||||
siteInfo: {
|
||||
title: 'Generator',
|
||||
description: 'Description',
|
||||
},
|
||||
systemFeatures: {},
|
||||
textToSpeechConfig: { enabled: true },
|
||||
visionConfig: { enabled: false },
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createBatchState = (overrides: Record<string, unknown> = {}) => ({
|
||||
allFailedTaskList: [],
|
||||
allSuccessTaskList: [],
|
||||
allTaskList: [],
|
||||
allTasksRun: true,
|
||||
controlRetry: 0,
|
||||
exportRes: [],
|
||||
handleCompleted: vi.fn(),
|
||||
handleRetryAllFailedTask: vi.fn(),
|
||||
handleRunBatch: (data: string[][], options: { onStart: () => void }) => {
|
||||
mockHandleRunBatch(data, options)
|
||||
options.onStart()
|
||||
return true
|
||||
},
|
||||
isCallBatchAPI: false,
|
||||
noPendingTask: true,
|
||||
resetBatchExecution: () => mockResetBatchExecution(),
|
||||
setIsCallBatchAPI: (value: boolean) => mockSetIsCallBatchAPI(value),
|
||||
showTaskList: [],
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('TextGeneration', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.useFakeTimers()
|
||||
mockMode.value = 'create'
|
||||
mockMedia.value = 'pc'
|
||||
mockAppStateRef.value = createAppState()
|
||||
mockBatchStateRef.value = createBatchState()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should render the loading state until app state is ready', () => {
|
||||
mockAppStateRef.value = createAppState({ appId: '', siteInfo: null, promptConfig: null })
|
||||
|
||||
render(<TextGeneration />)
|
||||
|
||||
expect(screen.getByTestId('loading-app')).toHaveTextContent('app')
|
||||
})
|
||||
|
||||
it('should fall back to create mode for unsupported query params and keep installed-app layout classes', () => {
|
||||
mockMode.value = 'unsupported'
|
||||
|
||||
const { container } = render(<TextGeneration isInstalledApp />)
|
||||
|
||||
expect(screen.getByTestId('sidebar-current-tab')).toHaveTextContent('create')
|
||||
expect(sidebarPropsSpy).toHaveBeenCalledWith(expect.objectContaining({
|
||||
currentTab: 'create',
|
||||
isInstalledApp: true,
|
||||
isPC: true,
|
||||
}))
|
||||
|
||||
const root = container.firstElementChild as HTMLElement
|
||||
expect(root).toHaveClass('flex', 'h-full', 'rounded-2xl', 'shadow-md')
|
||||
})
|
||||
|
||||
it('should orchestrate a run-once request and reveal the result panel', async () => {
|
||||
render(<TextGeneration />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'run-once' }))
|
||||
|
||||
act(() => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(mockSetIsCallBatchAPI).toHaveBeenCalledWith(false)
|
||||
expect(mockResetBatchExecution).toHaveBeenCalledTimes(1)
|
||||
expect(screen.getByTestId('show-result')).toHaveTextContent('shown')
|
||||
expect(Number(screen.getByTestId('control-send').textContent)).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should orchestrate batch runs through the batch hook and expose the result panel', async () => {
|
||||
mockMode.value = 'batch'
|
||||
|
||||
render(<TextGeneration />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'run-batch' }))
|
||||
|
||||
act(() => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(mockHandleRunBatch).toHaveBeenCalledWith(
|
||||
[['name'], ['Alice']],
|
||||
expect.objectContaining({ onStart: expect.any(Function) }),
|
||||
)
|
||||
expect(screen.getByTestId('show-result')).toHaveTextContent('shown')
|
||||
expect(Number(screen.getByTestId('control-stop').textContent)).toBeGreaterThan(0)
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,41 @@
|
||||
import { renderHook } from '@testing-library/react'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { useIsChatMode } from '../use-is-chat-mode'
|
||||
|
||||
const { mockStoreState } = vi.hoisted(() => ({
|
||||
mockStoreState: {
|
||||
appDetail: undefined as { mode?: AppModeEnum } | undefined,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: (selector: (state: typeof mockStoreState) => unknown) => selector(mockStoreState),
|
||||
}))
|
||||
|
||||
describe('useIsChatMode', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockStoreState.appDetail = undefined
|
||||
})
|
||||
|
||||
it('should return true when the app mode is ADVANCED_CHAT', () => {
|
||||
mockStoreState.appDetail = { mode: AppModeEnum.ADVANCED_CHAT }
|
||||
|
||||
const { result } = renderHook(() => useIsChatMode())
|
||||
|
||||
expect(result.current).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false when the app mode is not chat or app detail is missing', () => {
|
||||
mockStoreState.appDetail = { mode: AppModeEnum.WORKFLOW }
|
||||
|
||||
const { result, rerender } = renderHook(() => useIsChatMode())
|
||||
|
||||
expect(result.current).toBe(false)
|
||||
|
||||
mockStoreState.appDetail = undefined
|
||||
rerender()
|
||||
|
||||
expect(result.current).toBe(false)
|
||||
})
|
||||
})
|
||||
@ -1890,12 +1890,6 @@
|
||||
}
|
||||
},
|
||||
"app/components/base/chat/chat/index.tsx": {
|
||||
"react/set-state-in-effect": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 2
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 3
|
||||
}
|
||||
|
||||
@ -61,13 +61,15 @@ const createResponseFromHTTPError = (error: HTTPError): Response => {
|
||||
const afterResponseErrorCode = (otherOptions: IOtherOptions): AfterResponseHook => {
|
||||
return async ({ response }) => {
|
||||
if (!/^([23])\d{2}$/.test(String(response.status))) {
|
||||
const errorData = await response.clone()
|
||||
.json()
|
||||
.then(data => data as ResponseError)
|
||||
.catch(() => null)
|
||||
let errorData: ResponseError | null = null
|
||||
try {
|
||||
const data: unknown = await response.clone().json()
|
||||
errorData = data as ResponseError
|
||||
}
|
||||
catch {}
|
||||
const shouldNotifyError = response.status !== 401 && errorData && !otherOptions.silent
|
||||
|
||||
if (shouldNotifyError)
|
||||
if (shouldNotifyError && errorData)
|
||||
toast.error(errorData.message)
|
||||
|
||||
if (response.status === 403 && errorData?.code === 'already_setup')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user